"""Functionality for splitting text.""" from __future__ import annotations from typing import Any, List, Optional, cast from core.model_manager import ModelInstance from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer from langchain.text_splitter import (TS, AbstractSet, Collection, Literal, RecursiveCharacterTextSplitter, TokenTextSplitter, Type, Union) class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): """ This class is used to implement from_gpt2_encoder, to prevent using of tiktoken """ @classmethod def from_encoder( cls: Type[TS], embedding_model_instance: Optional[ModelInstance], allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), disallowed_special: Union[Literal["all"], Collection[str]] = "all", **kwargs: Any, ): def _token_encoder(text: str) -> int: if not text: return 0 if embedding_model_instance: embedding_model_type_instance = embedding_model_instance.model_type_instance embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) return embedding_model_type_instance.get_num_tokens( model=embedding_model_instance.model, credentials=embedding_model_instance.credentials, texts=[text] ) else: return GPT2Tokenizer.get_num_tokens(text) if issubclass(cls, TokenTextSplitter): extra_kwargs = { "model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2', "allowed_special": allowed_special, "disallowed_special": disallowed_special, } kwargs = {**kwargs, **extra_kwargs} return cls(length_function=_token_encoder, **kwargs) class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter): def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) self._fixed_separator = fixed_separator self._separators = separators or ["\n\n", "\n", " ", ""] def split_text(self, text: str) -> List[str]: """Split incoming text and return chunks.""" if self._fixed_separator: chunks = text.split(self._fixed_separator) else: chunks = list(text) final_chunks = [] for chunk in chunks: if self._length_function(chunk) > self._chunk_size: final_chunks.extend(self.recursive_split_text(chunk)) else: final_chunks.append(chunk) return final_chunks def recursive_split_text(self, text: str) -> List[str]: """Split incoming text and return chunks.""" final_chunks = [] # Get appropriate separator to use separator = self._separators[-1] for _s in self._separators: if _s == "": separator = _s break if _s in text: separator = _s break # Now that we have the separator, split the text if separator: splits = text.split(separator) else: splits = list(text) # Now go merging things, recursively splitting longer texts. _good_splits = [] for s in splits: if self._length_function(s) < self._chunk_size: _good_splits.append(s) else: if _good_splits: merged_text = self._merge_splits(_good_splits, separator) final_chunks.extend(merged_text) _good_splits = [] other_info = self.recursive_split_text(s) final_chunks.extend(other_info) if _good_splits: merged_text = self._merge_splits(_good_splits, separator) final_chunks.extend(merged_text) return final_chunks