2023-03-05 15:39:13 +08:00
|
|
|
|
"""OpenAI 接口底层封装
|
|
|
|
|
|
|
|
|
|
目前使用的对话接口有:
|
|
|
|
|
ChatCompletion - gpt-3.5-turbo 等模型
|
|
|
|
|
Completion - text-davinci-003 等模型
|
|
|
|
|
此模块封装此两个接口的请求实现,为上层提供统一的调用方式
|
|
|
|
|
"""
|
2023-03-03 14:12:53 +08:00
|
|
|
|
import openai, logging, threading, asyncio
|
2023-03-05 13:52:43 +08:00
|
|
|
|
import openai.error as aiE
|
2023-03-02 15:31:12 +08:00
|
|
|
|
|
2023-07-28 19:03:02 +08:00
|
|
|
|
from pkg.openai.api.model import RequestBase
|
|
|
|
|
from pkg.openai.api.completion import CompletionRequest
|
|
|
|
|
from pkg.openai.api.chat_completion import ChatCompletionRequest
|
|
|
|
|
|
2023-03-02 17:57:39 +08:00
|
|
|
|
# Model names that select_request_cls routes to CompletionRequest
# (the legacy "Completion" interface described in the module docstring).
# Kept as a set so membership checks are O(1).
COMPLETION_MODELS = {
    "text-davinci-003",
    "text-davinci-002",
    "code-davinci-002",
    "code-cushman-001",
    "text-curie-001",
    "text-babbage-001",
    "text-ada-001",
}
|
2023-01-01 22:52:27 +08:00
|
|
|
|
|
2023-03-02 17:57:39 +08:00
|
|
|
|
# Model names that select_request_cls routes to ChatCompletionRequest
# (the "ChatCompletion" interface described in the module docstring).
# This set is checked before COMPLETION_MODELS when dispatching.
CHAT_COMPLETION_MODELS = {
    # gpt-3.5 family
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-16k",
    "gpt-3.5-turbo-0613",
    "gpt-3.5-turbo-16k-0613",
    # NOTE(review): "gpt-3.5-turbo-0301" was commented out in the original
    # list and is deliberately not included here.
    # gpt-4 family
    "gpt-4",
    "gpt-4-0613",
    "gpt-4-32k",
    "gpt-4-32k-0613",
}
|
|
|
|
|
|
|
|
|
|
# Model names for an edit-style interface; currently none are supported.
# `set()` rather than `{}`: a bare `{}` literal is an empty *dict*, which
# would make this constant inconsistent with the other *_MODELS sets.
EDIT_MODELS = set()
|
|
|
|
|
|
|
|
|
|
# Model names for an image interface; currently none are supported.
# `set()` rather than `{}`: a bare `{}` literal is an empty *dict*, which
# would make this constant inconsistent with the other *_MODELS sets.
IMAGE_MODELS = set()
|
|
|
|
|
|
2023-03-02 16:41:03 +08:00
|
|
|
|
|
2023-07-28 19:03:02 +08:00
|
|
|
|
def select_request_cls(model_name: str, messages: list, args: dict) -> RequestBase:
    """Build the request object appropriate for ``model_name``.

    Chat models (``CHAT_COMPLETION_MODELS``) are checked first and map to
    ``ChatCompletionRequest``; otherwise models in ``COMPLETION_MODELS`` map
    to ``CompletionRequest``.

    Args:
        model_name: Name of the model to request.
        messages: Passed through unchanged as the second positional argument
            of the chosen request class.
        args: Extra keyword arguments unpacked into the request constructor.

    Returns:
        The constructed request instance.

    Raises:
        ValueError: If ``model_name`` is in neither supported model set.
    """
    if model_name in CHAT_COMPLETION_MODELS:
        request_cls = ChatCompletionRequest
    elif model_name in COMPLETION_MODELS:
        request_cls = CompletionRequest
    else:
        raise ValueError("不支持模型[{}],请检查配置文件".format(model_name))

    return request_cls(model_name, messages, **args)
|