dify/api/services/tools/api_tools_manage_service.py

460 lines
17 KiB
Python
Raw Normal View History

2024-02-01 18:11:57 +08:00
import json
import logging
from httpx import get
from core.model_runtime.utils.encoders import jsonable_encoder
2024-05-27 22:01:11 +08:00
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.entities.common_entities import I18nObject
2024-05-27 22:01:11 +08:00
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ApiProviderSchemaType,
ToolCredentialsOption,
ToolProviderCredentials,
)
2024-05-27 22:01:11 +08:00
from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
2024-02-01 18:11:57 +08:00
from core.tools.tool_manager import ToolManager
2024-03-08 20:31:13 +08:00
from core.tools.utils.configuration import ToolConfigurationManager
2024-02-01 18:11:57 +08:00
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
2024-05-27 22:01:11 +08:00
from models.tools import ApiToolProvider
from services.tools.tools_transform_service import ToolTransformService
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
2024-05-27 22:01:11 +08:00
class ApiToolManageService:
    """
    Service entry points for managing the API-based tool providers of a tenant.

    Every method is a ``staticmethod``. Persistent state lives in ``ApiToolProvider`` rows that are
    read and written through ``db.session``.
    """

    @staticmethod
    def parser_api_schema(schema: str) -> dict:
        """
        parse api schema to tool bundle

        :param schema: raw schema text handed to ``ApiBasedToolSchemaParser.auto_parse_to_tool_bundle``
        :return: the ``jsonable_encoder`` output for a dict with the keys ``schema_type``,
            ``parameters_schema``, ``credentials_schema`` and ``warning``
        :raises ValueError: if parsing or encoding fails; the message is ``invalid schema: <cause>``
        """
        # Passed to the parser via ``warning=`` and returned to the caller as-is;
        # presumably the parser records non-fatal findings in it.
        warnings = {}

        # A single handler wraps every failure exactly once, so the "invalid schema: " prefix
        # is not repeated in the final message.
        try:
            tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)

            # Fixed credential form offered for API providers: auth type selector,
            # header name and key value.
            credentials_schema = [
                ToolProviderCredentials(
                    name="auth_type",
                    type=ToolProviderCredentials.CredentialsType.SELECT,
                    required=True,
                    default="none",
                    options=[
                        ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
                        ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
                    ],
                    placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
                ),
                ToolProviderCredentials(
                    name="api_key_header",
                    type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
                    required=False,
                    placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"),
                    default="api_key",
                    help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
                ),
                ToolProviderCredentials(
                    name="api_key_value",
                    type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
                    required=False,
                    placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
                    default="",
                ),
            ]

            return jsonable_encoder(
                {
                    "schema_type": schema_type,
                    "parameters_schema": tool_bundles,
                    "credentials_schema": credentials_schema,
                    "warning": warnings,
                }
            )
        except Exception as e:
            raise ValueError(f"invalid schema: {str(e)}") from e

    @staticmethod
    def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> tuple[list[ApiToolBundle], str]:
        """
        convert schema to tool bundles

        :param schema: raw schema text
        :param extra_info: optional dict forwarded to the parser as ``extra_info``; callers in this
            class read ``description`` from it afterwards
        :return: whatever ``auto_parse_to_tool_bundle`` returns; callers in this class unpack it as
            ``(tool_bundles, schema_type)``
        :raises ValueError: if the parser raises; the message is ``invalid schema: <cause>``
        """
        try:
            return ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
        except Exception as e:
            raise ValueError(f"invalid schema: {str(e)}") from e

    @staticmethod
    def create_api_tool_provider(
        user_id: str,
        tenant_id: str,
        provider_name: str,
        icon: dict,
        credentials: dict,
        schema_type: str,
        schema: str,
        privacy_policy: str,
        custom_disclaimer: str,
        labels: list[str],
    ):
        """
        create api tool provider

        :param user_id: stored on the new provider row
        :param tenant_id: tenant that owns the provider; also used for the uniqueness check
        :param provider_name: name of the new provider, unique per tenant
        :param icon: icon description, stored as JSON text
        :param credentials: credential values; must contain ``auth_type``
        :param schema_type: must be one of the ``ApiProviderSchemaType`` values
        :param schema: raw schema text, parsed into tool bundles
        :param privacy_policy: stored on the provider row
        :param custom_disclaimer: stored on the provider row
        :param labels: labels passed to ``ToolLabelManager.update_tool_labels`` after the commit
        :return: ``{"result": "success"}``
        :raises ValueError: on an unknown schema type, a duplicate provider name, an unparsable
            schema, more than 100 parsed tools, or missing ``auth_type``
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f"invalid schema type {schema_type}")

        # check if the provider exists
        existing_provider = (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == provider_name,
            )
            .first()
        )
        if existing_provider is not None:
            raise ValueError(f"provider {provider_name} already exists")

        # parse openapi to tool bundle
        # extra info like description will be set here
        extra_info = {}
        # the schema type reported by the parser replaces the one supplied by the caller
        tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)

        if len(tool_bundles) > 100:
            raise ValueError("the number of apis should be less than 100")

        # create db provider
        db_provider = ApiToolProvider(
            tenant_id=tenant_id,
            user_id=user_id,
            name=provider_name,
            icon=json.dumps(icon),
            schema=schema,
            description=extra_info.get("description", ""),
            schema_type_str=schema_type,
            tools_str=json.dumps(jsonable_encoder(tool_bundles)),
            # placeholder only; replaced below by the encrypted credentials JSON before the row is added
            credentials_str={},
            privacy_policy=privacy_policy,
            custom_disclaimer=custom_disclaimer,
        )

        if "auth_type" not in credentials:
            raise ValueError("auth_type is required")

        # get auth type, none or api key
        auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])

        # create provider entity
        provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
        # load tools into provider entity
        provider_controller.load_bundled_tools(tool_bundles)

        # encrypt credentials
        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
        encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
        db_provider.credentials_str = json.dumps(encrypted_credentials)

        db.session.add(db_provider)
        db.session.commit()

        # update labels
        ToolLabelManager.update_tool_labels(provider_controller, labels)

        return {"result": "success"}

    @staticmethod
    def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str):
        """
        get api tool provider remote schema

        :param user_id: accepted but not used by this method
        :param tenant_id: accepted but not used by this method
        :param url: address the schema is downloaded from with an HTTP GET (10 second timeout)
        :return: ``{"schema": <response text>}``
        :raises ValueError: if the request fails, the status code is not 200, or the body cannot be
            parsed by ``parser_api_schema``; the underlying reason is only logged
        """
        headers = {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko)"
            " Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
            "Accept": "*/*",
        }

        # NOTE(review): ``url`` is caller supplied and is requested before any validation, so the
        # parse step below only keeps non-schema bodies from being returned; it does not stop the
        # outbound request itself. Consider restricting the allowed targets if SSRF is a concern.
        try:
            response = get(url, headers=headers, timeout=10)
            if response.status_code != 200:
                raise ValueError(f"Got status code {response.status_code}")
            schema = response.text

            # only return bodies that parse as a schema
            ApiToolManageService.parser_api_schema(schema)
        except Exception as e:
            logger.error("parse api schema error: %s", e)
            # the detailed cause stays in the log and the exception chain; the message stays generic
            raise ValueError("invalid schema, please check the url you provided") from e

        return {"schema": schema}

    @staticmethod
    def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
        """
        list api tool provider tools

        :param user_id: accepted but not used by this method
        :param tenant_id: tenant that owns the provider
        :param provider: name of the provider to look up
        :return: one ``ToolTransformService.tool_to_user_tool`` result per entry in the row's ``tools``
        :raises ValueError: if the tenant has no provider with that name
        """
        # keep the row in its own variable so the error message can still show the requested name
        db_provider = (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == provider,
            )
            .first()
        )
        if db_provider is None:
            raise ValueError(f"you have not added provider {provider}")

        controller = ToolTransformService.api_provider_to_controller(db_provider=db_provider)
        labels = ToolLabelManager.get_tool_labels(controller)

        return [
            ToolTransformService.tool_to_user_tool(
                tool_bundle,
                labels=labels,
            )
            for tool_bundle in db_provider.tools
        ]

    @staticmethod
    def update_api_tool_provider(
        user_id: str,
        tenant_id: str,
        provider_name: str,
        original_provider: str,
        icon: dict,
        credentials: dict,
        schema_type: str,
        schema: str,
        privacy_policy: str,
        custom_disclaimer: str,
        labels: list[str],
    ):
        """
        update api tool provider

        :param user_id: accepted but not used by this method
        :param tenant_id: tenant that owns the provider
        :param provider_name: name the provider has after the update
        :param original_provider: current name, used to find the row
        :param icon: icon description, stored as JSON text
        :param credentials: credential values; must contain ``auth_type``. Values equal to the masked
            form of the stored credential are replaced in place by the stored value
        :param schema_type: must be one of the ``ApiProviderSchemaType`` values
        :param schema: raw schema text, parsed into tool bundles
        :param privacy_policy: stored on the provider row
        :param custom_disclaimer: stored on the provider row
        :param labels: labels passed to ``ToolLabelManager.update_tool_labels`` after the commit
        :return: ``{"result": "success"}``
        :raises ValueError: on an unknown schema type, an unknown provider, an unparsable schema, or
            missing ``auth_type``
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f"invalid schema type {schema_type}")

        # check if the provider exists
        provider = (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == original_provider,
            )
            .first()
        )
        if provider is None:
            # report the name that was actually searched for
            raise ValueError(f"api provider {original_provider} does not exists")

        # parse openapi to tool bundle
        # extra info like description will be set here
        extra_info = {}
        tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)

        # validate the credentials before touching the row, so a rejected request
        # leaves no modified object behind in the session
        if "auth_type" not in credentials:
            raise ValueError("auth_type is required")

        # get auth type, none or api key
        auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])

        # update db provider
        provider.name = provider_name
        provider.icon = json.dumps(icon)
        provider.schema = schema
        provider.description = extra_info.get("description", "")
        # store the type reported by the parser, as create_api_tool_provider does
        provider.schema_type_str = schema_type
        provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
        provider.privacy_policy = privacy_policy
        provider.custom_disclaimer = custom_disclaimer

        # create provider entity
        provider_controller = ApiToolProviderController.from_db(provider, auth_type)
        # load tools into provider entity
        provider_controller.load_bundled_tools(tool_bundles)

        # get original credentials if exists
        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
        original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
        masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)

        # check if the credential has changed, save the original credential
        # (only existing keys are reassigned, so the dict size is stable during iteration)
        for name, value in credentials.items():
            if name in masked_credentials and value == masked_credentials[name]:
                credentials[name] = original_credentials[name]

        credentials = tool_configuration.encrypt_tool_credentials(credentials)
        provider.credentials_str = json.dumps(credentials)

        db.session.add(provider)
        db.session.commit()

        # delete cache
        tool_configuration.delete_tool_credentials_cache()

        # update labels
        ToolLabelManager.update_tool_labels(provider_controller, labels)

        return {"result": "success"}

    @staticmethod
    def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
        """
        delete tool provider

        :param user_id: accepted but not used by this method
        :param tenant_id: tenant that owns the provider
        :param provider_name: name of the provider to delete
        :return: ``{"result": "success"}``
        :raises ValueError: if the tenant has no provider with that name
        """
        provider = (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == provider_name,
            )
            .first()
        )
        if provider is None:
            raise ValueError(f"you have not added provider {provider_name}")

        db.session.delete(provider)
        db.session.commit()

        return {"result": "success"}

    @staticmethod
    def get_api_tool_provider(user_id: str, tenant_id: str, provider: str):
        """
        get api tool provider

        :param user_id: accepted but not used by this method
        :param tenant_id: tenant that owns the provider
        :param provider: provider name
        :return: the result of ``ToolManager.user_get_api_provider`` for that tenant and provider
        """
        return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)

    @staticmethod
    def test_api_tool_preview(
        tenant_id: str,
        provider_name: str,
        tool_name: str,
        credentials: dict,
        parameters: dict,
        schema_type: str,
        schema: str,
    ):
        """
        test api tool before adding api tool provider

        :param tenant_id: tenant used for the provider lookup and the tool runtime
        :param provider_name: name of an existing provider, if any; when none is found an unsaved
            stand-in row is built from the arguments
        :param tool_name: ``operation_id`` of the tool bundle to run
        :param credentials: credential values; must contain ``auth_type``
        :param parameters: parameters passed to the tool's ``validate_credentials``
        :param schema_type: must be one of the ``ApiProviderSchemaType`` values
        :param schema: raw schema text
        :return: ``{"result": ...}`` (``"empty response"`` when the tool returns something falsy), or
            ``{"error": <message>}`` when validating or running the tool raises
        :raises ValueError: on an unknown schema type, an unparsable schema, an unknown tool name, or
            missing ``auth_type``
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f"invalid schema type {schema_type}")

        try:
            tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
        except Exception as e:
            raise ValueError("invalid schema") from e

        # get tool bundle
        tool_bundle = next((bundle for bundle in tool_bundles if bundle.operation_id == tool_name), None)
        if tool_bundle is None:
            raise ValueError(f"invalid tool name {tool_name}")

        db_provider = (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == provider_name,
            )
            .first()
        )

        if not db_provider:
            # create a fake db provider (never added to the session)
            db_provider = ApiToolProvider(
                tenant_id="",
                user_id="",
                name="",
                icon="",
                schema=schema,
                description="",
                schema_type_str=ApiProviderSchemaType.OPENAPI.value,
                tools_str=json.dumps(jsonable_encoder(tool_bundles)),
                credentials_str=json.dumps(credentials),
            )

        if "auth_type" not in credentials:
            raise ValueError("auth_type is required")

        # get auth type, none or api key
        auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])

        # create provider entity
        provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
        # load tools into provider entity
        provider_controller.load_bundled_tools(tool_bundles)

        # decrypt credentials; the stand-in row has no id, so this only runs for stored providers
        if db_provider.id:
            tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
            # NOTE(review): this decrypts the incoming ``credentials`` whereas update_api_tool_provider
            # decrypts the stored ``provider.credentials`` - confirm which one is intended.
            decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
            # check if the credential has changed, save the original credential
            masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
            for name, value in credentials.items():
                if name in masked_credentials and value == masked_credentials[name]:
                    credentials[name] = decrypted_credentials[name]

        # failures from here on are reported in the return value instead of being raised
        try:
            provider_controller.validate_credentials_format(credentials)
            # get tool
            tool = provider_controller.get_tool(tool_name)
            tool = tool.fork_tool_runtime(
                runtime={
                    "credentials": credentials,
                    "tenant_id": tenant_id,
                }
            )
            result = tool.validate_credentials(credentials, parameters)
        except Exception as e:
            return {"error": str(e)}

        return {"result": result or "empty response"}

    @staticmethod
    def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
        """
        list api tools

        :param user_id: forwarded to each provider controller's ``get_tools``
        :param tenant_id: tenant whose providers are listed
        :return: one user provider per ``ApiToolProvider`` row of the tenant, with its labels and
            tools filled in
        """
        # get all api providers
        db_providers = db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or []

        result: list[UserToolProvider] = []

        for provider in db_providers:
            # convert provider controller to user provider
            provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
            labels = ToolLabelManager.get_tool_labels(provider_controller)
            user_provider = ToolTransformService.api_provider_to_user_provider(
                provider_controller, db_provider=provider, decrypt_credentials=True
            )
            user_provider.labels = labels

            # add icon
            ToolTransformService.repack_provider(user_provider)

            for tool in provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id):
                user_provider.tools.append(
                    ToolTransformService.tool_to_user_tool(
                        tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
                    )
                )

            result.append(user_provider)

        return result