refactor: move the embedding to the rag module and abstract the rerank runner for extension (#9423)

This commit is contained in:
zhuhao 2024-10-17 19:12:42 +08:00 committed by GitHub
parent e7aecb89dd
commit b90ad587c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
61 changed files with 135 additions and 78 deletions

View File

@ -3,7 +3,7 @@ import os
from collections.abc import Callable, Generator, Sequence
from typing import IO, Optional, Union, cast
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError

View File

@ -4,7 +4,7 @@ from typing import Optional
from pydantic import ConfigDict
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.ai_model import AIModel

View File

@ -7,7 +7,7 @@ import numpy as np
import tiktoken
from openai import AzureOpenAI
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import AIModelEntity, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError

View File

@ -4,7 +4,7 @@ from typing import Optional
from requests import post
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (

View File

@ -13,7 +13,7 @@ from botocore.exceptions import (
UnknownServiceError,
)
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (

View File

@ -5,7 +5,7 @@ import cohere
import numpy as np
from cohere.core import RequestOptions
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (

View File

@ -5,7 +5,7 @@ from typing import Optional, Union
import numpy as np
from openai import OpenAI
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError

View File

@ -6,7 +6,7 @@ import numpy as np
import requests
from huggingface_hub import HfApi, InferenceClient
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -1,7 +1,7 @@
import time
from typing import Optional
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -9,7 +9,7 @@ from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (

View File

@ -4,7 +4,7 @@ from typing import Optional
from requests import post
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -5,7 +5,7 @@ from typing import Optional
from requests import post
from yarl import URL
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -4,7 +4,7 @@ from typing import Optional
from requests import post
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (

View File

@ -4,7 +4,7 @@ from typing import Optional
import requests
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -5,7 +5,7 @@ from typing import Optional
from nomic import embed
from nomic import login as nomic_login
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import (
EmbeddingUsage,

View File

@ -4,7 +4,7 @@ from typing import Optional
from requests import post
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (

View File

@ -6,7 +6,7 @@ from typing import Optional
import numpy as np
import oci
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (

View File

@ -8,7 +8,7 @@ from urllib.parse import urljoin
import numpy as np
import requests
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
AIModelEntity,

View File

@ -6,7 +6,7 @@ import numpy as np
import tiktoken
from openai import OpenAI
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError

View File

@ -7,7 +7,7 @@ from urllib.parse import urljoin
import numpy as np
import requests
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
AIModelEntity,

View File

@ -5,7 +5,7 @@ from typing import Optional
from requests import post
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (

View File

@ -7,7 +7,7 @@ from urllib.parse import urljoin
import numpy as np
import requests
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
AIModelEntity,

View File

@ -4,7 +4,7 @@ from typing import Optional
from replicate import Client as ReplicateClient
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -6,7 +6,7 @@ from typing import Any, Optional
import boto3
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -1,6 +1,6 @@
from typing import Optional
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
OAICompatEmbeddingModel,

View File

@ -4,7 +4,7 @@ from typing import Optional
import dashscope
import numpy as np
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import (
EmbeddingUsage,

View File

@ -7,7 +7,7 @@ import numpy as np
from openai import OpenAI
from tokenizers import Tokenizer
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError

View File

@ -9,7 +9,7 @@ from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
AIModelEntity,

View File

@ -2,7 +2,7 @@ import time
from decimal import Decimal
from typing import Optional
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
AIModelEntity,

View File

@ -4,7 +4,7 @@ from typing import Optional
import requests
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -7,7 +7,7 @@ from typing import Any, Optional
import numpy as np
from requests import Response, post
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import InvokeError

View File

@ -3,7 +3,7 @@ from typing import Optional
from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult

View File

@ -3,7 +3,7 @@ from typing import Optional
from zhipuai import ZhipuAI
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError

View File

@ -1,14 +1,14 @@
from typing import Optional
from core.model_manager import ModelManager
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.models.document import Document
from core.rag.rerank.constants.rerank_mode import RerankMode
from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.rerank.weight_rerank import WeightRerankRunner
from core.rag.rerank.rerank_base import BaseRerankRunner
from core.rag.rerank.rerank_factory import RerankRunnerFactory
from core.rag.rerank.rerank_type import RerankMode
class DataPostProcessor:
@ -47,11 +47,12 @@ class DataPostProcessor:
tenant_id: str,
reranking_model: Optional[dict] = None,
weights: Optional[dict] = None,
) -> Optional[RerankModelRunner | WeightRerankRunner]:
) -> Optional[BaseRerankRunner]:
if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
return WeightRerankRunner(
tenant_id,
Weights(
runner = RerankRunnerFactory.create_rerank_runner(
runner_type=reranking_mode,
tenant_id=tenant_id,
weights=Weights(
vector_setting=VectorSetting(
vector_weight=weights["vector_setting"]["vector_weight"],
embedding_provider_name=weights["vector_setting"]["embedding_provider_name"],
@ -62,7 +63,23 @@ class DataPostProcessor:
),
),
)
return runner
elif reranking_mode == RerankMode.RERANKING_MODEL.value:
rerank_model_instance = self._get_rerank_model_instance(tenant_id, reranking_model)
if rerank_model_instance is None:
return None
runner = RerankRunnerFactory.create_rerank_runner(
runner_type=reranking_mode, rerank_model_instance=rerank_model_instance
)
return runner
return None
def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:
if reorder_enabled:
return ReorderRunner()
return None
def _get_rerank_model_instance(self, tenant_id: str, reranking_model: Optional[dict]) -> ModelInstance | None:
if reranking_model:
try:
model_manager = ModelManager()
@ -72,13 +89,7 @@ class DataPostProcessor:
model_type=ModelType.RERANK,
model=reranking_model["reranking_model_name"],
)
return rerank_model_instance
except InvokeAuthorizationError:
return None
return RerankModelRunner(rerank_model_instance)
return None
return None
def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:
if reorder_enabled:
return ReorderRunner()
return None

View File

@ -6,7 +6,7 @@ from flask import Flask, current_app
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.rerank.constants.rerank_mode import RerankMode
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.dataset import Dataset

View File

@ -9,10 +9,10 @@ _import_err_msg = (
)
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

View File

@ -12,10 +12,10 @@ from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

View File

@ -6,10 +6,10 @@ from chromadb import QueryResult, Settings
from pydantic import BaseModel
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

View File

@ -9,11 +9,11 @@ from elasticsearch import Elasticsearch
from flask import current_app
from pydantic import BaseModel, model_validator
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

View File

@ -7,11 +7,11 @@ from pymilvus import MilvusClient, MilvusException
from pymilvus.milvus_client import IndexParams
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

View File

@ -8,10 +8,10 @@ from clickhouse_connect import get_client
from pydantic import BaseModel
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from models.dataset import Dataset

View File

@ -9,11 +9,11 @@ from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

View File

@ -13,10 +13,10 @@ from nltk.corpus import stopwords
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

View File

@ -12,11 +12,11 @@ from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

View File

@ -8,10 +8,10 @@ import psycopg2.pool
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

View File

@ -20,11 +20,11 @@ from qdrant_client.http.models import (
from qdrant_client.local.qdrant_local import QdrantLocal
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client

View File

@ -8,9 +8,9 @@ from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT
from sqlalchemy.orm import Session
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from models.dataset import Dataset
try:

View File

@ -8,10 +8,10 @@ from tcvectordb.model import index as vdb_index
from tcvectordb.model.document import Filter
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

View File

@ -9,10 +9,10 @@ from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

View File

@ -2,12 +2,12 @@ from abc import ABC, abstractmethod
from typing import Any, Optional
from configs import dify_config
from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

View File

@ -14,11 +14,11 @@ from volcengine.viking_db import (
)
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field as vdb_Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

View File

@ -7,11 +7,11 @@ import weaviate
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

View File

View File

@ -6,11 +6,11 @@ import numpy as np
from sqlalchemy.exc import IntegrityError
from configs import dify_config
from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.embedding_type import EmbeddingInputType
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper

View File

@ -7,10 +7,12 @@ class Embeddings(ABC):
@abstractmethod
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""
raise NotImplementedError
@abstractmethod
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
raise NotImplementedError
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Asynchronous Embed search docs."""

View File

@ -0,0 +1,26 @@
from abc import ABC, abstractmethod
from typing import Optional
from core.rag.models.document import Document
class BaseRerankRunner(ABC):
@abstractmethod
def run(
self,
query: str,
documents: list[Document],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> list[Document]:
"""
Run rerank model
:param query: search query
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:return:
"""
raise NotImplementedError

View File

@ -0,0 +1,16 @@
from core.rag.rerank.rerank_base import BaseRerankRunner
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.rerank.rerank_type import RerankMode
from core.rag.rerank.weight_rerank import WeightRerankRunner
class RerankRunnerFactory:
@staticmethod
def create_rerank_runner(runner_type: str, *args, **kwargs) -> BaseRerankRunner:
match runner_type:
case RerankMode.RERANKING_MODEL.value:
return RerankModelRunner(*args, **kwargs)
case RerankMode.WEIGHTED_SCORE.value:
return WeightRerankRunner(*args, **kwargs)
case _:
raise ValueError(f"Unknown runner type: {runner_type}")

View File

@ -2,9 +2,10 @@ from typing import Optional
from core.model_manager import ModelInstance
from core.rag.models.document import Document
from core.rag.rerank.rerank_base import BaseRerankRunner
class RerankModelRunner:
class RerankModelRunner(BaseRerankRunner):
def __init__(self, rerank_model_instance: ModelInstance) -> None:
self.rerank_model_instance = rerank_model_instance

View File

@ -4,15 +4,16 @@ from typing import Optional
import numpy as np
from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.models.document import Document
from core.rag.rerank.entity.weight import VectorSetting, Weights
from core.rag.rerank.rerank_base import BaseRerankRunner
class WeightRerankRunner:
class WeightRerankRunner(BaseRerankRunner):
def __init__(self, tenant_id: str, weights: Weights) -> None:
self.tenant_id = tenant_id
self.weights = weights