mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
feat: Add tools for open weather search and image generation using the Spark API. (#2845)
This commit is contained in:
parent
4502436c47
commit
cb79a90031
|
@ -124,7 +124,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
|||
elif err == 'insufficient_quota':
|
||||
raise InsufficientAccountBalance(msg)
|
||||
elif err == 'invalid_authentication':
|
||||
raise InvalidAuthenticationError(msg)
|
||||
raise InvalidAuthenticationError(msg)
|
||||
elif err and 'rate' in err:
|
||||
raise RateLimitReachedError(msg)
|
||||
elif err and 'internal' in err:
|
||||
|
|
12
api/core/tools/provider/builtin/openweather/_assets/icon.svg
Normal file
12
api/core/tools/provider/builtin/openweather/_assets/icon.svg
Normal file
|
@ -0,0 +1,12 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<g clip-path="url(#clip0_16624_62807)">
|
||||
<path d="M7.11111 0.888889C7.11111 0.888889 7.11111 0 8 0C8.88889 0 8.88889 0.888889 8.88889 0.888889V1.77778C8.88889 1.77778 8.88889 2.66667 8 2.66667C7.11111 2.66667 7.11111 1.77778 7.11111 1.77778V0.888889ZM15.1111 7.11111C15.1111 7.11111 16 7.11111 16 8C16 8.88889 15.1111 8.88889 15.1111 8.88889H14.2222C14.2222 8.88889 13.3333 8.88889 13.3333 8C13.3333 7.11111 14.2222 7.11111 14.2222 7.11111H15.1111ZM1.77778 7.11111C1.77778 7.11111 2.66667 7.11111 2.66667 8C2.66667 8.88889 1.77778 8.88889 1.77778 8.88889H0.888889C0.888889 8.88889 0 8.88889 0 8C0 7.11111 0.888889 7.11111 0.888889 7.11111H1.77778ZM4.05378 3.24133C4.05378 3.24133 4.68222 3.86978 4.05378 4.49822C3.42533 5.12667 2.79689 4.49822 2.79689 4.49822L2.168 3.87022C2.168 3.87022 1.53956 3.24178 2.168 2.61289C2.79689 1.98444 3.42533 2.61289 3.42533 2.61289L4.05378 3.24133ZM13.2036 4.49822C13.2036 4.49822 12.5751 5.12667 11.9467 4.49822C11.3182 3.86978 11.9467 3.24133 11.9467 3.24133L12.5751 2.61289C12.5751 2.61289 13.2036 1.98444 13.832 2.61289C14.4604 3.24133 13.832 3.86978 13.832 3.86978L13.2036 4.49822ZM3.87022 13.8316C3.87022 13.8316 3.24178 14.46 2.61333 13.8316C1.98489 13.2031 2.61333 12.5747 2.61333 12.5747L3.24178 11.9462C3.24178 11.9462 3.87022 11.3178 4.49867 11.9462C5.12711 12.5747 4.49867 13.2031 4.49867 13.2031L3.87022 13.8316Z" fill="#FFCF27"/>
|
||||
<path d="M8.00011 12.4446C10.4547 12.4446 12.4446 10.4547 12.4446 8.00011C12.4446 5.54551 10.4547 3.55566 8.00011 3.55566C5.54551 3.55566 3.55566 5.54551 3.55566 8.00011C3.55566 10.4547 5.54551 12.4446 8.00011 12.4446Z" fill="#FFCB13"/>
|
||||
<path d="M13.2343 10.3111C12.949 10.3111 12.6743 10.3556 12.4152 10.4378C12.1094 9.53647 11.2774 8.88892 10.2966 8.88892C9.24411 8.88892 8.36322 9.63469 8.11922 10.6387C7.85878 10.436 7.53744 10.3116 7.18544 10.3116C6.32633 10.3116 5.62989 11.0276 5.62989 11.9116C5.62989 12.1262 5.67255 12.3298 5.74722 12.5174C5.59878 12.4742 5.44544 12.4445 5.28411 12.4445C4.32944 12.4445 3.55566 13.2405 3.55566 14.2222C3.55566 15.204 4.32944 16 5.28411 16H13.2348C14.7619 16 16.0001 14.7271 16.0001 13.1556C16.0001 11.5845 14.7619 10.3111 13.2343 10.3111Z" fill="#E9F6FF"/>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="clip0_16624_62807">
|
||||
<rect width="16" height="16" fill="white"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
After Width: | Height: | Size: 2.4 KiB |
36
api/core/tools/provider/builtin/openweather/openweather.py
Normal file
36
api/core/tools/provider/builtin/openweather/openweather.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
import requests
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
def query_weather(city="Beijing", units="metric", language="zh_cn", api_key=None):
|
||||
|
||||
url = "https://api.openweathermap.org/data/2.5/weather"
|
||||
params = {"q": city, "appid": api_key, "units": units, "lang": language}
|
||||
|
||||
return requests.get(url, params=params)
|
||||
|
||||
|
||||
class OpenweatherProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
if "api_key" not in credentials or not credentials.get("api_key"):
|
||||
raise ToolProviderCredentialValidationError(
|
||||
"Open weather API key is required."
|
||||
)
|
||||
apikey = credentials.get("api_key")
|
||||
try:
|
||||
response = query_weather(api_key=apikey)
|
||||
if response.status_code == 200:
|
||||
pass
|
||||
else:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
(response.json()).get("info")
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
"Open weather API Key is invalid. {}".format(e)
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
29
api/core/tools/provider/builtin/openweather/openweather.yaml
Normal file
29
api/core/tools/provider/builtin/openweather/openweather.yaml
Normal file
|
@ -0,0 +1,29 @@
|
|||
identity:
|
||||
author: Onelevenvy
|
||||
name: openweather
|
||||
label:
|
||||
en_US: Open weather query
|
||||
zh_Hans: Open Weather
|
||||
pt_BR: Consulta de clima open weather
|
||||
description:
|
||||
en_US: Weather query toolkit based on Open Weather
|
||||
zh_Hans: 基于open weather的天气查询工具包
|
||||
pt_BR: Kit de consulta de clima baseado no Open Weather
|
||||
icon: icon.svg
|
||||
credentials_for_provider:
|
||||
api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: API Key
|
||||
zh_Hans: API Key
|
||||
pt_BR: Fogo a chave
|
||||
placeholder:
|
||||
en_US: Please enter your open weather API Key
|
||||
zh_Hans: 请输入你的open weather API Key
|
||||
pt_BR: Insira sua chave de API open weather
|
||||
help:
|
||||
en_US: Get your API Key from open weather
|
||||
zh_Hans: 从open weather获取您的 API Key
|
||||
pt_BR: Obtenha sua chave de API do open weather
|
||||
url: https://openweathermap.org
|
60
api/core/tools/provider/builtin/openweather/tools/weather.py
Normal file
60
api/core/tools/provider/builtin/openweather/tools/weather.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class OpenweatherTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
city = tool_parameters.get("city", "")
|
||||
if not city:
|
||||
return self.create_text_message("Please tell me your city")
|
||||
if (
|
||||
"api_key" not in self.runtime.credentials
|
||||
or not self.runtime.credentials.get("api_key")
|
||||
):
|
||||
return self.create_text_message("OpenWeather API key is required.")
|
||||
|
||||
units = tool_parameters.get("units", "metric")
|
||||
lang = tool_parameters.get("lang", "zh_cn")
|
||||
try:
|
||||
# request URL
|
||||
url = "https://api.openweathermap.org/data/2.5/weather"
|
||||
|
||||
# request parmas
|
||||
params = {
|
||||
"q": city,
|
||||
"appid": self.runtime.credentials.get("api_key"),
|
||||
"units": units,
|
||||
"lang": lang,
|
||||
}
|
||||
response = requests.get(url, params=params)
|
||||
|
||||
if response.status_code == 200:
|
||||
|
||||
data = response.json()
|
||||
return self.create_text_message(
|
||||
self.summary(
|
||||
user_id=user_id, content=json.dumps(data, ensure_ascii=False)
|
||||
)
|
||||
)
|
||||
else:
|
||||
error_message = {
|
||||
"error": f"failed:{response.status_code}",
|
||||
"data": response.text,
|
||||
}
|
||||
# return error
|
||||
return json.dumps(error_message)
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(
|
||||
"Openweather API Key is invalid. {}".format(e)
|
||||
)
|
|
@ -0,0 +1,80 @@
|
|||
identity:
|
||||
name: weather
|
||||
author: Onelevenvy
|
||||
label:
|
||||
en_US: Open Weather Query
|
||||
zh_Hans: 天气查询
|
||||
pt_BR: Previsão do tempo
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Weather forecast inquiry
|
||||
zh_Hans: 天气查询
|
||||
pt_BR: Inquérito sobre previsão meteorológica
|
||||
llm: A tool when you want to ask about the weather or weather-related question
|
||||
parameters:
|
||||
- name: city
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: city
|
||||
zh_Hans: 城市
|
||||
pt_BR: cidade
|
||||
human_description:
|
||||
en_US: Target city for weather forecast query
|
||||
zh_Hans: 天气预报查询的目标城市
|
||||
pt_BR: Cidade de destino para consulta de previsão do tempo
|
||||
llm_description: If you don't know you can extract the city name from the
|
||||
question or you can reply:Please tell me your city. You have to extract
|
||||
the Chinese city name from the question.If the input region is in Chinese
|
||||
characters for China, it should be replaced with the corresponding English
|
||||
name, such as '北京' for correct input is 'Beijing'
|
||||
form: llm
|
||||
- name: lang
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: language
|
||||
zh_Hans: 语言
|
||||
pt_BR: language
|
||||
label:
|
||||
en_US: language
|
||||
zh_Hans: 语言
|
||||
pt_BR: language
|
||||
form: form
|
||||
options:
|
||||
- value: zh_cn
|
||||
label:
|
||||
en_US: cn
|
||||
zh_Hans: 中国
|
||||
pt_BR: cn
|
||||
- value: en_us
|
||||
label:
|
||||
en_US: usa
|
||||
zh_Hans: 美国
|
||||
pt_BR: usa
|
||||
default: zh_cn
|
||||
- name: units
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: units for temperature
|
||||
zh_Hans: 温度单位
|
||||
pt_BR: units for temperature
|
||||
label:
|
||||
en_US: units
|
||||
zh_Hans: 单位
|
||||
pt_BR: units
|
||||
form: form
|
||||
options:
|
||||
- value: metric
|
||||
label:
|
||||
en_US: metric
|
||||
zh_Hans: ℃
|
||||
pt_BR: metric
|
||||
- value: imperial
|
||||
label:
|
||||
en_US: imperial
|
||||
zh_Hans: ℉
|
||||
pt_BR: imperial
|
||||
default: metric
|
0
api/core/tools/provider/builtin/spark/__init__.py
Normal file
0
api/core/tools/provider/builtin/spark/__init__.py
Normal file
5
api/core/tools/provider/builtin/spark/_assets/icon.svg
Normal file
5
api/core/tools/provider/builtin/spark/_assets/icon.svg
Normal file
|
@ -0,0 +1,5 @@
|
|||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M21.6547 16.7993C21.3111 18.0034 20.7384 19.0938 20.0054 20.048C18.9058 21.4111 15.1261 21.4111 12.8583 20.8204C10.4072 20.1616 8.6433 18.6395 8.50586 18.5259C9.46797 19.2756 10.6821 19.7072 12.0107 19.7072C15.1948 19.7072 17.7605 17.1174 17.7605 13.9368C17.7605 12.9826 17.5314 12.0966 17.119 11.3015C17.0961 11.2561 17.1419 11.2106 17.1649 11.2333C18.9745 11.5287 22.571 13.2098 21.6547 16.7993Z" fill="#2751D0"/>
|
||||
<path d="M21.9994 12.7773C21.9994 12.8454 21.9306 12.8682 21.8848 12.8C21.0372 11.0053 19.5483 10.46 17.7615 10.0511C16.4099 9.75577 15.5166 9.3014 15.1271 9.09694C15.0355 9.0515 14.9668 8.98335 14.8751 8.93791C12.0575 7.23404 12.0117 4.30339 12.0117 4.30339V0.0550813C12.0117 0.00964486 12.0804 -0.0130733 12.1034 0.0096449L18.7694 6.50706L19.2734 6.98414C20.7394 8.52898 21.7474 10.5509 21.9994 12.7773Z" fill="#D82F20"/>
|
||||
<path d="M20.0052 20.0462C18.1726 22.4316 15.2863 23.9992 12.0334 23.9992C6.48985 23.9992 2 19.501 2 13.9577C2 11.2543 3.05374 8.8234 4.7947 7.00594L5.29866 6.50614L9.65107 2.25783C9.69688 2.2124 9.7656 2.25783 9.7427 2.30327C9.67397 2.59861 9.55944 3.28015 9.62816 4.18888C9.71979 5.25664 10.0634 6.68789 11.0713 8.27817C11.6898 9.27777 12.5832 10.3228 13.8202 11.4133C13.9577 11.5496 14.118 11.6632 14.2784 11.7995C14.8281 12.3674 15.1488 13.1171 15.1488 13.9577C15.1488 15.6616 13.7515 17.0474 12.0563 17.0474C11.3233 17.0474 10.659 16.7975 10.1321 16.3659C10.0863 16.3204 10.1321 16.2523 10.1779 16.275C10.2925 16.2977 10.407 16.3204 10.5215 16.3204C11.1171 16.3204 11.6211 15.8433 11.6211 15.2299C11.6211 14.8665 11.4378 14.5257 11.163 14.3439C10.4299 13.7533 9.81142 13.1853 9.28455 12.6173C8.55151 11.8222 8.00174 11.0498 7.61231 10.3001C6.81055 11.2997 6.30659 12.5492 6.30659 13.935C6.30659 15.7979 7.17707 17.4563 8.55152 18.5014C8.68896 18.615 10.4528 20.1371 12.9039 20.7959C15.1259 21.432 18.9057 21.4093 20.0052 20.0462Z" fill="#69C5F4"/>
|
||||
</svg>
|
After Width: | Height: | Size: 2.0 KiB |
40
api/core/tools/provider/builtin/spark/spark.py
Normal file
40
api/core/tools/provider/builtin/spark/spark.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
import json
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.spark.tools.spark_img_generation import spark_response
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class SparkProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
if "APPID" not in credentials or not credentials.get("APPID"):
|
||||
raise ToolProviderCredentialValidationError("APPID is required.")
|
||||
if "APISecret" not in credentials or not credentials.get("APISecret"):
|
||||
raise ToolProviderCredentialValidationError("APISecret is required.")
|
||||
if "APIKey" not in credentials or not credentials.get("APIKey"):
|
||||
raise ToolProviderCredentialValidationError("APIKey is required.")
|
||||
|
||||
appid = credentials.get("APPID")
|
||||
apisecret = credentials.get("APISecret")
|
||||
apikey = credentials.get("APIKey")
|
||||
prompt = "a cute black dog"
|
||||
|
||||
try:
|
||||
response = spark_response(prompt, appid, apikey, apisecret)
|
||||
data = json.loads(response)
|
||||
code = data["header"]["code"]
|
||||
|
||||
if code == 0:
|
||||
# 0 success,
|
||||
pass
|
||||
else:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
"image generate error, code:{}".format(code)
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
"APPID APISecret APIKey is invalid. {}".format(e)
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
59
api/core/tools/provider/builtin/spark/spark.yaml
Normal file
59
api/core/tools/provider/builtin/spark/spark.yaml
Normal file
|
@ -0,0 +1,59 @@
|
|||
identity:
|
||||
author: Onelevenvy
|
||||
name: spark
|
||||
label:
|
||||
en_US: Spark
|
||||
zh_Hans: 讯飞星火
|
||||
pt_BR: Spark
|
||||
description:
|
||||
en_US: Spark Platform Toolkit
|
||||
zh_Hans: 讯飞星火平台工具
|
||||
pt_BR: Pacote de Ferramentas da Plataforma Spark
|
||||
icon: icon.svg
|
||||
credentials_for_provider:
|
||||
APPID:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Spark APPID
|
||||
zh_Hans: APPID
|
||||
pt_BR: Spark APPID
|
||||
help:
|
||||
en_US: Please input your APPID
|
||||
zh_Hans: 请输入你的 APPID
|
||||
pt_BR: Please input your APPID
|
||||
placeholder:
|
||||
en_US: Please input your APPID
|
||||
zh_Hans: 请输入你的 APPID
|
||||
pt_BR: Please input your APPID
|
||||
APISecret:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Spark APISecret
|
||||
zh_Hans: APISecret
|
||||
pt_BR: Spark APISecret
|
||||
help:
|
||||
en_US: Please input your Spark APISecret
|
||||
zh_Hans: 请输入你的 APISecret
|
||||
pt_BR: Please input your Spark APISecret
|
||||
placeholder:
|
||||
en_US: Please input your Spark APISecret
|
||||
zh_Hans: 请输入你的 APISecret
|
||||
pt_BR: Please input your Spark APISecret
|
||||
APIKey:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Spark APIKey
|
||||
zh_Hans: APIKey
|
||||
pt_BR: Spark APIKey
|
||||
help:
|
||||
en_US: Please input your Spark APIKey
|
||||
zh_Hans: 请输入你的 APIKey
|
||||
pt_BR: Please input your Spark APIKey
|
||||
placeholder:
|
||||
en_US: Please input your Spark APIKey
|
||||
zh_Hans: 请输入你的 APIKey
|
||||
pt_BR: Please input Spark APIKey
|
||||
url: https://console.xfyun.cn/services
|
|
@ -0,0 +1,154 @@
|
|||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from base64 import b64decode
|
||||
from datetime import datetime
|
||||
from time import mktime
|
||||
from typing import Any, Union
|
||||
from urllib.parse import urlencode
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AssembleHeaderException(Exception):
|
||||
def __init__(self, msg):
|
||||
self.message = msg
|
||||
|
||||
|
||||
class Url:
|
||||
def __init__(this, host, path, schema):
|
||||
this.host = host
|
||||
this.path = path
|
||||
this.schema = schema
|
||||
|
||||
|
||||
# calculate sha256 and encode to base64
|
||||
def sha256base64(data):
|
||||
sha256 = hashlib.sha256()
|
||||
sha256.update(data)
|
||||
digest = base64.b64encode(sha256.digest()).decode(encoding="utf-8")
|
||||
return digest
|
||||
|
||||
|
||||
def parse_url(requset_url):
|
||||
stidx = requset_url.index("://")
|
||||
host = requset_url[stidx + 3 :]
|
||||
schema = requset_url[: stidx + 3]
|
||||
edidx = host.index("/")
|
||||
if edidx <= 0:
|
||||
raise AssembleHeaderException("invalid request url:" + requset_url)
|
||||
path = host[edidx:]
|
||||
host = host[:edidx]
|
||||
u = Url(host, path, schema)
|
||||
return u
|
||||
|
||||
def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""):
|
||||
u = parse_url(requset_url)
|
||||
host = u.host
|
||||
path = u.path
|
||||
now = datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(
|
||||
host, date, method, path
|
||||
)
|
||||
signature_sha = hmac.new(
|
||||
api_secret.encode("utf-8"),
|
||||
signature_origin.encode("utf-8"),
|
||||
digestmod=hashlib.sha256,
|
||||
).digest()
|
||||
signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
|
||||
authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"'
|
||||
|
||||
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
|
||||
encoding="utf-8"
|
||||
)
|
||||
values = {"host": host, "date": date, "authorization": authorization}
|
||||
|
||||
return requset_url + "?" + urlencode(values)
|
||||
|
||||
|
||||
def get_body(appid, text):
|
||||
body = {
|
||||
"header": {"app_id": appid, "uid": "123456789"},
|
||||
"parameter": {
|
||||
"chat": {"domain": "general", "temperature": 0.5, "max_tokens": 4096}
|
||||
},
|
||||
"payload": {"message": {"text": [{"role": "user", "content": text}]}},
|
||||
}
|
||||
return body
|
||||
|
||||
|
||||
def spark_response(text, appid, apikey, apisecret):
|
||||
host = "http://spark-api.cn-huabei-1.xf-yun.com/v2.1/tti"
|
||||
url = assemble_ws_auth_url(
|
||||
host, method="POST", api_key=apikey, api_secret=apisecret
|
||||
)
|
||||
content = get_body(appid, text)
|
||||
response = requests.post(
|
||||
url, json=content, headers={"content-type": "application/json"}
|
||||
).text
|
||||
return response
|
||||
|
||||
|
||||
class SparkImgGeneratorTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
|
||||
if "APPID" not in self.runtime.credentials or not self.runtime.credentials.get(
|
||||
"APPID"
|
||||
):
|
||||
return self.create_text_message("APPID is required.")
|
||||
if (
|
||||
"APISecret" not in self.runtime.credentials
|
||||
or not self.runtime.credentials.get("APISecret")
|
||||
):
|
||||
return self.create_text_message("APISecret is required.")
|
||||
if (
|
||||
"APIKey" not in self.runtime.credentials
|
||||
or not self.runtime.credentials.get("APIKey")
|
||||
):
|
||||
return self.create_text_message("APIKey is required.")
|
||||
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
if not prompt:
|
||||
return self.create_text_message("Please input prompt")
|
||||
res = self.img_generation(prompt)
|
||||
result = []
|
||||
for image in res:
|
||||
result.append(
|
||||
self.create_blob_message(
|
||||
blob=b64decode(image["base64_image"]),
|
||||
meta={"mime_type": "image/png"},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
def img_generation(self, prompt):
|
||||
response = spark_response(
|
||||
text=prompt,
|
||||
appid=self.runtime.credentials.get("APPID"),
|
||||
apikey=self.runtime.credentials.get("APIKey"),
|
||||
apisecret=self.runtime.credentials.get("APISecret"),
|
||||
)
|
||||
data = json.loads(response)
|
||||
code = data["header"]["code"]
|
||||
if code != 0:
|
||||
return self.create_text_message(f"error: {code}, {data}")
|
||||
else:
|
||||
text = data["payload"]["choices"]["text"]
|
||||
image_content = text[0]
|
||||
image_base = image_content["content"]
|
||||
json_data = {"base64_image": image_base}
|
||||
return [json_data]
|
|
@ -0,0 +1,36 @@
|
|||
identity:
|
||||
name: spark_img_generation
|
||||
author: Onelevenvy
|
||||
label:
|
||||
en_US: Spark Image Generation
|
||||
zh_Hans: 图片生成
|
||||
pt_BR: Geração de imagens Spark
|
||||
icon: icon.svg
|
||||
description:
|
||||
en_US: Spark Image Generation
|
||||
zh_Hans: 图片生成
|
||||
pt_BR: Geração de imagens Spark
|
||||
description:
|
||||
human:
|
||||
en_US: Generate images based on user input, with image generation API
|
||||
provided by Spark
|
||||
zh_Hans: 根据用户的输入生成图片,由讯飞星火提供图片生成api
|
||||
pt_BR: Gerar imagens com base na entrada do usuário, com API de geração
|
||||
de imagem fornecida pela Spark
|
||||
llm: spark_img_generation is a tool used to generate images from text
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
pt_BR: Prompt
|
||||
human_description:
|
||||
en_US: Image prompt
|
||||
zh_Hans: 图像提示词
|
||||
pt_BR: Image prompt
|
||||
llm_description: Image prompt of spark_img_generation tooll, you should
|
||||
describe the image you want to generate as a list of words as possible
|
||||
as detailed
|
||||
form: llm
|
|
@ -1 +1 @@
|
|||
from dify_client.client import ChatClient, CompletionClient, DifyClient
|
||||
from dify_client.client import ChatClient, CompletionClient, DifyClient
|
Loading…
Reference in New Issue
Block a user