Merge pull request #523 from RockChinQ/feat-prompt-preprocess-event

[Feat] 新增PromptPreprocessing事件
This commit is contained in:
Junyan Qin 2023-07-31 17:55:06 +08:00 committed by GitHub
commit a2bda85a9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 126 additions and 7 deletions

View File

@ -214,7 +214,29 @@ class Session:
config = pkg.utils.context.get_config()
max_length = config.prompt_submit_length
prompts, _ = self.cut_out(text, max_length)
local_default_prompt = self.default_prompt.copy()
local_prompt = self.prompt.copy()
# 触发PromptPreProcessing事件
args = {
'session_name': self.name,
'default_prompt': self.default_prompt,
'prompt': self.prompt,
'text_message': text,
}
event = pkg.plugin.host.emit(plugin_models.PromptPreProcessing, **args)
if event.get_return_value('default_prompt') is not None:
local_default_prompt = event.get_return_value('default_prompt')
if event.get_return_value('prompt') is not None:
local_prompt = event.get_return_value('prompt')
if event.get_return_value('text_message') is not None:
text = event.get_return_value('text_message')
prompts, _ = self.cut_out(text, max_length, local_default_prompt, local_prompt)
res_text = ""
@ -301,7 +323,7 @@ class Session:
return question
# 构建对话体
def cut_out(self, msg: str, max_tokens: int) -> tuple[list, list]:
def cut_out(self, msg: str, max_tokens: int, default_prompt: list, prompt: list) -> tuple[list, list]:
"""将现有prompt进行切割处理使得新的prompt长度不超过max_tokens
:return: (新的prompt, 新的token_counts)
@ -317,19 +339,19 @@ class Session:
use_model = pkg.utils.context.get_config().completion_api_params['model']
ptr = len(self.prompt) - 1
ptr = len(prompt) - 1
# 直接从后向前扫描拼接,不管是否是整回合
while ptr >= 0:
if count_tokens(self.prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens:
if count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens:
break
changable_prompts.insert(0, self.prompt[ptr])
changable_prompts.insert(0, prompt[ptr])
ptr -= 1
# 将default_prompt和changable_prompts合并
result_prompt = self.default_prompt + changable_prompts
result_prompt = default_prompt + changable_prompts
# 添加当前问题
if msg:

View File

@ -266,7 +266,7 @@ class EventContext:
self.__return_value__[key] = []
self.__return_value__[key].append(ret)
def get_return(self, key: str):
def get_return(self, key: str) -> list:
"""获取key的所有返回值"""
if key in self.__return_value__:
return self.__return_value__[key]

View File

@ -132,6 +132,20 @@ KeySwitched = "key_switched"
key_list: list[str] api-key列表
"""
PromptPreProcessing = "prompt_pre_processing"
"""每回合调用接口前对prompt进行预处理时触发此事件不支持阻止默认行为
kwargs:
session_name: str 会话名称(<launcher_type>_<launcher_id>)
default_prompt: list 此session使用的情景预设内容
prompt: list 此session现有的prompt内容
text_message: str 用户发送的消息文本
returns (optional):
default_prompt: list 修改后的情景预设内容
prompt: list 修改后的prompt内容
text_message: str 修改后的消息文本
"""
def on(*args, **kwargs):
"""注册事件监听器
@ -150,6 +164,32 @@ def func(*args, **kwargs):
__current_registering_plugin__ = ""
def require_ver(ge: str, le: str="v999.9.9") -> bool:
"""插件版本要求装饰器
Args:
ge (str): 最低版本要求
le (str, optional): 最高版本要求
Returns:
bool: 是否满足要求, False时为无法获取版本号True时为满足要求报错为不满足要求
"""
qchatgpt_version = ""
from pkg.utils.updater import get_current_tag, compare_version_str
try:
qchatgpt_version = get_current_tag() # 从updater模块获取版本号
except:
return False
if compare_version_str(qchatgpt_version, ge) < 0 or \
(compare_version_str(qchatgpt_version, le) > 0):
raise Exception("QChatGPT 版本不满足要求,某些功能(可能是由插件提供的)无法正常使用。(要求版本:{}-{},但当前版本:{}".format(ge, le, qchatgpt_version))
return True
class Plugin:
"""插件基类"""

View File

@ -78,6 +78,34 @@ def get_current_tag() -> str:
return current_tag
def compare_version_str(v0: str, v1: str) -> int:
"""比较两个版本号"""
# 删除版本号前的v
if v0.startswith("v"):
v0 = v0[1:]
if v1.startswith("v"):
v1 = v1[1:]
v0:list = v0.split(".")
v1:list = v1.split(".")
# 如果两个版本号节数不同把短的后面用0补齐
if len(v0) < len(v1):
v0.extend(["0"]*(len(v1)-len(v0)))
elif len(v0) > len(v1):
v1.extend(["0"]*(len(v0)-len(v1)))
# 从高位向低位比较
for i in range(len(v0)):
if int(v0[i]) > int(v1[i]):
return 1
elif int(v0[i]) < int(v1[i]):
return -1
return 0
def update_all(cli: bool = False) -> bool:
"""检查更新并下载源码"""
current_tag = get_current_tag()

View File

@ -291,6 +291,21 @@ class HelloPlugin(Plugin):
- 这仅仅是一个示例,需要更高效的网络访问能力支持插件,请查看[WebwlkrPlugin](https://github.com/RockChinQ/WebwlkrPlugin)
## 🔒版本要求
若您的插件对主程序的版本有要求,可以使用以下函数进行断言,若不符合版本,此函数将报错并打断此函数所在的流程:
```python
require_ver("v2.5.1") # 要求最低版本为 v2.5.1
```
```python
require_ver("v2.5.1", "v2.6.0") # 要求最低版本为 v2.5.1, 同时要求最高版本为 v2.6.0
```
- 此函数在主程序`v2.5.1`中加入
- 此函数声明在`pkg.plugin.models`模块中,在插件示例代码最前方已引入此模块所有内容,故可直接使用
## 📄API参考
### 说明
@ -435,6 +450,20 @@ KeySwitched = "key_switched"
key_name: str 切换成功的api-key名称
key_list: list[str] api-key列表
"""
PromptPreProcessing = "prompt_pre_processing" # 于v2.5.1加入
"""每回合调用接口前对prompt进行预处理时触发此事件不支持阻止默认行为
kwargs:
session_name: str 会话名称(<launcher_type>_<launcher_id>)
default_prompt: list 此session使用的情景预设内容
prompt: list 此session现有的prompt内容
text_message: str 用户发送的消息文本
returns (optional):
default_prompt: list 修改后的情景预设内容
prompt: list 修改后的prompt内容
text_message: str 修改后的消息文本
"""
```
### host: PluginHost 详解