2023-10-12 23:13:10 +08:00
|
|
|
from langchain.schema import BaseMessage, SystemMessage, AIMessage, HumanMessage
|
2023-05-15 08:51:32 +08:00
|
|
|
|
2023-10-12 23:13:10 +08:00
|
|
|
from core.prompt.prompt_template import PromptTemplateParser
|
2023-05-15 08:51:32 +08:00
|
|
|
|
|
|
|
|
|
|
|
class PromptBuilder:
    """Render prompt templates and wrap the result in LangChain chat messages.

    Every public helper is a classmethod. The template text is rendered
    against ``inputs`` through :class:`PromptTemplateParser`, and the rendered
    string becomes the ``content`` of the requested message type.
    """

    @classmethod
    def parse_prompt(cls, prompt: str, inputs: dict) -> str:
        """Render ``prompt`` using the matching values from ``inputs``.

        Only the entries of ``inputs`` whose key appears in the template's
        ``variable_keys`` are handed to the parser. Extra keys in ``inputs``
        are ignored, and template variables absent from ``inputs`` are not
        supplied at all. How the parser treats those is up to
        ``PromptTemplateParser.format``.

        :param prompt: raw template text.
        :param inputs: candidate variable values, keyed by variable name.
        :return: the string produced by ``PromptTemplateParser.format``.
        """
        parser = PromptTemplateParser(prompt)

        # Walk the template's own variables so the parser never sees
        # unrelated keys the caller happened to pass in.
        selected = {}
        for key in parser.variable_keys:
            if key in inputs:
                selected[key] = inputs[key]

        return parser.format(selected)

    @classmethod
    def _render_as(cls, message_type: type, prompt_content: str, inputs: dict) -> BaseMessage:
        """Render ``prompt_content`` and wrap it as ``message_type(content=...)``."""
        rendered = cls.parse_prompt(prompt_content, inputs)
        return message_type(content=rendered)

    @classmethod
    def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
        """Build a ``SystemMessage`` whose content is the rendered template."""
        return cls._render_as(SystemMessage, prompt_content, inputs)

    @classmethod
    def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
        """Build an ``AIMessage`` whose content is the rendered template."""
        return cls._render_as(AIMessage, prompt_content, inputs)

    @classmethod
    def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
        """Build a ``HumanMessage`` whose content is the rendered template."""
        return cls._render_as(HumanMessage, prompt_content, inputs)
|