mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
Feat/support tool credentials bool schema (#2875)
This commit is contained in:
parent
cb79a90031
commit
95b74c211d
|
@ -171,6 +171,7 @@ class ToolProviderCredentials(BaseModel):
|
||||||
SECRET_INPUT = "secret-input"
|
SECRET_INPUT = "secret-input"
|
||||||
TEXT_INPUT = "text-input"
|
TEXT_INPUT = "text-input"
|
||||||
SELECT = "select"
|
SELECT = "select"
|
||||||
|
BOOLEAN = "boolean"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType":
|
def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType":
|
||||||
|
@ -192,7 +193,7 @@ class ToolProviderCredentials(BaseModel):
|
||||||
name: str = Field(..., description="The name of the credentials")
|
name: str = Field(..., description="The name of the credentials")
|
||||||
type: CredentialsType = Field(..., description="The type of the credentials")
|
type: CredentialsType = Field(..., description="The type of the credentials")
|
||||||
required: bool = False
|
required: bool = False
|
||||||
default: Optional[str] = None
|
default: Optional[Union[int, str]] = None
|
||||||
options: Optional[list[ToolCredentialsOption]] = None
|
options: Optional[list[ToolCredentialsOption]] = None
|
||||||
label: Optional[I18nObject] = None
|
label: Optional[I18nObject] = None
|
||||||
help: Optional[I18nObject] = None
|
help: Optional[I18nObject] = None
|
||||||
|
|
|
@ -12,12 +12,11 @@ class BingProvider(BuiltinToolProviderController):
|
||||||
meta={
|
meta={
|
||||||
"credentials": credentials,
|
"credentials": credentials,
|
||||||
}
|
}
|
||||||
).invoke(
|
).validate_credentials(
|
||||||
user_id='',
|
credentials=credentials,
|
||||||
tool_parameters={
|
tool_parameters={
|
||||||
"query": "test",
|
"query": "test",
|
||||||
"result_type": "link",
|
"result_type": "link",
|
||||||
"enable_webpages": True,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -43,3 +43,63 @@ credentials_for_provider:
|
||||||
zh_Hans: 例如 "https://api.bing.microsoft.com/v7.0/search"
|
zh_Hans: 例如 "https://api.bing.microsoft.com/v7.0/search"
|
||||||
pt_BR: An endpoint is like "https://api.bing.microsoft.com/v7.0/search"
|
pt_BR: An endpoint is like "https://api.bing.microsoft.com/v7.0/search"
|
||||||
default: https://api.bing.microsoft.com/v7.0/search
|
default: https://api.bing.microsoft.com/v7.0/search
|
||||||
|
allow_entities:
|
||||||
|
type: boolean
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Allow Entities Search
|
||||||
|
zh_Hans: 支持实体搜索
|
||||||
|
pt_BR: Allow Entities Search
|
||||||
|
help:
|
||||||
|
en_US: Does your subscription plan allow entity search
|
||||||
|
zh_Hans: 您的订阅计划是否支持实体搜索
|
||||||
|
pt_BR: Does your subscription plan allow entity search
|
||||||
|
default: true
|
||||||
|
allow_web_pages:
|
||||||
|
type: boolean
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Allow Web Pages Search
|
||||||
|
zh_Hans: 支持网页搜索
|
||||||
|
pt_BR: Allow Web Pages Search
|
||||||
|
help:
|
||||||
|
en_US: Does your subscription plan allow web pages search
|
||||||
|
zh_Hans: 您的订阅计划是否支持网页搜索
|
||||||
|
pt_BR: Does your subscription plan allow web pages search
|
||||||
|
default: true
|
||||||
|
allow_computation:
|
||||||
|
type: boolean
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Allow Computation Search
|
||||||
|
zh_Hans: 支持计算搜索
|
||||||
|
pt_BR: Allow Computation Search
|
||||||
|
help:
|
||||||
|
en_US: Does your subscription plan allow computation search
|
||||||
|
zh_Hans: 您的订阅计划是否支持计算搜索
|
||||||
|
pt_BR: Does your subscription plan allow computation search
|
||||||
|
default: false
|
||||||
|
allow_news:
|
||||||
|
type: boolean
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Allow News Search
|
||||||
|
zh_Hans: 支持新闻搜索
|
||||||
|
pt_BR: Allow News Search
|
||||||
|
help:
|
||||||
|
en_US: Does your subscription plan allow news search
|
||||||
|
zh_Hans: 您的订阅计划是否支持新闻搜索
|
||||||
|
pt_BR: Does your subscription plan allow news search
|
||||||
|
default: false
|
||||||
|
allow_related_searches:
|
||||||
|
type: boolean
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Allow Related Searches
|
||||||
|
zh_Hans: 支持相关搜索
|
||||||
|
pt_BR: Allow Related Searches
|
||||||
|
help:
|
||||||
|
en_US: Does your subscription plan allow related searches
|
||||||
|
zh_Hans: 您的订阅计划是否支持相关搜索
|
||||||
|
pt_BR: Does your subscription plan allow related searches
|
||||||
|
default: false
|
||||||
|
|
|
@ -10,53 +10,23 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
class BingSearchTool(BuiltinTool):
|
class BingSearchTool(BuiltinTool):
|
||||||
url = 'https://api.bing.microsoft.com/v7.0/search'
|
url = 'https://api.bing.microsoft.com/v7.0/search'
|
||||||
|
|
||||||
def _invoke(self,
|
def _invoke_bing(self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
tool_parameters: dict[str, Any],
|
subscription_key: str, query: str, limit: int,
|
||||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
result_type: str, market: str, lang: str,
|
||||||
|
filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||||
"""
|
"""
|
||||||
invoke tools
|
invoke bing search
|
||||||
"""
|
"""
|
||||||
|
|
||||||
key = self.runtime.credentials.get('subscription_key', None)
|
|
||||||
if not key:
|
|
||||||
raise Exception('subscription_key is required')
|
|
||||||
|
|
||||||
server_url = self.runtime.credentials.get('server_url', None)
|
|
||||||
if not server_url:
|
|
||||||
server_url = self.url
|
|
||||||
|
|
||||||
query = tool_parameters.get('query', None)
|
|
||||||
if not query:
|
|
||||||
raise Exception('query is required')
|
|
||||||
|
|
||||||
limit = min(tool_parameters.get('limit', 5), 10)
|
|
||||||
result_type = tool_parameters.get('result_type', 'text') or 'text'
|
|
||||||
|
|
||||||
market = tool_parameters.get('market', 'US')
|
|
||||||
lang = tool_parameters.get('language', 'en')
|
|
||||||
filter = []
|
|
||||||
|
|
||||||
if tool_parameters.get('enable_computation', False):
|
|
||||||
filter.append('Computation')
|
|
||||||
if tool_parameters.get('enable_entities', False):
|
|
||||||
filter.append('Entities')
|
|
||||||
if tool_parameters.get('enable_news', False):
|
|
||||||
filter.append('News')
|
|
||||||
if tool_parameters.get('enable_related_search', False):
|
|
||||||
filter.append('RelatedSearches')
|
|
||||||
if tool_parameters.get('enable_webpages', False):
|
|
||||||
filter.append('WebPages')
|
|
||||||
|
|
||||||
market_code = f'{lang}-{market}'
|
market_code = f'{lang}-{market}'
|
||||||
accept_language = f'{lang},{market_code};q=0.9'
|
accept_language = f'{lang},{market_code};q=0.9'
|
||||||
headers = {
|
headers = {
|
||||||
'Ocp-Apim-Subscription-Key': key,
|
'Ocp-Apim-Subscription-Key': subscription_key,
|
||||||
'Accept-Language': accept_language
|
'Accept-Language': accept_language
|
||||||
}
|
}
|
||||||
|
|
||||||
query = quote(query)
|
query = quote(query)
|
||||||
server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filter)}'
|
server_url = f'{self.url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
|
||||||
response = get(server_url, headers=headers)
|
response = get(server_url, headers=headers)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
|
@ -124,3 +94,105 @@ class BingSearchTool(BuiltinTool):
|
||||||
text += f'{related["displayText"]} - {related["webSearchUrl"]}\n'
|
text += f'{related["displayText"]} - {related["webSearchUrl"]}\n'
|
||||||
|
|
||||||
return self.create_text_message(text=self.summary(user_id=user_id, content=text))
|
return self.create_text_message(text=self.summary(user_id=user_id, content=text))
|
||||||
|
|
||||||
|
|
||||||
|
def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None:
|
||||||
|
key = credentials.get('subscription_key', None)
|
||||||
|
if not key:
|
||||||
|
raise Exception('subscription_key is required')
|
||||||
|
|
||||||
|
server_url = credentials.get('server_url', None)
|
||||||
|
if not server_url:
|
||||||
|
server_url = self.url
|
||||||
|
|
||||||
|
query = tool_parameters.get('query', None)
|
||||||
|
if not query:
|
||||||
|
raise Exception('query is required')
|
||||||
|
|
||||||
|
limit = min(tool_parameters.get('limit', 5), 10)
|
||||||
|
result_type = tool_parameters.get('result_type', 'text') or 'text'
|
||||||
|
|
||||||
|
market = tool_parameters.get('market', 'US')
|
||||||
|
lang = tool_parameters.get('language', 'en')
|
||||||
|
filter = []
|
||||||
|
|
||||||
|
if credentials.get('allow_entities', False):
|
||||||
|
filter.append('Entities')
|
||||||
|
|
||||||
|
if credentials.get('allow_computation', False):
|
||||||
|
filter.append('Computation')
|
||||||
|
|
||||||
|
if credentials.get('allow_news', False):
|
||||||
|
filter.append('News')
|
||||||
|
|
||||||
|
if credentials.get('allow_related_searches', False):
|
||||||
|
filter.append('RelatedSearches')
|
||||||
|
|
||||||
|
if credentials.get('allow_web_pages', False):
|
||||||
|
filter.append('WebPages')
|
||||||
|
|
||||||
|
if not filter:
|
||||||
|
raise Exception('At least one filter is required')
|
||||||
|
|
||||||
|
self._invoke_bing(
|
||||||
|
user_id='test',
|
||||||
|
subscription_key=key,
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
result_type=result_type,
|
||||||
|
market=market,
|
||||||
|
lang=lang,
|
||||||
|
filters=filter
|
||||||
|
)
|
||||||
|
|
||||||
|
def _invoke(self,
|
||||||
|
user_id: str,
|
||||||
|
tool_parameters: dict[str, Any],
|
||||||
|
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||||
|
"""
|
||||||
|
invoke tools
|
||||||
|
"""
|
||||||
|
|
||||||
|
key = self.runtime.credentials.get('subscription_key', None)
|
||||||
|
if not key:
|
||||||
|
raise Exception('subscription_key is required')
|
||||||
|
|
||||||
|
server_url = self.runtime.credentials.get('server_url', None)
|
||||||
|
if not server_url:
|
||||||
|
server_url = self.url
|
||||||
|
|
||||||
|
query = tool_parameters.get('query', None)
|
||||||
|
if not query:
|
||||||
|
raise Exception('query is required')
|
||||||
|
|
||||||
|
limit = min(tool_parameters.get('limit', 5), 10)
|
||||||
|
result_type = tool_parameters.get('result_type', 'text') or 'text'
|
||||||
|
|
||||||
|
market = tool_parameters.get('market', 'US')
|
||||||
|
lang = tool_parameters.get('language', 'en')
|
||||||
|
filter = []
|
||||||
|
|
||||||
|
if tool_parameters.get('enable_computation', False):
|
||||||
|
filter.append('Computation')
|
||||||
|
if tool_parameters.get('enable_entities', False):
|
||||||
|
filter.append('Entities')
|
||||||
|
if tool_parameters.get('enable_news', False):
|
||||||
|
filter.append('News')
|
||||||
|
if tool_parameters.get('enable_related_search', False):
|
||||||
|
filter.append('RelatedSearches')
|
||||||
|
if tool_parameters.get('enable_webpages', False):
|
||||||
|
filter.append('WebPages')
|
||||||
|
|
||||||
|
if not filter:
|
||||||
|
raise Exception('At least one filter is required')
|
||||||
|
|
||||||
|
return self._invoke_bing(
|
||||||
|
user_id=user_id,
|
||||||
|
subscription_key=key,
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
result_type=result_type,
|
||||||
|
market=market,
|
||||||
|
lang=lang,
|
||||||
|
filters=filter
|
||||||
|
)
|
|
@ -246,8 +246,27 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||||
|
|
||||||
if credentials[credential_name] not in [x.value for x in options]:
|
if credentials[credential_name] not in [x.value for x in options]:
|
||||||
raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be one of {options}')
|
raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be one of {options}')
|
||||||
|
elif credential_schema.type == ToolProviderCredentials.CredentialsType.BOOLEAN:
|
||||||
if credentials[credential_name]:
|
if isinstance(credentials[credential_name], bool):
|
||||||
|
pass
|
||||||
|
elif isinstance(credentials[credential_name], str):
|
||||||
|
if credentials[credential_name].lower() == 'true':
|
||||||
|
credentials[credential_name] = True
|
||||||
|
elif credentials[credential_name].lower() == 'false':
|
||||||
|
credentials[credential_name] = False
|
||||||
|
else:
|
||||||
|
raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
|
||||||
|
elif isinstance(credentials[credential_name], int):
|
||||||
|
if credentials[credential_name] == 1:
|
||||||
|
credentials[credential_name] = True
|
||||||
|
elif credentials[credential_name] == 0:
|
||||||
|
credentials[credential_name] = False
|
||||||
|
else:
|
||||||
|
raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
|
||||||
|
else:
|
||||||
|
raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
|
||||||
|
|
||||||
|
if credentials[credential_name] or credentials[credential_name] == False:
|
||||||
credentials_need_to_validate.pop(credential_name)
|
credentials_need_to_validate.pop(credential_name)
|
||||||
|
|
||||||
for credential_name in credentials_need_to_validate:
|
for credential_name in credentials_need_to_validate:
|
||||||
|
|
|
@ -138,9 +138,9 @@ class ToolManageService:
|
||||||
:return: the list of tool providers
|
:return: the list of tool providers
|
||||||
"""
|
"""
|
||||||
provider = ToolManager.get_builtin_provider(provider_name)
|
provider = ToolManager.get_builtin_provider(provider_name)
|
||||||
return [
|
return json.loads(serialize_base_model_array([
|
||||||
v.to_dict() for _, v in (provider.credentials_schema or {}).items()
|
v for _, v in (provider.credentials_schema or {}).items()
|
||||||
]
|
]))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parser_api_schema(schema: str) -> list[ApiBasedToolBundle]:
|
def parser_api_schema(schema: str) -> list[ApiBasedToolBundle]:
|
||||||
|
|
|
@ -3,7 +3,7 @@ import type { FC } from 'react'
|
||||||
import React, { useEffect, useState } from 'react'
|
import React, { useEffect, useState } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import cn from 'classnames'
|
import cn from 'classnames'
|
||||||
import { toolCredentialToFormSchemas } from '../../utils/to-form-schema'
|
import { addDefaultValue, toolCredentialToFormSchemas } from '../../utils/to-form-schema'
|
||||||
import type { Collection } from '../../types'
|
import type { Collection } from '../../types'
|
||||||
import Drawer from '@/app/components/base/drawer-plus'
|
import Drawer from '@/app/components/base/drawer-plus'
|
||||||
import Button from '@/app/components/base/button'
|
import Button from '@/app/components/base/button'
|
||||||
|
@ -28,12 +28,15 @@ const ConfigCredential: FC<Props> = ({
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const [credentialSchema, setCredentialSchema] = useState<any>(null)
|
const [credentialSchema, setCredentialSchema] = useState<any>(null)
|
||||||
const { team_credentials: credentialValue, name: collectionName } = collection
|
const { team_credentials: credentialValue, name: collectionName } = collection
|
||||||
|
const [tempCredential, setTempCredential] = React.useState<any>(credentialValue)
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
fetchBuiltInToolCredentialSchema(collectionName).then((res) => {
|
fetchBuiltInToolCredentialSchema(collectionName).then((res) => {
|
||||||
setCredentialSchema(toolCredentialToFormSchemas(res))
|
const toolCredentialSchemas = toolCredentialToFormSchemas(res)
|
||||||
|
const defaultCredentials = addDefaultValue(credentialValue, toolCredentialSchemas)
|
||||||
|
setCredentialSchema(toolCredentialSchemas)
|
||||||
|
setTempCredential(defaultCredentials)
|
||||||
})
|
})
|
||||||
}, [])
|
}, [])
|
||||||
const [tempCredential, setTempCredential] = React.useState<any>(credentialValue)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Drawer
|
<Drawer
|
||||||
|
|
Loading…
Reference in New Issue
Block a user