add RAG test

This commit is contained in:
jyong 2024-03-04 15:07:56 +08:00
parent 796f7d4d29
commit 4ea468b52a
28 changed files with 807 additions and 3 deletions

View File

@ -277,7 +277,7 @@ class QdrantVector(BaseVector):
from qdrant_client.http import models
filter = models.Filter(
must=[
models.FieldCondition(
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),

9
api/core/rag/test.py Normal file
View File

@ -0,0 +1,9 @@
from llama_index.core import DocumentSummaryIndex
doc_summary_index = DocumentSummaryIndex.from_documents(
city_docs,
llm=chatgpt,
transformations=[splitter],
response_synthesizer=response_synthesizer,
show_progress=True,
)

View File

@ -66,4 +66,4 @@ JINA_API_KEY=
OLLAMA_BASE_URL=
# Mock Switch
MOCK_SWITCH=false
MOCK_SWITCH=true

View File

@ -0,0 +1,73 @@
from ctypes import Union
from typing import List, Optional, Tuple
from qdrant_client.conversions import common_types as types
class MockMilvusClass(object):
@staticmethod
def get_collections() -> types.CollectionsResponse:
collections_response = types.CollectionsResponse(
collections=["test"]
)
return collections_response
@staticmethod
def recreate_collection() -> bool:
return True
@staticmethod
def create_payload_index() -> types.UpdateResult:
update_result = types.UpdateResult(
updated=1
)
return update_result
@staticmethod
def upsert() -> types.UpdateResult:
update_result = types.UpdateResult(
updated=1
)
return update_result
@staticmethod
def insert() -> List[Union[str, int]]:
result = ['d48632d7-c972-484a-8ed9-262490919c79']
return result
@staticmethod
def delete() -> List[Union[str, int]]:
result = ['d48632d7-c972-484a-8ed9-262490919c79']
return result
@staticmethod
def scroll() -> Tuple[List[types.Record], Optional[types.PointId]]:
record = types.Record(
id='d48632d7-c972-484a-8ed9-262490919c79',
payload={'group_id': '06798db6-1f99-489a-b599-dd386a043f2d',
'metadata': {'dataset_id': '06798db6-1f99-489a-b599-dd386a043f2d',
'doc_hash': '85197672a2c2b05d2c8690cb7f1eedc78fe5f0ca7b8ae8a301f64eb8d959b436',
'doc_id': 'd48632d7-c972-484a-8ed9-262490919c79',
'document_id': '1518a57d-9049-426e-99ae-5a6d479175c0'},
'page_content': 'Dify is a company that provides a platform for the development of AI models.'},
vector=[0.23333 for _ in range(233)]
)
return [record], 'd48632d7-c972-484a-8ed9-262490919c79'
@staticmethod
def search() -> List[types.ScoredPoint]:
result = types.ScoredPoint(
id='d48632d7-c972-484a-8ed9-262490919c79',
payload={'group_id': '06798db6-1f99-489a-b599-dd386a043f2d',
'metadata': {'dataset_id': '06798db6-1f99-489a-b599-dd386a043f2d',
'doc_hash': '85197672a2c2b05d2c8690cb7f1eedc78fe5f0ca7b8ae8a301f64eb8d959b436',
'doc_id': 'd48632d7-c972-484a-8ed9-262490919c79',
'document_id': '1518a57d-9049-426e-99ae-5a6d479175c0'},
'page_content': 'Dify is a company that provides a platform for the development of AI models.'},
vision=999,
vector=[0.23333 for _ in range(233)],
score=0.99
)
return [result]

View File

@ -0,0 +1,58 @@
import os
from typing import Callable, List, Literal
import pytest
# import monkeypatch
from _pytest.monkeypatch import MonkeyPatch
from pymilvus import Connections, MilvusClient
from pymilvus.orm import utility
from qdrant_client import QdrantClient
from unstructured.chunking.title import chunk_by_title
from unstructured.partition.md import partition_md
from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass
from tests.integration_tests.rag.__mock.milvus_function import MockMilvusClass
from tests.integration_tests.rag.__mock.qdrant_function import MockQdrantClass
from tests.integration_tests.rag.__mock.unstructured_function import MockUnstructuredClass
def mock_milvus(monkeypatch: MonkeyPatch, methods: List[Literal["get_collections", "delete", "recreate_collection", "create_payload_index", "upsert", "scroll", "search"]]) -> Callable[[], None]:
"""
mock unstructured module
:param monkeypatch: pytest monkeypatch fixture
:return: unpatch function
"""
def unpatch() -> None:
monkeypatch.undo()
if "connect" in methods:
monkeypatch.setattr(Connections, "connect", MockMilvusClass.delete())
if "get_collections" in methods:
monkeypatch.setattr(utility, "has_collection", MockMilvusClass.get_collections())
if "insert" in methods:
monkeypatch.setattr(MilvusClient, "insert", MockMilvusClass.insert())
if "create_payload_index" in methods:
monkeypatch.setattr(QdrantClient, "create_payload_index", MockMilvusClass.create_payload_index())
if "upsert" in methods:
monkeypatch.setattr(QdrantClient, "upsert", MockMilvusClass.upsert())
if "scroll" in methods:
monkeypatch.setattr(QdrantClient, "scroll", MockMilvusClass.scroll())
if "search" in methods:
monkeypatch.setattr(QdrantClient, "search", MockMilvusClass.search())
return unpatch
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
@pytest.fixture
def setup_milvus_mock(request, monkeypatch):
methods = request.param if hasattr(request, 'param') else []
if MOCK:
unpatch = mock_milvus(monkeypatch, methods=methods)
yield
if MOCK:
unpatch()

View File

@ -0,0 +1,68 @@
from typing import List, Optional, Tuple
from qdrant_client.conversions import common_types as types
class MockQdrantClass(object):
@staticmethod
def get_collections() -> types.CollectionsResponse:
collections_response = types.CollectionsResponse(
collections=["test"]
)
return collections_response
@staticmethod
def recreate_collection() -> bool:
return True
@staticmethod
def create_payload_index() -> types.UpdateResult:
update_result = types.UpdateResult(
updated=1
)
return update_result
@staticmethod
def upsert() -> types.UpdateResult:
update_result = types.UpdateResult(
updated=1
)
return update_result
@staticmethod
def delete() -> types.UpdateResult:
update_result = types.UpdateResult(
updated=1
)
return update_result
@staticmethod
def scroll() -> Tuple[List[types.Record], Optional[types.PointId]]:
record = types.Record(
id='d48632d7-c972-484a-8ed9-262490919c79',
payload={'group_id': '06798db6-1f99-489a-b599-dd386a043f2d',
'metadata': {'dataset_id': '06798db6-1f99-489a-b599-dd386a043f2d',
'doc_hash': '85197672a2c2b05d2c8690cb7f1eedc78fe5f0ca7b8ae8a301f64eb8d959b436',
'doc_id': 'd48632d7-c972-484a-8ed9-262490919c79',
'document_id': '1518a57d-9049-426e-99ae-5a6d479175c0'},
'page_content': 'Dify is a company that provides a platform for the development of AI models.'},
vector=[0.23333 for _ in range(233)]
)
return [record], 'd48632d7-c972-484a-8ed9-262490919c79'
@staticmethod
def search() -> List[types.ScoredPoint]:
result = types.ScoredPoint(
id='d48632d7-c972-484a-8ed9-262490919c79',
payload={'group_id': '06798db6-1f99-489a-b599-dd386a043f2d',
'metadata': {'dataset_id': '06798db6-1f99-489a-b599-dd386a043f2d',
'doc_hash': '85197672a2c2b05d2c8690cb7f1eedc78fe5f0ca7b8ae8a301f64eb8d959b436',
'doc_id': 'd48632d7-c972-484a-8ed9-262490919c79',
'document_id': '1518a57d-9049-426e-99ae-5a6d479175c0'},
'page_content': 'Dify is a company that provides a platform for the development of AI models.'},
vision=999,
vector=[0.23333 for _ in range(233)],
score=0.99
)
return [result]

View File

@ -0,0 +1,55 @@
import os
from typing import Callable, List, Literal
import pytest
# import monkeypatch
from _pytest.monkeypatch import MonkeyPatch
from qdrant_client import QdrantClient
from unstructured.chunking.title import chunk_by_title
from unstructured.partition.md import partition_md
from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass
from tests.integration_tests.rag.__mock.qdrant_function import MockQdrantClass
from tests.integration_tests.rag.__mock.unstructured_function import MockUnstructuredClass
def mock_qdrant(monkeypatch: MonkeyPatch, methods: List[Literal["get_collections", "delete", "recreate_collection", "create_payload_index", "upsert", "scroll", "search"]]) -> Callable[[], None]:
"""
mock unstructured module
:param monkeypatch: pytest monkeypatch fixture
:return: unpatch function
"""
def unpatch() -> None:
monkeypatch.undo()
if "delete" in methods:
monkeypatch.setattr(QdrantClient, "delete", MockQdrantClass.delete())
if "get_collections" in methods:
monkeypatch.setattr(QdrantClient, "get_collections", MockQdrantClass.get_collections())
if "recreate_collection" in methods:
monkeypatch.setattr(QdrantClient, "recreate_collection", MockQdrantClass.recreate_collection())
if "create_payload_index" in methods:
monkeypatch.setattr(QdrantClient, "create_payload_index", MockQdrantClass.create_payload_index())
if "upsert" in methods:
monkeypatch.setattr(QdrantClient, "upsert", MockQdrantClass.upsert())
if "scroll" in methods:
monkeypatch.setattr(QdrantClient, "scroll", MockQdrantClass.scroll())
if "search" in methods:
monkeypatch.setattr(QdrantClient, "search", MockQdrantClass.search())
return unpatch
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
@pytest.fixture
def setup_qdrant_mock(request, monkeypatch):
methods = request.param if hasattr(request, 'param') else []
if MOCK:
unpatch = mock_qdrant(monkeypatch, methods=methods)
yield
if MOCK:
unpatch()

View File

@ -0,0 +1,39 @@
from typing import List
from unstructured.documents.elements import Element
class MockUnstructuredClass(object):
@staticmethod
def partition_md() -> List[Element]:
element = Element(
category="title",
embeddings=[],
id="test",
metadata={},
text="test"
)
return [element]
@staticmethod
def partition_text() -> List[Element]:
element = Element(
category="title",
embeddings=[],
id="test",
metadata={},
text="test"
)
return [element]
@staticmethod
def chunk_by_title() -> List[Element]:
element = Element(
category="title",
embeddings=[],
id="test",
metadata={},
text="test"
)
return [element]

View File

@ -0,0 +1,45 @@
import os
from typing import Callable, List, Literal
import pytest
# import monkeypatch
from _pytest.monkeypatch import MonkeyPatch
from unstructured.chunking import title
from unstructured.partition import md, text
from tests.integration_tests.rag.__mock.unstructured_function import MockUnstructuredClass
def mock_unstructured(monkeypatch: MonkeyPatch, methods: List[Literal["partition_md", "chunk_by_title"]]) -> Callable[[], None]:
"""
mock unstructured module
:param monkeypatch: pytest monkeypatch fixture
:return: unpatch function
"""
def unpatch() -> None:
monkeypatch.undo()
if "partition_md" in methods:
monkeypatch.setattr(md, "partition_md", MockUnstructuredClass.partition_md())
if "partition_text" in methods:
monkeypatch.setattr(text, "partition_text", MockUnstructuredClass.partition_text())
if "chunk_by_title" in methods:
monkeypatch.setattr(title, "chunk_by_title", MockUnstructuredClass.chunk_by_title())
return unpatch
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
@pytest.fixture
def setup_unstructured_mock(request, monkeypatch):
methods = request.param if hasattr(request, 'param') else []
if MOCK:
unpatch = mock_unstructured(monkeypatch, methods=methods)
yield
if MOCK:
unpatch()

View File

@ -0,0 +1,39 @@
from typing import List, Optional, Tuple
from qdrant_client.conversions import common_types as types
class MockWeaviateClass(object):
@staticmethod
def contains() -> bool:
return True
@staticmethod
def add_data_object() -> str:
return 'd48632d7-c972-484a-8ed9-262490919c79'
@staticmethod
def delete_class() -> None:
return None
@staticmethod
def do() -> dict:
record = {
'Get': {
'Vector_index_a5f66ab4_cc83_4061_85a5_cb775933d52a_Node': [
{
'_additional': {
'distance': 0.10660946,
'vector': [0.23333 for _ in range(233)]
},
'dataset_id': 'a5f66ab4-cc83-4061-85a5-cb775933d52a',
'doc_hash': '52c3c8889c34d2d7b50bb04ca4d77081b1b4b625bc69c82294abfbdf7e918c21',
'doc_id': 'b3fdec03-99ad-4a7c-a565-94d02dcde05e',
'document_id': '71ec7e68-c45a-4d8b-886b-6077730a83ee',
'text': '1、你知道孙悟空是从哪里生出来的吗'
}
]
}
}
return record

View File

@ -0,0 +1,53 @@
import os
from typing import Callable, List, Literal
import pytest
# import monkeypatch
from _pytest.monkeypatch import MonkeyPatch
from qdrant_client import QdrantClient
from unstructured.chunking.title import chunk_by_title
from unstructured.partition.md import partition_md
from weaviate.batch import Batch
from weaviate.gql.get import GetBuilder
from weaviate.schema import Schema
from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass
from tests.integration_tests.rag.__mock.qdrant_function import MockQdrantClass
from tests.integration_tests.rag.__mock.weaviate_function import MockWeaviateClass
def mock_weaviate(monkeypatch: MonkeyPatch, methods: List[Literal["get_collections", "delete", "recreate_collection", "create_payload_index", "upsert", "scroll", "search"]]) -> Callable[[], None]:
"""
mock unstructured module
:param monkeypatch: pytest monkeypatch fixture
:return: unpatch function
"""
def unpatch() -> None:
monkeypatch.undo()
if "delete" in methods:
monkeypatch.setattr(Schema, "delete", MockWeaviateClass.delete_class())
if "contains" in methods:
monkeypatch.setattr(Schema, "contains", MockWeaviateClass.contains())
if "add_data_object" in methods:
monkeypatch.setattr(Batch, "add_data_object", MockWeaviateClass.add_data_object())
if "do" in methods:
monkeypatch.setattr(GetBuilder, "do", MockWeaviateClass.do())
return unpatch
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
@pytest.fixture
def setup_weaviate_mock(request, monkeypatch):
methods = request.param if hasattr(request, 'param') else []
if MOCK:
unpatch = mock_weaviate(monkeypatch, methods=methods)
yield
if MOCK:
unpatch()

View File

@ -0,0 +1,20 @@
import os
from core.rag.extractor.excel_extractor import ExcelExtractor
from core.rag.models.document import Document
def test_extract_xlsx():
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
# Construct the path to the xlsx file
test_file_path = os.path.join(assets_dir, 'test.xlsx')
extractor = ExcelExtractor(test_file_path)
result = extractor.extract()
assert isinstance(result, list)
for item in result:
assert isinstance(item, Document)

View File

@ -0,0 +1,22 @@
import os
from core.rag.extractor.csv_extractor import CSVExtractor
from core.rag.models.document import Document
def test_extract_csv():
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
# Construct the path to the txt file
test_file_path = os.path.join(assets_dir, 'test.csv')
extractor = CSVExtractor(test_file_path, autodetect_encoding=True)
result = extractor.extract()
assert isinstance(result, list)
for item in result:
assert isinstance(item, Document)

View File

@ -0,0 +1,25 @@
import os
from core.rag.extractor.markdown_extractor import MarkdownExtractor
from core.rag.extractor.pdf_extractor import PdfExtractor
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.extractor.word_extractor import WordExtractor
from core.rag.models.document import Document
def test_extract_docx():
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
# Construct the path to the markdown file
test_file_path = os.path.join(assets_dir, 'test.docx')
extractor = WordExtractor(test_file_path)
result = extractor.extract()
assert isinstance(result, list)
for item in result:
assert isinstance(item, Document)

View File

@ -0,0 +1,22 @@
import os
from core.rag.extractor.html_extractor import HtmlExtractor
from core.rag.models.document import Document
def test_extract_html():
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
# Construct the path to the markdown file
test_file_path = os.path.join(assets_dir, 'test.html')
extractor = HtmlExtractor(test_file_path)
result = extractor.extract()
assert isinstance(result, list)
for item in result:
assert isinstance(item, Document)

View File

@ -0,0 +1,23 @@
import os
from core.rag.extractor.markdown_extractor import MarkdownExtractor
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.models.document import Document
def test_extract_markdown():
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
# Construct the path to the markdown file
test_file_path = os.path.join(assets_dir, 'test.md')
extractor = MarkdownExtractor(test_file_path, autodetect_encoding=True)
result = extractor.extract()
assert isinstance(result, list)
for item in result:
assert isinstance(item, Document)

View File

@ -0,0 +1,24 @@
import os
from core.rag.extractor.markdown_extractor import MarkdownExtractor
from core.rag.extractor.pdf_extractor import PdfExtractor
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.models.document import Document
def test_extract_pdf():
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
# Construct the path to the markdown file
test_file_path = os.path.join(assets_dir, 'test.pdf')
extractor = PdfExtractor(test_file_path)
result = extractor.extract()
assert isinstance(result, list)
for item in result:
assert isinstance(item, Document)

View File

@ -0,0 +1,21 @@
import os
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.models.document import Document
def test_extract_text():
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
# Construct the path to the txt file
test_file_path = os.path.join(assets_dir, 'test.txt')
extractor = TextExtractor(test_file_path, autodetect_encoding=True)
result = extractor.extract()
assert isinstance(result, list)
for item in result:
assert isinstance(item, Document)

View File

@ -0,0 +1,27 @@
import os
from core.rag.extractor.markdown_extractor import MarkdownExtractor
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
from core.rag.models.document import Document
def test_extract_unstructured_docx():
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
# Construct the path to the docx file
test_file_path = os.path.join(assets_dir, 'test.docx')
unstructured_api_url = os.getenv('UNSTRUCTURED_API_URL')
extractor = UnstructuredWordExtractor(test_file_path, unstructured_api_url)
result = extractor.extract()
assert isinstance(result, list)
for item in result:
assert isinstance(item, Document)

View File

@ -0,0 +1,25 @@
import os
import pytest
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
from core.rag.models.document import Document
from tests.integration_tests.rag.__mock.unstructured_mock import setup_unstructured_mock
@pytest.mark.parametrize('setup_unstructured_mock', [['partition_md', 'chunk_by_title']], indirect=True)
def test_extract_unstructured_markdown(setup_unstructured_mock):
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
# Construct the path to the markdown file
test_file_path = os.path.join(assets_dir, 'test.md')
unstructured_api_url = os.getenv('UNSTRUCTURED_API_URL')
extractor = UnstructuredMarkdownExtractor(test_file_path, unstructured_api_url)
result = extractor.extract()
assert isinstance(result, list)
for item in result:
assert isinstance(item, Document)

View File

@ -74,7 +74,6 @@ def test_invoke_model(setup_google_mock):
@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
def test_invoke_stream_model(setup_google_mock):
model = GoogleLargeLanguageModel()
response = model.invoke(
model='gemini-pro',
credentials={

View File

@ -0,0 +1,20 @@
import os
from core.rag.extractor.excel_extractor import ExcelExtractor
from core.rag.models.document import Document
def test_extract_xlsx():
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
# Construct the path to the xlsx file
test_file_path = os.path.join(assets_dir, 'test.xlsx')
extractor = ExcelExtractor(test_file_path)
result = extractor.extract()
assert isinstance(result, list)
for item in result:
assert isinstance(item, Document)

View File

@ -0,0 +1,22 @@
import os
from core.rag.extractor.csv_extractor import CSVExtractor
from core.rag.models.document import Document
def test_extract_csv():
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
# Construct the path to the txt file
test_file_path = os.path.join(assets_dir, 'test.csv')
extractor = CSVExtractor(test_file_path, autodetect_encoding=True)
result = extractor.extract()
assert isinstance(result, list)
for item in result:
assert isinstance(item, Document)

View File

@ -0,0 +1,25 @@
import os
from core.rag.extractor.markdown_extractor import MarkdownExtractor
from core.rag.extractor.pdf_extractor import PdfExtractor
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.extractor.word_extractor import WordExtractor
from core.rag.models.document import Document
def test_extract_docx():
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
# Construct the path to the markdown file
test_file_path = os.path.join(assets_dir, 'test.docx')
extractor = WordExtractor(test_file_path)
result = extractor.extract()
assert isinstance(result, list)
for item in result:
assert isinstance(item, Document)

View File

@ -0,0 +1,22 @@
import os
from core.rag.extractor.html_extractor import HtmlExtractor
from core.rag.models.document import Document
def test_extract_html():
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
# Construct the path to the markdown file
test_file_path = os.path.join(assets_dir, 'test.html')
extractor = HtmlExtractor(test_file_path)
result = extractor.extract()
assert isinstance(result, list)
for item in result:
assert isinstance(item, Document)

View File

@ -0,0 +1,23 @@
import os
from core.rag.extractor.markdown_extractor import MarkdownExtractor
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.models.document import Document
def test_extract_markdown():
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
# Construct the path to the markdown file
test_file_path = os.path.join(assets_dir, 'test.md')
extractor = MarkdownExtractor(test_file_path, autodetect_encoding=True)
result = extractor.extract()
assert isinstance(result, list)
for item in result:
assert isinstance(item, Document)

View File

@ -0,0 +1,24 @@
import os
from core.rag.extractor.markdown_extractor import MarkdownExtractor
from core.rag.extractor.pdf_extractor import PdfExtractor
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.models.document import Document
def test_extract_pdf():
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
# Construct the path to the markdown file
test_file_path = os.path.join(assets_dir, 'test.pdf')
extractor = PdfExtractor(test_file_path)
result = extractor.extract()
assert isinstance(result, list)
for item in result:
assert isinstance(item, Document)

View File

@ -0,0 +1,21 @@
import os
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.models.document import Document
def test_extract_text():
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
# Construct the path to the txt file
test_file_path = os.path.join(assets_dir, 'test.txt')
extractor = TextExtractor(test_file_path, autodetect_encoding=True)
result = extractor.extract()
assert isinstance(result, list)
for item in result:
assert isinstance(item, Document)