mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
Feat/support tool credentials bool schema (#2875)
This commit is contained in:
parent
cb79a90031
commit
95b74c211d
|
@ -171,6 +171,7 @@ class ToolProviderCredentials(BaseModel):
|
|||
SECRET_INPUT = "secret-input"
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
BOOLEAN = "boolean"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType":
|
||||
|
@ -192,7 +193,7 @@ class ToolProviderCredentials(BaseModel):
|
|||
name: str = Field(..., description="The name of the credentials")
|
||||
type: CredentialsType = Field(..., description="The type of the credentials")
|
||||
required: bool = False
|
||||
default: Optional[str] = None
|
||||
default: Optional[Union[int, str]] = None
|
||||
options: Optional[list[ToolCredentialsOption]] = None
|
||||
label: Optional[I18nObject] = None
|
||||
help: Optional[I18nObject] = None
|
||||
|
|
|
@ -12,12 +12,11 @@ class BingProvider(BuiltinToolProviderController):
|
|||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
).validate_credentials(
|
||||
credentials=credentials,
|
||||
tool_parameters={
|
||||
"query": "test",
|
||||
"result_type": "link",
|
||||
"enable_webpages": True,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
|
@ -43,3 +43,63 @@ credentials_for_provider:
|
|||
zh_Hans: 例如 "https://api.bing.microsoft.com/v7.0/search"
|
||||
pt_BR: An endpoint is like "https://api.bing.microsoft.com/v7.0/search"
|
||||
default: https://api.bing.microsoft.com/v7.0/search
|
||||
allow_entities:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow Entities Search
|
||||
zh_Hans: 支持实体搜索
|
||||
pt_BR: Allow Entities Search
|
||||
help:
|
||||
en_US: Does your subscription plan allow entity search
|
||||
zh_Hans: 您的订阅计划是否支持实体搜索
|
||||
pt_BR: Does your subscription plan allow entity search
|
||||
default: true
|
||||
allow_web_pages:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow Web Pages Search
|
||||
zh_Hans: 支持网页搜索
|
||||
pt_BR: Allow Web Pages Search
|
||||
help:
|
||||
en_US: Does your subscription plan allow web pages search
|
||||
zh_Hans: 您的订阅计划是否支持网页搜索
|
||||
pt_BR: Does your subscription plan allow web pages search
|
||||
default: true
|
||||
allow_computation:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow Computation Search
|
||||
zh_Hans: 支持计算搜索
|
||||
pt_BR: Allow Computation Search
|
||||
help:
|
||||
en_US: Does your subscription plan allow computation search
|
||||
zh_Hans: 您的订阅计划是否支持计算搜索
|
||||
pt_BR: Does your subscription plan allow computation search
|
||||
default: false
|
||||
allow_news:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow News Search
|
||||
zh_Hans: 支持新闻搜索
|
||||
pt_BR: Allow News Search
|
||||
help:
|
||||
en_US: Does your subscription plan allow news search
|
||||
zh_Hans: 您的订阅计划是否支持新闻搜索
|
||||
pt_BR: Does your subscription plan allow news search
|
||||
default: false
|
||||
allow_related_searches:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow Related Searches
|
||||
zh_Hans: 支持相关搜索
|
||||
pt_BR: Allow Related Searches
|
||||
help:
|
||||
en_US: Does your subscription plan allow related searches
|
||||
zh_Hans: 您的订阅计划是否支持相关搜索
|
||||
pt_BR: Does your subscription plan allow related searches
|
||||
default: false
|
||||
|
|
|
@ -10,53 +10,23 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
|||
class BingSearchTool(BuiltinTool):
|
||||
url = 'https://api.bing.microsoft.com/v7.0/search'
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke_bing(self,
|
||||
user_id: str,
|
||||
subscription_key: str, query: str, limit: int,
|
||||
result_type: str, market: str, lang: str,
|
||||
filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke bing search
|
||||
"""
|
||||
|
||||
key = self.runtime.credentials.get('subscription_key', None)
|
||||
if not key:
|
||||
raise Exception('subscription_key is required')
|
||||
|
||||
server_url = self.runtime.credentials.get('server_url', None)
|
||||
if not server_url:
|
||||
server_url = self.url
|
||||
|
||||
query = tool_parameters.get('query', None)
|
||||
if not query:
|
||||
raise Exception('query is required')
|
||||
|
||||
limit = min(tool_parameters.get('limit', 5), 10)
|
||||
result_type = tool_parameters.get('result_type', 'text') or 'text'
|
||||
|
||||
market = tool_parameters.get('market', 'US')
|
||||
lang = tool_parameters.get('language', 'en')
|
||||
filter = []
|
||||
|
||||
if tool_parameters.get('enable_computation', False):
|
||||
filter.append('Computation')
|
||||
if tool_parameters.get('enable_entities', False):
|
||||
filter.append('Entities')
|
||||
if tool_parameters.get('enable_news', False):
|
||||
filter.append('News')
|
||||
if tool_parameters.get('enable_related_search', False):
|
||||
filter.append('RelatedSearches')
|
||||
if tool_parameters.get('enable_webpages', False):
|
||||
filter.append('WebPages')
|
||||
|
||||
market_code = f'{lang}-{market}'
|
||||
accept_language = f'{lang},{market_code};q=0.9'
|
||||
headers = {
|
||||
'Ocp-Apim-Subscription-Key': key,
|
||||
'Ocp-Apim-Subscription-Key': subscription_key,
|
||||
'Accept-Language': accept_language
|
||||
}
|
||||
|
||||
query = quote(query)
|
||||
server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filter)}'
|
||||
server_url = f'{self.url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
|
||||
response = get(server_url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
|
@ -124,3 +94,105 @@ class BingSearchTool(BuiltinTool):
|
|||
text += f'{related["displayText"]} - {related["webSearchUrl"]}\n'
|
||||
|
||||
return self.create_text_message(text=self.summary(user_id=user_id, content=text))
|
||||
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None:
|
||||
key = credentials.get('subscription_key', None)
|
||||
if not key:
|
||||
raise Exception('subscription_key is required')
|
||||
|
||||
server_url = credentials.get('server_url', None)
|
||||
if not server_url:
|
||||
server_url = self.url
|
||||
|
||||
query = tool_parameters.get('query', None)
|
||||
if not query:
|
||||
raise Exception('query is required')
|
||||
|
||||
limit = min(tool_parameters.get('limit', 5), 10)
|
||||
result_type = tool_parameters.get('result_type', 'text') or 'text'
|
||||
|
||||
market = tool_parameters.get('market', 'US')
|
||||
lang = tool_parameters.get('language', 'en')
|
||||
filter = []
|
||||
|
||||
if credentials.get('allow_entities', False):
|
||||
filter.append('Entities')
|
||||
|
||||
if credentials.get('allow_computation', False):
|
||||
filter.append('Computation')
|
||||
|
||||
if credentials.get('allow_news', False):
|
||||
filter.append('News')
|
||||
|
||||
if credentials.get('allow_related_searches', False):
|
||||
filter.append('RelatedSearches')
|
||||
|
||||
if credentials.get('allow_web_pages', False):
|
||||
filter.append('WebPages')
|
||||
|
||||
if not filter:
|
||||
raise Exception('At least one filter is required')
|
||||
|
||||
self._invoke_bing(
|
||||
user_id='test',
|
||||
subscription_key=key,
|
||||
query=query,
|
||||
limit=limit,
|
||||
result_type=result_type,
|
||||
market=market,
|
||||
lang=lang,
|
||||
filters=filter
|
||||
)
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
|
||||
key = self.runtime.credentials.get('subscription_key', None)
|
||||
if not key:
|
||||
raise Exception('subscription_key is required')
|
||||
|
||||
server_url = self.runtime.credentials.get('server_url', None)
|
||||
if not server_url:
|
||||
server_url = self.url
|
||||
|
||||
query = tool_parameters.get('query', None)
|
||||
if not query:
|
||||
raise Exception('query is required')
|
||||
|
||||
limit = min(tool_parameters.get('limit', 5), 10)
|
||||
result_type = tool_parameters.get('result_type', 'text') or 'text'
|
||||
|
||||
market = tool_parameters.get('market', 'US')
|
||||
lang = tool_parameters.get('language', 'en')
|
||||
filter = []
|
||||
|
||||
if tool_parameters.get('enable_computation', False):
|
||||
filter.append('Computation')
|
||||
if tool_parameters.get('enable_entities', False):
|
||||
filter.append('Entities')
|
||||
if tool_parameters.get('enable_news', False):
|
||||
filter.append('News')
|
||||
if tool_parameters.get('enable_related_search', False):
|
||||
filter.append('RelatedSearches')
|
||||
if tool_parameters.get('enable_webpages', False):
|
||||
filter.append('WebPages')
|
||||
|
||||
if not filter:
|
||||
raise Exception('At least one filter is required')
|
||||
|
||||
return self._invoke_bing(
|
||||
user_id=user_id,
|
||||
subscription_key=key,
|
||||
query=query,
|
||||
limit=limit,
|
||||
result_type=result_type,
|
||||
market=market,
|
||||
lang=lang,
|
||||
filters=filter
|
||||
)
|
|
@ -246,8 +246,27 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
|
||||
if credentials[credential_name] not in [x.value for x in options]:
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be one of {options}')
|
||||
|
||||
if credentials[credential_name]:
|
||||
elif credential_schema.type == ToolProviderCredentials.CredentialsType.BOOLEAN:
|
||||
if isinstance(credentials[credential_name], bool):
|
||||
pass
|
||||
elif isinstance(credentials[credential_name], str):
|
||||
if credentials[credential_name].lower() == 'true':
|
||||
credentials[credential_name] = True
|
||||
elif credentials[credential_name].lower() == 'false':
|
||||
credentials[credential_name] = False
|
||||
else:
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
|
||||
elif isinstance(credentials[credential_name], int):
|
||||
if credentials[credential_name] == 1:
|
||||
credentials[credential_name] = True
|
||||
elif credentials[credential_name] == 0:
|
||||
credentials[credential_name] = False
|
||||
else:
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
|
||||
else:
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
|
||||
|
||||
if credentials[credential_name] or credentials[credential_name] == False:
|
||||
credentials_need_to_validate.pop(credential_name)
|
||||
|
||||
for credential_name in credentials_need_to_validate:
|
||||
|
|
|
@ -138,9 +138,9 @@ class ToolManageService:
|
|||
:return: the list of tool providers
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name)
|
||||
return [
|
||||
v.to_dict() for _, v in (provider.credentials_schema or {}).items()
|
||||
]
|
||||
return json.loads(serialize_base_model_array([
|
||||
v for _, v in (provider.credentials_schema or {}).items()
|
||||
]))
|
||||
|
||||
@staticmethod
|
||||
def parser_api_schema(schema: str) -> list[ApiBasedToolBundle]:
|
||||
|
|
|
@ -3,7 +3,7 @@ import type { FC } from 'react'
|
|||
import React, { useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import cn from 'classnames'
|
||||
import { toolCredentialToFormSchemas } from '../../utils/to-form-schema'
|
||||
import { addDefaultValue, toolCredentialToFormSchemas } from '../../utils/to-form-schema'
|
||||
import type { Collection } from '../../types'
|
||||
import Drawer from '@/app/components/base/drawer-plus'
|
||||
import Button from '@/app/components/base/button'
|
||||
|
@ -28,12 +28,15 @@ const ConfigCredential: FC<Props> = ({
|
|||
const { t } = useTranslation()
|
||||
const [credentialSchema, setCredentialSchema] = useState<any>(null)
|
||||
const { team_credentials: credentialValue, name: collectionName } = collection
|
||||
const [tempCredential, setTempCredential] = React.useState<any>(credentialValue)
|
||||
useEffect(() => {
|
||||
fetchBuiltInToolCredentialSchema(collectionName).then((res) => {
|
||||
setCredentialSchema(toolCredentialToFormSchemas(res))
|
||||
const toolCredentialSchemas = toolCredentialToFormSchemas(res)
|
||||
const defaultCredentials = addDefaultValue(credentialValue, toolCredentialSchemas)
|
||||
setCredentialSchema(toolCredentialSchemas)
|
||||
setTempCredential(defaultCredentials)
|
||||
})
|
||||
}, [])
|
||||
const [tempCredential, setTempCredential] = React.useState<any>(credentialValue)
|
||||
|
||||
return (
|
||||
<Drawer
|
||||
|
|
Loading…
Reference in New Issue
Block a user