feat: type

This commit is contained in:
Yeuoly 2024-08-30 21:10:19 +08:00
parent db8bf2a85e
commit 279dee485d
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
7 changed files with 121 additions and 147 deletions

View File

@ -128,10 +128,6 @@ class BasicProviderConfig(BaseModel):
return mode
raise ValueError(f'invalid mode value {value}')
@staticmethod
def default(value: str) -> str:
return ""
type: Type = Field(..., description="The type of the credentials")
name: str = Field(..., description="The name of the credentials")

View File

@ -26,7 +26,7 @@ class UserToolProvider(BaseModel):
author: str
name: str # identifier
description: I18nObject
icon: str
icon: str | dict
label: I18nObject # label
type: ToolProviderType
masked_credentials: Optional[dict] = None

View File

@ -208,8 +208,12 @@ class WorkflowToolProviderController(ToolProviderController):
if not db_providers:
return []
app = db_providers.app
if not app:
raise ValueError("can not read app of workflow")
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
self.tools = [self._get_db_provider_tool(db_providers, app)]
return self.tools

View File

@ -1,4 +1,5 @@
import json
from datetime import datetime
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
@ -13,7 +14,7 @@ from .model import Account, App, Tenant
from .types import StringUUID
class BuiltinToolProvider(db.Model):
class BuiltinToolProvider(Base):
"""
This table stores the tool provider information for built-in tools for each tenant.
"""
@ -25,61 +26,22 @@ class BuiltinToolProvider(db.Model):
)
# id of the tool provider
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# id of the tenant
tenant_id = db.Column(StringUUID, nullable=True)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
# who created this tool provider
user_id = db.Column(StringUUID, nullable=False)
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# name of the tool provider
provider = db.Column(db.String(40), nullable=False)
provider: Mapped[str] = mapped_column(db.String(40), nullable=False)
# credential of the tool provider
encrypted_credentials = db.Column(db.Text, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@property
def credentials(self) -> dict:
return json.loads(self.encrypted_credentials)
class PublishedAppTool(db.Model):
"""
The table stores the apps published as a tool for each person.
"""
__tablename__ = 'tool_published_apps'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='published_app_tool_pkey'),
db.UniqueConstraint('app_id', 'user_id', name='unique_published_app_tool')
)
# id of the tool provider
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# id of the app
app_id = db.Column(StringUUID, ForeignKey('apps.id'), nullable=False)
# who published this tool
user_id = db.Column(StringUUID, nullable=False)
# description of the tool, stored in i18n format, for human
description = db.Column(db.Text, nullable=False)
# llm_description of the tool, for LLM
llm_description = db.Column(db.Text, nullable=False)
# query description, query will be seem as a parameter of the tool, to describe this parameter to llm, we need this field
query_description = db.Column(db.Text, nullable=False)
# query name, the name of the query parameter
query_name = db.Column(db.String(40), nullable=False)
# name of the tool provider
tool_name = db.Column(db.String(40), nullable=False)
# author
author = db.Column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@property
def description_i18n(self) -> I18nObject:
return I18nObject(**json.loads(self.description))
@property
def app(self) -> App | None:
return db.session.query(App).filter(App.id == self.app_id).first()
class ApiToolProvider(Base):
"""
The table stores the api providers.
@ -129,14 +91,14 @@ class ApiToolProvider(Base):
return json.loads(self.credentials_str)
@property
def user(self) -> Account:
def user(self) -> Account | None:
return db.session.query(Account).filter(Account.id == self.user_id).first()
@property
def tenant(self) -> Tenant:
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
class ToolLabelBinding(db.Model):
class ToolLabelBinding(Base):
"""
The table stores the labels for tools.
"""
@ -146,15 +108,15 @@ class ToolLabelBinding(db.Model):
db.UniqueConstraint('tool_id', 'label_name', name='unique_tool_label_bind'),
)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# tool id
tool_id = db.Column(db.String(64), nullable=False)
tool_id: Mapped[str] = mapped_column(db.String(64), nullable=False)
# tool type
tool_type = db.Column(db.String(40), nullable=False)
tool_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
# label name
label_name = db.Column(db.String(40), nullable=False)
label_name: Mapped[str] = mapped_column(db.String(40), nullable=False)
class WorkflowToolProvider(db.Model):
class WorkflowToolProvider(Base):
"""
The table stores the workflow providers.
"""
@ -165,41 +127,37 @@ class WorkflowToolProvider(db.Model):
db.UniqueConstraint('tenant_id', 'app_id', name='unique_workflow_tool_provider_app_id'),
)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# name of the workflow provider
name = db.Column(db.String(40), nullable=False)
name: Mapped[str] = mapped_column(db.String(40), nullable=False)
# label of the workflow provider
label = db.Column(db.String(255), nullable=False, server_default='')
label: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default='')
# icon
icon = db.Column(db.String(255), nullable=False)
icon: Mapped[str] = mapped_column(db.String(255), nullable=False)
# app id of the workflow provider
app_id = db.Column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# version of the workflow provider
version = db.Column(db.String(255), nullable=False, server_default='')
version: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default='')
# who created this tool
user_id = db.Column(StringUUID, nullable=False)
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# tenant id
tenant_id = db.Column(StringUUID, nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# description of the provider
description = db.Column(db.Text, nullable=False)
description: Mapped[str] = mapped_column(db.Text, nullable=False)
# parameter configuration
parameter_configuration = db.Column(db.Text, nullable=False, server_default='[]')
parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default='[]')
# privacy policy
privacy_policy = db.Column(db.String(255), nullable=True, server_default='')
privacy_policy: Mapped[str] = mapped_column(db.String(255), nullable=True, server_default='')
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@property
def schema_type(self) -> ApiProviderSchemaType:
return ApiProviderSchemaType.value_of(self.schema_type_str)
@property
def user(self) -> Account:
def user(self) -> Account | None:
return db.session.query(Account).filter(Account.id == self.user_id).first()
@property
def tenant(self) -> Tenant:
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
@property
@ -210,7 +168,7 @@ class WorkflowToolProvider(db.Model):
]
@property
def app(self) -> App:
def app(self) -> App | None:
return db.session.query(App).filter(App.id == self.app_id).first()
class ToolModelInvoke(db.Model):

View File

@ -28,10 +28,13 @@ class BuiltinToolManageService:
tools = provider_controller.get_tools()
tool_provider_configurations = ToolConfigurationManager(
tenant_id=tenant_id, provider_controller=provider_controller
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name,
)
# check if user has added the provider
builtin_provider: BuiltinToolProvider = (
builtin_provider: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
@ -75,7 +78,7 @@ class BuiltinToolManageService:
update builtin tool provider
"""
# get if the provider exists
provider: BuiltinToolProvider = (
provider: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
@ -89,7 +92,13 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name,
)
# get original credentials if exists
if provider is not None:
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
@ -132,7 +141,7 @@ class BuiltinToolManageService:
"""
get builtin tool provider credentials
"""
provider: BuiltinToolProvider = (
provider_obj: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
@ -141,12 +150,17 @@ class BuiltinToolManageService:
.first()
)
if provider is None:
if provider_obj is None:
return {}
provider_controller = ToolManager.get_builtin_provider(provider.provider)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider)
tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name,
)
credentials = tool_configuration.decrypt_tool_credentials(provider_obj.credentials)
credentials = tool_configuration.mask_tool_credentials(credentials)
return credentials
@ -155,7 +169,7 @@ class BuiltinToolManageService:
"""
delete tool provider
"""
provider: BuiltinToolProvider = (
provider_obj: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
@ -164,15 +178,20 @@ class BuiltinToolManageService:
.first()
)
if provider is None:
if provider_obj is None:
raise ValueError(f"you have not added provider {provider_name}")
db.session.delete(provider)
db.session.delete(provider_obj)
db.session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name,
)
tool_configuration.delete_tool_credentials_cache()
return {"result": "success"}
@ -212,8 +231,8 @@ class BuiltinToolManageService:
try:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
data=provider_controller,
name_func=lambda x: x.identity.name,
):

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Optional, Union
from typing import Literal, Optional, Union, overload
from configs import dify_config
from core.entities.provider_entities import ProviderConfig
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
class ToolTransformService:
@classmethod
def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]:
"""
get tool provider icon url
"""
@ -35,7 +35,9 @@ class ToolTransformService:
return url_prefix + "builtin/" + provider_name + "/icon"
elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]:
try:
return json.loads(icon)
if isinstance(icon, str):
return json.loads(icon)
return icon
except:
return {"background": "#252525", "content": "\ud83d\ude01"}
@ -92,7 +94,8 @@ class ToolTransformService:
# get credentials schema
schema = provider_controller.get_credentials_schema()
for name, value in schema.items():
result.masked_credentials[name] = ProviderConfig.Type.default(value.type)
if result.masked_credentials:
result.masked_credentials[name] = ""
# check if the provider need credentials
if not provider_controller.need_credentials:
@ -184,9 +187,14 @@ class ToolTransformService:
"""
username = "Anonymous"
try:
username = db_provider.user.name
user = db_provider.user
if not user:
raise ValueError("user not found")
username = user.name
except Exception as e:
logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}")
# add provider into providers
credentials = db_provider.credentials
result = UserToolProvider(
@ -266,9 +274,9 @@ class ToolTransformService:
author=tool.identity.author,
name=tool.identity.name,
label=tool.identity.label,
description=tool.description.human,
description=tool.description.human if tool.description else I18nObject(en_US=''),
parameters=current_parameters,
labels=labels,
labels=labels or [],
)
if isinstance(tool, ApiToolBundle):
return UserTool(
@ -277,5 +285,5 @@ class ToolTransformService:
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
parameters=tool.parameters,
labels=labels,
labels=labels or [],
)

View File

@ -4,7 +4,7 @@ from datetime import datetime
from sqlalchemy import or_
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserToolProvider
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
@ -32,7 +32,7 @@ class WorkflowToolManageService:
description: str,
parameters: list[dict],
privacy_policy: str = "",
labels: list[str] = None,
labels: list[str] | None = None,
) -> dict:
"""
Create a workflow tool.
@ -62,12 +62,12 @@ class WorkflowToolManageService:
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
app: App = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
app: App | None = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
if app is None:
raise ValueError(f"App {workflow_app_id} not found")
workflow: Workflow = app.workflow
workflow: Workflow | None = app.workflow
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_app_id}")
@ -106,7 +106,7 @@ class WorkflowToolManageService:
description: str,
parameters: list[dict],
privacy_policy: str = "",
labels: list[str] = None,
labels: list[str] | None = None,
) -> dict:
"""
Update a workflow tool.
@ -138,7 +138,7 @@ class WorkflowToolManageService:
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} already exists")
workflow_tool_provider: WorkflowToolProvider = (
workflow_tool_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
@ -147,14 +147,14 @@ class WorkflowToolManageService:
if workflow_tool_provider is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
app: App = (
app: App | None = (
db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
)
if app is None:
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
workflow: Workflow = app.workflow
workflow: Workflow | None = app.workflow
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
@ -243,36 +243,12 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider = (
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
)
if db_tool is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
if workflow_app is None:
raise ValueError(f"App {db_tool.app_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
return {
"name": db_tool.name,
"label": db_tool.label,
"workflow_tool_id": db_tool.id,
"workflow_app_id": db_tool.app_id,
"icon": json.loads(db_tool.icon),
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
),
"synced": workflow_app.workflow.version == db_tool.version,
"privacy_policy": db_tool.privacy_policy,
}
return cls._get_workflow_tool(db_tool)
@classmethod
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
@ -283,19 +259,31 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider = (
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first()
)
return cls._get_workflow_tool(db_tool)
@classmethod
def _get_workflow_tool(cls, db_tool: WorkflowToolProvider | None):
"""
Get a workflow tool.
:db_tool: the database tool
:return: the tool
"""
if db_tool is None:
raise ValueError(f"Tool {workflow_app_id} not found")
raise ValueError("Tool not found")
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
workflow_app: App | None = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
if workflow_app is None:
raise ValueError(f"App {db_tool.app_id} not found")
workflow = workflow_app.workflow
if not workflow:
raise ValueError("Workflow not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
@ -308,14 +296,14 @@ class WorkflowToolManageService:
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
tool.get_tools(db_tool.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
),
"synced": workflow_app.workflow.version == db_tool.version,
"synced": workflow.version == db_tool.version,
"privacy_policy": db_tool.privacy_policy,
}
@classmethod
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]:
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]:
"""
List workflow tool provider tools.
:param user_id: the user id
@ -323,7 +311,7 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the list of tools
"""
db_tool: WorkflowToolProvider = (
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
@ -336,6 +324,7 @@ class WorkflowToolManageService:
return [
ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
tool=tool.get_tools(db_tool.tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool)
)
]