refactor: move the embedding to the rag module and abstract the rerank runner for extension (#9423)

This commit is contained in:
zhuhao 2024-10-17 19:12:42 +08:00 committed by GitHub
parent e7aecb89dd
commit b90ad587c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
61 changed files with 135 additions and 78 deletions

View File

@ -3,7 +3,7 @@ import os
from collections.abc import Callable, Generator, Sequence from collections.abc import Callable, Generator, Sequence
from typing import IO, Optional, Union, cast from typing import IO, Optional, Union, cast
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError from core.errors.error import ProviderTokenNotInitError

View File

@ -4,7 +4,7 @@ from typing import Optional
from pydantic import ConfigDict from pydantic import ConfigDict
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.__base.ai_model import AIModel

View File

@ -7,7 +7,7 @@ import numpy as np
import tiktoken import tiktoken
from openai import AzureOpenAI from openai import AzureOpenAI
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import AIModelEntity, PriceType from core.model_runtime.entities.model_entities import AIModelEntity, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError

View File

@ -4,7 +4,7 @@ from typing import Optional
from requests import post from requests import post
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.invoke import (

View File

@ -13,7 +13,7 @@ from botocore.exceptions import (
UnknownServiceError, UnknownServiceError,
) )
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.invoke import (

View File

@ -5,7 +5,7 @@ import cohere
import numpy as np import numpy as np
from cohere.core import RequestOptions from cohere.core import RequestOptions
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.invoke import (

View File

@ -5,7 +5,7 @@ from typing import Optional, Union
import numpy as np import numpy as np
from openai import OpenAI from openai import OpenAI
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError

View File

@ -6,7 +6,7 @@ import numpy as np
import requests import requests
from huggingface_hub import HfApi, InferenceClient from huggingface_hub import HfApi, InferenceClient
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -1,7 +1,7 @@
import time import time
from typing import Optional from typing import Optional
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -9,7 +9,7 @@ from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.invoke import (

View File

@ -4,7 +4,7 @@ from typing import Optional
from requests import post from requests import post
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -5,7 +5,7 @@ from typing import Optional
from requests import post from requests import post
from yarl import URL from yarl import URL
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -4,7 +4,7 @@ from typing import Optional
from requests import post from requests import post
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.invoke import (

View File

@ -4,7 +4,7 @@ from typing import Optional
import requests import requests
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -5,7 +5,7 @@ from typing import Optional
from nomic import embed from nomic import embed
from nomic import login as nomic_login from nomic import login as nomic_login
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import ( from core.model_runtime.entities.text_embedding_entities import (
EmbeddingUsage, EmbeddingUsage,

View File

@ -4,7 +4,7 @@ from typing import Optional
from requests import post from requests import post
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.invoke import (

View File

@ -6,7 +6,7 @@ from typing import Optional
import numpy as np import numpy as np
import oci import oci
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.invoke import (

View File

@ -8,7 +8,7 @@ from urllib.parse import urljoin
import numpy as np import numpy as np
import requests import requests
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.model_entities import (
AIModelEntity, AIModelEntity,

View File

@ -6,7 +6,7 @@ import numpy as np
import tiktoken import tiktoken
from openai import OpenAI from openai import OpenAI
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError

View File

@ -7,7 +7,7 @@ from urllib.parse import urljoin
import numpy as np import numpy as np
import requests import requests
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.model_entities import (
AIModelEntity, AIModelEntity,

View File

@ -5,7 +5,7 @@ from typing import Optional
from requests import post from requests import post
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.invoke import (

View File

@ -7,7 +7,7 @@ from urllib.parse import urljoin
import numpy as np import numpy as np
import requests import requests
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.model_entities import (
AIModelEntity, AIModelEntity,

View File

@ -4,7 +4,7 @@ from typing import Optional
from replicate import Client as ReplicateClient from replicate import Client as ReplicateClient
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -6,7 +6,7 @@ from typing import Any, Optional
import boto3 import boto3
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
OAICompatEmbeddingModel, OAICompatEmbeddingModel,

View File

@ -4,7 +4,7 @@ from typing import Optional
import dashscope import dashscope
import numpy as np import numpy as np
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import ( from core.model_runtime.entities.text_embedding_entities import (
EmbeddingUsage, EmbeddingUsage,

View File

@ -7,7 +7,7 @@ import numpy as np
from openai import OpenAI from openai import OpenAI
from tokenizers import Tokenizer from tokenizers import Tokenizer
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError

View File

@ -9,7 +9,7 @@ from google.cloud import aiplatform
from google.oauth2 import service_account from google.oauth2 import service_account
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.model_entities import (
AIModelEntity, AIModelEntity,

View File

@ -2,7 +2,7 @@ import time
from decimal import Decimal from decimal import Decimal
from typing import Optional from typing import Optional
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.model_entities import (
AIModelEntity, AIModelEntity,

View File

@ -4,7 +4,7 @@ from typing import Optional
import requests import requests
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -7,7 +7,7 @@ from typing import Any, Optional
import numpy as np import numpy as np
from requests import Response, post from requests import Response, post
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError

View File

@ -3,7 +3,7 @@ from typing import Optional
from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -3,7 +3,7 @@ from typing import Optional
from zhipuai import ZhipuAI from zhipuai import ZhipuAI
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError

View File

@ -1,14 +1,14 @@
from typing import Optional from typing import Optional
from core.model_manager import ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.models.document import Document from core.rag.models.document import Document
from core.rag.rerank.constants.rerank_mode import RerankMode
from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_base import BaseRerankRunner
from core.rag.rerank.weight_rerank import WeightRerankRunner from core.rag.rerank.rerank_factory import RerankRunnerFactory
from core.rag.rerank.rerank_type import RerankMode
class DataPostProcessor: class DataPostProcessor:
@ -47,11 +47,12 @@ class DataPostProcessor:
tenant_id: str, tenant_id: str,
reranking_model: Optional[dict] = None, reranking_model: Optional[dict] = None,
weights: Optional[dict] = None, weights: Optional[dict] = None,
) -> Optional[RerankModelRunner | WeightRerankRunner]: ) -> Optional[BaseRerankRunner]:
if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
return WeightRerankRunner( runner = RerankRunnerFactory.create_rerank_runner(
tenant_id, runner_type=reranking_mode,
Weights( tenant_id=tenant_id,
weights=Weights(
vector_setting=VectorSetting( vector_setting=VectorSetting(
vector_weight=weights["vector_setting"]["vector_weight"], vector_weight=weights["vector_setting"]["vector_weight"],
embedding_provider_name=weights["vector_setting"]["embedding_provider_name"], embedding_provider_name=weights["vector_setting"]["embedding_provider_name"],
@ -62,23 +63,33 @@ class DataPostProcessor:
), ),
), ),
) )
return runner
elif reranking_mode == RerankMode.RERANKING_MODEL.value: elif reranking_mode == RerankMode.RERANKING_MODEL.value:
if reranking_model: rerank_model_instance = self._get_rerank_model_instance(tenant_id, reranking_model)
try: if rerank_model_instance is None:
model_manager = ModelManager() return None
rerank_model_instance = model_manager.get_model_instance( runner = RerankRunnerFactory.create_rerank_runner(
tenant_id=tenant_id, runner_type=reranking_mode, rerank_model_instance=rerank_model_instance
provider=reranking_model["reranking_provider_name"], )
model_type=ModelType.RERANK, return runner
model=reranking_model["reranking_model_name"],
)
except InvokeAuthorizationError:
return None
return RerankModelRunner(rerank_model_instance)
return None
return None return None
def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:
if reorder_enabled: if reorder_enabled:
return ReorderRunner() return ReorderRunner()
return None return None
def _get_rerank_model_instance(self, tenant_id: str, reranking_model: Optional[dict]) -> ModelInstance | None:
if reranking_model:
try:
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=reranking_model["reranking_provider_name"],
model_type=ModelType.RERANK,
model=reranking_model["reranking_model_name"],
)
return rerank_model_instance
except InvokeAuthorizationError:
return None
return None

View File

@ -6,7 +6,7 @@ from flask import Flask, current_app
from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.rerank.constants.rerank_mode import RerankMode from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset from models.dataset import Dataset

View File

@ -9,10 +9,10 @@ _import_err_msg = (
) )
from configs import dify_config from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset

View File

@ -12,10 +12,10 @@ from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
from configs import dify_config from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset

View File

@ -6,10 +6,10 @@ from chromadb import QueryResult, Settings
from pydantic import BaseModel from pydantic import BaseModel
from configs import dify_config from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset

View File

@ -9,11 +9,11 @@ from elasticsearch import Elasticsearch
from flask import current_app from flask import current_app
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset

View File

@ -7,11 +7,11 @@ from pymilvus import MilvusClient, MilvusException
from pymilvus.milvus_client import IndexParams from pymilvus.milvus_client import IndexParams
from configs import dify_config from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset

View File

@ -8,10 +8,10 @@ from clickhouse_connect import get_client
from pydantic import BaseModel from pydantic import BaseModel
from configs import dify_config from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from models.dataset import Dataset from models.dataset import Dataset

View File

@ -9,11 +9,11 @@ from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from configs import dify_config from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset

View File

@ -13,10 +13,10 @@ from nltk.corpus import stopwords
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from configs import dify_config from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset

View File

@ -12,11 +12,11 @@ from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Mapped, Session, mapped_column from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset

View File

@ -8,10 +8,10 @@ import psycopg2.pool
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from configs import dify_config from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset

View File

@ -20,11 +20,11 @@ from qdrant_client.http.models import (
from qdrant_client.local.qdrant_local import QdrantLocal from qdrant_client.local.qdrant_local import QdrantLocal
from configs import dify_config from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client

View File

@ -8,9 +8,9 @@ from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT from sqlalchemy.dialects.postgresql import JSON, TEXT
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from models.dataset import Dataset from models.dataset import Dataset
try: try:

View File

@ -8,10 +8,10 @@ from tcvectordb.model import index as vdb_index
from tcvectordb.model.document import Filter from tcvectordb.model.document import Filter
from configs import dify_config from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset

View File

@ -9,10 +9,10 @@ from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base from sqlalchemy.orm import Session, declarative_base
from configs import dify_config from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset

View File

@ -2,12 +2,12 @@ from abc import ABC, abstractmethod
from typing import Any, Optional from typing import Any, Optional
from configs import dify_config from configs import dify_config
from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset

View File

@ -14,11 +14,11 @@ from volcengine.viking_db import (
) )
from configs import dify_config from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field as vdb_Field from core.rag.datasource.vdb.field import Field as vdb_Field
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset

View File

@ -7,11 +7,11 @@ import weaviate
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from configs import dify_config from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset

View File

View File

@ -6,11 +6,11 @@ import numpy as np
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from configs import dify_config from configs import dify_config
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.entity.embedding import Embeddings from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from libs import helper from libs import helper

View File

@ -7,10 +7,12 @@ class Embeddings(ABC):
@abstractmethod @abstractmethod
def embed_documents(self, texts: list[str]) -> list[list[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs.""" """Embed search docs."""
raise NotImplementedError
@abstractmethod @abstractmethod
def embed_query(self, text: str) -> list[float]: def embed_query(self, text: str) -> list[float]:
"""Embed query text.""" """Embed query text."""
raise NotImplementedError
async def aembed_documents(self, texts: list[str]) -> list[list[float]]: async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Asynchronous Embed search docs.""" """Asynchronous Embed search docs."""

View File

@ -0,0 +1,26 @@
from abc import ABC, abstractmethod
from typing import Optional
from core.rag.models.document import Document
class BaseRerankRunner(ABC):
@abstractmethod
def run(
self,
query: str,
documents: list[Document],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> list[Document]:
"""
Run rerank model
:param query: search query
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:return:
"""
raise NotImplementedError

View File

@ -0,0 +1,16 @@
from core.rag.rerank.rerank_base import BaseRerankRunner
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.rerank.rerank_type import RerankMode
from core.rag.rerank.weight_rerank import WeightRerankRunner
class RerankRunnerFactory:
@staticmethod
def create_rerank_runner(runner_type: str, *args, **kwargs) -> BaseRerankRunner:
match runner_type:
case RerankMode.RERANKING_MODEL.value:
return RerankModelRunner(*args, **kwargs)
case RerankMode.WEIGHTED_SCORE.value:
return WeightRerankRunner(*args, **kwargs)
case _:
raise ValueError(f"Unknown runner type: {runner_type}")

View File

@ -2,9 +2,10 @@ from typing import Optional
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.rag.models.document import Document from core.rag.models.document import Document
from core.rag.rerank.rerank_base import BaseRerankRunner
class RerankModelRunner: class RerankModelRunner(BaseRerankRunner):
def __init__(self, rerank_model_instance: ModelInstance) -> None: def __init__(self, rerank_model_instance: ModelInstance) -> None:
self.rerank_model_instance = rerank_model_instance self.rerank_model_instance = rerank_model_instance

View File

@ -4,15 +4,16 @@ from typing import Optional
import numpy as np import numpy as np
from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.models.document import Document from core.rag.models.document import Document
from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.entity.weight import VectorSetting, Weights
from core.rag.rerank.rerank_base import BaseRerankRunner
class WeightRerankRunner: class WeightRerankRunner(BaseRerankRunner):
def __init__(self, tenant_id: str, weights: Weights) -> None: def __init__(self, tenant_id: str, weights: Weights) -> None:
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.weights = weights self.weights = weights