mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
text splitter length method use default embedding model tokenizer (#2011)
Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
parent
1779cea6e3
commit
a63a9c7d45
|
@ -13,7 +13,7 @@ from core.docstore.dataset_docstore import DatasetDocumentStore
|
||||||
from core.errors.error import ProviderTokenNotInitError
|
from core.errors.error import ProviderTokenNotInitError
|
||||||
from core.generator.llm_generator import LLMGenerator
|
from core.generator.llm_generator import LLMGenerator
|
||||||
from core.index.index import IndexBuilder
|
from core.index.index import IndexBuilder
|
||||||
from core.model_manager import ModelManager
|
from core.model_manager import ModelManager, ModelInstance
|
||||||
from core.model_runtime.entities.model_entities import ModelType, PriceType
|
from core.model_runtime.entities.model_entities import ModelType, PriceType
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||||
|
@ -61,8 +61,24 @@ class IndexingRunner:
|
||||||
# load file
|
# load file
|
||||||
text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
|
text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
|
||||||
|
|
||||||
|
# get embedding model instance
|
||||||
|
embedding_model_instance = None
|
||||||
|
if dataset.indexing_technique == 'high_quality':
|
||||||
|
if dataset.embedding_model_provider:
|
||||||
|
embedding_model_instance = self.model_manager.get_model_instance(
|
||||||
|
tenant_id=dataset.tenant_id,
|
||||||
|
provider=dataset.embedding_model_provider,
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
model=dataset.embedding_model
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
embedding_model_instance = self.model_manager.get_default_model_instance(
|
||||||
|
tenant_id=dataset.tenant_id,
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
)
|
||||||
|
|
||||||
# get splitter
|
# get splitter
|
||||||
splitter = self._get_splitter(processing_rule)
|
splitter = self._get_splitter(processing_rule, embedding_model_instance)
|
||||||
|
|
||||||
# split to documents
|
# split to documents
|
||||||
documents = self._step_split(
|
documents = self._step_split(
|
||||||
|
@ -121,8 +137,24 @@ class IndexingRunner:
|
||||||
# load file
|
# load file
|
||||||
text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
|
text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
|
||||||
|
|
||||||
|
# get embedding model instance
|
||||||
|
embedding_model_instance = None
|
||||||
|
if dataset.indexing_technique == 'high_quality':
|
||||||
|
if dataset.embedding_model_provider:
|
||||||
|
embedding_model_instance = self.model_manager.get_model_instance(
|
||||||
|
tenant_id=dataset.tenant_id,
|
||||||
|
provider=dataset.embedding_model_provider,
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
model=dataset.embedding_model
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
embedding_model_instance = self.model_manager.get_default_model_instance(
|
||||||
|
tenant_id=dataset.tenant_id,
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
)
|
||||||
|
|
||||||
# get splitter
|
# get splitter
|
||||||
splitter = self._get_splitter(processing_rule)
|
splitter = self._get_splitter(processing_rule, embedding_model_instance)
|
||||||
|
|
||||||
# split to documents
|
# split to documents
|
||||||
documents = self._step_split(
|
documents = self._step_split(
|
||||||
|
@ -253,7 +285,7 @@ class IndexingRunner:
|
||||||
text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')
|
text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')
|
||||||
|
|
||||||
# get splitter
|
# get splitter
|
||||||
splitter = self._get_splitter(processing_rule)
|
splitter = self._get_splitter(processing_rule, embedding_model_instance)
|
||||||
|
|
||||||
# split to documents
|
# split to documents
|
||||||
documents = self._split_to_documents_for_estimate(
|
documents = self._split_to_documents_for_estimate(
|
||||||
|
@ -384,7 +416,7 @@ class IndexingRunner:
|
||||||
)
|
)
|
||||||
|
|
||||||
# get splitter
|
# get splitter
|
||||||
splitter = self._get_splitter(processing_rule)
|
splitter = self._get_splitter(processing_rule, embedding_model_instance)
|
||||||
|
|
||||||
# split to documents
|
# split to documents
|
||||||
documents = self._split_to_documents_for_estimate(
|
documents = self._split_to_documents_for_estimate(
|
||||||
|
@ -502,7 +534,8 @@ class IndexingRunner:
|
||||||
text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]', '', text)
|
text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]', '', text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
|
def _get_splitter(self, processing_rule: DatasetProcessRule,
|
||||||
|
embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
|
||||||
"""
|
"""
|
||||||
Get the NodeParser object according to the processing rule.
|
Get the NodeParser object according to the processing rule.
|
||||||
"""
|
"""
|
||||||
|
@ -517,19 +550,20 @@ class IndexingRunner:
|
||||||
if separator:
|
if separator:
|
||||||
separator = separator.replace('\\n', '\n')
|
separator = separator.replace('\\n', '\n')
|
||||||
|
|
||||||
|
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
|
||||||
character_splitter = FixedRecursiveCharacterTextSplitter.from_gpt2_encoder(
|
|
||||||
chunk_size=segmentation["max_tokens"],
|
chunk_size=segmentation["max_tokens"],
|
||||||
chunk_overlap=0,
|
chunk_overlap=0,
|
||||||
fixed_separator=separator,
|
fixed_separator=separator,
|
||||||
separators=["\n\n", "。", ".", " ", ""]
|
separators=["\n\n", "。", ".", " ", ""],
|
||||||
|
embedding_model_instance=embedding_model_instance
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Automatic segmentation
|
# Automatic segmentation
|
||||||
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_gpt2_encoder(
|
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
|
||||||
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
|
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
|
||||||
chunk_overlap=0,
|
chunk_overlap=0,
|
||||||
separators=["\n\n", "。", ".", " ", ""]
|
separators=["\n\n", "。", ".", " ", ""],
|
||||||
|
embedding_model_instance=embedding_model_instance
|
||||||
)
|
)
|
||||||
|
|
||||||
return character_splitter
|
return character_splitter
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
"""Functionality for splitting text."""
|
"""Functionality for splitting text."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional, cast
|
||||||
|
|
||||||
|
from core.model_manager import ModelInstance
|
||||||
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||||
from langchain.text_splitter import (TS, AbstractSet, Collection, Literal, RecursiveCharacterTextSplitter,
|
from langchain.text_splitter import (TS, AbstractSet, Collection, Literal, RecursiveCharacterTextSplitter,
|
||||||
TokenTextSplitter, Type, Union)
|
TokenTextSplitter, Type, Union)
|
||||||
|
@ -12,22 +14,30 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||||
"""
|
"""
|
||||||
This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
|
This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_gpt2_encoder(
|
def from_encoder(
|
||||||
cls: Type[TS],
|
cls: Type[TS],
|
||||||
encoding_name: str = "gpt2",
|
embedding_model_instance: Optional[ModelInstance],
|
||||||
model_name: Optional[str] = None,
|
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
|
||||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
|
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
**kwargs: Any,
|
||||||
**kwargs: Any,
|
|
||||||
):
|
):
|
||||||
def _token_encoder(text: str) -> int:
|
def _token_encoder(text: str) -> int:
|
||||||
return GPT2Tokenizer.get_num_tokens(text)
|
if embedding_model_instance:
|
||||||
|
embedding_model_type_instance = embedding_model_instance.model_type_instance
|
||||||
|
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
|
||||||
|
return embedding_model_type_instance.get_num_tokens(
|
||||||
|
model=embedding_model_instance.model,
|
||||||
|
credentials=embedding_model_instance.credentials,
|
||||||
|
texts=[text]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return GPT2Tokenizer.get_num_tokens(text)
|
||||||
|
|
||||||
if issubclass(cls, TokenTextSplitter):
|
if issubclass(cls, TokenTextSplitter):
|
||||||
extra_kwargs = {
|
extra_kwargs = {
|
||||||
"encoding_name": encoding_name,
|
"model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2',
|
||||||
"model_name": model_name,
|
|
||||||
"allowed_special": allowed_special,
|
"allowed_special": allowed_special,
|
||||||
"disallowed_special": disallowed_special,
|
"disallowed_special": disallowed_special,
|
||||||
}
|
}
|
||||||
|
@ -35,6 +45,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||||
|
|
||||||
return cls(length_function=_token_encoder, **kwargs)
|
return cls(length_function=_token_encoder, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
|
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
|
||||||
def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
|
def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
|
||||||
"""Create a new TextSplitter."""
|
"""Create a new TextSplitter."""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user