fix: prompt for baichuan text generation models (#1299)

This commit is contained in:
takatost 2023-10-10 13:01:18 +08:00 committed by GitHub
parent df07fb5951
commit 8480b0197b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -37,6 +37,12 @@ class BaichuanModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.