import logging
import re
from typing import Optional, List, Union, Tuple

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, HumanMessage
from requests.exceptions import ChunkedEncodingError

from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
    DifyStdOutCallbackHandler
from core.constant import llm_constant
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
from core.llm.fake import FakeLLM
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser


class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
                 is_override: bool = False):
        """
        errors: ProviderTokenNotInitError
        """
        query = PromptBuilder.process_template(query)

        memory = None
        if conversation:
            # get memory of conversation (read-only)
            memory = cls.get_memory_from_conversation(
                tenant_id=app.tenant_id,
                app_model_config=app_model_config,
                conversation=conversation,
                return_messages=False
            )

            inputs = conversation.inputs
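
        # check that the prompt fits the model's context window and compute the token
        # budget left over for retrieved context and conversation memory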
        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
            mode=app.mode,
            tenant_id=app.tenant_id,
            app_model_config=app_model_config,
            query=query,
            inputs=inputs
        )
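
        # set up the task object that records this run's conversation and message state
        # and is fed by the callback handlers while the completion is generated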
        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            streaming=streaming
        )

        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)

        # init orchestrator rule parser
        orchestrator_rule_parser = OrchestratorRuleParser(
            tenant_id=app.tenant_id,
            app_model_config=app_model_config
        )

        # parse sensitive_word_avoidance_chain
        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
        if sensitive_word_avoidance_chain:
            query = sensitive_word_avoidance_chain.run(query)

        # get agent executor
        agent_executor = orchestrator_rule_parser.to_agent_executor(
            conversation_message_task=conversation_message_task,
            memory=memory,
            rest_tokens=rest_tokens_for_context_and_memory,
            chain_callback=chain_callback
        )

        # run agent executor
        agent_execute_result = None
        if agent_executor:
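            # ask the executor first whether this query needs the agent/tools at all;
            # only run the full agent loop when it does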
            should_use_agent = agent_executor.should_use_agent(query)
            if should_use_agent:
                agent_execute_result = agent_executor.run(query)

        # run the final llm
        try:
            cls.run_final_llm(
                tenant_id=app.tenant_id,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                agent_execute_result=agent_execute_result,
                conversation_message_task=conversation_message_task,
                memory=memory,
                streaming=streaming
            )
        except ConversationTaskStoppedException:
            return
        except ChunkedEncodingError as e:
            # stream interrupted by the LLM provider (e.g. OpenAI); end the task gracefully
            logging.warning(f'ChunkedEncodingError: {e}')
            conversation_message_task.end()
            return

    @classmethod
    def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
                      agent_execute_result: Optional[AgentExecuteResult],
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
        # When no pre prompt is specified, the agent's output can be used directly as the
        # final output content without calling the LLM again
        if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
                and agent_execute_result.strategy != PlanningStrategy.ROUTER:
            final_llm = FakeLLM(response=agent_execute_result.output,
                                origin_llm=agent_execute_result.configuration.llm,
                                streaming=streaming)
            final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
            response = final_llm.generate([[HumanMessage(content=query)]])
            return response

        final_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict,
            streaming=streaming
        )

        # get llm prompt
        prompt, stop_words = cls.get_main_llm_prompt(
            mode=mode,
            llm=final_llm,
            model=app_model_config.model_dict,
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
            agent_execute_result=agent_execute_result,
            memory=memory
        )

        final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
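
        # shrink max_tokens if prompt tokens plus the requested completion would overflow
        # the model's context window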
        cls.recale_llm_max_tokens(
            final_llm=final_llm,
            model=app_model_config.model_dict,
            prompt=prompt,
            mode=mode
        )

        response = final_llm.generate([prompt], stop_words)

        return response

    @classmethod
    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
                            pre_prompt: str, query: str, inputs: dict,
                            agent_execute_result: Optional[AgentExecuteResult],
                            memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
            Tuple[Union[str, List[BaseMessage]], Optional[List[str]]]:
        if mode == 'completion':
            prompt_template = JinjaPromptTemplate.from_template(
                template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.

<context>
{{context}}
</context>

When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
""" if agent_execute_result else "")
                          + (pre_prompt + "\n" if pre_prompt else "")
                          + "{{query}}\n"
            )

            if agent_execute_result:
                inputs['context'] = agent_execute_result.output
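
            # only pass through the inputs that the compiled template actually references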
            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
            prompt_content = prompt_template.format(
                query=query,
                **prompt_inputs
            )

            if isinstance(llm, BaseChatModel):
                # use chat llm as completion model
                return [HumanMessage(content=prompt_content)], None
            else:
                return prompt_content, None
        else:
            messages: List[BaseMessage] = []

            human_inputs = {
                "query": query
            }

            human_message_prompt = ""

            if pre_prompt:
                pre_prompt_inputs = {k: inputs[k] for k in
                                     JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
                                     if k in inputs}

                if pre_prompt_inputs:
                    human_inputs.update(pre_prompt_inputs)

            if agent_execute_result:
                human_inputs['context'] = agent_execute_result.output
                human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.

<context>
{{context}}
</context>

When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
"""

            if pre_prompt:
                human_message_prompt += pre_prompt

            query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "

            if memory:
                # append chat histories
                tmp_human_message = PromptBuilder.to_human_message(
                    prompt_content=human_message_prompt + query_prompt,
                    inputs=human_inputs
                )

                curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
                model_name = model['name']
                max_tokens = model.get("completion_params").get('max_tokens')
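                # history budget = model context window - reserved completion tokens - current prompt tokens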
                rest_tokens = llm_constant.max_context_token_length[model_name] \
                              - max_tokens - curr_message_tokens
                rest_tokens = max(rest_tokens, 0)
                histories = cls.get_history_messages_from_memory(memory, rest_tokens)
                human_message_prompt += "\n\n" if human_message_prompt else ""
                human_message_prompt += "Here is the chat histories between human and assistant, " \
                                        "inside <histories></histories> XML tags.\n\n<histories>\n"
                human_message_prompt += histories + "\n</histories>"

            human_message_prompt += query_prompt

            # construct main prompt
            human_message = PromptBuilder.to_human_message(
                prompt_content=human_message_prompt,
                inputs=human_inputs
            )

            messages.append(human_message)
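
            # strip special-token markers such as <|endoftext|> so they cannot leak into the prompt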
            for message in messages:
                message.content = re.sub(r'<\|.*?\|>', '', message.content)

            return messages, ['\nHuman:', '</histories>']

    @classmethod
    def get_llm_callbacks(cls, llm: BaseLanguageModel,
                          streaming: bool,
                          conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
        llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
        if streaming:
            return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
        else:
            return [llm_callback_handler, DifyStdOutCallbackHandler()]

    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
        # only used to count tokens for the memory buffer
        memory_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        # use llm config from conversation
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            llm=memory_llm,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10),
        )

        return memory

    @classmethod
    def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
                                 query: str, inputs: dict) -> int:
        llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        model_name = app_model_config.model_dict.get("name")
        model_limited_tokens = llm_constant.max_context_token_length[model_name]
        max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')

        # get prompt without memory and context
        prompt, _ = cls.get_main_llm_prompt(
            mode=mode,
            llm=llm,
            model=app_model_config.model_dict,
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
            agent_execute_result=None,
            memory=None
        )
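
        # count the prompt's tokens, then see how much of the context window remains
        # after reserving room for the completion (max_tokens)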
        prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
            else llm.get_num_tokens_from_messages(prompt)

        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
                                     "shrink the max tokens, or switch to an LLM with a larger context window.")

        return rest_tokens

    @classmethod
    def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
                              prompt: Union[str, List[BaseMessage]], mode: str):
        # recalculate max_tokens if prompt_tokens + max_tokens exceeds the model's token limit
        model_name = model.get("name")
        model_limited_tokens = llm_constant.max_context_token_length[model_name]
        max_tokens = model.get("completion_params").get('max_tokens')
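
        # count prompt tokens with the API that matches the prompt shape (plain string vs. chat messages)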
        if mode == 'completion' and isinstance(final_llm, BaseLLM):
            prompt_tokens = final_llm.get_num_tokens(prompt)
        else:
            prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)

        if prompt_tokens + max_tokens > model_limited_tokens:
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)
            final_llm.max_tokens = max_tokens

    @classmethod
    def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                app_model_config: AppModelConfig, user: Account, streaming: bool):
        llm = LLMBuilder.to_llm_from_model(
            tenant_id=app.tenant_id,
            model=app_model_config.model_dict,
            streaming=streaming
        )

        # get llm prompt
        original_prompt, _ = cls.get_main_llm_prompt(
            mode="completion",
            llm=llm,
            model=app_model_config.model_dict,
            pre_prompt=pre_prompt,
            query=message.query,
            inputs=message.inputs,
            agent_execute_result=None,
            memory=None
        )

        original_completion = message.answer.strip()
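
        # wrap the original prompt and its completion in the "more like this" meta prompt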
        prompt = MORE_LIKE_THIS_GENERATE_PROMPT
        prompt = prompt.format(prompt=original_prompt, original_completion=original_completion)

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            inputs=message.inputs,
            query=message.query,
            is_override=True if message.override_model_configs else False,
            streaming=streaming
        )

        llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)

        cls.recale_llm_max_tokens(
            final_llm=llm,
            model=app_model_config.model_dict,
            prompt=prompt,
            mode='completion'
        )

        llm.generate([prompt])