From a63a9c7d458c126cdb6ff534de285e4ec8bebacd Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Fri, 12 Jan 2024 18:45:34 +0800 Subject: [PATCH] text spliter length method use default embedding model tokenizer (#2011) Co-authored-by: jyong --- api/core/indexing_runner.py | 58 +++++++++++++++++++----- api/core/spiltter/fixed_text_splitter.py | 35 +++++++++----- 2 files changed, 69 insertions(+), 24 deletions(-) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 51c42bf75b..28a99d2f7f 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -13,7 +13,7 @@ from core.docstore.dataset_docstore import DatasetDocumentStore from core.errors.error import ProviderTokenNotInitError from core.generator.llm_generator import LLMGenerator from core.index.index import IndexBuilder -from core.model_manager import ModelManager +from core.model_manager import ModelManager, ModelInstance from core.model_runtime.entities.model_entities import ModelType, PriceType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel @@ -61,8 +61,24 @@ class IndexingRunner: # load file text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic') + # get embedding model instance + embedding_model_instance = None + if dataset.indexing_technique == 'high_quality': + if dataset.embedding_model_provider: + embedding_model_instance = self.model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model + ) + else: + embedding_model_instance = self.model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + # get splitter - splitter = self._get_splitter(processing_rule) + splitter = self._get_splitter(processing_rule, embedding_model_instance) # split to documents documents = self._step_split( @@ -121,8 +137,24 @@ class IndexingRunner: # load file text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic') + # get embedding model instance + embedding_model_instance = None + if dataset.indexing_technique == 'high_quality': + if dataset.embedding_model_provider: + embedding_model_instance = self.model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model + ) + else: + embedding_model_instance = self.model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + # get splitter - splitter = self._get_splitter(processing_rule) + splitter = self._get_splitter(processing_rule, embedding_model_instance) # split to documents documents = self._step_split( @@ -253,7 +285,7 @@ class IndexingRunner: text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic') # get splitter - splitter = self._get_splitter(processing_rule) + splitter = self._get_splitter(processing_rule, embedding_model_instance) # split to documents documents = self._split_to_documents_for_estimate( @@ -384,7 +416,7 @@ class IndexingRunner: ) # get splitter - splitter = self._get_splitter(processing_rule) + splitter = self._get_splitter(processing_rule, embedding_model_instance) # split to documents documents = self._split_to_documents_for_estimate( @@ -502,7 +534,8 @@ class IndexingRunner: text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]', '', text) return text - def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter: + def _get_splitter(self, processing_rule: DatasetProcessRule, + embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: """ Get the NodeParser object according to the processing rule. """ @@ -517,19 +550,20 @@ class IndexingRunner: if separator: separator = separator.replace('\\n', '\n') - - character_splitter = FixedRecursiveCharacterTextSplitter.from_gpt2_encoder( + character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( chunk_size=segmentation["max_tokens"], chunk_overlap=0, fixed_separator=separator, - separators=["\n\n", "。", ".", " ", ""] + separators=["\n\n", "。", ".", " ", ""], + embedding_model_instance=embedding_model_instance ) else: # Automatic segmentation - character_splitter = EnhanceRecursiveCharacterTextSplitter.from_gpt2_encoder( + character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], chunk_overlap=0, - separators=["\n\n", "。", ".", " ", ""] + separators=["\n\n", "。", ".", " ", ""], + embedding_model_instance=embedding_model_instance ) return character_splitter @@ -714,7 +748,7 @@ class IndexingRunner: return text def format_split_text(self, text): - regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" + regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) return [ diff --git a/api/core/spiltter/fixed_text_splitter.py b/api/core/spiltter/fixed_text_splitter.py index 80d609d800..b8f384eee9 100644 --- a/api/core/spiltter/fixed_text_splitter.py +++ b/api/core/spiltter/fixed_text_splitter.py @@ -1,8 +1,10 @@ """Functionality for splitting text.""" from __future__ import annotations -from typing import Any, List, Optional +from typing import Any, List, Optional, cast +from core.model_manager import ModelInstance +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer from langchain.text_splitter import (TS, AbstractSet, Collection, Literal, RecursiveCharacterTextSplitter, TokenTextSplitter, Type, Union) @@ -12,22 +14,30 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): """ This class is used to implement from_gpt2_encoder, to prevent using of tiktoken """ + @classmethod - def from_gpt2_encoder( - cls: Type[TS], - encoding_name: str = "gpt2", - model_name: Optional[str] = None, - allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = "all", - **kwargs: Any, + def from_encoder( + cls: Type[TS], + embedding_model_instance: Optional[ModelInstance], + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, ): def _token_encoder(text: str) -> int: - return GPT2Tokenizer.get_num_tokens(text) + if embedding_model_instance: + embedding_model_type_instance = embedding_model_instance.model_type_instance + embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) + return embedding_model_type_instance.get_num_tokens( + model=embedding_model_instance.model, + credentials=embedding_model_instance.credentials, + texts=[text] + ) + else: + return GPT2Tokenizer.get_num_tokens(text) if issubclass(cls, TokenTextSplitter): extra_kwargs = { - "encoding_name": encoding_name, - "model_name": model_name, + "model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2', "allowed_special": allowed_special, "disallowed_special": disallowed_special, } @@ -35,6 +45,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): return cls(length_function=_token_encoder, **kwargs) + class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter): def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any): """Create a new TextSplitter.""" @@ -90,4 +101,4 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) if _good_splits: merged_text = self._merge_splits(_good_splits, separator) final_chunks.extend(merged_text) - return final_chunks \ No newline at end of file + return final_chunks