mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 19:59:50 +08:00
feat: backwards invoke tools
This commit is contained in:
parent
699d41deec
commit
118fa66567
|
@ -1,5 +1,3 @@
|
|||
import time
|
||||
|
||||
from flask_restful import Resource
|
||||
|
||||
from controllers.console.setup import setup_required
|
||||
|
@ -10,6 +8,7 @@ from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
|
|||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
|
||||
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
|
||||
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
|
||||
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
|
||||
from core.plugin.encrypt import PluginEncrypter
|
||||
from core.plugin.entities.request import (
|
||||
RequestInvokeApp,
|
||||
|
@ -24,7 +23,7 @@ from core.plugin.entities.request import (
|
|||
RequestInvokeTool,
|
||||
RequestInvokeTTS,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from libs.helper import compact_generate_response
|
||||
from models.account import Tenant
|
||||
|
||||
|
@ -138,17 +137,16 @@ class PluginInvokeToolApi(Resource):
|
|||
@plugin_data(payload_type=RequestInvokeTool)
|
||||
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTool):
|
||||
def generator():
|
||||
for i in range(10):
|
||||
time.sleep(0.1)
|
||||
yield (
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=ToolInvokeMessage.TextMessage(text="helloworld"),
|
||||
)
|
||||
.model_dump_json()
|
||||
.encode()
|
||||
+ b"\n\n"
|
||||
)
|
||||
return PluginToolBackwardsInvocation.convert_to_event_stream(
|
||||
PluginToolBackwardsInvocation.invoke_tool(
|
||||
tenant_id=tenant_model.id,
|
||||
user_id=user_id,
|
||||
tool_type=ToolProviderType.value_of(payload.tool_type),
|
||||
provider=payload.provider,
|
||||
tool_name=payload.tool,
|
||||
tool_parameters=payload.tool_parameters,
|
||||
),
|
||||
)
|
||||
|
||||
return compact_generate_response(generator())
|
||||
|
||||
|
|
45
api/core/plugin/backwards_invocation/tool.py
Normal file
45
api/core/plugin/backwards_invocation/tool.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
|
||||
|
||||
class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
Backwards invocation for plugin tools.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def invoke_tool(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
tool_type: ToolProviderType,
|
||||
provider: str,
|
||||
tool_name: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke tool
|
||||
"""
|
||||
# get tool runtime
|
||||
try:
|
||||
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
|
||||
tool_type, tenant_id, provider, tool_name, tool_parameters
|
||||
)
|
||||
response = ToolEngine.generic_invoke(
|
||||
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1
|
||||
)
|
||||
|
||||
response = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
response, user_id=user_id, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
|
@ -32,6 +32,11 @@ class RequestInvokeTool(BaseModel):
|
|||
Request to invoke a tool
|
||||
"""
|
||||
|
||||
tool_type: Literal["builtin", "workflow", "api"]
|
||||
provider: str
|
||||
tool: str
|
||||
tool_parameters: dict
|
||||
|
||||
|
||||
class BaseRequestInvokeModel(BaseModel):
|
||||
provider: str
|
||||
|
|
|
@ -378,6 +378,7 @@ class ToolInvokeFrom(Enum):
|
|||
|
||||
WORKFLOW = "workflow"
|
||||
AGENT = "agent"
|
||||
PLUGIN = "plugin"
|
||||
|
||||
|
||||
class ToolProviderID:
|
||||
|
|
|
@ -131,7 +131,7 @@ class ToolEngine:
|
|||
return error_response, [], ToolInvokeMeta.error_instance(error_response)
|
||||
|
||||
@staticmethod
|
||||
def workflow_invoke(
|
||||
def generic_invoke(
|
||||
tool: Tool,
|
||||
tool_parameters: dict[str, Any],
|
||||
user_id: str,
|
||||
|
|
|
@ -365,6 +365,40 @@ class ToolManager:
|
|||
tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_runtime
|
||||
|
||||
@classmethod
|
||||
def get_tool_runtime_from_plugin(
|
||||
cls,
|
||||
tool_type: ToolProviderType,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
tool_name: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Tool:
|
||||
"""
|
||||
get tool runtime from plugin
|
||||
"""
|
||||
tool_entity = cls.get_tool_runtime(
|
||||
provider_type=tool_type,
|
||||
provider_id=provider,
|
||||
tool_name=tool_name,
|
||||
tenant_id=tenant_id,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
tool_invoke_from=ToolInvokeFrom.PLUGIN,
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
# save tool parameter to tool entity memory
|
||||
value = cls._init_runtime_parameter(parameter, tool_parameters)
|
||||
runtime_parameters[parameter.name] = value
|
||||
|
||||
if not tool_entity.runtime:
|
||||
raise Exception("tool missing runtime")
|
||||
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_entity
|
||||
|
||||
@classmethod
|
||||
def get_builtin_provider_icon(cls, provider: str, tenant_id: str) -> tuple[str, str]:
|
||||
"""
|
||||
|
|
|
@ -66,7 +66,7 @@ class ToolNode(BaseNode):
|
|||
)
|
||||
|
||||
try:
|
||||
message_stream = ToolEngine.workflow_invoke(
|
||||
message_stream = ToolEngine.generic_invoke(
|
||||
tool=tool_runtime,
|
||||
tool_parameters=parameters,
|
||||
user_id=self.user_id,
|
||||
|
|
Loading…
Reference in New Issue
Block a user