refactor: tool parameter cache (#3703)

This commit is contained in:
Yeuoly 2024-04-23 15:22:42 +08:00 committed by GitHub
parent 65ac4f69af
commit 3480f1c59e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 95 additions and 57 deletions

View File

@ -1,5 +1,3 @@
import json
from flask_login import current_user
from flask_restful import Resource, inputs, marshal_with, reqparse
from werkzeug.exceptions import BadRequest, Forbidden
@ -8,17 +6,12 @@ from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from extensions.ext_database import db
from fields.app_fields import (
app_detail_fields,
app_detail_fields_with_site,
app_pagination_fields,
)
from libs.login import login_required
from models.model import App, AppMode, AppModelConfig
from services.app_service import AppService
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
@ -108,43 +101,9 @@ class AppApi(Resource):
@marshal_with(app_detail_fields_with_site)
def get(self, app_model):
"""Get app detail"""
# get original app model config
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
model_config: AppModelConfig = app_model.app_model_config
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get('tools') or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
app_service = AppService()
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
masked_parameter = {}
# override tool parameters
tool['tool_parameters'] = masked_parameter
except Exception as e:
pass
# override agent mode
model_config.agent_mode = json.dumps(agent_mode)
db.session.commit()
app_model = app_service.get_app(app_model)
return app_model

View File

@ -57,6 +57,7 @@ class ModelConfigResource(Resource):
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
manager = ToolParameterConfigurationManager(
@ -64,6 +65,7 @@ class ModelConfigResource(Resource):
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
identity_id=f'AGENT.{app_model.id}'
)
except Exception as e:
continue
@ -94,6 +96,7 @@ class ModelConfigResource(Resource):
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
except Exception as e:
@ -104,6 +107,7 @@ class ModelConfigResource(Resource):
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
identity_id=f'AGENT.{app_model.id}'
)
manager.delete_tool_parameters_cache()
@ -111,9 +115,11 @@ class ModelConfigResource(Resource):
if agent_tool_entity.tool_parameters:
if key not in masked_parameter_map:
continue
if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
agent_tool_entity.tool_parameters = parameter_map[key]
for masked_key, masked_value in masked_parameter_map[key].items():
if masked_key in agent_tool_entity.tool_parameters and \
agent_tool_entity.tool_parameters[masked_key] == masked_value:
agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key)
# encrypt parameters
if agent_tool_entity.tool_parameters:

View File

@ -163,6 +163,7 @@ class BaseAgentRunner(AppRunner):
"""
tool_entity = ToolManager.get_agent_tool_runtime(
tenant_id=self.tenant_id,
app_id=self.app_config.app_id,
agent_tool=tool,
)
tool_entity.load_variables(self.variables_pool)

View File

@ -11,12 +11,13 @@ class ToolParameterCacheType(Enum):
class ToolParameterCache:
def __init__(self,
tenant_id: str,
provider: str,
tool_name: str,
cache_type: ToolParameterCacheType
tenant_id: str,
provider: str,
tool_name: str,
cache_type: ToolParameterCacheType,
identity_id: str
):
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}"
def get(self) -> Optional[dict]:
"""

View File

@ -222,7 +222,7 @@ class ToolManager:
return parameter_value
@classmethod
def get_agent_tool_runtime(cls, tenant_id: str, agent_tool: AgentToolEntity) -> Tool:
def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity) -> Tool:
"""
get the agent tool runtime
"""
@ -245,6 +245,7 @@ class ToolManager:
tool_runtime=tool_entity,
provider_name=agent_tool.provider_id,
provider_type=agent_tool.provider_type,
identity_id=f'AGENT.{app_id}'
)
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
@ -252,7 +253,7 @@ class ToolManager:
return tool_entity
@classmethod
def get_workflow_tool_runtime(cls, tenant_id: str, workflow_tool: ToolEntity):
def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity):
"""
get the workflow tool runtime
"""
@ -277,6 +278,7 @@ class ToolManager:
tool_runtime=tool_entity,
provider_name=workflow_tool.provider_id,
provider_type=workflow_tool.provider_type,
identity_id=f'WORKFLOW.{app_id}.{node_id}'
)
if runtime_parameters:

View File

@ -113,12 +113,13 @@ class ToolParameterConfigurationManager(BaseModel):
tool_runtime: Tool
provider_name: str
provider_type: str
identity_id: str
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
deep copy parameters
"""
return {key: value for key, value in parameters.items()}
return deepcopy(parameters)
def _merge_parameters(self) -> list[ToolParameter]:
"""
@ -176,6 +177,8 @@ class ToolParameterConfigurationManager(BaseModel):
# override parameters
current_parameters = self._merge_parameters()
parameters = self._deep_copy(parameters)
for parameter in current_parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
if parameter.name in parameters:
@ -194,7 +197,8 @@ class ToolParameterConfigurationManager(BaseModel):
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id
)
cached_parameters = cache.get()
if cached_parameters:
@ -223,7 +227,8 @@ class ToolParameterConfigurationManager(BaseModel):
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id
)
cache.delete()

View File

@ -39,7 +39,8 @@ class ToolNode(BaseNode):
parameters = self._generate_parameters(variable_pool, node_data)
# get tool runtime
try:
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data)
self.app_id
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, self.app_id, self.node_id, node_data)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,

View File

@ -22,5 +22,6 @@ def handle(sender, **kwargs):
tool_runtime=tool_runtime,
provider_name=tool_entity.provider_name,
provider_type=tool_entity.provider_type,
identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}'
)
manager.delete_tool_parameters_cache()

View File

@ -5,13 +5,17 @@ from typing import cast
import yaml
from flask import current_app
from flask_login import current_user
from flask_sqlalchemy.pagination import Pagination
from constants.model_template import default_app_templates
from core.agent.entities import AgentToolEntity
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated, app_was_created, app_was_deleted
from extensions.ext_database import db
from models.account import Account
@ -240,6 +244,64 @@ class AppService:
return yaml.dump(export_data)
def get_app(self, app: App) -> App:
"""
Get App
"""
# get original app model config
if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
model_config: AppModelConfig = app.app_model_config
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get('tools') or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
app_id=app.id,
agent_tool=agent_tool_entity,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
identity_id=f'AGENT.{app.id}'
)
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
masked_parameter = {}
# override tool parameters
tool['tool_parameters'] = masked_parameter
except Exception as e:
pass
# override agent mode
model_config.agent_mode = json.dumps(agent_mode)
class ModifiedApp(App):
"""
Modified App class
"""
def __init__(self, app):
self.__dict__.update(app.__dict__)
@property
def app_model_config(self):
return model_config
app = ModifiedApp(app)
return app
def update_app(self, app: App, args: dict) -> App:
"""
Update app