diff --git a/api/core/embedding/embedding_constant.py b/api/core/entities/embedding_type.py similarity index 100% rename from api/core/embedding/embedding_constant.py rename to api/core/entities/embedding_type.py diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 74b4452362..e394233d2c 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -3,7 +3,7 @@ import os from collections.abc import Callable, Generator, Sequence from typing import IO, Optional, Union, cast -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.errors.error import ProviderTokenNotInitError diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index a948dca20d..2d38fba955 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -4,7 +4,7 @@ from typing import Optional from pydantic import ConfigDict -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.model_providers.__base.ai_model import AIModel diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index 8701a38050..c45ce87ea7 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ import numpy as np import tiktoken from openai import AzureOpenAI -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import AIModelEntity, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index 56b9be1c36..1ace68d2b9 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from requests import post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index d9c5726592..2f998d8bda 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -13,7 +13,7 @@ from botocore.exceptions import ( UnknownServiceError, ) -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index 4da2080690..5fd4d637be 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ import cohere import numpy as np from cohere.core import RequestOptions -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py index cdce69ff38..c745a7e978 100644 --- a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ from typing import Optional, Union import numpy as np from openai import OpenAI -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py index b2e6d1b652..8278d1e64d 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py @@ -6,7 +6,7 @@ import numpy as np import requests from huggingface_hub import HfApi, InferenceClient -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py index b8ff3ca549..6b43934538 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import time from typing import Optional -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py index 75701ebc54..b6d857cb37 100644 --- a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py @@ -9,7 +9,7 @@ from tencentcloud.common.profile.client_profile import ClientProfile from tencentcloud.common.profile.http_profile import HttpProfile from tencentcloud.hunyuan.v20230901 import hunyuan_client, models -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index b397129512..49c558f4a4 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from requests import post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py index ab8ca76c2f..b4dfc1a4de 100644 --- a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ from typing import Optional from requests import post from yarl import URL -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py index d031bfa04d..29be5888af 100644 --- a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from requests import post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py index 68b7b448bf..ca949cb953 100644 --- a/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py index 857dfb5f41..56a707333c 100644 --- a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ from typing import Optional from nomic import embed from nomic import login as nomic_login -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import ( EmbeddingUsage, diff --git a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py index 936ceb8dd2..04363e11be 100644 --- a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from requests import post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py index 4de9296cca..50fa63768c 100644 --- a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py @@ -6,7 +6,7 @@ from typing import Optional import numpy as np import oci -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index 5cf3f1c6fa..a16c91cd7e 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -8,7 +8,7 @@ from urllib.parse import urljoin import numpy as np import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index 16f1a0cfa1..bec01fe679 100644 --- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -6,7 +6,7 @@ import numpy as np import tiktoken from openai import OpenAI -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index 64fa6aaa3c..c2b7297aac 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ from urllib.parse import urljoin import numpy as np import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py index c5d4330912..43a2e948e2 100644 --- a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ from typing import Optional from requests import post from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index 1e86f351c8..d78bdaa75e 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ from urllib.parse import urljoin import numpy as np import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index 9f724a77ac..c4e9d0b9c6 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from replicate import Client as ReplicateClient -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py index 8f993ce672..ae7d805b4e 100644 --- a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py @@ -6,7 +6,7 @@ from typing import Any, Optional import boto3 -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py index c5dcc12610..5e29a4827a 100644 --- a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py @@ -1,6 +1,6 @@ from typing import Optional -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( OAICompatEmbeddingModel, diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index 736cd44df8..2ef7f3f577 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional import dashscope import numpy as np -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import ( EmbeddingUsage, diff --git a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py index b6509cd26c..7dd495b55e 100644 --- a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ import numpy as np from openai import OpenAI from tokenizers import Tokenizer -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py index fce9544df0..43233e6126 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py @@ -9,7 +9,7 @@ from google.cloud import aiplatform from google.oauth2 import service_account from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py index 0dd4037c95..4d13e4708b 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -2,7 +2,7 @@ import time from decimal import Decimal from typing import Optional -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py index a8a4d3c15b..e69c9fccba 100644 --- a/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py index c21d0c0552..19135deb27 100644 --- a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ from typing import Any, Optional import numpy as np from requests import Response, post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import InvokeError diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index ddc21b365c..f64b9c50af 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -3,7 +3,7 @@ from typing import Optional from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 5a34a3d593..f629b62fd5 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -3,7 +3,7 @@ from typing import Optional from zhipuai import ZhipuAI -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index b1d6f93cff..992415657e 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,14 +1,14 @@ from typing import Optional -from core.model_manager import ModelManager +from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.models.document import Document -from core.rag.rerank.constants.rerank_mode import RerankMode from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights -from core.rag.rerank.rerank_model import RerankModelRunner -from core.rag.rerank.weight_rerank import WeightRerankRunner +from core.rag.rerank.rerank_base import BaseRerankRunner +from core.rag.rerank.rerank_factory import RerankRunnerFactory +from core.rag.rerank.rerank_type import RerankMode class DataPostProcessor: @@ -47,11 +47,12 @@ class DataPostProcessor: tenant_id: str, reranking_model: Optional[dict] = None, weights: Optional[dict] = None, - ) -> Optional[RerankModelRunner | WeightRerankRunner]: + ) -> Optional[BaseRerankRunner]: if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: - return WeightRerankRunner( - tenant_id, - Weights( + runner = RerankRunnerFactory.create_rerank_runner( + runner_type=reranking_mode, + tenant_id=tenant_id, + weights=Weights( vector_setting=VectorSetting( vector_weight=weights["vector_setting"]["vector_weight"], embedding_provider_name=weights["vector_setting"]["embedding_provider_name"], @@ -62,23 +63,33 @@ class DataPostProcessor: ), ), ) + return runner elif reranking_mode == RerankMode.RERANKING_MODEL.value: - if reranking_model: - try: - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=tenant_id, - provider=reranking_model["reranking_provider_name"], - model_type=ModelType.RERANK, - model=reranking_model["reranking_model_name"], - ) - except InvokeAuthorizationError: - return None - return RerankModelRunner(rerank_model_instance) - return None + rerank_model_instance = self._get_rerank_model_instance(tenant_id, reranking_model) + if rerank_model_instance is None: + return None + runner = RerankRunnerFactory.create_rerank_runner( + runner_type=reranking_mode, rerank_model_instance=rerank_model_instance + ) + return runner return None def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: if reorder_enabled: return ReorderRunner() return None + + def _get_rerank_model_instance(self, tenant_id: str, reranking_model: Optional[dict]) -> ModelInstance | None: + if reranking_model: + try: + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + provider=reranking_model["reranking_provider_name"], + model_type=ModelType.RERANK, + model=reranking_model["reranking_model_name"], + ) + return rerank_model_instance + except InvokeAuthorizationError: + return None + return None diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index d3fd0c672a..3affbd2d0a 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -6,7 +6,7 @@ from flask import Flask, current_app from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.rerank.constants.rerank_mode import RerankMode +from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index 6dcd98dcfd..c77cb87376 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -9,10 +9,10 @@ _import_err_msg = ( ) from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 543cfa67b3..1d4bfef76d 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -12,10 +12,10 @@ from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 610aa498ab..a9e1486edd 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -6,10 +6,10 @@ from chromadb import QueryResult, Settings from pydantic import BaseModel from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index f420373d5b..052a187225 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -9,11 +9,11 @@ from elasticsearch import Elasticsearch from flask import current_app from pydantic import BaseModel, model_validator -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index bdca59f869..080a1ef567 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -7,11 +7,11 @@ from pymilvus import MilvusClient, MilvusException from pymilvus.milvus_client import IndexParams from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index b30aa7ca22..1fca926a2d 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -8,10 +8,10 @@ from clickhouse_connect import get_client from pydantic import BaseModel from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 8d2e0a86ab..0e0f107268 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -9,11 +9,11 @@ from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 84a4381cd1..4ced5d61e5 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -13,10 +13,10 @@ from nltk.corpus import stopwords from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index a82a9b96dd..9233cd63dc 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -12,11 +12,11 @@ from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 6f336d27e7..40a9cdd136 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -8,10 +8,10 @@ import psycopg2.pool from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index f418e3ca05..69d2aa4f76 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -20,11 +20,11 @@ from qdrant_client.http.models import ( from qdrant_client.local.qdrant_local import QdrantLocal from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 13a63784be..f373dcfeab 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -8,9 +8,9 @@ from sqlalchemy import text as sql_text from sqlalchemy.dialects.postgresql import JSON, TEXT from sqlalchemy.orm import Session -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from models.dataset import Dataset try: diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 39e3a7f6cf..f971a9c5eb 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -8,10 +8,10 @@ from tcvectordb.model import index as vdb_index from tcvectordb.model.document import Filter from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 7837c5a4aa..1147e35ce8 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -9,10 +9,10 @@ from sqlalchemy import text as sql_text from sqlalchemy.orm import Session, declarative_base from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 873b289027..fb956a16ed 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -2,12 +2,12 @@ from abc import ABC, abstractmethod from typing import Any, Optional from configs import dify_config -from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.cached_embedding import CacheEmbedding +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index 5f60f10acb..4f927f2899 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -14,11 +14,11 @@ from volcengine.viking_db import ( ) from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field as vdb_Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 4009efe7a7..649cfbfea8 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -7,11 +7,11 @@ import weaviate from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/embedding/__init__.py b/api/core/rag/embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py similarity index 97% rename from api/core/embedding/cached_embedding.py rename to api/core/rag/embedding/cached_embedding.py index 31d2171e72..b3e93ce760 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -6,11 +6,11 @@ import numpy as np from sqlalchemy.exc import IntegrityError from configs import dify_config -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.rag.datasource.entity.embedding import Embeddings +from core.rag.embedding.embedding_base import Embeddings from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper diff --git a/api/core/rag/datasource/entity/embedding.py b/api/core/rag/embedding/embedding_base.py similarity index 90% rename from api/core/rag/datasource/entity/embedding.py rename to api/core/rag/embedding/embedding_base.py index 126c1a3723..9f232ab910 100644 --- a/api/core/rag/datasource/entity/embedding.py +++ b/api/core/rag/embedding/embedding_base.py @@ -7,10 +7,12 @@ class Embeddings(ABC): @abstractmethod def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs.""" + raise NotImplementedError @abstractmethod def embed_query(self, text: str) -> list[float]: """Embed query text.""" + raise NotImplementedError async def aembed_documents(self, texts: list[str]) -> list[list[float]]: """Asynchronous Embed search docs.""" diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py new file mode 100644 index 0000000000..818b04b2ff --- /dev/null +++ b/api/core/rag/rerank/rerank_base.py @@ -0,0 +1,26 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from core.rag.models.document import Document + + +class BaseRerankRunner(ABC): + @abstractmethod + def run( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: + """ + Run rerank model + :param query: search query + :param documents: documents for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id if needed + :return: + """ + raise NotImplementedError diff --git a/api/core/rag/rerank/rerank_factory.py b/api/core/rag/rerank/rerank_factory.py new file mode 100644 index 0000000000..1a3cf85736 --- /dev/null +++ b/api/core/rag/rerank/rerank_factory.py @@ -0,0 +1,16 @@ +from core.rag.rerank.rerank_base import BaseRerankRunner +from core.rag.rerank.rerank_model import RerankModelRunner +from core.rag.rerank.rerank_type import RerankMode +from core.rag.rerank.weight_rerank import WeightRerankRunner + + +class RerankRunnerFactory: + @staticmethod + def create_rerank_runner(runner_type: str, *args, **kwargs) -> BaseRerankRunner: + match runner_type: + case RerankMode.RERANKING_MODEL.value: + return RerankModelRunner(*args, **kwargs) + case RerankMode.WEIGHTED_SCORE.value: + return WeightRerankRunner(*args, **kwargs) + case _: + raise ValueError(f"Unknown runner type: {runner_type}") diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 27f86aed34..40ebf0befd 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -2,9 +2,10 @@ from typing import Optional from core.model_manager import ModelInstance from core.rag.models.document import Document +from core.rag.rerank.rerank_base import BaseRerankRunner -class RerankModelRunner: +class RerankModelRunner(BaseRerankRunner): def __init__(self, rerank_model_instance: ModelInstance) -> None: self.rerank_model_instance = rerank_model_instance diff --git a/api/core/rag/rerank/constants/rerank_mode.py b/api/core/rag/rerank/rerank_type.py similarity index 100% rename from api/core/rag/rerank/constants/rerank_mode.py rename to api/core/rag/rerank/rerank_type.py diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 16d6b879a4..2e3fbe04e2 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -4,15 +4,16 @@ from typing import Optional import numpy as np -from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler +from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights +from core.rag.rerank.rerank_base import BaseRerankRunner -class WeightRerankRunner: +class WeightRerankRunner(BaseRerankRunner): def __init__(self, tenant_id: str, weights: Weights) -> None: self.tenant_id = tenant_id self.weights = weights