diff --git a/api/core/tools/provider/builtin/xinference/tools/stable_diffusion.py b/api/core/tools/provider/builtin/xinference/tools/stable_diffusion.py index 847f2730f2..a44d3b730a 100644 --- a/api/core/tools/provider/builtin/xinference/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/xinference/tools/stable_diffusion.py @@ -104,14 +104,15 @@ class StableDiffusionTool(BuiltinTool): model = self.runtime.credentials.get("model", None) if not model: return self.create_text_message("Please input model") - + api_key = self.runtime.credentials.get("api_key") or "abc" + headers = {"Authorization": f"Bearer {api_key}"} # set model try: url = str(URL(base_url) / "sdapi" / "v1" / "options") response = post( url, json={"sd_model_checkpoint": model}, - headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"}, + headers=headers, ) if response.status_code != 200: raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model") @@ -257,14 +258,15 @@ class StableDiffusionTool(BuiltinTool): draw_options["prompt"] = f"{lora},{prompt}" else: draw_options["prompt"] = prompt - + api_key = self.runtime.credentials.get("api_key") or "abc" + headers = {"Authorization": f"Bearer {api_key}"} try: url = str(URL(base_url) / "sdapi" / "v1" / "img2img") response = post( url, json=draw_options, timeout=120, - headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"}, + headers=headers, ) if response.status_code != 200: return self.create_text_message("Failed to generate image") @@ -298,14 +300,15 @@ class StableDiffusionTool(BuiltinTool): else: draw_options["prompt"] = prompt draw_options["override_settings"]["sd_model_checkpoint"] = model - + api_key = self.runtime.credentials.get("api_key") or "abc" + headers = {"Authorization": f"Bearer {api_key}"} try: url = str(URL(base_url) / "sdapi" / "v1" / "txt2img") response = post( url, json=draw_options, timeout=120, - headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"}, + headers=headers, ) if response.status_code != 200: return self.create_text_message("Failed to generate image") diff --git a/api/core/tools/provider/builtin/xinference/xinference.py b/api/core/tools/provider/builtin/xinference/xinference.py index 7c2428cc00..9692e4060e 100644 --- a/api/core/tools/provider/builtin/xinference/xinference.py +++ b/api/core/tools/provider/builtin/xinference/xinference.py @@ -6,12 +6,18 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class XinferenceProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: - base_url = credentials.get("base_url") - api_key = credentials.get("api_key") - model = credentials.get("model") + base_url = credentials.get("base_url", "").removesuffix("/") + api_key = credentials.get("api_key", "") + if not api_key: + api_key = "abc" + credentials["api_key"] = api_key + model = credentials.get("model", "") + if not base_url or not model: + raise ToolProviderCredentialValidationError("Xinference base_url and model is required") + headers = {"Authorization": f"Bearer {api_key}"} res = requests.post( f"{base_url}/sdapi/v1/options", - headers={"Authorization": f"Bearer {api_key}"}, + headers=headers, json={"sd_model_checkpoint": model}, ) if res.status_code != 200: diff --git a/api/core/tools/provider/builtin/xinference/xinference.yaml b/api/core/tools/provider/builtin/xinference/xinference.yaml index 19aaf5cbd1..b0c02b9cbc 100644 --- a/api/core/tools/provider/builtin/xinference/xinference.yaml +++ b/api/core/tools/provider/builtin/xinference/xinference.yaml @@ -31,7 +31,7 @@ credentials_for_provider: zh_Hans: 请输入你的模型名称 api_key: type: secret-input - required: true + required: false label: en_US: API Key zh_Hans: Xinference 服务器的 API Key