2023-01-01 22:52:27 +08:00
|
|
|
# 提供与模型交互的抽象接口
|
2023-03-02 15:31:12 +08:00
|
|
|
import openai, logging
|
|
|
|
|
|
|
|
# Model names that are routed to the chat-completion request wrapper.
CHAT_COMPLETION_MODELS = {'gpt-3.5-turbo'}

# Model names that are routed to the plain text-completion request wrapper.
COMPLETION_MODELS = {'text-davinci-003'}

# No edit models are registered yet.
# NOTE(review): `{}` is an empty dict, not an empty set like the two
# registries above; membership tests behave the same either way.
EDIT_MODELS = {}

# No image models are registered yet (same empty-dict caveat as EDIT_MODELS).
IMAGE_MODELS = {}
|
|
|
|
|
2023-03-02 15:31:12 +08:00
|
|
|
class Model:
    """Base wrapper that sends a request through a callable and keeps the reply.

    The wrapper stores the model name, the user name and the callable used
    to perform the request. After ``request`` has run, the processed
    response and the extracted message are available through the getters.
    """

    # Chat capability flag; False unless an instance overrides it.
    can_chat = False

    def __init__(self, model_name, user_name, request_fun):
        """Store the configuration and start with an empty response state.

        Args:
            model_name: Name of the model this wrapper represents.
            user_name: Name identifying the end user.
            request_fun: Callable invoked with keyword arguments to perform
                the request; its return value is indexed like a mapping.
        """
        self.model_name = model_name
        self.user_name = user_name
        self.request_fun = request_fun
        # Initialised here so the getters work before the first request
        # instead of failing with AttributeError.
        self.ret = None
        self.message = None

    def request(self, **kwargs):
        """Call ``request_fun`` with ``kwargs`` and record the response.

        The raw response goes through ``ret_handle``; the message of the
        first choice is then stored for ``get_message``.
        """
        ret = self.request_fun(**kwargs)
        self.ret = self.ret_handle(ret)
        self.message = self.ret["choices"][0]["message"]

    def msg_handle(self, msg):
        """Hook for transforming the outgoing message; returns it unchanged."""
        return msg

    def ret_handle(self, ret):
        """Hook for transforming the raw response; returns it unchanged."""
        return ret

    def get_total_tokens(self):
        """Return ``usage.total_tokens`` of the last response, or 0 if none yet."""
        if self.ret is None:
            return 0
        return self.ret['usage']['total_tokens']

    def get_message(self):
        """Return the message extracted by the last request, or None if none yet."""
        return self.message

    def get_response(self):
        """Return the processed response of the last request, or None if none yet."""
        return self.ret
|
|
|
|
|
|
|
|
class ChatCompletionModel(Model):
    """Model wrapper that sends requests through ``openai.ChatCompletion.create``."""

    def __init__(self, model_name, user_name):
        """Bind the chat-completion endpoint and mark the instance as chat-capable.

        Args:
            model_name: Name of the chat model.
            user_name: Name identifying the end user, forwarded as ``user``.
        """
        request_fun = openai.ChatCompletion.create
        self.can_chat = True
        super().__init__(model_name, user_name, request_fun)

    def request(self, messages, **kwargs):
        """Send ``messages`` to the chat endpoint and record the reply.

        Args:
            messages: Chat messages, passed through ``msg_handle`` first.
            **kwargs: Extra parameters forwarded to the endpoint. ``model``
                and ``user`` default to this wrapper's ``model_name`` and
                ``user_name`` when the caller does not supply them.
        """
        # setdefault keeps caller-supplied values and avoids the
        # "multiple values for keyword argument 'user'" TypeError.
        kwargs.setdefault('model', self.model_name)
        kwargs.setdefault('user', self.user_name)
        ret = self.request_fun(messages=self.msg_handle(messages), **kwargs)
        self.ret = self.ret_handle(ret)
        # Keep only the text content of the first choice.
        self.message = self.ret["choices"][0]["message"]['content']

    def get_content(self):
        """Return the content string stored by the last request."""
        return self.message
|
2023-01-01 22:52:27 +08:00
|
|
|
|
2023-03-02 15:31:12 +08:00
|
|
|
class CompletionModel(Model):
    """Model wrapper that sends requests through ``openai.Completion.create``."""

    def __init__(self, model_name, user_name):
        """Bind the text-completion endpoint.

        Args:
            model_name: Name of the completion model.
            user_name: Name identifying the end user.
        """
        request_fun = openai.Completion.create
        super().__init__(model_name, user_name, request_fun)

    def request(self, prompt, **kwargs):
        """Send ``prompt`` to the completion endpoint and record the reply.

        Args:
            prompt: Either a ready-made prompt string or a list of message
                dicts with ``role`` and ``content`` keys; it is flattened
                by ``msg_handle`` before sending.
            **kwargs: Extra parameters forwarded to the endpoint. ``model``
                defaults to this wrapper's ``model_name`` when absent.
        """
        kwargs.setdefault('model', self.model_name)
        ret = self.request_fun(prompt=self.msg_handle(prompt), **kwargs)
        self.ret = self.ret_handle(ret)
        # Completion responses carry the generated text directly.
        self.message = self.ret["choices"][0]["text"]

    def msg_handle(self, msgs):
        """Flatten chat-style messages into a single prompt string.

        Each message becomes one line: ``content`` alone when the role is
        empty, otherwise ``role:content``. The ``system`` role is replaced
        by a fixed instruction label. A plain string is returned unchanged.

        Args:
            msgs: A prompt string, or an iterable of dicts with ``role``
                and ``content`` keys.

        Returns:
            The prompt text to send to the completion endpoint.
        """
        if isinstance(msgs, str):
            # Already a prompt; iterating it would index characters.
            return msgs

        lines = []
        for msg in msgs:
            role = msg['role']
            if role == '':
                lines.append("{}\n".format(msg['content']))
            else:
                speaker = '你的回答要遵守此规则' if role == 'system' else role
                lines.append("{}:{}\n".format(speaker, msg['content']))
        prompt = ''.join(lines)
        # Logged at debug level rather than printed to stdout.
        logging.debug("completion prompt:\n%s", prompt)
        return prompt

    def get_text(self):
        """Return the completion text stored by the last request."""
        return self.message
|
2023-01-01 22:52:27 +08:00
|
|
|
|
2023-03-02 15:31:12 +08:00
|
|
|
def OpenaiModel(model_name:str, user_name='user'):
    """Build the request wrapper that matches ``model_name``.

    Args:
        model_name: Name looked up in the model registries.
        user_name: Name identifying the end user, handed to the wrapper.

    Returns:
        A ``ChatCompletionModel`` when the name is a registered chat model,
        or a ``CompletionModel`` when it is a registered completion model.

    Raises:
        IndexError: If the name is in neither registry; the same message is
            logged as an error first.
    """
    if model_name in CHAT_COMPLETION_MODELS:
        return ChatCompletionModel(model_name, user_name)

    if model_name in COMPLETION_MODELS:
        return CompletionModel(model_name, user_name)

    # Unknown model: log the problem, then raise with the same text.
    message = "找不到模型[{}],请检查配置文件".format(model_name)
    logging.error(message)
    raise IndexError(message)
|