add qdrant test

jyong 2024-03-07 13:30:04 +08:00
parent 52e6f458be
commit 3622691f38
2 changed files with 53 additions and 261 deletions

@@ -2,25 +2,26 @@
 import datetime
 import uuid
 from typing import Optional

 import pytest

 from core.rag.cleaner.clean_processor import CleanProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
-from core.rag.index_processor.index_processor_base import BaseIndexProcessor
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.models.document import Document
 from libs import helper
 from models.dataset import Dataset
 from models.model import UploadFile


 @pytest.mark.parametrize('setup_unstructured_mock', [['partition_md', 'chunk_by_title']], indirect=True)
-def extract() -> list[Document]:
+def extract():
+    index_processor = IndexProcessorFactory('text_model').init_index_processor()
+    # extract
     file_detail = UploadFile(
         tenant_id='test',
         storage_type='local',
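
The remaining UploadFile fields and the ExtractSetting construction fall between this hunk and the next and are not shown. As a minimal sketch of how the extract step is usually wired from an upload file, the ExtractSetting field names below (datasource_type, upload_file, document_model) are assumptions, not the suppressed diff lines:

    # Hypothetical wiring; the ExtractSetting field names are assumptions,
    # not the lines hidden between the two hunks.
    extract_setting = ExtractSetting(
        datasource_type='upload_file',   # extracting from the UploadFile built above
        upload_file=file_detail,
        document_model='text_model'      # same model name passed to IndexProcessorFactory
    )
    text_docs = ExtractProcessor.extract(extract_setting=extract_setting, is_automatic=True)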
@@ -44,45 +45,30 @@ def extract() -> list[Document]:
     text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
                                          is_automatic=True)
     assert isinstance(text_docs, list)
-    return text_docs
-
-
-def transform(self, documents: list[Document], **kwargs) -> list[Document]:
-    # Split the text documents into nodes.
-    splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
-                                  embedding_model_instance=kwargs.get('embedding_model_instance'))
-    all_documents = []
-    for document in documents:
-        # document clean
-        document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
-        document.page_content = document_text
-        # parse document to nodes
-        document_nodes = splitter.split_documents([document])
-        split_documents = []
-        for document_node in document_nodes:
-            if document_node.page_content.strip():
-                doc_id = str(uuid.uuid4())
-                hash = helper.generate_text_hash(document_node.page_content)
-                document_node.metadata['doc_id'] = doc_id
-                document_node.metadata['doc_hash'] = hash
-                # delete Spliter character
-                page_content = document_node.page_content
-                if page_content.startswith(".") or page_content.startswith("。"):
-                    page_content = page_content[1:]
-                else:
-                    page_content = page_content
-                document_node.page_content = page_content
-                split_documents.append(document_node)
-        all_documents.extend(split_documents)
-    return all_documents
-
-
-def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
-    if dataset.indexing_technique == 'high_quality':
-        vector = Vector(dataset)
-        vector.create(documents)
-    if with_keywords:
-        keyword = Keyword(dataset)
-        keyword.create(documents)
+    for text_doc in text_docs:
+        assert isinstance(text_doc, Document)
+
+    # transform
+    process_rule = {
+        'pre_processing_rules': [
+            {'id': 'remove_extra_spaces', 'enabled': True},
+            {'id': 'remove_urls_emails', 'enabled': False}
+        ],
+        'segmentation': {
+            'delimiter': '\n',
+            'max_tokens': 500,
+            'chunk_overlap': 50
+        }
+    }
+    documents = index_processor.transform(text_docs, embedding_model_instance=None,
+                                          process_rule=process_rule)
+    for document in documents:
+        assert isinstance(document, Document)
+
+    # load
+    vector = Vector(dataset)
+    vector.create(documents)


 def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
     if dataset.indexing_technique == 'high_quality':
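
Read only from the '+' lines of this hunk, the new test body runs a three-step pipeline: extract the documents, transform them with an explicit process rule, then load the resulting chunks into the vector store. A short consolidated sketch of that flow, assuming a 'dataset' object that is built elsewhere in the file and not visible here:

    # Assembled from the added lines above; 'dataset' is assumed to be a
    # Qdrant-backed Dataset created earlier in the test file.
    text_docs = ExtractProcessor.extract(extract_setting=extract_setting, is_automatic=True)

    documents = index_processor.transform(text_docs, embedding_model_instance=None,
                                          process_rule=process_rule)   # rule defined in the hunk above

    vector = Vector(dataset)     # factory that resolves the concrete vector store (Qdrant in this test)
    vector.create(documents)     # embeds the chunk documents and writes them into the store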
@@ -98,6 +84,7 @@ def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords:
     else:
         keyword.delete()


 def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
              score_threshold: float, reranking_model: dict) -> list[Document]:
     # Set search parameters.
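
The retrieve helper's body is cut off at this point, so only its signature is visible. A hypothetical call for a similarity search against the Qdrant-backed dataset might look like the following; the value 'semantic_search', the empty reranking_model, and passing None for the stray 'self' parameter are all assumptions:

    # Hypothetical invocation based only on the visible signature;
    # the argument values are assumptions, not code from this commit.
    results = retrieve(
        None,                                # stray 'self' in the module-level signature
        retrival_method='semantic_search',
        query='test query',
        dataset=dataset,
        top_k=4,
        score_threshold=0.0,
        reranking_model={}
    )
    for doc in results:
        assert isinstance(doc, Document)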

File diff suppressed because one or more lines are too long