feat: backwards invoke tools

This commit is contained in:
Yeuoly 2024-10-10 18:09:06 +08:00
parent 699d41deec
commit 118fa66567
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
7 changed files with 99 additions and 16 deletions

View File

@ -1,5 +1,3 @@
import time
from flask_restful import Resource
from controllers.console.setup import setup_required
@ -10,6 +8,7 @@ from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
from core.plugin.encrypt import PluginEncrypter
from core.plugin.entities.request import (
RequestInvokeApp,
@ -24,7 +23,7 @@ from core.plugin.entities.request import (
RequestInvokeTool,
RequestInvokeTTS,
)
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.entities.tool_entities import ToolProviderType
from libs.helper import compact_generate_response
from models.account import Tenant
@ -138,17 +137,16 @@ class PluginInvokeToolApi(Resource):
@plugin_data(payload_type=RequestInvokeTool)
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTool):
def generator():
for i in range(10):
time.sleep(0.1)
yield (
ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT,
message=ToolInvokeMessage.TextMessage(text="helloworld"),
)
.model_dump_json()
.encode()
+ b"\n\n"
)
return PluginToolBackwardsInvocation.convert_to_event_stream(
PluginToolBackwardsInvocation.invoke_tool(
tenant_id=tenant_model.id,
user_id=user_id,
tool_type=ToolProviderType.value_of(payload.tool_type),
provider=payload.provider,
tool_name=payload.tool,
tool_parameters=payload.tool_parameters,
),
)
return compact_generate_response(generator())

View File

@ -0,0 +1,45 @@
from collections.abc import Generator
from typing import Any
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
"""
Backwards invocation for plugin tools.
"""
@classmethod
def invoke_tool(
cls,
tenant_id: str,
user_id: str,
tool_type: ToolProviderType,
provider: str,
tool_name: str,
tool_parameters: dict[str, Any],
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke tool
"""
# get tool runtime
try:
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
tool_type, tenant_id, provider, tool_name, tool_parameters
)
response = ToolEngine.generic_invoke(
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1
)
response = ToolFileMessageTransformer.transform_tool_invoke_messages(
response, user_id=user_id, tenant_id=tenant_id
)
return response
except Exception as e:
raise e

View File

@ -32,6 +32,11 @@ class RequestInvokeTool(BaseModel):
Request to invoke a tool
"""
tool_type: Literal["builtin", "workflow", "api"]
provider: str
tool: str
tool_parameters: dict
class BaseRequestInvokeModel(BaseModel):
provider: str

View File

@ -378,6 +378,7 @@ class ToolInvokeFrom(Enum):
WORKFLOW = "workflow"
AGENT = "agent"
PLUGIN = "plugin"
class ToolProviderID:

View File

@ -131,7 +131,7 @@ class ToolEngine:
return error_response, [], ToolInvokeMeta.error_instance(error_response)
@staticmethod
def workflow_invoke(
def generic_invoke(
tool: Tool,
tool_parameters: dict[str, Any],
user_id: str,

View File

@ -365,6 +365,40 @@ class ToolManager:
tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
return tool_runtime
@classmethod
def get_tool_runtime_from_plugin(
cls,
tool_type: ToolProviderType,
tenant_id: str,
provider: str,
tool_name: str,
tool_parameters: dict[str, Any],
) -> Tool:
"""
get tool runtime from plugin
"""
tool_entity = cls.get_tool_runtime(
provider_type=tool_type,
provider_id=provider,
tool_name=tool_name,
tenant_id=tenant_id,
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.PLUGIN,
)
runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters()
for parameter in parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# save tool parameter to tool entity memory
value = cls._init_runtime_parameter(parameter, tool_parameters)
runtime_parameters[parameter.name] = value
if not tool_entity.runtime:
raise Exception("tool missing runtime")
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@classmethod
def get_builtin_provider_icon(cls, provider: str, tenant_id: str) -> tuple[str, str]:
"""

View File

@ -66,7 +66,7 @@ class ToolNode(BaseNode):
)
try:
message_stream = ToolEngine.workflow_invoke(
message_stream = ToolEngine.generic_invoke(
tool=tool_runtime,
tool_parameters=parameters,
user_id=self.user_id,