feat(vannaai): add base_url configuration (#10294)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions

This commit is contained in:
Benjamin 2024-11-05 20:58:49 +08:00 committed by GitHub
parent 1279e27825
commit d7b4d0756e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 31 additions and 2 deletions

View File

@ -35,7 +35,8 @@ class VannaTool(BuiltinTool):
password = tool_parameters.get("password", "")
port = tool_parameters.get("port", 0)
vn = VannaDefault(model=model, api_key=api_key)
base_url = self.runtime.credentials.get("base_url", None)
vn = VannaDefault(model=model, api_key=api_key, config={"endpoint": base_url})
db_type = tool_parameters.get("db_type", "")
if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}:

View File

@ -1,4 +1,6 @@
import re
from typing import Any
from urllib.parse import urlparse
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.vanna.tools.vanna import VannaTool
@ -6,7 +8,26 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
class VannaProvider(BuiltinToolProviderController):
def _get_protocol_and_main_domain(self, url):
parsed_url = urlparse(url)
protocol = parsed_url.scheme
hostname = parsed_url.hostname
port = f":{parsed_url.port}" if parsed_url.port else ""
# Check if the hostname is an IP address
is_ip = re.match(r"^\d{1,3}(\.\d{1,3}){3}$", hostname) is not None
# Return the full hostname (with port if present) for IP addresses, otherwise return the main domain
main_domain = f"{hostname}{port}" if is_ip else ".".join(hostname.split(".")[-2:]) + port
return f"{protocol}://{main_domain}"
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
base_url = credentials.get("base_url")
if not base_url:
base_url = "https://ask.vanna.ai/rpc"
else:
base_url = base_url.removesuffix("/")
credentials["base_url"] = base_url
try:
VannaTool().fork_tool_runtime(
runtime={
@ -17,7 +38,7 @@ class VannaProvider(BuiltinToolProviderController):
tool_parameters={
"model": "chinook",
"db_type": "SQLite",
"url": "https://vanna.ai/Chinook.sqlite",
"url": f'{self._get_protocol_and_main_domain(credentials["base_url"])}/Chinook.sqlite',
"query": "What are the top 10 customers by sales?",
},
)

View File

@ -26,3 +26,10 @@ credentials_for_provider:
en_US: Get your API key from Vanna.AI
zh_Hans: 从 Vanna.AI 获取你的 API key
url: https://vanna.ai/account/profile
base_url:
type: text-input
required: false
label:
en_US: Vanna.AI Endpoint Base URL
placeholder:
en_US: https://ask.vanna.ai/rpc