From d3c06a3f76e73ed3afe4cbf0504467809a60b5df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Thu, 17 Oct 2024 16:48:42 +0800 Subject: [PATCH 01/12] feat: add the workflow tool of comfyUI (#9447) --- .../tools/provider/builtin/comfyui/comfyui.py | 18 +-- .../provider/builtin/comfyui/comfyui.yaml | 25 +---- .../builtin/comfyui/tools/comfyui_client.py | 105 ++++++++++++++++++ .../tools/comfyui_stable_diffusion.yaml | 8 +- .../builtin/comfyui/tools/comfyui_workflow.py | 32 ++++++ .../comfyui/tools/comfyui_workflow.yaml | 35 ++++++ 6 files changed, 190 insertions(+), 33 deletions(-) create mode 100644 api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py create mode 100644 api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py create mode 100644 api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml diff --git a/api/core/tools/provider/builtin/comfyui/comfyui.py b/api/core/tools/provider/builtin/comfyui/comfyui.py index 7013a0b93c..bab690af82 100644 --- a/api/core/tools/provider/builtin/comfyui/comfyui.py +++ b/api/core/tools/provider/builtin/comfyui/comfyui.py @@ -1,17 +1,21 @@ from typing import Any +import websocket +from yarl import URL + from core.tools.errors import ToolProviderCredentialValidationError -from core.tools.provider.builtin.comfyui.tools.comfyui_stable_diffusion import ComfyuiStableDiffusionTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController class ComfyUIProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: + ws = websocket.WebSocket() + base_url = URL(credentials.get("base_url")) + ws_address = f"ws://{base_url.authority}/ws?clientId=test123" + try: - ComfyuiStableDiffusionTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).validate_models() + ws.connect(ws_address) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) + raise ToolProviderCredentialValidationError(f"can not connect to {ws_address}") + finally: + ws.close() diff --git a/api/core/tools/provider/builtin/comfyui/comfyui.yaml b/api/core/tools/provider/builtin/comfyui/comfyui.yaml index 3891eebf3a..24ae43cd44 100644 --- a/api/core/tools/provider/builtin/comfyui/comfyui.yaml +++ b/api/core/tools/provider/builtin/comfyui/comfyui.yaml @@ -4,11 +4,9 @@ identity: label: en_US: ComfyUI zh_Hans: ComfyUI - pt_BR: ComfyUI description: en_US: ComfyUI is a tool for generating images which can be deployed locally. zh_Hans: ComfyUI 是一个可以在本地部署的图片生成的工具。 - pt_BR: ComfyUI is a tool for generating images which can be deployed locally. icon: icon.png tags: - image @@ -17,26 +15,9 @@ credentials_for_provider: type: text-input required: true label: - en_US: Base URL - zh_Hans: ComfyUI服务器的Base URL - pt_BR: Base URL + en_US: The URL of ComfyUI Server + zh_Hans: ComfyUI服务器的URL placeholder: en_US: Please input your ComfyUI server's Base URL zh_Hans: 请输入你的 ComfyUI 服务器的 Base URL - pt_BR: Please input your ComfyUI server's Base URL - model: - type: text-input - required: true - label: - en_US: Model with suffix - zh_Hans: 模型, 需要带后缀 - pt_BR: Model with suffix - placeholder: - en_US: Please input your model - zh_Hans: 请输入你的模型名称 - pt_BR: Please input your model - help: - en_US: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors - zh_Hans: ComfyUI服务器的模型名称, 比如 xxx.safetensors - pt_BR: The checkpoint name of the ComfyUI server, e.g. 
xxx.safetensors - url: https://github.com/comfyanonymous/ComfyUI#installing + url: https://docs.dify.ai/guides/tools/tool-configuration/comfyui diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py new file mode 100644 index 0000000000..a41d34d40f --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py @@ -0,0 +1,105 @@ +import json +import random +import uuid + +import httpx +from websocket import WebSocket +from yarl import URL + + +class ComfyUiClient: + def __init__(self, base_url: str): + self.base_url = URL(base_url) + + def get_history(self, prompt_id: str): + res = httpx.get(str(self.base_url / "history"), params={"prompt_id": prompt_id}) + history = res.json()[prompt_id] + return history + + def get_image(self, filename: str, subfolder: str, folder_type: str): + response = httpx.get( + str(self.base_url / "view"), + params={"filename": filename, "subfolder": subfolder, "type": folder_type}, + ) + return response.content + + def upload_image(self, input_path: str, name: str, image_type: str = "input", overwrite: bool = False): + # plan to support img2img in dify 0.10.0 + with open(input_path, "rb") as file: + files = {"image": (name, file, "image/png")} + data = {"type": image_type, "overwrite": str(overwrite).lower()} + + res = httpx.post(str(self.base_url / "upload/image"), data=data, files=files) + return res + + def queue_prompt(self, client_id: str, prompt: dict): + res = httpx.post(str(self.base_url / "prompt"), json={"client_id": client_id, "prompt": prompt}) + prompt_id = res.json()["prompt_id"] + return prompt_id + + def open_websocket_connection(self): + client_id = str(uuid.uuid4()) + ws = WebSocket() + ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}" + ws.connect(ws_address) + return ws, client_id + + def set_prompt(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = ""): + """ + find the first KSampler, then can find the prompt node through it. 
+ """ + prompt = origin_prompt.copy() + id_to_class_type = {id: details["class_type"] for id, details in prompt.items()} + k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0] + prompt.get(k_sampler)["inputs"]["seed"] = random.randint(10**14, 10**15 - 1) + positive_input_id = prompt.get(k_sampler)["inputs"]["positive"][0] + prompt.get(positive_input_id)["inputs"]["text"] = positive_prompt + + if negative_prompt != "": + negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0] + prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt + return prompt + + def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str): + node_ids = list(prompt.keys()) + finished_nodes = [] + + while True: + out = ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message["type"] == "progress": + data = message["data"] + current_step = data["value"] + print("In K-Sampler -> Step: ", current_step, " of: ", data["max"]) + if message["type"] == "execution_cached": + data = message["data"] + for itm in data["nodes"]: + if itm not in finished_nodes: + finished_nodes.append(itm) + print("Progress: ", len(finished_nodes), "/", len(node_ids), " Tasks done") + if message["type"] == "executing": + data = message["data"] + if data["node"] not in finished_nodes: + finished_nodes.append(data["node"]) + print("Progress: ", len(finished_nodes), "/", len(node_ids), " Tasks done") + + if data["node"] is None and data["prompt_id"] == prompt_id: + break # Execution is done + else: + continue + + def generate_image_by_prompt(self, prompt: dict): + try: + ws, client_id = self.open_websocket_connection() + prompt_id = self.queue_prompt(client_id, prompt) + self.track_progress(prompt, ws, prompt_id) + history = self.get_history(prompt_id) + images = [] + for output in history["outputs"].values(): + for img in output.get("images", []): + image_data = self.get_image(img["filename"], img["subfolder"], img["type"]) + images.append(image_data) + return images + finally: + ws.close() diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.yaml b/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.yaml index 4f4a6942b3..75fe746965 100644 --- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.yaml +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.yaml @@ -1,10 +1,10 @@ identity: - name: txt2img workflow + name: txt2img author: Qun label: - en_US: Txt2Img Workflow - zh_Hans: Txt2Img Workflow - pt_BR: Txt2Img Workflow + en_US: Txt2Img + zh_Hans: Txt2Img + pt_BR: Txt2Img description: human: en_US: a pre-defined comfyui workflow that can use one model and up to 3 loras to generate images. Support SD1.5, SDXL, SD3 and FLUX which contain text encoders/clip, but does not support models that requires a triple clip loader. 
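For reference, a minimal usage sketch of the ComfyUiClient added above — not part of the patch itself. It assumes the api package root is on the import path, a ComfyUI server reachable at http://127.0.0.1:8188, and a workflow exported from the ComfyUI editor in API format to workflow_api.json that contains a KSampler node; the URL and file name are illustrative only.

import json

from core.tools.provider.builtin.comfyui.tools.comfyui_client import ComfyUiClient

# Load a workflow exported from the ComfyUI editor ("Save (API Format)").
with open("workflow_api.json") as f:
    origin_prompt = json.load(f)

client = ComfyUiClient("http://127.0.0.1:8188")

# set_prompt locates the first KSampler node, randomizes its seed, and rewrites
# the positive/negative text nodes wired into it.
prompt = client.set_prompt(origin_prompt, "a cat sitting on a windowsill", "blurry, low quality")

# generate_image_by_prompt queues the prompt, tracks progress over the websocket,
# then fetches the finished images from the /history and /view endpoints.
images = client.generate_image_by_prompt(prompt)

for i, image_bytes in enumerate(images):
    with open(f"output_{i}.png", "wb") as out:
        out.write(image_bytes)

One detail worth noting in generate_image_by_prompt: the websocket is opened inside the try block but closed in the finally clause, so a failed connection leaves ws unbound when finally runs; opening the connection before the try, or guarding the close, would avoid masking the original error.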
diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py new file mode 100644 index 0000000000..e4df9f8c3b --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py @@ -0,0 +1,32 @@ +import json +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +from .comfyui_client import ComfyUiClient + + +class ComfyUIWorkflowTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + comfyui = ComfyUiClient(self.runtime.credentials["base_url"]) + + positive_prompt = tool_parameters.get("positive_prompt") + negative_prompt = tool_parameters.get("negative_prompt") + workflow = tool_parameters.get("workflow_json") + + try: + origin_prompt = json.loads(workflow) + except: + return self.create_text_message("the Workflow JSON is not correct") + + prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt) + images = comfyui.generate_image_by_prompt(prompt) + result = [] + for img in images: + result.append( + self.create_blob_message( + blob=img, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value + ) + ) + return result diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml new file mode 100644 index 0000000000..6342d6d468 --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml @@ -0,0 +1,35 @@ +identity: + name: workflow + author: hjlarry + label: + en_US: workflow + zh_Hans: 工作流 +description: + human: + en_US: Run ComfyUI workflow. + zh_Hans: 运行ComfyUI工作流。 + llm: Run ComfyUI workflow. +parameters: + - name: positive_prompt + type: string + label: + en_US: Prompt + zh_Hans: 提示词 + llm_description: Image prompt, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English. + form: llm + - name: negative_prompt + type: string + label: + en_US: Negative Prompt + zh_Hans: 负面提示词 + llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English. 
+ form: llm + - name: workflow_json + type: string + required: true + label: + en_US: Workflow JSON + human_description: + en_US: exported from ComfyUI workflow + zh_Hans: 从ComfyUI的工作流中导出 + form: form From a45f8969a0742607cfc047fe8aa7985dbd7d8050 Mon Sep 17 00:00:00 2001 From: zhuhao <37029601+hwzhuhao@users.noreply.github.com> Date: Thu, 17 Oct 2024 17:25:14 +0800 Subject: [PATCH 02/12] fix: remove the undefined variable line (#9446) --- .../model_providers/sagemaker/speech2text/speech2text.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py index 6aa8c9995f..94bae71e53 100644 --- a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py @@ -14,6 +14,7 @@ from core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) +from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel from core.model_runtime.model_providers.sagemaker.sagemaker import generate_presigned_url @@ -77,7 +78,8 @@ class SageMakerSpeech2TextModel(Speech2TextModel): json_obj = json.loads(json_str) asr_text = json_obj["text"] except Exception as e: - logger.exception(f"Exception {e}, line : {line}") + logger.exception(f"failed to invoke speech2text model, {e}") + raise CredentialsValidateFailedError(str(e)) return asr_text From e7aecb89dd792072fa8cc4c91bfd3c6157f4801e Mon Sep 17 00:00:00 2001 From: Kevin9703 <51311316+Kevin9703@users.noreply.github.com> Date: Thu, 17 Oct 2024 19:01:50 +0800 Subject: [PATCH 03/12] fix(workflow): Implement automatic variable addition from opening statement to start node (#9450) --- .../base/features/feature-panel/index.tsx | 4 +++ .../feature-panel/opening-statement/index.tsx | 11 ++++++-- web/app/components/workflow/features.tsx | 26 ++++++++++++++++++- 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/web/app/components/base/features/feature-panel/index.tsx b/web/app/components/base/features/feature-panel/index.tsx index e979391c92..72799ef2fc 100644 --- a/web/app/components/base/features/feature-panel/index.tsx +++ b/web/app/components/base/features/feature-panel/index.tsx @@ -13,16 +13,19 @@ import TextToSpeech from './text-to-speech' import SpeechToText from './speech-to-text' import Citation from './citation' import Moderation from './moderation' +import type { InputVar } from '@/app/components/workflow/types' export type FeaturePanelProps = { onChange?: OnFeaturesChange openingStatementProps: OpeningStatementProps disabled?: boolean + workflowVariables: InputVar[] } const FeaturePanel = ({ onChange, openingStatementProps, disabled, + workflowVariables, }: FeaturePanelProps) => { const { t } = useTranslation() const features = useFeatures(s => s.features) @@ -60,6 +63,7 @@ const FeaturePanel = ({ {...openingStatementProps} onChange={onChange} readonly={disabled} + workflowVariables={workflowVariables} /> ) } diff --git a/web/app/components/base/features/feature-panel/opening-statement/index.tsx b/web/app/components/base/features/feature-panel/opening-statement/index.tsx index b039165c9e..1f102700ad 100644 --- a/web/app/components/base/features/feature-panel/opening-statement/index.tsx +++ 
b/web/app/components/base/features/feature-panel/opening-statement/index.tsx @@ -24,6 +24,7 @@ import ConfirmAddVar from '@/app/components/app/configuration/config-prompt/conf import { getNewVar } from '@/utils/var' import { varHighlightHTML } from '@/app/components/app/configuration/base/var-highlight' import type { PromptVariable } from '@/models/debug' +import type { InputVar } from '@/app/components/workflow/types' const MAX_QUESTION_NUM = 5 @@ -32,6 +33,7 @@ export type OpeningStatementProps = { readonly?: boolean promptVariables?: PromptVariable[] onAutoAddPromptVariable: (variable: PromptVariable[]) => void + workflowVariables?: InputVar[] } // regex to match the {{}} and replace it with a span @@ -42,6 +44,7 @@ const OpeningStatement: FC = ({ readonly, promptVariables = [], onAutoAddPromptVariable, + workflowVariables = [], }) => { const { t } = useTranslation() const featureStore = useFeaturesStore() @@ -96,14 +99,18 @@ const OpeningStatement: FC = ({ const handleConfirm = () => { const keys = getInputKeys(tempValue) const promptKeys = promptVariables.map(item => item.key) + const workflowVariableKeys = workflowVariables.map(item => item.variable) let notIncludeKeys: string[] = [] - if (promptKeys.length === 0) { + if (promptKeys.length === 0 && workflowVariables.length === 0) { if (keys.length > 0) notIncludeKeys = keys } else { - notIncludeKeys = keys.filter(key => !promptKeys.includes(key)) + if (workflowVariables.length > 0) + notIncludeKeys = keys.filter(key => !workflowVariableKeys.includes(key)) + + else notIncludeKeys = keys.filter(key => !promptKeys.includes(key)) } if (notIncludeKeys.length > 0) { diff --git a/web/app/components/workflow/features.tsx b/web/app/components/workflow/features.tsx index 60a47bf177..16b638108c 100644 --- a/web/app/components/workflow/features.tsx +++ b/web/app/components/workflow/features.tsx @@ -4,16 +4,21 @@ import { } from 'react' import { useTranslation } from 'react-i18next' import { RiCloseLine } from '@remixicon/react' +import { useNodes } from 'reactflow' import { useStore } from './store' import { useIsChatMode, useNodesReadOnly, useNodesSyncDraft, } from './hooks' +import { type CommonNodeType, type InputVar, InputVarType, type Node } from './types' +import useConfig from './nodes/start/use-config' +import type { StartNodeType } from './nodes/start/types' import { FeaturesChoose, FeaturesPanel, } from '@/app/components/base/features' +import type { PromptVariable } from '@/models/debug' const Features = () => { const { t } = useTranslation() @@ -21,6 +26,24 @@ const Features = () => { const setShowFeaturesPanel = useStore(s => s.setShowFeaturesPanel) const { nodesReadOnly } = useNodesReadOnly() const { handleSyncWorkflowDraft } = useNodesSyncDraft() + const nodes = useNodes() + + const startNode = nodes.find(node => node.data.type === 'start') + const { id, data } = startNode as Node + const { handleAddVariable } = useConfig(id, data) + + const handleAddOpeningStatementVariable = (variables: PromptVariable[]) => { + const newVariable = variables[0] + const startNodeVariable: InputVar = { + variable: newVariable.key, + label: newVariable.name, + type: InputVarType.textInput, + max_length: newVariable.max_length, + required: newVariable.required || false, + options: [], + } + handleAddVariable(startNodeVariable) + } const handleFeaturesChange = useCallback(() => { handleSyncWorkflowDraft() @@ -55,8 +78,9 @@ const Features = () => { disabled={nodesReadOnly} onChange={handleFeaturesChange} openingStatementProps={{ - 
onAutoAddPromptVariable: () => {}, + onAutoAddPromptVariable: handleAddOpeningStatementVariable, }} + workflowVariables={data.variables} /> From b90ad587c2f16588f18f2968726d6b803ab35ef1 Mon Sep 17 00:00:00 2001 From: zhuhao <37029601+hwzhuhao@users.noreply.github.com> Date: Thu, 17 Oct 2024 19:12:42 +0800 Subject: [PATCH 04/12] refactor: move the embedding to the rag module and abstract the rerank runner for extension (#9423) --- .../embedding_type.py} | 0 api/core/model_manager.py | 2 +- .../__base/text_embedding_model.py | 2 +- .../text_embedding/text_embedding.py | 2 +- .../baichuan/text_embedding/text_embedding.py | 2 +- .../bedrock/text_embedding/text_embedding.py | 2 +- .../cohere/text_embedding/text_embedding.py | 2 +- .../text_embedding/text_embedding.py | 2 +- .../text_embedding/text_embedding.py | 2 +- .../text_embedding/text_embedding.py | 2 +- .../hunyuan/text_embedding/text_embedding.py | 2 +- .../jina/text_embedding/text_embedding.py | 2 +- .../localai/text_embedding/text_embedding.py | 2 +- .../minimax/text_embedding/text_embedding.py | 2 +- .../text_embedding/text_embedding.py | 2 +- .../nomic/text_embedding/text_embedding.py | 2 +- .../nvidia/text_embedding/text_embedding.py | 2 +- .../oci/text_embedding/text_embedding.py | 2 +- .../ollama/text_embedding/text_embedding.py | 2 +- .../openai/text_embedding/text_embedding.py | 2 +- .../text_embedding/text_embedding.py | 2 +- .../openllm/text_embedding/text_embedding.py | 2 +- .../text_embedding/text_embedding.py | 2 +- .../text_embedding/text_embedding.py | 2 +- .../text_embedding/text_embedding.py | 2 +- .../text_embedding/text_embedding.py | 2 +- .../tongyi/text_embedding/text_embedding.py | 2 +- .../upstage/text_embedding/text_embedding.py | 2 +- .../text_embedding/text_embedding.py | 2 +- .../text_embedding/text_embedding.py | 2 +- .../voyage/text_embedding/text_embedding.py | 2 +- .../wenxin/text_embedding/text_embedding.py | 2 +- .../text_embedding/text_embedding.py | 2 +- .../zhipuai/text_embedding/text_embedding.py | 2 +- .../data_post_processor.py | 53 +++++++++++-------- api/core/rag/datasource/retrieval_service.py | 2 +- .../vdb/analyticdb/analyticdb_vector.py | 2 +- .../rag/datasource/vdb/baidu/baidu_vector.py | 2 +- .../datasource/vdb/chroma/chroma_vector.py | 2 +- .../vdb/elasticsearch/elasticsearch_vector.py | 2 +- .../datasource/vdb/milvus/milvus_vector.py | 2 +- .../datasource/vdb/myscale/myscale_vector.py | 2 +- .../vdb/opensearch/opensearch_vector.py | 2 +- .../rag/datasource/vdb/oracle/oraclevector.py | 2 +- .../datasource/vdb/pgvecto_rs/pgvecto_rs.py | 2 +- .../rag/datasource/vdb/pgvector/pgvector.py | 2 +- .../datasource/vdb/qdrant/qdrant_vector.py | 2 +- .../rag/datasource/vdb/relyt/relyt_vector.py | 2 +- .../datasource/vdb/tencent/tencent_vector.py | 2 +- .../datasource/vdb/tidb_vector/tidb_vector.py | 2 +- api/core/rag/datasource/vdb/vector_factory.py | 4 +- .../vdb/vikingdb/vikingdb_vector.py | 2 +- .../vdb/weaviate/weaviate_vector.py | 2 +- api/core/rag/embedding/__init__.py | 0 .../{ => rag}/embedding/cached_embedding.py | 4 +- .../embedding_base.py} | 2 + api/core/rag/rerank/rerank_base.py | 26 +++++++++ api/core/rag/rerank/rerank_factory.py | 16 ++++++ api/core/rag/rerank/rerank_model.py | 3 +- .../rerank_mode.py => rerank_type.py} | 0 api/core/rag/rerank/weight_rerank.py | 5 +- 61 files changed, 135 insertions(+), 78 deletions(-) rename api/core/{embedding/embedding_constant.py => entities/embedding_type.py} (100%) create mode 100644 api/core/rag/embedding/__init__.py rename api/core/{ => 
rag}/embedding/cached_embedding.py (97%) rename api/core/rag/{datasource/entity/embedding.py => embedding/embedding_base.py} (90%) create mode 100644 api/core/rag/rerank/rerank_base.py create mode 100644 api/core/rag/rerank/rerank_factory.py rename api/core/rag/rerank/{constants/rerank_mode.py => rerank_type.py} (100%) diff --git a/api/core/embedding/embedding_constant.py b/api/core/entities/embedding_type.py similarity index 100% rename from api/core/embedding/embedding_constant.py rename to api/core/entities/embedding_type.py diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 74b4452362..e394233d2c 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -3,7 +3,7 @@ import os from collections.abc import Callable, Generator, Sequence from typing import IO, Optional, Union, cast -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.errors.error import ProviderTokenNotInitError diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index a948dca20d..2d38fba955 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -4,7 +4,7 @@ from typing import Optional from pydantic import ConfigDict -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.model_providers.__base.ai_model import AIModel diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index 8701a38050..c45ce87ea7 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ import numpy as np import tiktoken from openai import AzureOpenAI -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import AIModelEntity, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index 56b9be1c36..1ace68d2b9 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from requests import post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import 
EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index d9c5726592..2f998d8bda 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -13,7 +13,7 @@ from botocore.exceptions import ( UnknownServiceError, ) -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index 4da2080690..5fd4d637be 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ import cohere import numpy as np from cohere.core import RequestOptions -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py index cdce69ff38..c745a7e978 100644 --- a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ from typing import Optional, Union import numpy as np from openai import OpenAI -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py index b2e6d1b652..8278d1e64d 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py @@ -6,7 +6,7 @@ import numpy as np import requests from huggingface_hub import HfApi, InferenceClient -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git 
a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py index b8ff3ca549..6b43934538 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import time from typing import Optional -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py index 75701ebc54..b6d857cb37 100644 --- a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py @@ -9,7 +9,7 @@ from tencentcloud.common.profile.client_profile import ClientProfile from tencentcloud.common.profile.http_profile import HttpProfile from tencentcloud.hunyuan.v20230901 import hunyuan_client, models -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index b397129512..49c558f4a4 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from requests import post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py index ab8ca76c2f..b4dfc1a4de 100644 --- a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ from typing import Optional from requests import post from yarl import URL -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult 
diff --git a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py index d031bfa04d..29be5888af 100644 --- a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from requests import post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py index 68b7b448bf..ca949cb953 100644 --- a/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py index 857dfb5f41..56a707333c 100644 --- a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ from typing import Optional from nomic import embed from nomic import login as nomic_login -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import ( EmbeddingUsage, diff --git a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py index 936ceb8dd2..04363e11be 100644 --- a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from requests import post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py index 4de9296cca..50fa63768c 100644 --- a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py +++ 
b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py @@ -6,7 +6,7 @@ from typing import Optional import numpy as np import oci -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index 5cf3f1c6fa..a16c91cd7e 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -8,7 +8,7 @@ from urllib.parse import urljoin import numpy as np import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index 16f1a0cfa1..bec01fe679 100644 --- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -6,7 +6,7 @@ import numpy as np import tiktoken from openai import OpenAI -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index 64fa6aaa3c..c2b7297aac 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ from urllib.parse import urljoin import numpy as np import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py index c5d4330912..43a2e948e2 100644 --- a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ from typing import Optional from requests import post from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import 
PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index 1e86f351c8..d78bdaa75e 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ from urllib.parse import urljoin import numpy as np import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index 9f724a77ac..c4e9d0b9c6 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from replicate import Client as ReplicateClient -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py index 8f993ce672..ae7d805b4e 100644 --- a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py @@ -6,7 +6,7 @@ from typing import Any, Optional import boto3 -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py index c5dcc12610..5e29a4827a 100644 --- a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py @@ -1,6 +1,6 @@ from typing import Optional -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( OAICompatEmbeddingModel, diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py 
b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index 736cd44df8..2ef7f3f577 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional import dashscope import numpy as np -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import ( EmbeddingUsage, diff --git a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py index b6509cd26c..7dd495b55e 100644 --- a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ import numpy as np from openai import OpenAI from tokenizers import Tokenizer -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py index fce9544df0..43233e6126 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py @@ -9,7 +9,7 @@ from google.cloud import aiplatform from google.oauth2 import service_account from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py index 0dd4037c95..4d13e4708b 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -2,7 +2,7 @@ import time from decimal import Decimal from typing import Optional -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py index a8a4d3c15b..e69c9fccba 100644 --- a/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional import requests -from 
core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py index c21d0c0552..19135deb27 100644 --- a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ from typing import Any, Optional import numpy as np from requests import Response, post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import InvokeError diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index ddc21b365c..f64b9c50af 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -3,7 +3,7 @@ from typing import Optional from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 5a34a3d593..f629b62fd5 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -3,7 +3,7 @@ from typing import Optional from zhipuai import ZhipuAI -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index b1d6f93cff..992415657e 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,14 +1,14 @@ from typing import Optional -from core.model_manager import ModelManager +from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke 
import InvokeAuthorizationError from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.models.document import Document -from core.rag.rerank.constants.rerank_mode import RerankMode from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights -from core.rag.rerank.rerank_model import RerankModelRunner -from core.rag.rerank.weight_rerank import WeightRerankRunner +from core.rag.rerank.rerank_base import BaseRerankRunner +from core.rag.rerank.rerank_factory import RerankRunnerFactory +from core.rag.rerank.rerank_type import RerankMode class DataPostProcessor: @@ -47,11 +47,12 @@ class DataPostProcessor: tenant_id: str, reranking_model: Optional[dict] = None, weights: Optional[dict] = None, - ) -> Optional[RerankModelRunner | WeightRerankRunner]: + ) -> Optional[BaseRerankRunner]: if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: - return WeightRerankRunner( - tenant_id, - Weights( + runner = RerankRunnerFactory.create_rerank_runner( + runner_type=reranking_mode, + tenant_id=tenant_id, + weights=Weights( vector_setting=VectorSetting( vector_weight=weights["vector_setting"]["vector_weight"], embedding_provider_name=weights["vector_setting"]["embedding_provider_name"], @@ -62,23 +63,33 @@ class DataPostProcessor: ), ), ) + return runner elif reranking_mode == RerankMode.RERANKING_MODEL.value: - if reranking_model: - try: - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=tenant_id, - provider=reranking_model["reranking_provider_name"], - model_type=ModelType.RERANK, - model=reranking_model["reranking_model_name"], - ) - except InvokeAuthorizationError: - return None - return RerankModelRunner(rerank_model_instance) - return None + rerank_model_instance = self._get_rerank_model_instance(tenant_id, reranking_model) + if rerank_model_instance is None: + return None + runner = RerankRunnerFactory.create_rerank_runner( + runner_type=reranking_mode, rerank_model_instance=rerank_model_instance + ) + return runner return None def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: if reorder_enabled: return ReorderRunner() return None + + def _get_rerank_model_instance(self, tenant_id: str, reranking_model: Optional[dict]) -> ModelInstance | None: + if reranking_model: + try: + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + provider=reranking_model["reranking_provider_name"], + model_type=ModelType.RERANK, + model=reranking_model["reranking_model_name"], + ) + return rerank_model_instance + except InvokeAuthorizationError: + return None + return None diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index d3fd0c672a..3affbd2d0a 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -6,7 +6,7 @@ from flask import Flask, current_app from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.rerank.constants.rerank_mode import RerankMode +from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py 
b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index 6dcd98dcfd..c77cb87376 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -9,10 +9,10 @@ _import_err_msg = ( ) from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 543cfa67b3..1d4bfef76d 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -12,10 +12,10 @@ from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 610aa498ab..a9e1486edd 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -6,10 +6,10 @@ from chromadb import QueryResult, Settings from pydantic import BaseModel from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index f420373d5b..052a187225 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -9,11 +9,11 @@ from elasticsearch import Elasticsearch from flask import current_app from pydantic import BaseModel, model_validator -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 
bdca59f869..080a1ef567 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -7,11 +7,11 @@ from pymilvus import MilvusClient, MilvusException from pymilvus.milvus_client import IndexParams from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index b30aa7ca22..1fca926a2d 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -8,10 +8,10 @@ from clickhouse_connect import get_client from pydantic import BaseModel from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 8d2e0a86ab..0e0f107268 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -9,11 +9,11 @@ from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 84a4381cd1..4ced5d61e5 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -13,10 +13,10 @@ from nltk.corpus import stopwords from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index a82a9b96dd..9233cd63dc 100644 --- 
a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -12,11 +12,11 @@ from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 6f336d27e7..40a9cdd136 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -8,10 +8,10 @@ import psycopg2.pool from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index f418e3ca05..69d2aa4f76 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -20,11 +20,11 @@ from qdrant_client.http.models import ( from qdrant_client.local.qdrant_local import QdrantLocal from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 13a63784be..f373dcfeab 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -8,9 +8,9 @@ from sqlalchemy import text as sql_text from sqlalchemy.dialects.postgresql import JSON, TEXT from sqlalchemy.orm import Session -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from models.dataset import Dataset try: diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 39e3a7f6cf..f971a9c5eb 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -8,10 +8,10 @@ from tcvectordb.model import index as vdb_index from 
tcvectordb.model.document import Filter from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 7837c5a4aa..1147e35ce8 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -9,10 +9,10 @@ from sqlalchemy import text as sql_text from sqlalchemy.orm import Session, declarative_base from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 873b289027..fb956a16ed 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -2,12 +2,12 @@ from abc import ABC, abstractmethod from typing import Any, Optional from configs import dify_config -from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.cached_embedding import CacheEmbedding +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index 5f60f10acb..4f927f2899 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -14,11 +14,11 @@ from volcengine.viking_db import ( ) from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field as vdb_Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 4009efe7a7..649cfbfea8 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -7,11 +7,11 @@ import weaviate from pydantic import 
BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/embedding/__init__.py b/api/core/rag/embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py similarity index 97% rename from api/core/embedding/cached_embedding.py rename to api/core/rag/embedding/cached_embedding.py index 31d2171e72..b3e93ce760 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -6,11 +6,11 @@ import numpy as np from sqlalchemy.exc import IntegrityError from configs import dify_config -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.rag.datasource.entity.embedding import Embeddings +from core.rag.embedding.embedding_base import Embeddings from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper diff --git a/api/core/rag/datasource/entity/embedding.py b/api/core/rag/embedding/embedding_base.py similarity index 90% rename from api/core/rag/datasource/entity/embedding.py rename to api/core/rag/embedding/embedding_base.py index 126c1a3723..9f232ab910 100644 --- a/api/core/rag/datasource/entity/embedding.py +++ b/api/core/rag/embedding/embedding_base.py @@ -7,10 +7,12 @@ class Embeddings(ABC): @abstractmethod def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs.""" + raise NotImplementedError @abstractmethod def embed_query(self, text: str) -> list[float]: """Embed query text.""" + raise NotImplementedError async def aembed_documents(self, texts: list[str]) -> list[list[float]]: """Asynchronous Embed search docs.""" diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py new file mode 100644 index 0000000000..818b04b2ff --- /dev/null +++ b/api/core/rag/rerank/rerank_base.py @@ -0,0 +1,26 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from core.rag.models.document import Document + + +class BaseRerankRunner(ABC): + @abstractmethod + def run( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: + """ + Run rerank model + :param query: search query + :param documents: documents for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id if needed + :return: + """ + raise NotImplementedError diff --git a/api/core/rag/rerank/rerank_factory.py b/api/core/rag/rerank/rerank_factory.py new file mode 100644 index 0000000000..1a3cf85736 --- /dev/null +++ b/api/core/rag/rerank/rerank_factory.py @@ -0,0 +1,16 @@ +from core.rag.rerank.rerank_base import 
BaseRerankRunner +from core.rag.rerank.rerank_model import RerankModelRunner +from core.rag.rerank.rerank_type import RerankMode +from core.rag.rerank.weight_rerank import WeightRerankRunner + + +class RerankRunnerFactory: + @staticmethod + def create_rerank_runner(runner_type: str, *args, **kwargs) -> BaseRerankRunner: + match runner_type: + case RerankMode.RERANKING_MODEL.value: + return RerankModelRunner(*args, **kwargs) + case RerankMode.WEIGHTED_SCORE.value: + return WeightRerankRunner(*args, **kwargs) + case _: + raise ValueError(f"Unknown runner type: {runner_type}") diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 27f86aed34..40ebf0befd 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -2,9 +2,10 @@ from typing import Optional from core.model_manager import ModelInstance from core.rag.models.document import Document +from core.rag.rerank.rerank_base import BaseRerankRunner -class RerankModelRunner: +class RerankModelRunner(BaseRerankRunner): def __init__(self, rerank_model_instance: ModelInstance) -> None: self.rerank_model_instance = rerank_model_instance diff --git a/api/core/rag/rerank/constants/rerank_mode.py b/api/core/rag/rerank/rerank_type.py similarity index 100% rename from api/core/rag/rerank/constants/rerank_mode.py rename to api/core/rag/rerank/rerank_type.py diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 16d6b879a4..2e3fbe04e2 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -4,15 +4,16 @@ from typing import Optional import numpy as np -from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler +from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights +from core.rag.rerank.rerank_base import BaseRerankRunner -class WeightRerankRunner: +class WeightRerankRunner(BaseRerankRunner): def __init__(self, tenant_id: str, weights: Weights) -> None: self.tenant_id = tenant_id self.weights = weights From 211f4168063b509e9e0c34ed4881741f4d7af859 Mon Sep 17 00:00:00 2001 From: chzphoenix Date: Thu, 17 Oct 2024 19:18:32 +0800 Subject: [PATCH 05/12] feat:add wenxin rerank (#9431) Co-authored-by: cuihz Co-authored-by: crazywoola <427733928@qq.com> --- .../model_providers/wenxin/_common.py | 1 + .../model_providers/wenxin/rerank/__init__.py | 0 .../wenxin/rerank/bce-reranker-base_v1.yaml | 8 + .../model_providers/wenxin/rerank/rerank.py | 147 ++++++++++++++++++ .../model_providers/wenxin/wenxin.yaml | 1 + .../model_runtime/wenxin/test_rerank.py | 21 +++ 6 files changed, 178 insertions(+) create mode 100644 api/core/model_runtime/model_providers/wenxin/rerank/__init__.py create mode 100644 api/core/model_runtime/model_providers/wenxin/rerank/bce-reranker-base_v1.yaml create mode 100644 api/core/model_runtime/model_providers/wenxin/rerank/rerank.py create mode 100644 api/tests/integration_tests/model_runtime/wenxin/test_rerank.py diff --git a/api/core/model_runtime/model_providers/wenxin/_common.py b/api/core/model_runtime/model_providers/wenxin/_common.py index d72d1bd83a..1a4cc15371 100644 --- a/api/core/model_runtime/model_providers/wenxin/_common.py +++ 
b/api/core/model_runtime/model_providers/wenxin/_common.py @@ -120,6 +120,7 @@ class _CommonWenxin: "bge-large-en": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en", "bge-large-zh": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh", "tao-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k", + "bce-reranker-base_v1": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/reranker/bce_reranker_base", } function_calling_supports = [ diff --git a/api/core/model_runtime/model_providers/wenxin/rerank/__init__.py b/api/core/model_runtime/model_providers/wenxin/rerank/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/wenxin/rerank/bce-reranker-base_v1.yaml b/api/core/model_runtime/model_providers/wenxin/rerank/bce-reranker-base_v1.yaml new file mode 100644 index 0000000000..ef4b07d767 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/rerank/bce-reranker-base_v1.yaml @@ -0,0 +1,8 @@ +model: bce-reranker-base_v1 +model_type: rerank +model_properties: + context_size: 4096 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/wenxin/rerank/rerank.py b/api/core/model_runtime/model_providers/wenxin/rerank/rerank.py new file mode 100644 index 0000000000..b22aead22b --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/rerank/rerank.py @@ -0,0 +1,147 @@ +from typing import Optional + +import httpx + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel +from core.model_runtime.model_providers.wenxin._common import _CommonWenxin + + +class WenxinRerank(_CommonWenxin): + def rerank(self, model: str, query: str, docs: list[str], top_n: Optional[int] = None): + access_token = self._get_access_token() + url = f"{self.api_bases[model]}?access_token={access_token}" + + try: + response = httpx.post( + url, + json={"model": model, "query": query, "documents": docs, "top_n": top_n}, + headers={"Content-Type": "application/json"}, + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + +class WenxinRerankModel(RerankModel): + """ + Model class for wenxin rerank model. 
+ """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n documents to return + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] + + wenxin_rerank: WenxinRerank = WenxinRerank(api_key, secret_key) + + try: + results = wenxin_rerank.rerank(model, query, docs, top_n) + + rerank_documents = [] + for result in results["results"]: + index = result["index"] + if "document" in result: + text = result["document"] + else: + # llama.cpp rerank maynot return original documents + text = docs[index] + + rerank_document = RerankDocument( + index=index, + text=text, + score=result["relevance_score"], + ) + + if score_threshold is None or result["relevance_score"] >= score_threshold: + rerank_documents.append(rerank_document) + + return RerankResult(model=model, docs=rerank_documents) + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. 
Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return { + InvokeConnectionError: [httpx.ConnectError], + InvokeServerUnavailableError: [httpx.RemoteProtocolError], + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.RERANK, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin.yaml b/api/core/model_runtime/model_providers/wenxin/wenxin.yaml index 6a6b38e6a1..d8acfd8120 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin.yaml +++ b/api/core/model_runtime/model_providers/wenxin/wenxin.yaml @@ -18,6 +18,7 @@ help: supported_model_types: - llm - text-embedding + - rerank configurate_methods: - predefined-model provider_credential_schema: diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_rerank.py b/api/tests/integration_tests/model_runtime/wenxin/test_rerank.py new file mode 100644 index 0000000000..33c803e8e1 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/wenxin/test_rerank.py @@ -0,0 +1,21 @@ +import os +from time import sleep + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.model_providers.wenxin.rerank.rerank import WenxinRerankModel + + +def test_invoke_bce_reranker_base_v1(): + sleep(3) + model = WenxinRerankModel() + + response = model.invoke( + model="bce-reranker-base_v1", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + query="What is Deep Learning?", + docs=["Deep Learning is ...", "My Book is ..."], + user="abc-123", + ) + + assert isinstance(response, RerankResult) + assert len(response.docs) == 2 From 3fc0ebdd51251faf3c96cdd4163293a0d97c315f Mon Sep 17 00:00:00 2001 From: zhuhao <37029601+hwzhuhao@users.noreply.github.com> Date: Fri, 18 Oct 2024 08:19:58 +0800 Subject: [PATCH 06/12] feat: add yi-lightning llm model for yi (#9458) --- .../model_providers/yi/llm/_position.yaml | 1 + .../model_providers/yi/llm/yi-lightning.yaml | 43 +++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 api/core/model_runtime/model_providers/yi/llm/yi-lightning.yaml diff --git a/api/core/model_runtime/model_providers/yi/llm/_position.yaml b/api/core/model_runtime/model_providers/yi/llm/_position.yaml index e876893b41..5fa098beda 100644 --- a/api/core/model_runtime/model_providers/yi/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/yi/llm/_position.yaml @@ -7,3 +7,4 @@ - yi-medium-200k - yi-spark - yi-large-turbo +- yi-lightning diff --git a/api/core/model_runtime/model_providers/yi/llm/yi-lightning.yaml b/api/core/model_runtime/model_providers/yi/llm/yi-lightning.yaml new file mode 100644 index 0000000000..fccf1b3a26 --- /dev/null +++ b/api/core/model_runtime/model_providers/yi/llm/yi-lightning.yaml @@ -0,0 +1,43 @@ +model: yi-lightning +label: + zh_Hans: yi-lightning + en_US: 
yi-lightning +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 16384 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。 + en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is. + - name: max_tokens + use_template: max_tokens + type: int + default: 1024 + min: 1 + max: 4000 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + type: float + default: 0.9 + min: 0.01 + max: 1.00 + help: + zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。 + en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature. +pricing: + input: '0.99' + output: '0.99' + unit: '0.000001' + currency: RMB From a53fdc712679fe36eb703a59d998d730e9c0053a Mon Sep 17 00:00:00 2001 From: ice yao Date: Fri, 18 Oct 2024 08:20:22 +0800 Subject: [PATCH 07/12] fix: add missing vector type to migrate command (#9470) --- api/commands.py | 81 +++++++++++++++---------------------------------- 1 file changed, 25 insertions(+), 56 deletions(-) diff --git a/api/commands.py b/api/commands.py index dbcd8a744d..5b7f79c8f0 100644 --- a/api/commands.py +++ b/api/commands.py @@ -259,6 +259,25 @@ def migrate_knowledge_vector_database(): skipped_count = 0 total_count = 0 vector_type = dify_config.VECTOR_STORE + upper_colletion_vector_types = { + VectorType.MILVUS, + VectorType.PGVECTOR, + VectorType.RELYT, + VectorType.WEAVIATE, + VectorType.ORACLE, + VectorType.ELASTICSEARCH, + } + lower_colletion_vector_types = { + VectorType.ANALYTICDB, + VectorType.CHROMA, + VectorType.MYSCALE, + VectorType.PGVECTO_RS, + VectorType.TIDB_VECTOR, + VectorType.OPENSEARCH, + VectorType.TENCENT, + VectorType.BAIDU, + VectorType.VIKINGDB, + } page = 1 while True: try: @@ -284,11 +303,9 @@ def migrate_knowledge_vector_database(): skipped_count = skipped_count + 1 continue collection_name = "" - if vector_type == VectorType.WEAVIATE: - dataset_id = dataset.id + dataset_id = dataset.id + if vector_type in upper_colletion_vector_types: collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.QDRANT: if dataset.collection_binding_id: dataset_collection_binding = ( @@ -301,63 +318,15 @@ def migrate_knowledge_vector_database(): else: raise ValueError("Dataset Collection Binding not found") else: - dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.MILVUS: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif 
vector_type == VectorType.RELYT: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.TENCENT: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.PGVECTOR: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.OPENSEARCH: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.OPENSEARCH, - "vector_store": {"class_prefix": collection_name}, - } - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.ANALYTICDB: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.ANALYTICDB, - "vector_store": {"class_prefix": collection_name}, - } - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.ELASTICSEARCH: - dataset_id = dataset.id - index_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.BAIDU: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.BAIDU, - "vector_store": {"class_prefix": collection_name}, - } - dataset.index_struct = json.dumps(index_struct_dict) + elif vector_type in lower_colletion_vector_types: + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() else: raise ValueError(f"Vector store {vector_type} is not supported.") + index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}} + dataset.index_struct = json.dumps(index_struct_dict) vector = Vector(dataset) click.echo(f"Migrating dataset {dataset.id}.") From 2155bba5b02975cbc13bcfa58136227c38d17bb1 Mon Sep 17 00:00:00 2001 From: ice yao Date: Fri, 18 Oct 2024 08:21:41 +0800 Subject: [PATCH 08/12] fix: update mismatch vector type (#9462) --- api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 9233cd63dc..7cbbdcc81f 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -216,7 +216,7 @@ class PGVectoRSFactory(AbstractVectorFactory): else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.PGVECTO_RS, collection_name)) dim = len(embeddings.embed_query("pgvecto_rs")) return PGVectoRS( From b3cde9900c55a73f8a28044d2a4da91a70bbf3a0 Mon Sep 17 00:00:00 2001 
From: zhuhao <37029601+hwzhuhao@users.noreply.github.com> Date: Fri, 18 Oct 2024 08:21:54 +0800 Subject: [PATCH 09/12] feat: add parameter top-k for the llm model provided by openrouter and siliconflow (#9455) --- .../model_providers/openrouter/llm/deepseek-chat.yaml | 9 +++++++++ .../model_providers/openrouter/llm/deepseek-coder.yaml | 9 +++++++++ .../model_providers/openrouter/llm/gpt-3.5-turbo.yaml | 9 +++++++++ .../model_providers/openrouter/llm/gpt-4-32k.yaml | 9 +++++++++ .../model_providers/openrouter/llm/gpt-4.yaml | 9 +++++++++ .../openrouter/llm/gpt-4o-2024-08-06.yaml | 9 +++++++++ .../model_providers/openrouter/llm/gpt-4o-mini.yaml | 9 +++++++++ .../model_providers/openrouter/llm/gpt-4o.yaml | 9 +++++++++ .../openrouter/llm/llama-3-70b-instruct.yaml | 9 +++++++++ .../openrouter/llm/llama-3-8b-instruct.yaml | 9 +++++++++ .../openrouter/llm/llama-3.1-405b-instruct.yaml | 9 +++++++++ .../openrouter/llm/llama-3.1-70b-instruct.yaml | 9 +++++++++ .../openrouter/llm/llama-3.1-8b-instruct.yaml | 9 +++++++++ .../openrouter/llm/mistral-7b-instruct.yaml | 9 +++++++++ .../openrouter/llm/mixtral-8x22b-instruct.yaml | 9 +++++++++ .../openrouter/llm/mixtral-8x7b-instruct.yaml | 9 +++++++++ .../model_providers/openrouter/llm/o1-mini.yaml | 9 +++++++++ .../model_providers/openrouter/llm/o1-preview.yaml | 9 +++++++++ .../openrouter/llm/qwen2-72b-instruct.yaml | 9 +++++++++ .../openrouter/llm/qwen2.5-72b-instruct.yaml | 9 +++++++++ .../siliconflow/llm/deepdeek-coder-v2-instruct.yaml | 9 +++++++++ .../siliconflow/llm/deepseek-v2-chat.yaml | 9 +++++++++ .../model_providers/siliconflow/llm/deepseek-v2.5.yaml | 9 +++++++++ .../model_providers/siliconflow/llm/gemma-2-27b-it.yaml | 9 +++++++++ .../model_providers/siliconflow/llm/gemma-2-9b-it.yaml | 9 +++++++++ .../model_providers/siliconflow/llm/glm4-9b-chat.yaml | 9 +++++++++ .../siliconflow/llm/internlm2_5-20b-chat.yaml | 9 +++++++++ .../siliconflow/llm/internlm2_5-7b-chat.yaml | 9 +++++++++ .../siliconflow/llm/meta-mlama-3-70b-instruct.yaml | 9 +++++++++ .../siliconflow/llm/meta-mlama-3-8b-instruct.yaml | 9 +++++++++ .../siliconflow/llm/meta-mlama-3.1-405b-instruct.yaml | 9 +++++++++ .../siliconflow/llm/meta-mlama-3.1-70b-instruct.yaml | 9 +++++++++ .../siliconflow/llm/meta-mlama-3.1-8b-instruct.yaml | 9 +++++++++ .../siliconflow/llm/mistral-7b-instruct-v0.2.yaml | 9 +++++++++ .../siliconflow/llm/mistral-8x7b-instruct-v0.1.yaml | 9 +++++++++ .../siliconflow/llm/qwen2-1.5b-instruct.yaml | 9 +++++++++ .../siliconflow/llm/qwen2-57b-a14b-instruct.yaml | 9 +++++++++ .../siliconflow/llm/qwen2-72b-instruct.yaml | 9 +++++++++ .../siliconflow/llm/qwen2-7b-instruct.yaml | 9 +++++++++ .../siliconflow/llm/qwen2.5-14b-instruct.yaml | 9 +++++++++ .../siliconflow/llm/qwen2.5-32b-instruct.yaml | 9 +++++++++ .../siliconflow/llm/qwen2.5-72b-instruct.yaml | 9 +++++++++ .../siliconflow/llm/qwen2.5-7b-instruct.yaml | 9 +++++++++ .../model_providers/siliconflow/llm/yi-1.5-34b-chat.yaml | 9 +++++++++ .../model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml | 9 +++++++++ .../model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml | 9 +++++++++ 46 files changed, 414 insertions(+) diff --git a/api/core/model_runtime/model_providers/openrouter/llm/deepseek-chat.yaml b/api/core/model_runtime/model_providers/openrouter/llm/deepseek-chat.yaml index 7a1dea6950..6743bfcad6 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/deepseek-chat.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/deepseek-chat.yaml @@ -35,6 +35,15 @@ parameter_rules: help: zh_Hans: 
控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。 en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature. + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty default: 0 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/deepseek-coder.yaml b/api/core/model_runtime/model_providers/openrouter/llm/deepseek-coder.yaml index c05f4769b8..375a4d2d52 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/deepseek-coder.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/deepseek-coder.yaml @@ -18,6 +18,15 @@ parameter_rules: min: 0 max: 1 default: 1 + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: max_tokens use_template: max_tokens min: 1 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-3.5-turbo.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-3.5-turbo.yaml index 186c1cc663..621ecf065e 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-3.5-turbo.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-3.5-turbo.yaml @@ -14,6 +14,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4-32k.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4-32k.yaml index 8c2989b300..887e6d60f9 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4-32k.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4-32k.yaml @@ -14,6 +14,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4.yaml index ef19d4f6f0..66d1f9ae67 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4.yaml @@ -14,6 +14,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml index 0be325f55b..695cc3eedf 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml @@ -16,6 +16,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml index 3b1d95643d..e1e5889085 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml @@ -15,6 +15,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o.yaml index a8c97efdd6..560bf9d7d0 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o.yaml @@ -15,6 +15,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3-70b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3-70b-instruct.yaml index b91c39e729..04a4a90c6d 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llama-3-70b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3-70b-instruct.yaml @@ -10,6 +10,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: max_tokens use_template: max_tokens required: true diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3-8b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3-8b-instruct.yaml index 84b2c7fac2..066949d431 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llama-3-8b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3-8b-instruct.yaml @@ -10,6 +10,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: max_tokens use_template: max_tokens required: true diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-405b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-405b-instruct.yaml index a489ce1b5a..0cd89dea71 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-405b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-405b-instruct.yaml @@ -10,6 +10,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: max_tokens use_template: max_tokens required: true diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-70b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-70b-instruct.yaml index 12037411b1..768ab5ecbb 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-70b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-70b-instruct.yaml @@ -10,6 +10,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: max_tokens use_template: max_tokens required: true diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-8b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-8b-instruct.yaml index 6f06493f29..67b6b82b5d 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-8b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-8b-instruct.yaml @@ -10,6 +10,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: max_tokens use_template: max_tokens required: true diff --git a/api/core/model_runtime/model_providers/openrouter/llm/mistral-7b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/mistral-7b-instruct.yaml index 012dfc55ce..d08c016e95 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/mistral-7b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/mistral-7b-instruct.yaml @@ -18,6 +18,15 @@ parameter_rules: default: 1 min: 0 max: 1 + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: max_tokens use_template: max_tokens default: 1024 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x22b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x22b-instruct.yaml index f4eb4e45d9..e3af0e64d8 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x22b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x22b-instruct.yaml @@ -18,6 +18,15 @@ parameter_rules: default: 1 min: 0 max: 1 + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: max_tokens use_template: max_tokens default: 1024 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x7b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x7b-instruct.yaml index 7871e1f7a0..095ea5a858 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x7b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x7b-instruct.yaml @@ -19,6 +19,15 @@ parameter_rules: default: 1 min: 0 max: 1 + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: max_tokens use_template: max_tokens default: 1024 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/o1-mini.yaml b/api/core/model_runtime/model_providers/openrouter/llm/o1-mini.yaml index 85a918ff5e..f4202ee814 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/o1-mini.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/o1-mini.yaml @@ -12,6 +12,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty diff --git a/api/core/model_runtime/model_providers/openrouter/llm/o1-preview.yaml b/api/core/model_runtime/model_providers/openrouter/llm/o1-preview.yaml index 74b0a511be..1281b84286 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/o1-preview.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/o1-preview.yaml @@ -12,6 +12,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty diff --git a/api/core/model_runtime/model_providers/openrouter/llm/qwen2-72b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/qwen2-72b-instruct.yaml index 7b75fcb0c9..b6058138d3 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/qwen2-72b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/qwen2-72b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/openrouter/llm/qwen2.5-72b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/qwen2.5-72b-instruct.yaml index f141a40a00..5392b11168 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/qwen2.5-72b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/qwen2.5-72b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/deepdeek-coder-v2-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/deepdeek-coder-v2-instruct.yaml index d4431179e5..d5f23776ea 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/deepdeek-coder-v2-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/deepdeek-coder-v2-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml index caa6508b5e..7aa684ef38 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2.5.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2.5.yaml index 1c8e15ae52..b30fa3e2d1 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2.5.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2.5.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-27b-it.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-27b-it.yaml index 2840e3dcf4..f2a1f64bfb 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-27b-it.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-27b-it.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-9b-it.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-9b-it.yaml index d7e19b46f6..b096b9b647 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-9b-it.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-9b-it.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml index 9b32a02477..87acc557b7 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-20b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-20b-chat.yaml index d9663582e5..60157c2b46 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-20b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-20b-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-7b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-7b-chat.yaml index 73ad4480aa..faf4af7ea3 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-7b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-7b-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-70b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-70b-instruct.yaml index 9993d781ac..d01770cb01 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-70b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-70b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-8b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-8b-instruct.yaml index 60e3764789..3cd75d89e8 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-8b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-8b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-405b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-405b-instruct.yaml index f992660aa2..3506a70bcc 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-405b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-405b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-70b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-70b-instruct.yaml index 1c69d63a40..994a754a82 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-70b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-70b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-8b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-8b-instruct.yaml index a97002a5ca..ebfa9aac9d 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-8b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-8b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/mistral-7b-instruct-v0.2.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/mistral-7b-instruct-v0.2.yaml index 89fb153ba0..a71d8688a8 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/mistral-7b-instruct-v0.2.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/mistral-7b-instruct-v0.2.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/mistral-8x7b-instruct-v0.1.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/mistral-8x7b-instruct-v0.1.yaml index 2785e7496f..db45a75c6d 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/mistral-8x7b-instruct-v0.1.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/mistral-8x7b-instruct-v0.1.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-1.5b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-1.5b-instruct.yaml index f6c976af8e..bec5d37c57 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-1.5b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-1.5b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml index a996e919ea..b2461335f8 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml index a6e2c22dac..e0f23bd89e 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml index d8bea5e129..47a9da8119 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-14b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-14b-instruct.yaml index 02a401464b..9cc5ac4c91 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-14b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-14b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-32b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-32b-instruct.yaml index d084617e7d..c7fb21e9e1 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-32b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-32b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct.yaml index dfbad2494c..03136c88a1 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-7b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-7b-instruct.yaml index cdc8ffc4d2..99412adde7 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-7b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-7b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-34b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-34b-chat.yaml index 864ba46f1a..3e25f82369 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-34b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-34b-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml index fe4c8b4b3e..827b2ce1e5 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml index c61f0dc53f..112fcbfe97 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: From 28de676956615c912c591ecb6358644122e1f7fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20Sacrist=C3=A1n?= Date: Fri, 18 Oct 2024 02:23:36 +0200 Subject: [PATCH 10/12] controller test (#9469) --- api/app.py | 205 +---------------- api/app_factory.py | 213 ++++++++++++++++++ .../controllers/app_fixture.py | 24 ++ .../controllers/test_controllers.py | 10 + 4 files changed, 249 insertions(+), 203 deletions(-) create mode 100644 api/app_factory.py create mode 100644 api/tests/integration_tests/controllers/app_fixture.py create mode 100644 api/tests/integration_tests/controllers/test_controllers.py diff --git a/api/app.py b/api/app.py index 52dd492225..7fef62cd38 100644 --- a/api/app.py +++ b/api/app.py @@ -10,44 +10,19 @@ if os.environ.get("DEBUG", "false").lower() != "true": grpc.experimental.gevent.init_gevent() import json -import logging -import sys import threading import time import warnings -from logging.handlers import RotatingFileHandler -from flask import Flask, Response, request -from flask_cors import CORS -from werkzeug.exceptions import Unauthorized +from flask import Response -import contexts -from commands import register_commands -from configs import dify_config +from app_factory import create_app # DO NOT REMOVE BELOW from events import event_handlers # noqa: F401 -from extensions import ( - ext_celery, - ext_code_based_extension, - ext_compress, - ext_database, - ext_hosting_provider, - ext_login, - ext_mail, - ext_migrate, - ext_proxy_fix, - ext_redis, - ext_sentry, - ext_storage, -) -from extensions.ext_database import db -from extensions.ext_login import login_manager -from libs.passport import PassportService # TODO: Find a way to avoid importing models here from models import account, dataset, model, source, task, tool, tools, web # noqa: F401 -from services.account_service import AccountService # DO NOT REMOVE ABOVE @@ -60,188 +35,12 @@ if hasattr(time, "tzset"): time.tzset() -class DifyApp(Flask): - pass - - # ------------- # Configuration # ------------- - - config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first -# ---------------------------- -# Application Factory Function -# ---------------------------- - - -def create_flask_app_with_configs() -> Flask: - """ - create a raw flask app - with configs loaded from .env file - """ - dify_app = DifyApp(__name__) - dify_app.config.from_mapping(dify_config.model_dump()) - - # populate configs into system environment variables - for key, value in dify_app.config.items(): - if isinstance(value, str): - os.environ[key] = value - elif isinstance(value, int | float | bool): - os.environ[key] = str(value) - elif value is None: - os.environ[key] = "" - - return dify_app - - -def create_app() -> Flask: - app = create_flask_app_with_configs() - - app.secret_key = app.config["SECRET_KEY"] - - log_handlers = None - log_file = app.config.get("LOG_FILE") - if log_file: - log_dir = os.path.dirname(log_file) - os.makedirs(log_dir, exist_ok=True) - log_handlers = [ - RotatingFileHandler( - filename=log_file, - maxBytes=1024 * 1024 * 1024, - backupCount=5, - ), - logging.StreamHandler(sys.stdout), - ] - - logging.basicConfig( - level=app.config.get("LOG_LEVEL"), - format=app.config.get("LOG_FORMAT"), - datefmt=app.config.get("LOG_DATEFORMAT"), - handlers=log_handlers, - force=True, - ) - log_tz = app.config.get("LOG_TZ") - if log_tz: - from datetime import datetime - - import pytz - - timezone = pytz.timezone(log_tz) - - def 
time_converter(seconds): - return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() - - for handler in logging.root.handlers: - handler.formatter.converter = time_converter - initialize_extensions(app) - register_blueprints(app) - register_commands(app) - - return app - - -def initialize_extensions(app): - # Since the application instance is now created, pass it to each Flask - # extension instance to bind it to the Flask application instance (app) - ext_compress.init_app(app) - ext_code_based_extension.init() - ext_database.init_app(app) - ext_migrate.init(app, db) - ext_redis.init_app(app) - ext_storage.init_app(app) - ext_celery.init_app(app) - ext_login.init_app(app) - ext_mail.init_app(app) - ext_hosting_provider.init_app(app) - ext_sentry.init_app(app) - ext_proxy_fix.init_app(app) - - -# Flask-Login configuration -@login_manager.request_loader -def load_user_from_request(request_from_flask_login): - """Load user based on the request.""" - if request.blueprint not in {"console", "inner_api"}: - return None - # Check if the user_id contains a dot, indicating the old format - auth_header = request.headers.get("Authorization", "") - if not auth_header: - auth_token = request.args.get("_token") - if not auth_token: - raise Unauthorized("Invalid Authorization token.") - else: - if " " not in auth_header: - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - auth_scheme, auth_token = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - - decoded = PassportService().verify(auth_token) - user_id = decoded.get("user_id") - - logged_in_account = AccountService.load_logged_in_account(account_id=user_id) - if logged_in_account: - contexts.tenant_id.set(logged_in_account.current_tenant_id) - return logged_in_account - - -@login_manager.unauthorized_handler -def unauthorized_handler(): - """Handle unauthorized requests.""" - return Response( - json.dumps({"code": "unauthorized", "message": "Unauthorized."}), - status=401, - content_type="application/json", - ) - - -# register blueprint routers -def register_blueprints(app): - from controllers.console import bp as console_app_bp - from controllers.files import bp as files_bp - from controllers.inner_api import bp as inner_api_bp - from controllers.service_api import bp as service_api_bp - from controllers.web import bp as web_bp - - CORS( - service_api_bp, - allow_headers=["Content-Type", "Authorization", "X-App-Code"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - ) - app.register_blueprint(service_api_bp) - - CORS( - web_bp, - resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, - supports_credentials=True, - allow_headers=["Content-Type", "Authorization", "X-App-Code"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - expose_headers=["X-Version", "X-Env"], - ) - - app.register_blueprint(web_bp) - - CORS( - console_app_bp, - resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, - supports_credentials=True, - allow_headers=["Content-Type", "Authorization"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - expose_headers=["X-Version", "X-Env"], - ) - - app.register_blueprint(console_app_bp) - - CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) - app.register_blueprint(files_bp) - - 
app.register_blueprint(inner_api_bp) - - # create app app = create_app() celery = app.extensions["celery"] diff --git a/api/app_factory.py b/api/app_factory.py new file mode 100644 index 0000000000..04654c2699 --- /dev/null +++ b/api/app_factory.py @@ -0,0 +1,213 @@ +import os + +if os.environ.get("DEBUG", "false").lower() != "true": + from gevent import monkey + + monkey.patch_all() + + import grpc.experimental.gevent + + grpc.experimental.gevent.init_gevent() + +import json +import logging +import sys +from logging.handlers import RotatingFileHandler + +from flask import Flask, Response, request +from flask_cors import CORS +from werkzeug.exceptions import Unauthorized + +import contexts +from commands import register_commands +from configs import dify_config +from extensions import ( + ext_celery, + ext_code_based_extension, + ext_compress, + ext_database, + ext_hosting_provider, + ext_login, + ext_mail, + ext_migrate, + ext_proxy_fix, + ext_redis, + ext_sentry, + ext_storage, +) +from extensions.ext_database import db +from extensions.ext_login import login_manager +from libs.passport import PassportService +from services.account_service import AccountService + + +class DifyApp(Flask): + pass + + +# ---------------------------- +# Application Factory Function +# ---------------------------- +def create_flask_app_with_configs() -> Flask: + """ + create a raw flask app + with configs loaded from .env file + """ + dify_app = DifyApp(__name__) + dify_app.config.from_mapping(dify_config.model_dump()) + + # populate configs into system environment variables + for key, value in dify_app.config.items(): + if isinstance(value, str): + os.environ[key] = value + elif isinstance(value, int | float | bool): + os.environ[key] = str(value) + elif value is None: + os.environ[key] = "" + + return dify_app + + +def create_app() -> Flask: + app = create_flask_app_with_configs() + + app.secret_key = app.config["SECRET_KEY"] + + log_handlers = None + log_file = app.config.get("LOG_FILE") + if log_file: + log_dir = os.path.dirname(log_file) + os.makedirs(log_dir, exist_ok=True) + log_handlers = [ + RotatingFileHandler( + filename=log_file, + maxBytes=1024 * 1024 * 1024, + backupCount=5, + ), + logging.StreamHandler(sys.stdout), + ] + + logging.basicConfig( + level=app.config.get("LOG_LEVEL"), + format=app.config.get("LOG_FORMAT"), + datefmt=app.config.get("LOG_DATEFORMAT"), + handlers=log_handlers, + force=True, + ) + log_tz = app.config.get("LOG_TZ") + if log_tz: + from datetime import datetime + + import pytz + + timezone = pytz.timezone(log_tz) + + def time_converter(seconds): + return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() + + for handler in logging.root.handlers: + handler.formatter.converter = time_converter + initialize_extensions(app) + register_blueprints(app) + register_commands(app) + + return app + + +def initialize_extensions(app): + # Since the application instance is now created, pass it to each Flask + # extension instance to bind it to the Flask application instance (app) + ext_compress.init_app(app) + ext_code_based_extension.init() + ext_database.init_app(app) + ext_migrate.init(app, db) + ext_redis.init_app(app) + ext_storage.init_app(app) + ext_celery.init_app(app) + ext_login.init_app(app) + ext_mail.init_app(app) + ext_hosting_provider.init_app(app) + ext_sentry.init_app(app) + ext_proxy_fix.init_app(app) + + +# Flask-Login configuration +@login_manager.request_loader +def load_user_from_request(request_from_flask_login): + """Load user based on the 
request.""" + if request.blueprint not in {"console", "inner_api"}: + return None + # Check if the user_id contains a dot, indicating the old format + auth_header = request.headers.get("Authorization", "") + if not auth_header: + auth_token = request.args.get("_token") + if not auth_token: + raise Unauthorized("Invalid Authorization token.") + else: + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + auth_scheme, auth_token = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + + decoded = PassportService().verify(auth_token) + user_id = decoded.get("user_id") + + logged_in_account = AccountService.load_logged_in_account(account_id=user_id) + if logged_in_account: + contexts.tenant_id.set(logged_in_account.current_tenant_id) + return logged_in_account + + +@login_manager.unauthorized_handler +def unauthorized_handler(): + """Handle unauthorized requests.""" + return Response( + json.dumps({"code": "unauthorized", "message": "Unauthorized."}), + status=401, + content_type="application/json", + ) + + +# register blueprint routers +def register_blueprints(app): + from controllers.console import bp as console_app_bp + from controllers.files import bp as files_bp + from controllers.inner_api import bp as inner_api_bp + from controllers.service_api import bp as service_api_bp + from controllers.web import bp as web_bp + + CORS( + service_api_bp, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + ) + app.register_blueprint(service_api_bp) + + CORS( + web_bp, + resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) + + app.register_blueprint(web_bp) + + CORS( + console_app_bp, + resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) + + app.register_blueprint(console_app_bp) + + CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) + app.register_blueprint(files_bp) + + app.register_blueprint(inner_api_bp) diff --git a/api/tests/integration_tests/controllers/app_fixture.py b/api/tests/integration_tests/controllers/app_fixture.py new file mode 100644 index 0000000000..93065ee95c --- /dev/null +++ b/api/tests/integration_tests/controllers/app_fixture.py @@ -0,0 +1,24 @@ +import pytest + +from app_factory import create_app + +mock_user = type( + "MockUser", + (object,), + { + "is_authenticated": True, + "id": "123", + "is_editor": True, + "is_dataset_editor": True, + "status": "active", + "get_id": "123", + "current_tenant_id": "9d2074fc-6f86-45a9-b09d-6ecc63b9056b", + }, +) + + +@pytest.fixture +def app(): + app = create_app() + app.config["LOGIN_DISABLED"] = True + return app diff --git a/api/tests/integration_tests/controllers/test_controllers.py b/api/tests/integration_tests/controllers/test_controllers.py new file mode 100644 index 0000000000..6371694694 --- /dev/null +++ b/api/tests/integration_tests/controllers/test_controllers.py @@ -0,0 +1,10 
@@ +from unittest.mock import patch + +from app_fixture import app, mock_user + + +def test_post_requires_login(app): + with app.test_client() as client: + with patch("flask_login.utils._get_user", mock_user): + response = client.get("/console/api/data-source/integrates") + assert response.status_code == 200 From bd27b4c1620ecc0694be89aa466aceeddc0e5fed Mon Sep 17 00:00:00 2001 From: horochx <32632779+horochx@users.noreply.github.com> Date: Fri, 18 Oct 2024 08:24:07 +0800 Subject: [PATCH 11/12] fix fetch apps (#9453) --- web/app/(commonLayout)/apps/Apps.tsx | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/web/app/(commonLayout)/apps/Apps.tsx b/web/app/(commonLayout)/apps/Apps.tsx index 132096c6b4..accf6c67f2 100644 --- a/web/app/(commonLayout)/apps/Apps.tsx +++ b/web/app/(commonLayout)/apps/Apps.tsx @@ -87,15 +87,15 @@ const Apps = () => { localStorage.removeItem(NEED_REFRESH_APP_LIST_KEY) mutate() } - }, []) + }, [mutate, t]) useEffect(() => { if (isCurrentWorkspaceDatasetOperator) return router.replace('/datasets') - }, [isCurrentWorkspaceDatasetOperator]) + }, [router, isCurrentWorkspaceDatasetOperator]) - const hasMore = data?.at(-1)?.has_more ?? true useEffect(() => { + const hasMore = data?.at(-1)?.has_more ?? true let observer: IntersectionObserver | undefined if (anchorRef.current) { observer = new IntersectionObserver((entries) => { @@ -105,7 +105,7 @@ const Apps = () => { observer.observe(anchorRef.current) } return () => observer?.disconnect() - }, [isLoading, setSize, anchorRef, mutate, hasMore]) + }, [isLoading, setSize, anchorRef, mutate, data]) const { run: handleSearch } = useDebounceFn(() => { setSearchKeywords(keywords) From b9bf60ea23d4a09c700fa1e21fcd4b3b6e4c390b Mon Sep 17 00:00:00 2001 From: zhuhao <37029601+hwzhuhao@users.noreply.github.com> Date: Fri, 18 Oct 2024 12:30:25 +0800 Subject: [PATCH 12/12] fix: resolve the error with the db-pool-stat endpoint (#9478) --- api/app.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/app.py b/api/app.py index 7fef62cd38..a3efabf06c 100644 --- a/api/app.py +++ b/api/app.py @@ -20,6 +20,7 @@ from app_factory import create_app # DO NOT REMOVE BELOW from events import event_handlers # noqa: F401 +from extensions.ext_database import db # TODO: Find a way to avoid importing models here from models import account, dataset, model, source, task, tool, tools, web # noqa: F401
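The hunks repeated across the SiliconFlow model YAMLs above all add the same optional `top_k` sampling parameter (integer, `required: false`, with en_US/zh_Hans label and help text). As a rough sketch only, the rule can be read as plain data plus a trivial check; the dict shape and the validator below are illustrative, not Dify's runtime code:

```python
# Illustrative only: the top_k rule from the YAML hunks above as Python data,
# plus a minimal validator. Dify's actual parameter handling is not shown here.
from typing import Any

TOP_K_RULE: dict[str, Any] = {
    "name": "top_k",
    "type": "int",
    "required": False,
    "label": {"en_US": "Top k", "zh_Hans": "取样数量"},
    "help": {
        "en_US": "Only sample from the top K options for each subsequent token.",
        "zh_Hans": "仅从每个后续标记的前 K 个选项中采样。",
    },
}


def validate_top_k(params: dict[str, Any]) -> None:
    """Accept a missing top_k (the rule is optional) or a positive integer."""
    value = params.get("top_k")
    if value is None:
        return
    if not isinstance(value, int) or isinstance(value, bool) or value < 1:
        raise ValueError("top_k must be a positive integer")


validate_top_k({"top_k": 50})  # ok
validate_top_k({})             # ok, the parameter may be omitted
```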
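Patch 10 moves the Flask-Login request loader into `api/app_factory.py`; for the `console` and `inner_api` blueprints it reads credentials either from an `Authorization: Bearer <token>` header or, failing that, from a `_token` query argument. A minimal client-side sketch follows; the base URL, route, and token value are placeholders, and only the two credential forms come from the loader itself:

```python
# Sketch of the two credential forms accepted by load_user_from_request.
# Base URL, route, and token are placeholders, not real Dify values.
import httpx

BASE = "http://localhost:5001/console/api"  # assumed local API address
TOKEN = "<passport-jwt>"

# 1) The expected form: an Authorization header in "Bearer <token>" format.
httpx.get(f"{BASE}/apps", headers={"Authorization": f"Bearer {TOKEN}"})

# 2) The fallback the loader checks when no Authorization header is sent.
httpx.get(f"{BASE}/apps", params={"_token": TOKEN})
```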
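The new integration test in patch 10 exercises a console route through the application factory: the `app` fixture builds the app with `create_app()`, disables Flask-Login via `LOGIN_DISABLED`, and the test patches `flask_login.utils._get_user` with a mock account. Further controller tests can follow the same shape; in the sketch below only the route and test name are made up, everything else mirrors the fixture and test added by the patch:

```python
# Sketch of another controller test in the style of test_controllers.py.
# "/console/api/some-endpoint" is a placeholder, not a real Dify route; the file
# is assumed to sit next to app_fixture.py so the import resolves.
from unittest.mock import patch

from app_fixture import app, mock_user  # noqa: F401  # pytest discovers the fixture


def test_some_console_endpoint(app):
    with app.test_client() as client:
        # Substitute Flask-Login's current user with the mocked account.
        with patch("flask_login.utils._get_user", mock_user):
            response = client.get("/console/api/some-endpoint")
            # Swap the placeholder route for a real one; a missing route returns 404.
            assert response.status_code == 200
```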
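Patch 12 restores the `from extensions.ext_database import db` import in `api/app.py`, which the db-pool-stat endpoint relies on. The endpoint itself is not part of this series, so the snippet below is only a plausible sketch of the pool figures such an endpoint could report, using SQLAlchemy's standard pool accessors; the function name and dict keys are assumptions:

```python
# Hedged sketch of a connection-pool readout; not the actual endpoint code.
# Assumes a QueuePool (SQLAlchemy's default) and must run inside an app context.
from extensions.ext_database import db


def db_pool_stat() -> dict[str, int]:
    pool = db.engine.pool
    return {
        "pool_size": pool.size(),          # configured pool size
        "checked_out": pool.checkedout(),  # connections currently in use
        "checked_in": pool.checkedin(),    # idle connections available
        "overflow": pool.overflow(),       # connections created beyond pool_size
    }
```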