2024-02-01 18:11:57 +08:00
|
|
|
|
import json
|
2024-03-30 14:44:50 +08:00
|
|
|
|
import logging
|
2024-02-06 13:21:13 +08:00
|
|
|
|
|
|
|
|
|
from flask import current_app
|
|
|
|
|
from httpx import get
|
2024-01-23 19:58:23 +08:00
|
|
|
|
|
|
|
|
|
from core.tools.entities.common_entities import I18nObject
|
|
|
|
|
from core.tools.entities.tool_bundle import ApiBasedToolBundle
|
2024-02-06 13:21:13 +08:00
|
|
|
|
from core.tools.entities.tool_entities import (
|
|
|
|
|
ApiProviderAuthType,
|
|
|
|
|
ApiProviderSchemaType,
|
|
|
|
|
ToolCredentialsOption,
|
2024-03-07 15:04:42 +08:00
|
|
|
|
ToolParameter,
|
2024-02-06 13:21:13 +08:00
|
|
|
|
ToolProviderCredentials,
|
|
|
|
|
)
|
2024-02-01 18:11:57 +08:00
|
|
|
|
from core.tools.entities.user_entities import UserTool, UserToolProvider
|
|
|
|
|
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
2024-01-23 19:58:23 +08:00
|
|
|
|
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
|
2024-02-01 18:11:57 +08:00
|
|
|
|
from core.tools.provider.tool_provider import ToolProviderController
|
|
|
|
|
from core.tools.tool_manager import ToolManager
|
2024-03-08 20:31:13 +08:00
|
|
|
|
from core.tools.utils.configuration import ToolConfigurationManager
|
2024-02-01 18:11:57 +08:00
|
|
|
|
from core.tools.utils.encoder import serialize_base_model_array, serialize_base_model_dict
|
|
|
|
|
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
2024-01-23 19:58:23 +08:00
|
|
|
|
from extensions.ext_database import db
|
2024-02-01 18:11:57 +08:00
|
|
|
|
from models.tools import ApiToolProvider, BuiltinToolProvider
|
2024-03-08 15:22:55 +08:00
|
|
|
|
from services.model_provider_service import ModelProviderService
|
2024-01-23 19:58:23 +08:00
|
|
|
|
|
2024-03-30 14:44:50 +08:00
|
|
|
|
# Module-scoped logger; named after this module so log records can be filtered per module.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
2024-01-23 19:58:23 +08:00
|
|
|
|
|
|
|
|
|
class ToolManageService:
    """
    Service-layer operations on tool providers for a tenant.

    Handles builtin providers (credentials stored in ``BuiltinToolProvider`` rows),
    API-based providers (``ApiToolProvider`` rows built from an API schema) and
    model providers (resolved through ``ToolManager``). Every method is a
    ``@staticmethod``; persistent state is read from and written to ``db.session``.
    """

    @staticmethod
    def list_tool_providers(user_id: str, tenant_id: str):
        """
        list tool providers

        :param user_id: id of the requesting user
        :param tenant_id: id of the tenant (workspace)
        :return: the list of tool providers as dicts, with their ``icon`` field repacked
        """
        result = [provider.to_dict() for provider in ToolManager.user_list_providers(
            user_id, tenant_id
        )]

        # add icon url prefix
        for provider in result:
            ToolManageService.repack_provider(provider)

        return result

    @staticmethod
    def repack_provider(provider: dict):
        """
        repack provider

        Rewrites ``provider['icon']`` in place, only when the key is present:
        builtin and model providers get an absolute URL to their console icon endpoint,
        API providers get their stored JSON icon decoded (or a default icon when the
        stored value cannot be decoded).

        :param provider: the provider dict (mutated in place)
        """
        url_prefix = (current_app.config.get("CONSOLE_API_URL")
                      + "/console/api/workspaces/current/tool-provider/")

        if 'icon' in provider:
            if provider['type'] == UserToolProvider.ProviderType.BUILTIN.value:
                provider['icon'] = url_prefix + 'builtin/' + provider['name'] + '/icon'
            elif provider['type'] == UserToolProvider.ProviderType.MODEL.value:
                provider['icon'] = url_prefix + 'model/' + provider['name'] + '/icon'
            elif provider['type'] == UserToolProvider.ProviderType.API.value:
                try:
                    provider['icon'] = json.loads(provider['icon'])
                except (TypeError, ValueError):
                    # TypeError: icon is not str/bytes; ValueError (incl. JSONDecodeError):
                    # malformed JSON. Fall back to a default icon rather than failing the listing.
                    provider['icon'] = {
                        "background": "#252525",
                        # U+1F601 as a real code point; a "\ud83d\ude01" surrogate pair in a
                        # Python str cannot be UTF-8 encoded. json.dumps output is unchanged.
                        "content": "\U0001F601"
                    }

    @staticmethod
    def list_builtin_tool_provider_tools(
        user_id: str, tenant_id: str, provider: str
    ):
        """
        list builtin tool provider tools

        Each tool is forked with the tenant's decrypted credentials (empty if the tenant has
        not configured the provider) so that runtime parameters can be resolved. Runtime
        parameters replace static parameters with the same name and form; runtime parameters
        of form FORM that have no static counterpart are appended.

        :param user_id: id of the requesting user (unused here)
        :param tenant_id: id of the tenant
        :param provider: builtin provider name
        :return: JSON-compatible list of serialized ``UserTool`` objects
        """
        provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
        tools = provider_controller.get_tools()

        tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
        # check if user has added the provider
        builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
            BuiltinToolProvider.tenant_id == tenant_id,
            BuiltinToolProvider.provider == provider,
        ).first()

        credentials = {}
        if builtin_provider is not None:
            # get credentials
            credentials = builtin_provider.credentials
            credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)

        result = []
        for tool in tools:
            # fork tool runtime
            tool = tool.fork_tool_runtime(meta={
                'credentials': credentials,
                'tenant_id': tenant_id,
            })

            # get tool parameters
            parameters = tool.parameters or []
            # get tool runtime parameters
            runtime_parameters = tool.get_runtime_parameters()
            # override parameters
            current_parameters = parameters.copy()
            for runtime_parameter in runtime_parameters:
                found = False
                for index, parameter in enumerate(current_parameters):
                    if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
                        current_parameters[index] = runtime_parameter
                        found = True
                        break

                if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
                    current_parameters.append(runtime_parameter)

            user_tool = UserTool(
                author=tool.identity.author,
                name=tool.identity.name,
                label=tool.identity.label,
                description=tool.description.human,
                parameters=current_parameters
            )
            result.append(user_tool)

        return json.loads(
            serialize_base_model_array(result)
        )

    @staticmethod
    def list_builtin_provider_credentials_schema(
        provider_name
    ):
        """
        list builtin provider credentials schema

        :param provider_name: builtin provider name
        :return: JSON-compatible list of the provider's credential field definitions
                 (empty when the provider declares no credentials schema)
        """
        provider = ToolManager.get_builtin_provider(provider_name)
        return json.loads(serialize_base_model_array(
            list((provider.credentials_schema or {}).values())
        ))

    @staticmethod
    def parser_api_schema(schema: str) -> dict:
        """
        parse api schema to tool bundle

        :param schema: raw API schema text
        :return: JSON-compatible dict with keys ``schema_type``, ``parameters_schema``
                 (the parsed tool bundles), ``credentials_schema`` and ``warning``
        :raises ValueError: if the schema cannot be parsed or serialized
        """
        warnings = {}
        try:
            tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
        except Exception as e:
            raise ValueError(f'invalid schema: {str(e)}') from e

        # static credential form offered for every API provider: no auth or an API key header
        credentials_schema = [
            ToolProviderCredentials(
                name='auth_type',
                type=ToolProviderCredentials.CredentialsType.SELECT,
                required=True,
                default='none',
                options=[
                    ToolCredentialsOption(value='none', label=I18nObject(
                        en_US='None',
                        zh_Hans='无'
                    )),
                    ToolCredentialsOption(value='api_key', label=I18nObject(
                        en_US='Api Key',
                        zh_Hans='Api Key'
                    )),
                ],
                placeholder=I18nObject(
                    en_US='Select auth type',
                    zh_Hans='选择认证方式'
                )
            ),
            ToolProviderCredentials(
                name='api_key_header',
                type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
                required=False,
                placeholder=I18nObject(
                    en_US='Enter api key header',
                    zh_Hans='输入 api key header,如:X-API-KEY'
                ),
                default='api_key',
                help=I18nObject(
                    en_US='HTTP header name for api key',
                    zh_Hans='HTTP 头部字段名,用于传递 api key'
                )
            ),
            ToolProviderCredentials(
                name='api_key_value',
                type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
                required=False,
                placeholder=I18nObject(
                    en_US='Enter api key',
                    zh_Hans='输入 api key'
                ),
                default=''
            ),
        ]

        # guarded separately so a parse failure above is not wrapped a second time
        try:
            return json.loads(serialize_base_model_dict(
                {
                    'schema_type': schema_type,
                    'parameters_schema': tool_bundles,
                    'credentials_schema': credentials_schema,
                    'warning': warnings
                }
            ))
        except Exception as e:
            raise ValueError(f'invalid schema: {str(e)}') from e

    @staticmethod
    def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> tuple[list[ApiBasedToolBundle], str]:
        """
        convert schema to tool bundles

        :param schema: raw API schema text
        :param extra_info: optional dict the parser fills with extra data (e.g. description)
        :return: a ``(tool_bundles, schema_type)`` tuple
        :raises ValueError: if the schema cannot be parsed
        """
        try:
            tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
            return tool_bundles
        except Exception as e:
            raise ValueError(f'invalid schema: {str(e)}') from e

    @staticmethod
    def create_api_tool_provider(
        user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict,
        schema_type: str, schema: str, privacy_policy: str
    ):
        """
        create api tool provider

        :param credentials: must contain ``auth_type``; stored encrypted
        :param schema_type: must be an ``ApiProviderSchemaType`` value
        :return: ``{'result': 'success'}``
        :raises ValueError: on invalid schema type, duplicate name, unparsable schema,
                            more than 100 APIs, or missing ``auth_type``
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f'invalid schema type {schema_type}')

        # check if the provider exists
        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id,
            ApiToolProvider.name == provider_name,
        ).first()

        if provider is not None:
            raise ValueError(f'provider {provider_name} already exists')

        # parse openapi to tool bundle
        extra_info = {}
        # extra info like description will be set here
        tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info)

        if len(tool_bundles) > 100:
            raise ValueError('the number of apis should be less than 100')

        # create db provider
        db_provider = ApiToolProvider(
            tenant_id=tenant_id,
            user_id=user_id,
            name=provider_name,
            icon=json.dumps(icon),
            schema=schema,
            description=extra_info.get('description', ''),
            schema_type_str=schema_type,
            tools_str=serialize_base_model_array(tool_bundles),
            credentials_str={},
            privacy_policy=privacy_policy
        )

        if 'auth_type' not in credentials:
            raise ValueError('auth_type is required')

        # get auth type, none or api key
        auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])

        # create provider entity
        provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type)
        # load tools into provider entity
        provider_controller.load_bundled_tools(tool_bundles)

        # encrypt credentials
        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
        encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
        db_provider.credentials_str = json.dumps(encrypted_credentials)

        db.session.add(db_provider)
        db.session.commit()

        return { 'result': 'success' }

    @staticmethod
    def get_api_tool_provider_remote_schema(
        user_id: str, tenant_id: str, url: str
    ):
        """
        get api tool provider remote schema

        Fetches ``url`` (10 s timeout) and returns the body only if it parses as an API schema.

        :return: ``{'schema': <response text>}``
        :raises ValueError: on non-200 status, network error, or unparsable content
        """
        headers = {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
            "Accept": "*/*",
        }

        try:
            response = get(url, headers=headers, timeout=10)
            if response.status_code != 200:
                raise ValueError(f'Got status code {response.status_code}')
            schema = response.text

            # Only hand the body back if it parses as an API schema, so arbitrary responses
            # from the fetched URL are not echoed to the caller. Note the request itself has
            # already been made at this point; this does not restrict which URLs are fetched.
            ToolManageService.parser_api_schema(schema)
        except Exception as e:
            logger.error("parse api schema error: %s", str(e))
            raise ValueError('invalid schema, please check the url you provided') from e

        return {
            'schema': schema
        }

    @staticmethod
    def list_api_tool_provider_tools(
        user_id: str, tenant_id: str, provider: str
    ):
        """
        list api tool provider tools

        :param provider: name of the tenant's API provider
        :return: JSON-compatible list of serialized ``UserTool`` objects, one per tool bundle
        :raises ValueError: if the tenant has no API provider with that name
        """
        # separate name so the requested provider name is still available for the error message
        db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id,
            ApiToolProvider.name == provider,
        ).first()

        if db_provider is None:
            raise ValueError(f'you have not added provider {provider}')

        return json.loads(
            serialize_base_model_array([
                UserTool(
                    author=tool_bundle.author,
                    name=tool_bundle.operation_id,
                    label=I18nObject(
                        en_US=tool_bundle.operation_id,
                        zh_Hans=tool_bundle.operation_id
                    ),
                    description=I18nObject(
                        en_US=tool_bundle.summary or '',
                        zh_Hans=tool_bundle.summary or ''
                    ),
                    parameters=tool_bundle.parameters
                ) for tool_bundle in db_provider.tools
            ])
        )

    @staticmethod
    def update_builtin_tool_provider(
        user_id: str, tenant_id: str, provider_name: str, credentials: dict
    ):
        """
        update builtin tool provider

        Creates the tenant's ``BuiltinToolProvider`` row if absent, otherwise overwrites its
        credentials. Values submitted unchanged from their masked form are replaced by the
        stored originals before validation and encryption.

        :param credentials: submitted credentials (mutated in place)
        :return: ``{'result': 'success'}``
        :raises ValueError: if the provider is unknown, needs no credentials, or validation fails
        """
        # get if the provider exists
        provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
            BuiltinToolProvider.tenant_id == tenant_id,
            BuiltinToolProvider.provider == provider_name,
        ).first()

        try:
            # get provider
            provider_controller = ToolManager.get_builtin_provider(provider_name)
            if not provider_controller.need_credentials:
                raise ValueError(f'provider {provider_name} does not need credentials')
            tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
            # get original credentials if exists
            if provider is not None:
                original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
                masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
                # check if the credential has changed, save the original credential
                for name, value in credentials.items():
                    if name in masked_credentials and value == masked_credentials[name]:
                        credentials[name] = original_credentials[name]
            # validate credentials
            provider_controller.validate_credentials(credentials)
            # encrypt credentials
            credentials = tool_configuration.encrypt_tool_credentials(credentials)
        except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
            raise ValueError(str(e))

        if provider is None:
            # create provider
            provider = BuiltinToolProvider(
                tenant_id=tenant_id,
                user_id=user_id,
                provider=provider_name,
                encrypted_credentials=json.dumps(credentials),
            )

            db.session.add(provider)
            db.session.commit()

        else:
            provider.encrypted_credentials = json.dumps(credentials)
            db.session.add(provider)
            db.session.commit()

        # delete cache
        tool_configuration.delete_tool_credentials_cache()

        return { 'result': 'success' }

    @staticmethod
    def update_api_tool_provider(
        user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict,
        schema_type: str, schema: str, privacy_policy: str
    ):
        """
        update api tool provider

        :param provider_name: new name for the provider
        :param original_provider: current name used to look the provider up
        :param credentials: must contain ``auth_type``; masked values are replaced by the
                            stored originals before re-encryption (mutated in place)
        :return: ``{'result': 'success'}``
        :raises ValueError: on invalid schema type, unknown provider, name collision,
                            unparsable schema, or missing ``auth_type``
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f'invalid schema type {schema_type}')

        # check if the provider exists
        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id,
            ApiToolProvider.name == original_provider,
        ).first()

        if provider is None:
            raise ValueError(f'api provider {original_provider} does not exists')

        # renaming must not collide with another provider of this tenant
        # (create_api_tool_provider enforces the same one-name-per-tenant rule)
        if provider_name != original_provider:
            conflicting_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == provider_name,
            ).first()
            if conflicting_provider is not None:
                raise ValueError(f'provider {provider_name} already exists')

        # parse openapi to tool bundle
        extra_info = {}
        # extra info like description will be set here
        tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info)

        # update db provider
        provider.name = provider_name
        provider.icon = json.dumps(icon)
        provider.schema = schema
        provider.description = extra_info.get('description', '')
        # NOTE(review): create_api_tool_provider stores the parsed schema_type here, while
        # update always stores OPENAPI; confirm which is intended.
        provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
        provider.tools_str = serialize_base_model_array(tool_bundles)
        provider.privacy_policy = privacy_policy

        if 'auth_type' not in credentials:
            raise ValueError('auth_type is required')

        # get auth type, none or api key
        auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])

        # create provider entity
        provider_controller = ApiBasedToolProviderController.from_db(provider, auth_type)
        # load tools into provider entity
        provider_controller.load_bundled_tools(tool_bundles)

        # get original credentials if exists
        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)

        original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
        masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
        # check if the credential has changed, save the original credential
        for name, value in credentials.items():
            if name in masked_credentials and value == masked_credentials[name]:
                credentials[name] = original_credentials[name]

        credentials = tool_configuration.encrypt_tool_credentials(credentials)
        provider.credentials_str = json.dumps(credentials)

        db.session.add(provider)
        db.session.commit()

        # delete cache
        tool_configuration.delete_tool_credentials_cache()

        return { 'result': 'success' }

    @staticmethod
    def delete_builtin_tool_provider(
        user_id: str, tenant_id: str, provider_name: str
    ):
        """
        delete tool provider

        Deletes the tenant's stored credentials row for a builtin provider and clears the
        credentials cache.

        :return: ``{'result': 'success'}``
        :raises ValueError: if the tenant has not configured the provider
        """
        provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
            BuiltinToolProvider.tenant_id == tenant_id,
            BuiltinToolProvider.provider == provider_name,
        ).first()

        if provider is None:
            raise ValueError(f'you have not added provider {provider_name}')

        db.session.delete(provider)
        db.session.commit()

        # delete cache
        provider_controller = ToolManager.get_builtin_provider(provider_name)
        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
        tool_configuration.delete_tool_credentials_cache()

        return { 'result': 'success' }

    @staticmethod
    def get_builtin_tool_provider_icon(
        provider: str
    ):
        """
        get tool provider icon and its mimetype

        :param provider: builtin provider name
        :return: ``(icon_bytes, mime_type)``
        """
        icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
        with open(icon_path, 'rb') as f:
            icon_bytes = f.read()

        return icon_bytes, mime_type

    @staticmethod
    def get_model_tool_provider_icon(
        provider: str
    ):
        """
        get tool provider icon and its mimetype

        Uses the model provider's small icon in ``en_US``.

        :param provider: model provider name
        :return: ``(icon_bytes, mime_type)``
        :raises ValueError: if no icon is available for the provider
        """
        service = ModelProviderService()
        icon_bytes, mime_type = service.get_model_provider_icon(provider=provider, icon_type='icon_small', lang='en_US')

        if icon_bytes is None:
            raise ValueError(f'provider {provider} does not exists')

        return icon_bytes, mime_type

    @staticmethod
    def list_model_tool_provider_tools(
        user_id: str, tenant_id: str, provider: str
    ):
        """
        list model tool provider tools

        :return: JSON-compatible list of serialized ``UserTool`` objects
        """
        provider_controller = ToolManager.get_model_provider(tenant_id=tenant_id, provider_name=provider)
        tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)

        result = [
            UserTool(
                author=tool.identity.author,
                name=tool.identity.name,
                label=tool.identity.label,
                description=tool.description.human,
                parameters=tool.parameters or []
            ) for tool in tools
        ]

        return json.loads(
            serialize_base_model_array(result)
        )

    @staticmethod
    def delete_api_tool_provider(
        user_id: str, tenant_id: str, provider_name: str
    ):
        """
        delete tool provider

        :return: ``{'result': 'success'}``
        :raises ValueError: if the tenant has no API provider with that name
        """
        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id,
            ApiToolProvider.name == provider_name,
        ).first()

        if provider is None:
            raise ValueError(f'you have not added provider {provider_name}')

        db.session.delete(provider)
        db.session.commit()

        return { 'result': 'success' }

    @staticmethod
    def get_api_tool_provider(
        user_id: str, tenant_id: str, provider: str
    ):
        """
        get api tool provider

        Thin delegation to ``ToolManager.user_get_api_provider``.
        """
        return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)

    @staticmethod
    def test_api_tool_preview(
        tenant_id: str,
        provider_name: str,
        tool_name: str,
        credentials: dict,
        parameters: dict,
        schema_type: str,
        schema: str
    ):
        """
        test api tool before adding api tool provider

        If the tenant already has a provider named ``provider_name``, submitted credential
        values equal to the masked stored values are replaced by the stored originals.
        Otherwise an unsaved placeholder provider is built from ``schema``.

        :param credentials: must contain ``auth_type`` (mutated in place)
        :return: ``{'result': ...}`` on success, ``{'error': message}`` if the call fails
        :raises ValueError: on invalid schema type, unparsable schema, unknown tool name,
                            or missing ``auth_type``
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f'invalid schema type {schema_type}')

        try:
            tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
        except Exception as e:
            raise ValueError('invalid schema') from e

        # get tool bundle
        tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None)
        if tool_bundle is None:
            raise ValueError(f'invalid tool name {tool_name}')

        db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id,
            ApiToolProvider.name == provider_name,
        ).first()

        if not db_provider:
            # create a fake db provider (never added to the session)
            db_provider = ApiToolProvider(
                tenant_id='', user_id='', name='', icon='',
                schema=schema,
                description='',
                schema_type_str=ApiProviderSchemaType.OPENAPI.value,
                tools_str=serialize_base_model_array(tool_bundles),
                credentials_str=json.dumps(credentials),
            )

        if 'auth_type' not in credentials:
            raise ValueError('auth_type is required')

        # get auth type, none or api key
        auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])

        # create provider entity
        provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type)
        # load tools into provider entity
        provider_controller.load_bundled_tools(tool_bundles)

        # decrypt credentials (only a persisted provider has stored credentials)
        if db_provider.id:
            tool_configuration = ToolConfigurationManager(
                tenant_id=tenant_id,
                provider_controller=provider_controller
            )
            # decrypt the STORED credentials, as update_api_tool_provider does; the submitted
            # ones are plaintext or masked and cannot yield the original secrets
            decrypted_credentials = tool_configuration.decrypt_tool_credentials(db_provider.credentials)
            # check if the credential has changed, save the original credential
            masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
            for name, value in credentials.items():
                if name in masked_credentials and value == masked_credentials[name]:
                    credentials[name] = decrypted_credentials[name]

        try:
            provider_controller.validate_credentials_format(credentials)
            # get tool
            tool = provider_controller.get_tool(tool_name)
            tool = tool.fork_tool_runtime(meta={
                'credentials': credentials,
                'tenant_id': tenant_id,
            })
            result = tool.validate_credentials(credentials, parameters)
        except Exception as e:
            return { 'error': str(e) }

        return { 'result': result or 'empty response' }
|