dify/api/core/spiltter/fixed_text_splitter.py
2024-01-12 18:45:34 +08:00

105 lines
4.2 KiB
Python

"""Functionality for splitting text."""
from __future__ import annotations
from typing import Any, List, Optional, cast
from core.model_manager import ModelInstance
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from langchain.text_splitter import (TS, AbstractSet, Collection, Literal, RecursiveCharacterTextSplitter,
TokenTextSplitter, Type, Union)
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """Recursive character splitter that measures chunk length in tokens.

    The ``from_encoder`` alternate constructor wires in a token-counting
    length function backed by the embedding model (or by the bundled GPT-2
    tokenizer when no model is given), to prevent using of tiktoken.
    """

    @classmethod
    def from_encoder(
        cls: Type[TS],
        embedding_model_instance: Optional[ModelInstance],
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Build a splitter whose length function counts tokens.

        Args:
            embedding_model_instance: Model used for token counting. When it
                is falsy, ``GPT2Tokenizer.get_num_tokens`` is used instead.
            allowed_special: Forwarded to the constructor only when ``cls`` is
                a ``TokenTextSplitter`` subclass. The default set is never
                mutated here.
            disallowed_special: Forwarded under the same condition as
                ``allowed_special``.
            **kwargs: Passed through to ``cls(...)``.

        Returns:
            An instance of ``cls`` constructed with
            ``length_function=_token_encoder``.
        """

        def _token_encoder(text: str) -> int:
            # Empty text has no tokens; skip the (possibly remote) model call.
            if not text:
                return 0

            if embedding_model_instance:
                embedding_model_type_instance = cast(
                    TextEmbeddingModel,
                    embedding_model_instance.model_type_instance,
                )
                # NOTE(review): assumes get_num_tokens returns an int for a
                # single-element ``texts`` list -- confirm against the
                # TextEmbeddingModel interface.
                return embedding_model_type_instance.get_num_tokens(
                    model=embedding_model_instance.model,
                    credentials=embedding_model_instance.credentials,
                    texts=[text],
                )

            return GPT2Tokenizer.get_num_tokens(text)

        if issubclass(cls, TokenTextSplitter):
            extra_kwargs = {
                "model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2',
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            # extra_kwargs come last, so they override same-named caller kwargs.
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_token_encoder, **kwargs)
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
    """Split on a fixed separator first, then recursively on fallbacks.

    ``split_text`` cuts the input on ``fixed_separator``. Only the pieces
    whose measured length exceeds the chunk size are broken down further by
    ``recursive_split_text`` using the ``separators`` list.

    ``_length_function``, ``_chunk_size`` and ``_merge_splits`` are inherited
    from the base splitter and are not defined in this class.
    """

    def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
        """Create a new TextSplitter.

        Args:
            fixed_separator: Separator applied first to the whole text. An
                empty string disables this first pass.
            separators: Ordered fallback separators for oversized pieces.
                Defaults to ``["\\n\\n", "\\n", " ", ""]`` when falsy.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(**kwargs)
        self._fixed_separator = fixed_separator
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks.

        Pieces no longer than the chunk size are returned unchanged; longer
        pieces are replaced by the result of ``recursive_split_text``.
        """
        if self._fixed_separator:
            chunks = text.split(self._fixed_separator)
        else:
            # No fixed separator: treat the whole text as one candidate chunk.
            # (Splitting into characters here would turn every character into
            # its own chunk.)
            chunks = [text]

        final_chunks: List[str] = []
        for chunk in chunks:
            if self._length_function(chunk) > self._chunk_size:
                final_chunks.extend(self.recursive_split_text(chunk))
            else:
                final_chunks.append(chunk)

        return final_chunks

    def recursive_split_text(self, text: str) -> List[str]:
        """Recursively split ``text`` with the fallback separators.

        The first separator that occurs in ``text`` is used (an empty-string
        separator always matches and means "split into characters"). If none
        matches, the last separator in the list is used. Small pieces are
        merged via ``_merge_splits``; pieces that are still too long are
        split again.
        """
        final_chunks: List[str] = []

        # Get appropriate separator to use
        separator = self._separators[-1]
        for _s in self._separators:
            if _s == "" or _s in text:
                separator = _s
                break

        # Now that we have the separator, split the text
        splits = text.split(separator) if separator else list(text)

        # Now go merging things, recursively splitting longer texts.
        _good_splits: List[str] = []
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
                continue

            # Flush the accumulated small pieces before handling the big one
            # so that output order follows input order.
            if _good_splits:
                final_chunks.extend(self._merge_splits(_good_splits, separator))
                _good_splits = []

            if s == text:
                # Splitting made no progress (separator absent from the text,
                # or a single character that already reaches the chunk size).
                # Recursing would loop forever, so emit the piece as-is.
                final_chunks.append(s)
            else:
                final_chunks.extend(self.recursive_split_text(s))

        if _good_splits:
            final_chunks.extend(self._merge_splits(_good_splits, separator))

        return final_chunks