From c8bc3892b339e0c36e40683e05d870399a011dca Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sun, 29 Sep 2024 14:44:22 +0800 Subject: [PATCH] refactor: invoke tool from dify --- api/core/plugin/manager/tool.py | 52 +++++++++++++++++++++++++++++---- api/core/tools/tool_manager.py | 8 ++--- 2 files changed, 50 insertions(+), 10 deletions(-) diff --git a/api/core/plugin/manager/tool.py b/api/core/plugin/manager/tool.py index 4f5fa1fa5c..5981bcb55e 100644 --- a/api/core/plugin/manager/tool.py +++ b/api/core/plugin/manager/tool.py @@ -7,9 +7,23 @@ from core.tools.entities.tool_entities import ToolInvokeMessage class PluginToolManager(BasePluginManager): + def _split_provider(self, provider: str) -> tuple[str, str]: + """ + split the provider to plugin_id and provider_name + + provider follows format: plugin_id/provider_name + """ + if "/" in provider: + parts = provider.split("/", 1) + if len(parts) == 2: + return parts[0], parts[1] + raise ValueError(f"invalid provider format: {provider}") + + raise ValueError(f"invalid provider format: {provider}") + def fetch_tool_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]: """ - Fetch tool providers for the given asset. + Fetch tool providers for the given tenant. """ def transformer(json_response: dict[str, Any]) -> dict: @@ -28,18 +42,44 @@ class PluginToolManager(BasePluginManager): params={"page": 1, "page_size": 256}, transformer=transformer, ) + + for provider in response: + provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" + + return response + + def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity: + """ + Fetch tool provider for the given tenant and plugin. + """ + plugin_id, provider_name = self._split_provider(provider) + + response = self._request_with_plugin_daemon_response( + "GET", + f"plugin/{tenant_id}/management/tool", + PluginToolProviderEntity, + params={"provider": provider_name, "plugin_id": plugin_id}, + ) + + response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}" + return response def invoke( self, tenant_id: str, user_id: str, - plugin_id: str, tool_provider: str, tool_name: str, credentials: dict[str, Any], tool_parameters: dict[str, Any], ) -> Generator[ToolInvokeMessage, None, None]: + """ + Invoke the tool with the given tenant, user, plugin, provider, name, credentials and parameters. + """ + + plugin_id, provider_name = self._split_provider(tool_provider) + response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/tool/invoke", @@ -47,7 +87,7 @@ class PluginToolManager(BasePluginManager): data={ "user_id": user_id, "data": { - "provider": tool_provider, + "provider": provider_name, "tool": tool_name, "credentials": credentials, "tool_parameters": tool_parameters, @@ -61,11 +101,13 @@ class PluginToolManager(BasePluginManager): return response def validate_provider_credentials( - self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict[str, Any] + self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any] ) -> bool: """ validate the credentials of the provider """ + plugin_id, provider_name = self._split_provider(provider) + response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/tool/validate_credentials", @@ -73,7 +115,7 @@ class PluginToolManager(BasePluginManager): data={ "user_id": user_id, "data": { - "provider": provider, + "provider": provider_name, "credentials": credentials, }, }, diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 698ce3e900..0463f84817 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -59,6 +59,8 @@ class ToolManager: :param tenant_id: the id of the tenant :return: the provider """ + # split provider to + if len(cls._hardcoded_providers) == 0: # init the builtin providers cls.load_hardcoded_providers_cache() @@ -77,8 +79,7 @@ class ToolManager: get the plugin provider """ manager = PluginToolManager() - providers = manager.fetch_tool_providers(tenant_id) - provider_entity = next((x for x in providers if x.declaration.identity.name == provider), None) + provider_entity = manager.fetch_tool_provider(tenant_id, provider) if not provider_entity: raise ToolProviderNotFoundError(f"plugin provider {provider} not found") @@ -181,9 +182,6 @@ class ToolManager: ) elif provider_type == ToolProviderType.API: - if tenant_id is None: - raise ValueError("tenant id is required for api provider") - api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) # decrypt the credentials