mirror of
https://github.com/langgenius/dify.git
synced 2024-11-15 19:22:36 +08:00
add qdrant test
This commit is contained in:
parent
52e6f458be
commit
3622691f38
|
@ -2,25 +2,26 @@
|
|||
import datetime
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.rag.models.document import Document
|
||||
from libs import helper
|
||||
from models.dataset import Dataset
|
||||
from models.model import UploadFile
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize('setup_unstructured_mock', [['partition_md', 'chunk_by_title']], indirect=True)
|
||||
def extract() -> list[Document]:
|
||||
def extract():
|
||||
|
||||
index_processor = IndexProcessorFactory('text_model').init_index_processor()
|
||||
|
||||
# extract
|
||||
file_detail = UploadFile(
|
||||
tenant_id='test',
|
||||
storage_type='local',
|
||||
|
@ -44,45 +45,30 @@ def extract() -> list[Document]:
|
|||
text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
|
||||
is_automatic=True)
|
||||
assert isinstance(text_docs, list)
|
||||
return text_docs
|
||||
for text_doc in text_docs:
|
||||
assert isinstance(text_doc, Document)
|
||||
|
||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
||||
# Split the text documents into nodes.
|
||||
splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
|
||||
embedding_model_instance=kwargs.get('embedding_model_instance'))
|
||||
all_documents = []
|
||||
# transform
|
||||
process_rule = {
|
||||
'pre_processing_rules': [
|
||||
{'id': 'remove_extra_spaces', 'enabled': True},
|
||||
{'id': 'remove_urls_emails', 'enabled': False}
|
||||
],
|
||||
'segmentation': {
|
||||
'delimiter': '\n',
|
||||
'max_tokens': 500,
|
||||
'chunk_overlap': 50
|
||||
}
|
||||
}
|
||||
documents = index_processor.transform(text_docs, embedding_model_instance=None,
|
||||
process_rule=process_rule)
|
||||
for document in documents:
|
||||
# document clean
|
||||
document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
|
||||
document.page_content = document_text
|
||||
# parse document to nodes
|
||||
document_nodes = splitter.split_documents([document])
|
||||
split_documents = []
|
||||
for document_node in document_nodes:
|
||||
assert isinstance(document, Document)
|
||||
|
||||
if document_node.page_content.strip():
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(document_node.page_content)
|
||||
document_node.metadata['doc_id'] = doc_id
|
||||
document_node.metadata['doc_hash'] = hash
|
||||
# delete Spliter character
|
||||
page_content = document_node.page_content
|
||||
if page_content.startswith(".") or page_content.startswith("。"):
|
||||
page_content = page_content[1:]
|
||||
else:
|
||||
page_content = page_content
|
||||
document_node.page_content = page_content
|
||||
split_documents.append(document_node)
|
||||
all_documents.extend(split_documents)
|
||||
return all_documents
|
||||
# load
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
|
||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
if with_keywords:
|
||||
keyword = Keyword(dataset)
|
||||
keyword.create(documents)
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
|
@ -98,6 +84,7 @@ def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords:
|
|||
else:
|
||||
keyword.delete()
|
||||
|
||||
|
||||
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
|
||||
score_threshold: float, reranking_model: dict) -> list[Document]:
|
||||
# Set search parameters.
|
||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user