"""Functionality for splitting text.""" from __future__ import annotations from typing import ( Any, List, Optional, ) from langchain.text_splitter import RecursiveCharacterTextSplitter class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) self._fixed_separator = fixed_separator self._separators = separators or ["\n\n", "\n", " ", ""] def split_text(self, text: str) -> List[str]: """Split incoming text and return chunks.""" if self._fixed_separator: chunks = text.split(self._fixed_separator) else: chunks = list(text) final_chunks = [] for chunk in chunks: if self._length_function(chunk) > self._chunk_size: final_chunks.extend(self.recursive_split_text(chunk)) else: final_chunks.append(chunk) return final_chunks def recursive_split_text(self, text: str) -> List[str]: """Split incoming text and return chunks.""" final_chunks = [] # Get appropriate separator to use separator = self._separators[-1] for _s in self._separators: if _s == "": separator = _s break if _s in text: separator = _s break # Now that we have the separator, split the text if separator: splits = text.split(separator) else: splits = list(text) # Now go merging things, recursively splitting longer texts. _good_splits = [] for s in splits: if self._length_function(s) < self._chunk_size: _good_splits.append(s) else: if _good_splits: merged_text = self._merge_splits(_good_splits, separator) final_chunks.extend(merged_text) _good_splits = [] other_info = self.recursive_split_text(s) final_chunks.extend(other_info) if _good_splits: merged_text = self._merge_splits(_good_splits, separator) final_chunks.extend(merged_text) return final_chunks