From 3622691f384be231f0a7d251378fe581b7f24488 Mon Sep 17 00:00:00 2001 From: jyong Date: Thu, 7 Mar 2024 13:30:04 +0800 Subject: [PATCH] add qdrant test --- .../test_paragraph_index_processor.py | 67 ++--- .../rag/vector/test_qdrant.py | 247 ++---------------- 2 files changed, 53 insertions(+), 261 deletions(-) diff --git a/api/tests/integration_tests/rag/index_processor/test_paragraph_index_processor.py b/api/tests/integration_tests/rag/index_processor/test_paragraph_index_processor.py index 17f1f7581e..0bd6448b20 100644 --- a/api/tests/integration_tests/rag/index_processor/test_paragraph_index_processor.py +++ b/api/tests/integration_tests/rag/index_processor/test_paragraph_index_processor.py @@ -2,25 +2,26 @@ import datetime import uuid from typing import Optional - import pytest - from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import Document from libs import helper from models.dataset import Dataset from models.model import UploadFile - @pytest.mark.parametrize('setup_unstructured_mock', [['partition_md', 'chunk_by_title']], indirect=True) -def extract() -> list[Document]: +def extract(): + + index_processor = IndexProcessorFactory('text_model').init_index_processor() + + # extract file_detail = UploadFile( tenant_id='test', storage_type='local', @@ -44,45 +45,30 @@ def extract() -> list[Document]: text_docs = ExtractProcessor.extract(extract_setting=extract_setting, is_automatic=True) assert isinstance(text_docs, list) - return text_docs + for text_doc in text_docs: + assert isinstance(text_doc, Document) -def transform(self, documents: list[Document], **kwargs) -> list[Document]: - # Split the text documents into nodes. - splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), - embedding_model_instance=kwargs.get('embedding_model_instance')) - all_documents = [] + # transform + process_rule = { + 'pre_processing_rules': [ + {'id': 'remove_extra_spaces', 'enabled': True}, + {'id': 'remove_urls_emails', 'enabled': False} + ], + 'segmentation': { + 'delimiter': '\n', + 'max_tokens': 500, + 'chunk_overlap': 50 + } + } + documents = index_processor.transform(text_docs, embedding_model_instance=None, + process_rule=process_rule) for document in documents: - # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) - document.page_content = document_text - # parse document to nodes - document_nodes = splitter.split_documents([document]) - split_documents = [] - for document_node in document_nodes: + assert isinstance(document, Document) - if document_node.page_content.strip(): - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata['doc_id'] = doc_id - document_node.metadata['doc_hash'] = hash - # delete Spliter character - page_content = document_node.page_content - if page_content.startswith(".") or page_content.startswith("。"): - page_content = page_content[1:] - else: - page_content = page_content - document_node.page_content = page_content - split_documents.append(document_node) - all_documents.extend(split_documents) - return all_documents + # load + vector = Vector(dataset) + vector.create(documents) -def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): - if dataset.indexing_technique == 'high_quality': - vector = Vector(dataset) - vector.create(documents) - if with_keywords: - keyword = Keyword(dataset) - keyword.create(documents) def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): if dataset.indexing_technique == 'high_quality': @@ -98,6 +84,7 @@ def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: else: keyword.delete() + def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, score_threshold: float, reranking_model: dict) -> list[Document]: # Set search parameters. diff --git a/api/tests/integration_tests/rag/vector/test_qdrant.py b/api/tests/integration_tests/rag/vector/test_qdrant.py index 4da7f174f6..bfcf006c05 100644 --- a/api/tests/integration_tests/rag/vector/test_qdrant.py +++ b/api/tests/integration_tests/rag/vector/test_qdrant.py @@ -1,227 +1,32 @@ -import os -from typing import Generator - import pytest -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - SystemPromptMessage, TextPromptMessageContent, - UserPromptMessage) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel -from tests.integration_tests.model_runtime.__mock.google import setup_google_mock + +from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVector, QdrantConfig +from core.rag.models.document import Document -def test_validate_credentials(setup_google_mock): - model = GoogleLargeLanguageModel() - - with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='gemini-pro', - credentials={ - 'google_api_key': 'invalid_key' - } +@pytest.mark.parametrize('setup_qdrant_mock', + [['get_collections', 'recreate_collection', + 'create_payload_index', 'upsert', 'scroll', + 'search']], + indirect=True) +def test_qdrant(setup_qdrant_mock): + document = Document(page_content="test", metadata={"test": "test"}) + qdrant_vector = QdrantVector( + collection_name="test", + group_id='test', + config=QdrantConfig( + endpoint="http://localhost:6333", + api_key="test", + root_path="test", + timeout=10 ) - - model.validate_credentials( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - } ) + # create + qdrant_vector.create(texts=[document], embeddings=[[0.23333 for _ in range(233)]]) + # search + result = qdrant_vector.search_by_vector(query_vector=[0.23333 for _ in range(233)]) + for item in result: + assert isinstance(item, Document) + # delete + qdrant_vector.delete() -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) -def test_invoke_model(setup_google_mock): - model = GoogleLargeLanguageModel() - - response = model.invoke( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, - prompt_messages=[ - SystemPromptMessage( - content='You are a helpful AI assistant.', - ), - UserPromptMessage( - content='Give me your worst dad joke or i will unplug you' - ), - AssistantPromptMessage( - content='Why did the scarecrow win an award? Because he was outstanding in his field!' - ), - UserPromptMessage( - content=[ - TextPromptMessageContent( - data="ok something snarkier pls" - ), - TextPromptMessageContent( - data="i may still unplug you" - )] - ) - ], - model_parameters={ - 'temperature': 0.5, - 'top_p': 1.0, - 'max_tokens_to_sample': 2048 - }, - stop=['How'], - stream=False, - user="abc-123" - ) - - assert isinstance(response, LLMResult) - assert len(response.message.content) > 0 - -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) -def test_invoke_stream_model(setup_google_mock): - model = GoogleLargeLanguageModel() - response = model.invoke( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, - prompt_messages=[ - SystemPromptMessage( - content='You are a helpful AI assistant.', - ), - UserPromptMessage( - content='Give me your worst dad joke or i will unplug you' - ), - AssistantPromptMessage( - content='Why did the scarecrow win an award? Because he was outstanding in his field!' - ), - UserPromptMessage( - content=[ - TextPromptMessageContent( - data="ok something snarkier pls" - ), - TextPromptMessageContent( - data="i may still unplug you" - )] - ) - ], - model_parameters={ - 'temperature': 0.2, - 'top_k': 5, - 'max_tokens_to_sample': 2048 - }, - stream=True, - user="abc-123" - ) - - assert isinstance(response, Generator) - - for chunk in response: - assert isinstance(chunk, LLMResultChunk) - assert isinstance(chunk.delta, LLMResultChunkDelta) - assert isinstance(chunk.delta.message, AssistantPromptMessage) - assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True - -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) -def test_invoke_chat_model_with_vision(setup_google_mock): - model = GoogleLargeLanguageModel() - - result = model.invoke( - model='gemini-pro-vision', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, - prompt_messages=[ - SystemPromptMessage( - content='You are a helpful AI assistant.', - ), - UserPromptMessage( - content=[ - TextPromptMessageContent( - data="what do you see?" - ), - ImagePromptMessageContent( - data='' - ) - ] - ) - ], - model_parameters={ - 'temperature': 0.3, - 'top_p': 0.2, - 'top_k': 3, - 'max_tokens': 100 - }, - stream=False, - user="abc-123" - ) - - assert isinstance(result, LLMResult) - assert len(result.message.content) > 0 - -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) -def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock): - model = GoogleLargeLanguageModel() - - result = model.invoke( - model='gemini-pro-vision', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, - prompt_messages=[ - SystemPromptMessage( - content='You are a helpful AI assistant.' - ), - UserPromptMessage( - content=[ - TextPromptMessageContent( - data="what do you see?" - ), - ImagePromptMessageContent( - data='' - ) - ] - ), - AssistantPromptMessage( - content="I see a blue letter 'D' with a gradient from light blue to dark blue." - ), - UserPromptMessage( - content=[ - TextPromptMessageContent( - data="what about now?" - ), - ImagePromptMessageContent( - data='' - ) - ] - ) - ], - model_parameters={ - 'temperature': 0.3, - 'top_p': 0.2, - 'top_k': 3, - 'max_tokens': 100 - }, - stream=False, - user="abc-123" - ) - - print(f"resultz: {result.message.content}") - assert isinstance(result, LLMResult) - assert len(result.message.content) > 0 - - - -def test_get_num_tokens(): - model = GoogleLargeLanguageModel() - - num_tokens = model.get_num_tokens( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, - prompt_messages=[ - SystemPromptMessage( - content='You are a helpful AI assistant.', - ), - UserPromptMessage( - content='Hello World!' - ) - ] - ) - - assert num_tokens > 0 # The exact number of tokens may vary based on the model's tokenization