2023-03-05 15:39:13 +08:00
|
|
|
|
"""OpenAI 接口底层封装

目前使用的对话接口有:
ChatCompletion - gpt-3.5-turbo、gpt-4 等模型
Completion - text-davinci-003 等模型

此模块封装此两个接口的请求实现,为上层提供统一的调用方式
"""
|
2023-03-03 14:12:53 +08:00
|
|
|
|
import openai, logging, threading, asyncio
|
2023-03-05 13:52:43 +08:00
|
|
|
|
import openai.error as aiE
|
2023-03-02 15:31:12 +08:00
|
|
|
|
|
2023-03-02 17:57:39 +08:00
|
|
|
|
# Model names routed to the legacy Completion endpoint (CompletionModel in
# create_openai_model_request).
COMPLETION_MODELS = {
    'text-davinci-003',
    'text-davinci-002',
    'code-davinci-002',
    'code-cushman-001',
    'text-curie-001',
    'text-babbage-001',
    'text-ada-001',
}

# Model names routed to the ChatCompletion endpoint (ChatCompletionModel in
# create_openai_model_request).
CHAT_COMPLETION_MODELS = {
    'gpt-3.5-turbo',
    'gpt-3.5-turbo-0301',
    'gpt-4',
    'gpt-4-0314',
    'gpt-4-32k',
    'gpt-4-32k-0314',
}

# Placeholder collections; not referenced by create_openai_model_request.
# Written as ``set()`` because a bare ``{}`` literal is an empty *dict*, which
# was inconsistent with the other *_MODELS sets above.
EDIT_MODELS = set()

IMAGE_MODELS = set()
|
|
|
|
|
|
2023-03-02 16:41:03 +08:00
|
|
|
|
|
2023-03-05 15:39:13 +08:00
|
|
|
|
class ModelRequest:
    """Base class for OpenAI model request wrappers.

    Subclasses pass the SDK call to use as ``request_fun`` and override
    ``__msg_handle__`` / ``get_message`` to adapt the prompt format and read
    the reply.

    When an HTTP proxy is configured, ``request_fun`` is awaited as a
    coroutine on its own event loop inside a worker thread (see ``request``);
    otherwise it is called synchronously.
    """

    can_chat = False
    # Worker thread of the most recent proxied (async) request, else None.
    runtime: threading.Thread = None
    # Raw API response of the last completed request. Class-level default;
    # instances rebind it rather than mutating it.
    ret = {}
    # HTTP proxy URL; when set, requests take the async path.
    proxy: str = None
    # Set to True by the async path once ``ret`` holds a response.
    request_ready = True
    # User-facing message raised by ``ret_handle`` when an async request fails.
    error_info: str = "若在没有任何错误的情况下看到这句话,请带着配置文件上报Issues"

    def __init__(self, model_name, user_name, request_fun, http_proxy: str = None, time_out=None):
        """
        Args:
            model_name: model identifier.
            user_name: user name associated with this request object.
            request_fun: callable performing the API call. When ``http_proxy``
                is given it must return an awaitable, since it is awaited in
                ``__a_request__``.
            http_proxy: optional proxy URL. NOTE: it is also assigned to the
                module-global ``openai.proxy``, so it affects every other user
                of the openai package in this process.
            time_out: seconds ``ret_handle`` waits for the async worker thread;
                None waits until the thread finishes.
        """
        self.model_name = model_name
        self.user_name = user_name
        self.request_fun = request_fun
        self.time_out = time_out
        if http_proxy is not None:
            self.proxy = http_proxy
            openai.proxy = self.proxy
            self.request_ready = False

    async def __a_request__(self, **kwargs):
        """Await ``request_fun`` and store its result (async path).

        On success the response goes to ``ret`` and ``request_ready`` becomes
        True. On failure a user-facing message is stored in ``error_info``;
        ``ret_handle`` later raises it in the calling thread. Except for
        ValueError, the error is also re-raised inside the worker thread.
        """
        try:
            self.ret = await self.request_fun(**kwargs)
            self.request_ready = True
        except aiE.APIConnectionError as e:
            self.error_info = "{}\n请检查网络连接或代理是否正常".format(e)
            raise ConnectionError(self.error_info) from e
        except ValueError as e:
            # Previously missing .format(e), which leaked a literal "{}".
            self.error_info = "{}\n该错误可能是由于http_proxy格式设置错误引起的".format(e)
        except Exception as e:
            self.error_info = "{}\n由于请求异常产生的未知错误,请查看日志".format(e)
            try:
                wrapped = type(e)(self.error_info)
            except Exception:
                # Some exception types require extra constructor arguments.
                wrapped = Exception(self.error_info)
            raise wrapped from e

    def request(self, **kwargs):
        """Send a request to the API.

        With a proxy configured, the coroutine runs in a worker thread and
        this method returns immediately; call ``ret_handle`` to wait for it.
        Without a proxy, ``request_fun`` is called synchronously and its
        result stored in ``ret``.
        """
        if self.proxy is not None:  # async request
            self.request_ready = False
            # asyncio.run creates a fresh loop in the worker thread and closes
            # it afterwards; the previous new_event_loop() was never closed.
            self.runtime = threading.Thread(
                target=asyncio.run,
                args=(self.__a_request__(**kwargs),),
            )
            self.runtime.start()
        else:  # sync request
            self.ret = self.request_fun(**kwargs)

    def __msg_handle__(self, msg):
        """Convert the prompt into the format the endpoint expects (identity here)."""
        return msg

    def ret_handle(self):
        """Wait for a pending async request and surface its outcome.

        Subclasses overriding this should check the worker thread state
        themselves or call ``super().ret_handle()`` where needed.

        Raises:
            TimeoutError: the worker thread is still running after
                ``time_out`` seconds.
            Exception: the async request failed; the message is ``error_info``.
        """
        if not isinstance(self.runtime, threading.Thread):
            return
        self.runtime.join(self.time_out)
        if self.request_ready:
            return
        if self.runtime.is_alive():
            # join() returns None either way; only is_alive() reveals that the
            # timeout expired, as opposed to the request having failed.
            raise TimeoutError("请求超时({}秒),请检查网络连接或代理是否正常".format(self.time_out))
        raise Exception(self.error_info)

    def get_total_tokens(self):
        """Return ``ret['usage']['total_tokens']``, or 0 when it is unavailable."""
        try:
            return self.ret['usage']['total_tokens']
        except (KeyError, TypeError):
            return 0

    def get_message(self):
        """Return the reply text.

        Subclasses override this. The base version reads ``self.message``,
        which this class never sets.
        """
        return self.message

    def get_response(self):
        """Return the raw API response of the last request."""
        return self.ret
|
|
|
|
|
|
2023-03-05 15:39:13 +08:00
|
|
|
|
|
2023-03-02 19:50:31 +08:00
|
|
|
|
class ChatCompletionModel(ModelRequest):
    """Request implementation for the ChatCompletion endpoint.

    Prompts are sent as the ``messages`` argument and the reply is read from
    ``choices[0]["message"]["content"]`` of the response.
    """

    Chat_role = ['system', 'user', 'assistant']

    def __init__(self, model_name, user_name, http_proxy: str = None, **kwargs):
        # A configured proxy selects the awaitable SDK call, which the base
        # class runs in a worker thread; otherwise the blocking call is used.
        request_fun = (
            openai.ChatCompletion.create
            if http_proxy is None
            else openai.ChatCompletion.acreate
        )
        self.can_chat = True
        super().__init__(model_name, user_name, request_fun, http_proxy, **kwargs)

    def request(self, prompts, **kwargs):
        """Send *prompts* as ``messages`` and wait for the response."""
        kwargs['messages'] = self.__msg_handle__(prompts)
        super().request(**kwargs)
        self.ret_handle()

    def __msg_handle__(self, msgs):
        """Return a new list holding a shallow copy of each message in *msgs*."""
        # Copy each entry so the request does not share the caller's objects.
        return [message.copy() for message in msgs]

    def get_message(self):
        """Return the content of the first choice's message."""
        # Read from the raw response on demand instead of caching a copy.
        first_choice = self.ret["choices"][0]
        return first_choice["message"]['content']
|
|
|
|
|
|
2023-01-01 22:52:27 +08:00
|
|
|
|
|
2023-03-02 19:50:31 +08:00
|
|
|
|
class CompletionModel(ModelRequest):
    """Request implementation for the Completion endpoint.

    The chat-style message list is flattened into a single text prompt, and
    the reply is read from ``choices[0]["text"]`` of the response.
    """

    def __init__(self, model_name, user_name, http_proxy: str = None, **kwargs):
        # A configured proxy selects the awaitable SDK call, which the base
        # class runs in a worker thread; otherwise the blocking call is used.
        if http_proxy is None:
            request_fun = openai.Completion.create
        else:
            request_fun = openai.Completion.acreate
        super().__init__(model_name, user_name, request_fun, http_proxy, **kwargs)

    def request(self, prompts, **kwargs):
        """Send *prompts* (flattened to text) as ``prompt`` and wait for the response."""
        kwargs['prompt'] = self.__msg_handle__(prompts)
        super().request(**kwargs)
        self.ret_handle()

    def __msg_handle__(self, msgs):
        """Flatten message dicts into one prompt string.

        Each message becomes a ``"<role>: <content>\\n"`` line, and the text
        ends with ``"assistant: "``. For example,
        ``[{'role': 'user', 'content': 'hi'}]`` becomes ``"user: hi\\nassistant: "``.
        """
        parts = ["{}: {}\n".format(msg['role'], msg['content']) for msg in msgs]
        # Presumably cues the model to continue the text as the assistant.
        parts.append("assistant: ")
        # join() instead of repeated += avoids quadratic string building.
        return "".join(parts)

    def get_message(self):
        """Return the text of the first choice."""
        return self.ret["choices"][0]["text"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_openai_model_request(model_name: str, user_name: str = 'user', http_proxy: str = None) -> ModelRequest:
    """Build the request object matching *model_name*.

    Chat models map to ChatCompletionModel and completion models to
    CompletionModel.

    Raises:
        IndexError: *model_name* is in neither model set (the error is also
            logged).
    """
    if model_name in CHAT_COMPLETION_MODELS:
        model_cls = ChatCompletionModel
    elif model_name in COMPLETION_MODELS:
        model_cls = CompletionModel
    else:
        message = "找不到模型[{}],请检查配置文件".format(model_name)
        logging.error(message)
        raise IndexError(message)

    model = model_cls(model_name, user_name, http_proxy)
    logging.debug("使用接口[{}]创建模型请求[{}]".format(type(model).__name__, model_name))
    return model
|