refactor: invoke tool from dify

This commit is contained in:
Yeuoly 2024-09-29 14:44:22 +08:00
parent 735e57b73a
commit c8bc3892b3
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
2 changed files with 50 additions and 10 deletions

View File

@ -7,9 +7,23 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
class PluginToolManager(BasePluginManager):
def _split_provider(self, provider: str) -> tuple[str, str]:
"""
split the provider to plugin_id and provider_name
provider follows format: plugin_id/provider_name
"""
if "/" in provider:
parts = provider.split("/", 1)
if len(parts) == 2:
return parts[0], parts[1]
raise ValueError(f"invalid provider format: {provider}")
raise ValueError(f"invalid provider format: {provider}")
def fetch_tool_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]:
"""
Fetch tool providers for the given asset.
Fetch tool providers for the given tenant.
"""
def transformer(json_response: dict[str, Any]) -> dict:
@ -28,18 +42,44 @@ class PluginToolManager(BasePluginManager):
params={"page": 1, "page_size": 256},
transformer=transformer,
)
for provider in response:
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
return response
def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
"""
Fetch tool provider for the given tenant and plugin.
"""
plugin_id, provider_name = self._split_provider(provider)
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/tool",
PluginToolProviderEntity,
params={"provider": provider_name, "plugin_id": plugin_id},
)
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
return response
def invoke(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
tool_provider: str,
tool_name: str,
credentials: dict[str, Any],
tool_parameters: dict[str, Any],
) -> Generator[ToolInvokeMessage, None, None]:
"""
Invoke the tool with the given tenant, user, plugin, provider, name, credentials and parameters.
"""
plugin_id, provider_name = self._split_provider(tool_provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/tool/invoke",
@ -47,7 +87,7 @@ class PluginToolManager(BasePluginManager):
data={
"user_id": user_id,
"data": {
"provider": tool_provider,
"provider": provider_name,
"tool": tool_name,
"credentials": credentials,
"tool_parameters": tool_parameters,
@ -61,11 +101,13 @@ class PluginToolManager(BasePluginManager):
return response
def validate_provider_credentials(
self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict[str, Any]
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
) -> bool:
"""
validate the credentials of the provider
"""
plugin_id, provider_name = self._split_provider(provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/tool/validate_credentials",
@ -73,7 +115,7 @@ class PluginToolManager(BasePluginManager):
data={
"user_id": user_id,
"data": {
"provider": provider,
"provider": provider_name,
"credentials": credentials,
},
},

View File

@ -59,6 +59,8 @@ class ToolManager:
:param tenant_id: the id of the tenant
:return: the provider
"""
# split provider to
if len(cls._hardcoded_providers) == 0:
# init the builtin providers
cls.load_hardcoded_providers_cache()
@ -77,8 +79,7 @@ class ToolManager:
get the plugin provider
"""
manager = PluginToolManager()
providers = manager.fetch_tool_providers(tenant_id)
provider_entity = next((x for x in providers if x.declaration.identity.name == provider), None)
provider_entity = manager.fetch_tool_provider(tenant_id, provider)
if not provider_entity:
raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
@ -181,9 +182,6 @@ class ToolManager:
)
elif provider_type == ToolProviderType.API:
if tenant_id is None:
raise ValueError("tenant id is required for api provider")
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
# decrypt the credentials