mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 19:59:50 +08:00
feat: type
This commit is contained in:
parent
db8bf2a85e
commit
279dee485d
|
@ -128,10 +128,6 @@ class BasicProviderConfig(BaseModel):
|
|||
return mode
|
||||
raise ValueError(f'invalid mode value {value}')
|
||||
|
||||
@staticmethod
|
||||
def default(value: str) -> str:
|
||||
return ""
|
||||
|
||||
type: Type = Field(..., description="The type of the credentials")
|
||||
name: str = Field(..., description="The name of the credentials")
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ class UserToolProvider(BaseModel):
|
|||
author: str
|
||||
name: str # identifier
|
||||
description: I18nObject
|
||||
icon: str
|
||||
icon: str | dict
|
||||
label: I18nObject # label
|
||||
type: ToolProviderType
|
||||
masked_credentials: Optional[dict] = None
|
||||
|
|
|
@ -208,8 +208,12 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||
|
||||
if not db_providers:
|
||||
return []
|
||||
|
||||
app = db_providers.app
|
||||
if not app:
|
||||
raise ValueError("can not read app of workflow")
|
||||
|
||||
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
|
||||
self.tools = [self._get_db_provider_tool(db_providers, app)]
|
||||
|
||||
return self.tools
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
@ -13,7 +14,7 @@ from .model import Account, App, Tenant
|
|||
from .types import StringUUID
|
||||
|
||||
|
||||
class BuiltinToolProvider(db.Model):
|
||||
class BuiltinToolProvider(Base):
|
||||
"""
|
||||
This table stores the tool provider information for built-in tools for each tenant.
|
||||
"""
|
||||
|
@ -25,61 +26,22 @@ class BuiltinToolProvider(db.Model):
|
|||
)
|
||||
|
||||
# id of the tool provider
|
||||
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
# id of the tenant
|
||||
tenant_id = db.Column(StringUUID, nullable=True)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
||||
# who created this tool provider
|
||||
user_id = db.Column(StringUUID, nullable=False)
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# name of the tool provider
|
||||
provider = db.Column(db.String(40), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
# credential of the tool provider
|
||||
encrypted_credentials = db.Column(db.Text, nullable=True)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict:
|
||||
return json.loads(self.encrypted_credentials)
|
||||
|
||||
class PublishedAppTool(db.Model):
|
||||
"""
|
||||
The table stores the apps published as a tool for each person.
|
||||
"""
|
||||
__tablename__ = 'tool_published_apps'
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint('id', name='published_app_tool_pkey'),
|
||||
db.UniqueConstraint('app_id', 'user_id', name='unique_published_app_tool')
|
||||
)
|
||||
|
||||
# id of the tool provider
|
||||
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
# id of the app
|
||||
app_id = db.Column(StringUUID, ForeignKey('apps.id'), nullable=False)
|
||||
# who published this tool
|
||||
user_id = db.Column(StringUUID, nullable=False)
|
||||
# description of the tool, stored in i18n format, for human
|
||||
description = db.Column(db.Text, nullable=False)
|
||||
# llm_description of the tool, for LLM
|
||||
llm_description = db.Column(db.Text, nullable=False)
|
||||
# query description, query will be seem as a parameter of the tool, to describe this parameter to llm, we need this field
|
||||
query_description = db.Column(db.Text, nullable=False)
|
||||
# query name, the name of the query parameter
|
||||
query_name = db.Column(db.String(40), nullable=False)
|
||||
# name of the tool provider
|
||||
tool_name = db.Column(db.String(40), nullable=False)
|
||||
# author
|
||||
author = db.Column(db.String(40), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
||||
@property
|
||||
def description_i18n(self) -> I18nObject:
|
||||
return I18nObject(**json.loads(self.description))
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
return db.session.query(App).filter(App.id == self.app_id).first()
|
||||
|
||||
class ApiToolProvider(Base):
|
||||
"""
|
||||
The table stores the api providers.
|
||||
|
@ -129,14 +91,14 @@ class ApiToolProvider(Base):
|
|||
return json.loads(self.credentials_str)
|
||||
|
||||
@property
|
||||
def user(self) -> Account:
|
||||
def user(self) -> Account | None:
|
||||
return db.session.query(Account).filter(Account.id == self.user_id).first()
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant:
|
||||
def tenant(self) -> Tenant | None:
|
||||
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
|
||||
|
||||
class ToolLabelBinding(db.Model):
|
||||
class ToolLabelBinding(Base):
|
||||
"""
|
||||
The table stores the labels for tools.
|
||||
"""
|
||||
|
@ -146,15 +108,15 @@ class ToolLabelBinding(db.Model):
|
|||
db.UniqueConstraint('tool_id', 'label_name', name='unique_tool_label_bind'),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
# tool id
|
||||
tool_id = db.Column(db.String(64), nullable=False)
|
||||
tool_id: Mapped[str] = mapped_column(db.String(64), nullable=False)
|
||||
# tool type
|
||||
tool_type = db.Column(db.String(40), nullable=False)
|
||||
tool_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
# label name
|
||||
label_name = db.Column(db.String(40), nullable=False)
|
||||
label_name: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
|
||||
class WorkflowToolProvider(db.Model):
|
||||
class WorkflowToolProvider(Base):
|
||||
"""
|
||||
The table stores the workflow providers.
|
||||
"""
|
||||
|
@ -165,41 +127,37 @@ class WorkflowToolProvider(db.Model):
|
|||
db.UniqueConstraint('tenant_id', 'app_id', name='unique_workflow_tool_provider_app_id'),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
# name of the workflow provider
|
||||
name = db.Column(db.String(40), nullable=False)
|
||||
name: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
# label of the workflow provider
|
||||
label = db.Column(db.String(255), nullable=False, server_default='')
|
||||
label: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default='')
|
||||
# icon
|
||||
icon = db.Column(db.String(255), nullable=False)
|
||||
icon: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
# app id of the workflow provider
|
||||
app_id = db.Column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# version of the workflow provider
|
||||
version = db.Column(db.String(255), nullable=False, server_default='')
|
||||
version: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default='')
|
||||
# who created this tool
|
||||
user_id = db.Column(StringUUID, nullable=False)
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# description of the provider
|
||||
description = db.Column(db.Text, nullable=False)
|
||||
description: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
# parameter configuration
|
||||
parameter_configuration = db.Column(db.Text, nullable=False, server_default='[]')
|
||||
parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default='[]')
|
||||
# privacy policy
|
||||
privacy_policy = db.Column(db.String(255), nullable=True, server_default='')
|
||||
privacy_policy: Mapped[str] = mapped_column(db.String(255), nullable=True, server_default='')
|
||||
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
||||
@property
|
||||
def schema_type(self) -> ApiProviderSchemaType:
|
||||
return ApiProviderSchemaType.value_of(self.schema_type_str)
|
||||
|
||||
@property
|
||||
def user(self) -> Account:
|
||||
def user(self) -> Account | None:
|
||||
return db.session.query(Account).filter(Account.id == self.user_id).first()
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant:
|
||||
def tenant(self) -> Tenant | None:
|
||||
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
|
||||
|
||||
@property
|
||||
|
@ -210,7 +168,7 @@ class WorkflowToolProvider(db.Model):
|
|||
]
|
||||
|
||||
@property
|
||||
def app(self) -> App:
|
||||
def app(self) -> App | None:
|
||||
return db.session.query(App).filter(App.id == self.app_id).first()
|
||||
|
||||
class ToolModelInvoke(db.Model):
|
||||
|
|
|
@ -28,10 +28,13 @@ class BuiltinToolManageService:
|
|||
tools = provider_controller.get_tools()
|
||||
|
||||
tool_provider_configurations = ToolConfigurationManager(
|
||||
tenant_id=tenant_id, provider_controller=provider_controller
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name,
|
||||
)
|
||||
# check if user has added the provider
|
||||
builtin_provider: BuiltinToolProvider = (
|
||||
builtin_provider: BuiltinToolProvider | None = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
|
@ -75,7 +78,7 @@ class BuiltinToolManageService:
|
|||
update builtin tool provider
|
||||
"""
|
||||
# get if the provider exists
|
||||
provider: BuiltinToolProvider = (
|
||||
provider: BuiltinToolProvider | None = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
|
@ -89,7 +92,13 @@ class BuiltinToolManageService:
|
|||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"provider {provider_name} does not need credentials")
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name,
|
||||
)
|
||||
|
||||
# get original credentials if exists
|
||||
if provider is not None:
|
||||
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
|
@ -132,7 +141,7 @@ class BuiltinToolManageService:
|
|||
"""
|
||||
get builtin tool provider credentials
|
||||
"""
|
||||
provider: BuiltinToolProvider = (
|
||||
provider_obj: BuiltinToolProvider | None = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
|
@ -141,12 +150,17 @@ class BuiltinToolManageService:
|
|||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
if provider_obj is None:
|
||||
return {}
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider.provider)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider)
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name,
|
||||
)
|
||||
credentials = tool_configuration.decrypt_tool_credentials(provider_obj.credentials)
|
||||
credentials = tool_configuration.mask_tool_credentials(credentials)
|
||||
return credentials
|
||||
|
||||
|
@ -155,7 +169,7 @@ class BuiltinToolManageService:
|
|||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
provider: BuiltinToolProvider = (
|
||||
provider_obj: BuiltinToolProvider | None = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
|
@ -164,15 +178,20 @@ class BuiltinToolManageService:
|
|||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
if provider_obj is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
db.session.delete(provider)
|
||||
db.session.delete(provider_obj)
|
||||
db.session.commit()
|
||||
|
||||
# delete cache
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name,
|
||||
)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return {"result": "success"}
|
||||
|
@ -212,8 +231,8 @@ class BuiltinToolManageService:
|
|||
try:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
|
||||
data=provider_controller,
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
from typing import Literal, Optional, Union, overload
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class ToolTransformService:
|
||||
@classmethod
|
||||
def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
|
||||
def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]:
|
||||
"""
|
||||
get tool provider icon url
|
||||
"""
|
||||
|
@ -35,7 +35,9 @@ class ToolTransformService:
|
|||
return url_prefix + "builtin/" + provider_name + "/icon"
|
||||
elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]:
|
||||
try:
|
||||
return json.loads(icon)
|
||||
if isinstance(icon, str):
|
||||
return json.loads(icon)
|
||||
return icon
|
||||
except:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
|
@ -92,7 +94,8 @@ class ToolTransformService:
|
|||
# get credentials schema
|
||||
schema = provider_controller.get_credentials_schema()
|
||||
for name, value in schema.items():
|
||||
result.masked_credentials[name] = ProviderConfig.Type.default(value.type)
|
||||
if result.masked_credentials:
|
||||
result.masked_credentials[name] = ""
|
||||
|
||||
# check if the provider need credentials
|
||||
if not provider_controller.need_credentials:
|
||||
|
@ -184,9 +187,14 @@ class ToolTransformService:
|
|||
"""
|
||||
username = "Anonymous"
|
||||
try:
|
||||
username = db_provider.user.name
|
||||
user = db_provider.user
|
||||
if not user:
|
||||
raise ValueError("user not found")
|
||||
|
||||
username = user.name
|
||||
except Exception as e:
|
||||
logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}")
|
||||
|
||||
# add provider into providers
|
||||
credentials = db_provider.credentials
|
||||
result = UserToolProvider(
|
||||
|
@ -266,9 +274,9 @@ class ToolTransformService:
|
|||
author=tool.identity.author,
|
||||
name=tool.identity.name,
|
||||
label=tool.identity.label,
|
||||
description=tool.description.human,
|
||||
description=tool.description.human if tool.description else I18nObject(en_US=''),
|
||||
parameters=current_parameters,
|
||||
labels=labels,
|
||||
labels=labels or [],
|
||||
)
|
||||
if isinstance(tool, ApiToolBundle):
|
||||
return UserTool(
|
||||
|
@ -277,5 +285,5 @@ class ToolTransformService:
|
|||
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
|
||||
description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
|
||||
parameters=tool.parameters,
|
||||
labels=labels,
|
||||
labels=labels or [],
|
||||
)
|
||||
|
|
|
@ -4,7 +4,7 @@ from datetime import datetime
|
|||
from sqlalchemy import or_
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserToolProvider
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
|
@ -32,7 +32,7 @@ class WorkflowToolManageService:
|
|||
description: str,
|
||||
parameters: list[dict],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Create a workflow tool.
|
||||
|
@ -62,12 +62,12 @@ class WorkflowToolManageService:
|
|||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
|
||||
|
||||
app: App = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
|
||||
app: App | None = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
|
||||
|
||||
if app is None:
|
||||
raise ValueError(f"App {workflow_app_id} not found")
|
||||
|
||||
workflow: Workflow = app.workflow
|
||||
workflow: Workflow | None = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError(f"Workflow not found for app {workflow_app_id}")
|
||||
|
||||
|
@ -106,7 +106,7 @@ class WorkflowToolManageService:
|
|||
description: str,
|
||||
parameters: list[dict],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Update a workflow tool.
|
||||
|
@ -138,7 +138,7 @@ class WorkflowToolManageService:
|
|||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f"Tool with name {name} already exists")
|
||||
|
||||
workflow_tool_provider: WorkflowToolProvider = (
|
||||
workflow_tool_provider: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
|
@ -147,14 +147,14 @@ class WorkflowToolManageService:
|
|||
if workflow_tool_provider is None:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
app: App = (
|
||||
app: App | None = (
|
||||
db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
|
||||
)
|
||||
|
||||
if app is None:
|
||||
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
|
||||
|
||||
workflow: Workflow = app.workflow
|
||||
workflow: Workflow | None = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
|
||||
|
||||
|
@ -243,36 +243,12 @@ class WorkflowToolManageService:
|
|||
:param workflow_app_id: the workflow app id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: WorkflowToolProvider = (
|
||||
db_tool: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_tool is None:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
|
||||
|
||||
if workflow_app is None:
|
||||
raise ValueError(f"App {db_tool.app_id} not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
|
||||
return {
|
||||
"name": db_tool.name,
|
||||
"label": db_tool.label,
|
||||
"workflow_tool_id": db_tool.id,
|
||||
"workflow_app_id": db_tool.app_id,
|
||||
"icon": json.loads(db_tool.icon),
|
||||
"description": db_tool.description,
|
||||
"parameters": jsonable_encoder(db_tool.parameter_configurations),
|
||||
"tool": ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
),
|
||||
"synced": workflow_app.workflow.version == db_tool.version,
|
||||
"privacy_policy": db_tool.privacy_policy,
|
||||
}
|
||||
return cls._get_workflow_tool(db_tool)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
|
||||
|
@ -283,19 +259,31 @@ class WorkflowToolManageService:
|
|||
:param workflow_app_id: the workflow app id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: WorkflowToolProvider = (
|
||||
db_tool: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
return cls._get_workflow_tool(db_tool)
|
||||
|
||||
@classmethod
|
||||
def _get_workflow_tool(cls, db_tool: WorkflowToolProvider | None):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:db_tool: the database tool
|
||||
:return: the tool
|
||||
"""
|
||||
if db_tool is None:
|
||||
raise ValueError(f"Tool {workflow_app_id} not found")
|
||||
raise ValueError("Tool not found")
|
||||
|
||||
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
|
||||
workflow_app: App | None = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
|
||||
|
||||
if workflow_app is None:
|
||||
raise ValueError(f"App {db_tool.app_id} not found")
|
||||
|
||||
workflow = workflow_app.workflow
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
|
||||
|
@ -308,14 +296,14 @@ class WorkflowToolManageService:
|
|||
"description": db_tool.description,
|
||||
"parameters": jsonable_encoder(db_tool.parameter_configurations),
|
||||
"tool": ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
tool.get_tools(db_tool.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
),
|
||||
"synced": workflow_app.workflow.version == db_tool.version,
|
||||
"synced": workflow.version == db_tool.version,
|
||||
"privacy_policy": db_tool.privacy_policy,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]:
|
||||
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]:
|
||||
"""
|
||||
List workflow tool provider tools.
|
||||
:param user_id: the user id
|
||||
|
@ -323,7 +311,7 @@ class WorkflowToolManageService:
|
|||
:param workflow_app_id: the workflow app id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tool: WorkflowToolProvider = (
|
||||
db_tool: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
|
@ -336,6 +324,7 @@ class WorkflowToolManageService:
|
|||
|
||||
return [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
tool=tool.get_tools(db_tool.tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool)
|
||||
)
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue
Block a user