"""Functionality for splitting text."""
from __future__ import annotations

from typing import Any, Optional, cast

from core.model_manager import ModelInstance
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.splitter.text_splitter import (
    TS,
    Collection,
    Literal,
    RecursiveCharacterTextSplitter,
    Set,
    TokenTextSplitter,
    Union,
)


class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """
    Recursive splitter whose ``from_encoder`` builds the length function from the
    embedding model's tokenizer when one is available, falling back to GPT-2,
    so that tiktoken is never required.
    """

    @classmethod
    def from_encoder(
            cls: type[TS],
            embedding_model_instance: Optional[ModelInstance],
            allowed_special: Union[Literal["all"], Set[str]] = set(),
            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
            **kwargs: Any,
    ):
        def _token_encoder(text: str) -> int:
            # Count tokens with the embedding model's tokenizer when one is
            # configured; otherwise fall back to the bundled GPT-2 tokenizer.
            if not text:
                return 0

            if embedding_model_instance:
                embedding_model_type_instance = embedding_model_instance.model_type_instance
                embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
                return embedding_model_type_instance.get_num_tokens(
                    model=embedding_model_instance.model,
                    credentials=embedding_model_instance.credentials,
                    texts=[text]
                )
            else:
                return GPT2Tokenizer.get_num_tokens(text)

        if issubclass(cls, TokenTextSplitter):
            # TokenTextSplitter additionally needs the tokenizer configuration.
            extra_kwargs = {
                "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_token_encoder, **kwargs)


class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
    def __init__(self, fixed_separator: str = "\n\n", separators: Optional[list[str]] = None, **kwargs: Any):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._fixed_separator = fixed_separator
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        # Split on the user-defined fixed separator first; any chunk that is
        # still longer than the chunk size is split further recursively.
        if self._fixed_separator:
            chunks = text.split(self._fixed_separator)
        else:
            chunks = list(text)

        final_chunks = []
        for chunk in chunks:
            if self._length_function(chunk) > self._chunk_size:
                final_chunks.extend(self.recursive_split_text(chunk))
            else:
                final_chunks.append(chunk)

        return final_chunks

    def recursive_split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Pick the first separator that appears in the text; the last entry
        # (usually "") acts as the fallback.
        separator = self._separators[-1]
        for _s in self._separators:
            if _s == "":
                separator = _s
                break
            if _s in text:
                separator = _s
                break
        # Now that we have the separator, split the text.
        if separator:
            splits = text.split(separator)
        else:
            splits = list(text)
        # Merge small splits back together up to the chunk size, recursively
        # splitting any piece that is still too long.
        _good_splits = []
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                other_info = self.recursive_split_text(s)
                final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, separator)
            final_chunks.extend(merged_text)
        return final_chunks
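

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not exercised by the application code):
# how the splitter above might be constructed and called. Passing
# `embedding_model_instance=None` makes the GPT-2 tokenizer supply the length
# function; `chunk_size` and `chunk_overlap` are the standard TextSplitter
# parameters. The sample text and values below are made up for demonstration.
if __name__ == "__main__":
    splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
        embedding_model_instance=None,       # fall back to GPT-2 token counting
        fixed_separator="\n\n",              # split on paragraph boundaries first
        separators=["\n\n", "\n", " ", ""],
        chunk_size=64,
        chunk_overlap=8,
    )
    sample = "First paragraph.\n\nSecond, much longer paragraph that may be split again.\n\nThird."
    for index, chunk in enumerate(splitter.split_text(sample)):
        print(index, repr(chunk))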