QChatGPT/pkg/provider/entities.py

120 lines
3.2 KiB
Python
Raw Normal View History

from __future__ import annotations
import typing
import enum
import pydantic
import mirai
2024-01-27 21:50:40 +08:00
class FunctionCall(pydantic.BaseModel):
    """A function the model asked to invoke, with its arguments."""

    name: str
    """Name of the function to call."""

    arguments: str
    """Call arguments as a raw string (presumably JSON-encoded — confirm with the provider)."""
class ToolCall(pydantic.BaseModel):
    """A single tool invocation requested by the model."""

    id: str
    """Identifier of this tool call."""

    type: str
    """Kind of tool call (e.g. ``"function"`` — assumed; not validated here)."""

    function: FunctionCall
    """Function name and arguments for this call."""
class ImageURLContentObject(pydantic.BaseModel):
    """Reference to an image: either an http(s) URL or a base64 / data URL."""

    url: str
    """Image location or inline base64 payload."""

    def __str__(self):
        """Return the URL, truncated to 128 characters plus ``...`` when longer.

        Inline base64 images can be very large, so the printable form is
        capped to keep logs readable.
        """
        limit = 128
        if len(self.url) <= limit:
            return self.url
        return f'{self.url[:limit]}...'
class ContentElement(pydantic.BaseModel):
    """One part of a multi-part message: a text segment or an image."""

    type: str
    """Content type; ``'text'`` and ``'image_url'`` are the values handled in this class."""

    text: typing.Optional[str] = None
    """Text payload, used when ``type == 'text'``."""

    image_url: typing.Optional[ImageURLContentObject] = None
    """Image payload, used when ``type == 'image_url'``."""

    def __str__(self) -> str:
        """Return a human-readable rendering of this element.

        Text elements render as their text. An unset ``text`` renders as
        ``''`` so that ``str()`` never receives ``None``. Image elements render
        as ``[图片](<url>)``. Any other type renders as ``未知内容``.
        """
        if self.type == 'text':
            # __str__ must return a str; returning None makes str() raise TypeError.
            return self.text if self.text is not None else ''
        elif self.type == 'image_url':
            return f'[图片]({self.image_url})'
        else:
            return '未知内容'

    @classmethod
    def from_text(cls, text: str) -> ContentElement:
        """Build a text element holding ``text``."""
        return cls(type='text', text=text)

    @classmethod
    def from_image_url(cls, image_url: str) -> ContentElement:
        """Build an image element from a URL or base64 / data URL string."""
        return cls(type='image_url', image_url=ImageURLContentObject(url=image_url))
def _content_element_to_mirai(ce: ContentElement):
    """Convert one ContentElement into a Mirai component, or return None if it has nothing to render."""
    if ce.type == 'text':
        if ce.text is None:
            return None
        return mirai.Plain(ce.text)
    if ce.type == 'image_url':
        if ce.image_url is None:
            return None
        url = ce.image_url.url
        if url.startswith("http"):
            return mirai.Image(url=url)
        # Otherwise the value is base64. Strip a "data:<mime>;base64," header
        # if present; per RFC 2397 the payload is everything after the first comma.
        if url.startswith("data:"):
            url = url.partition(",")[2]
        return mirai.Image(base64=url)
    # Unknown content types are skipped.
    return None


class Message(pydantic.BaseModel):
    """A single message in a conversation."""

    role: str  # user, system, assistant, tool, command, plugin
    """Role of the message author."""

    name: typing.Optional[str] = None
    """Name; only set on function-call results."""

    content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None
    """Content: plain text or a list of content elements."""

    tool_calls: typing.Optional[list[ToolCall]] = None
    """Tool calls requested by this message."""

    tool_call_id: typing.Optional[str] = None
    """ID of the tool call this message answers (presumably set on tool results — confirm)."""

    def readable_str(self) -> str:
        """Return a short one-line description of the message.

        - Messages with content render as ``"<role>: <content>"``.
        - Messages carrying only tool calls show the first call's id.
        - Anything else, including an empty ``tool_calls`` list, renders as ``未知消息``.
        """
        if self.content is not None:
            return str(self.role) + ": " + str(self.get_content_mirai_message_chain())
        elif self.tool_calls:
            # Truthiness check: an empty list would otherwise raise IndexError below.
            return f'调用工具: {self.tool_calls[0].id}'
        else:
            return '未知消息'

    def get_content_mirai_message_chain(self, prefix_text: str = "") -> mirai.MessageChain | None:
        """Convert ``content`` into a Mirai ``MessageChain``.

        Args:
            prefix_text (str): Text prepended to the first text component.
                If the content has no text component, a new one holding only
                the prefix is inserted at the front.

        Returns:
            mirai.MessageChain | None: The converted chain, or ``None`` when
            ``content`` is ``None`` (or is neither a str nor a list).
        """
        if self.content is None:
            return None
        if isinstance(self.content, str):
            return mirai.MessageChain([mirai.Plain(prefix_text + self.content)])
        if not isinstance(self.content, list):
            return None

        components = []
        for ce in self.content:
            component = _content_element_to_mirai(ce)
            if component is not None:
                components.append(component)

        if prefix_text:
            # Attach the prefix to the first text component.
            for i, component in enumerate(components):
                if isinstance(component, mirai.Plain):
                    components[i] = mirai.Plain(prefix_text + component.text)
                    break
            else:
                # No text component exists, so add one at the front.
                components.insert(0, mirai.Plain(prefix_text))
        return mirai.MessageChain(components)