| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903 | from __future__ import annotationsimport copyimport loggingimport refrom abc import ABC, abstractmethodfrom collections.abc import Callable, Collection, Iterable, Sequence, Setfrom dataclasses import dataclassfrom enum import Enumfrom typing import (    Any,    Literal,    Optional,    TypedDict,    TypeVar,    Union,)from core.rag.models.document import BaseDocumentTransformer, Documentlogger = logging.getLogger(__name__)TS = TypeVar("TS", bound="TextSplitter")def _split_text_with_regex(        text: str, separator: str, keep_separator: bool) -> list[str]:    # Now that we have the separator, split the text    if separator:        if keep_separator:            # The parentheses in the pattern keep the delimiters in the result.            _splits = re.split(f"({re.escape(separator)})", text)            splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]            if len(_splits) % 2 == 0:                splits += _splits[-1:]            splits = [_splits[0]] + splits        else:            splits = re.split(separator, text)    else:        splits = list(text)    return [s for s in splits if s != ""]class TextSplitter(BaseDocumentTransformer, ABC):    """Interface for splitting text into chunks."""    def __init__(            self,            chunk_size: int = 4000,            chunk_overlap: int = 200,            length_function: Callable[[str], int] = len,            keep_separator: bool = False,            add_start_index: bool = False,    ) -> None:        """Create a new TextSplitter.        Args:            chunk_size: Maximum size of chunks to return            chunk_overlap: Overlap in characters between chunks            length_function: Function that measures the length of given chunks            keep_separator: Whether to keep the separator in the chunks            add_start_index: If `True`, includes chunk's start index in metadata        """        if chunk_overlap > chunk_size:            raise ValueError(                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "                f"({chunk_size}), should be smaller."            )        self._chunk_size = chunk_size        self._chunk_overlap = chunk_overlap        self._length_function = length_function        self._keep_separator = keep_separator        self._add_start_index = add_start_index    @abstractmethod    def split_text(self, text: str) -> list[str]:        """Split text into multiple components."""    def create_documents(            self, texts: list[str], metadatas: Optional[list[dict]] = None    ) -> list[Document]:        """Create documents from a list of texts."""        _metadatas = metadatas or [{}] * len(texts)        documents = []        for i, text in enumerate(texts):            index = -1            for chunk in self.split_text(text):                metadata = copy.deepcopy(_metadatas[i])                if self._add_start_index:                    index = text.find(chunk, index + 1)                    metadata["start_index"] = index                new_doc = Document(page_content=chunk, metadata=metadata)                documents.append(new_doc)        return documents    def split_documents(self, documents: Iterable[Document] ) -> list[Document]:        """Split documents."""        texts, metadatas = [], []        for doc in documents:            texts.append(doc.page_content)            metadatas.append(doc.metadata)        return self.create_documents(texts, metadatas=metadatas)    def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:        text = separator.join(docs)        text = text.strip()        if text == "":            return None        else:            return text    def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:        # We now want to combine these smaller pieces into medium size        # chunks to send to the LLM.        separator_len = self._length_function(separator)        docs = []        current_doc: list[str] = []        total = 0        for d in splits:            _len = self._length_function(d)            if (                    total + _len + (separator_len if len(current_doc) > 0 else 0)                    > self._chunk_size            ):                if total > self._chunk_size:                    logger.warning(                        f"Created a chunk of size {total}, "                        f"which is longer than the specified {self._chunk_size}"                    )                if len(current_doc) > 0:                    doc = self._join_docs(current_doc, separator)                    if doc is not None:                        docs.append(doc)                    # Keep on popping if:                    # - we have a larger chunk than in the chunk overlap                    # - or if we still have any chunks and the length is long                    while total > self._chunk_overlap or (                            total + _len + (separator_len if len(current_doc) > 0 else 0)                            > self._chunk_size                            and total > 0                    ):                        total -= self._length_function(current_doc[0]) + (                            separator_len if len(current_doc) > 1 else 0                        )                        current_doc = current_doc[1:]            current_doc.append(d)            total += _len + (separator_len if len(current_doc) > 1 else 0)        doc = self._join_docs(current_doc, separator)        if doc is not None:            docs.append(doc)        return docs    @classmethod    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:        """Text splitter that uses HuggingFace tokenizer to count length."""        try:            from transformers import PreTrainedTokenizerBase            if not isinstance(tokenizer, PreTrainedTokenizerBase):                raise ValueError(                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"                )            def _huggingface_tokenizer_length(text: str) -> int:                return len(tokenizer.encode(text))        except ImportError:            raise ValueError(                "Could not import transformers python package. "                "Please install it with `pip install transformers`."            )        return cls(length_function=_huggingface_tokenizer_length, **kwargs)    @classmethod    def from_tiktoken_encoder(            cls: type[TS],            encoding_name: str = "gpt2",            model_name: Optional[str] = None,            allowed_special: Union[Literal["all"], Set[str]] = set(),            disallowed_special: Union[Literal["all"], Collection[str]] = "all",            **kwargs: Any,    ) -> TS:        """Text splitter that uses tiktoken encoder to count length."""        try:            import tiktoken        except ImportError:            raise ImportError(                "Could not import tiktoken python package. "                "This is needed in order to calculate max_tokens_for_prompt. "                "Please install it with `pip install tiktoken`."            )        if model_name is not None:            enc = tiktoken.encoding_for_model(model_name)        else:            enc = tiktoken.get_encoding(encoding_name)        def _tiktoken_encoder(text: str) -> int:            return len(                enc.encode(                    text,                    allowed_special=allowed_special,                    disallowed_special=disallowed_special,                )            )        if issubclass(cls, TokenTextSplitter):            extra_kwargs = {                "encoding_name": encoding_name,                "model_name": model_name,                "allowed_special": allowed_special,                "disallowed_special": disallowed_special,            }            kwargs = {**kwargs, **extra_kwargs}        return cls(length_function=_tiktoken_encoder, **kwargs)    def transform_documents(            self, documents: Sequence[Document], **kwargs: Any    ) -> Sequence[Document]:        """Transform sequence of documents by splitting them."""        return self.split_documents(list(documents))    async def atransform_documents(            self, documents: Sequence[Document], **kwargs: Any    ) -> Sequence[Document]:        """Asynchronously transform a sequence of documents by splitting them."""        raise NotImplementedErrorclass CharacterTextSplitter(TextSplitter):    """Splitting text that looks at characters."""    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:        """Create a new TextSplitter."""        super().__init__(**kwargs)        self._separator = separator    def split_text(self, text: str) -> list[str]:        """Split incoming text and return chunks."""        # First we naively split the large input into a bunch of smaller ones.        splits = _split_text_with_regex(text, self._separator, self._keep_separator)        _separator = "" if self._keep_separator else self._separator        return self._merge_splits(splits, _separator)class LineType(TypedDict):    """Line type as typed dict."""    metadata: dict[str, str]    content: strclass HeaderType(TypedDict):    """Header type as typed dict."""    level: int    name: str    data: strclass MarkdownHeaderTextSplitter:    """Splitting markdown files based on specified headers."""    def __init__(            self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False    ):        """Create a new MarkdownHeaderTextSplitter.        Args:            headers_to_split_on: Headers we want to track            return_each_line: Return each line w/ associated headers        """        # Output line-by-line or aggregated into chunks w/ common headers        self.return_each_line = return_each_line        # Given the headers we want to split on,        # (e.g., "#, ##, etc") order by length        self.headers_to_split_on = sorted(            headers_to_split_on, key=lambda split: len(split[0]), reverse=True        )    def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:        """Combine lines with common metadata into chunks        Args:            lines: Line of text / associated header metadata        """        aggregated_chunks: list[LineType] = []        for line in lines:            if (                    aggregated_chunks                    and aggregated_chunks[-1]["metadata"] == line["metadata"]            ):                # If the last line in the aggregated list                # has the same metadata as the current line,                # append the current content to the last lines's content                aggregated_chunks[-1]["content"] += "  \n" + line["content"]            else:                # Otherwise, append the current line to the aggregated list                aggregated_chunks.append(line)        return [            Document(page_content=chunk["content"], metadata=chunk["metadata"])            for chunk in aggregated_chunks        ]    def split_text(self, text: str) -> list[Document]:        """Split markdown file        Args:            text: Markdown file"""        # Split the input text by newline character ("\n").        lines = text.split("\n")        # Final output        lines_with_metadata: list[LineType] = []        # Content and metadata of the chunk currently being processed        current_content: list[str] = []        current_metadata: dict[str, str] = {}        # Keep track of the nested header structure        # header_stack: List[Dict[str, Union[int, str]]] = []        header_stack: list[HeaderType] = []        initial_metadata: dict[str, str] = {}        for line in lines:            stripped_line = line.strip()            # Check each line against each of the header types (e.g., #, ##)            for sep, name in self.headers_to_split_on:                # Check if line starts with a header that we intend to split on                if stripped_line.startswith(sep) and (                        # Header with no text OR header is followed by space                        # Both are valid conditions that sep is being used a header                        len(stripped_line) == len(sep)                        or stripped_line[len(sep)] == " "                ):                    # Ensure we are tracking the header as metadata                    if name is not None:                        # Get the current header level                        current_header_level = sep.count("#")                        # Pop out headers of lower or same level from the stack                        while (                                header_stack                                and header_stack[-1]["level"] >= current_header_level                        ):                            # We have encountered a new header                            # at the same or higher level                            popped_header = header_stack.pop()                            # Clear the metadata for the                            # popped header in initial_metadata                            if popped_header["name"] in initial_metadata:                                initial_metadata.pop(popped_header["name"])                        # Push the current header to the stack                        header: HeaderType = {                            "level": current_header_level,                            "name": name,                            "data": stripped_line[len(sep):].strip(),                        }                        header_stack.append(header)                        # Update initial_metadata with the current header                        initial_metadata[name] = header["data"]                    # Add the previous line to the lines_with_metadata                    # only if current_content is not empty                    if current_content:                        lines_with_metadata.append(                            {                                "content": "\n".join(current_content),                                "metadata": current_metadata.copy(),                            }                        )                        current_content.clear()                    break            else:                if stripped_line:                    current_content.append(stripped_line)                elif current_content:                    lines_with_metadata.append(                        {                            "content": "\n".join(current_content),                            "metadata": current_metadata.copy(),                        }                    )                    current_content.clear()            current_metadata = initial_metadata.copy()        if current_content:            lines_with_metadata.append(                {"content": "\n".join(current_content), "metadata": current_metadata}            )        # lines_with_metadata has each line with associated header metadata        # aggregate these into chunks based on common metadata        if not self.return_each_line:            return self.aggregate_lines_to_chunks(lines_with_metadata)        else:            return [                Document(page_content=chunk["content"], metadata=chunk["metadata"])                for chunk in lines_with_metadata            ]# should be in newer Python versions (3.10+)# @dataclass(frozen=True, kw_only=True, slots=True)@dataclass(frozen=True)class Tokenizer:    chunk_overlap: int    tokens_per_chunk: int    decode: Callable[[list[int]], str]    encode: Callable[[str], list[int]]def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:    """Split incoming text and return chunks using tokenizer."""    splits: list[str] = []    input_ids = tokenizer.encode(text)    start_idx = 0    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))    chunk_ids = input_ids[start_idx:cur_idx]    while start_idx < len(input_ids):        splits.append(tokenizer.decode(chunk_ids))        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))        chunk_ids = input_ids[start_idx:cur_idx]    return splitsclass TokenTextSplitter(TextSplitter):    """Splitting text to tokens using model tokenizer."""    def __init__(            self,            encoding_name: str = "gpt2",            model_name: Optional[str] = None,            allowed_special: Union[Literal["all"], Set[str]] = set(),            disallowed_special: Union[Literal["all"], Collection[str]] = "all",            **kwargs: Any,    ) -> None:        """Create a new TextSplitter."""        super().__init__(**kwargs)        try:            import tiktoken        except ImportError:            raise ImportError(                "Could not import tiktoken python package. "                "This is needed in order to for TokenTextSplitter. "                "Please install it with `pip install tiktoken`."            )        if model_name is not None:            enc = tiktoken.encoding_for_model(model_name)        else:            enc = tiktoken.get_encoding(encoding_name)        self._tokenizer = enc        self._allowed_special = allowed_special        self._disallowed_special = disallowed_special    def split_text(self, text: str) -> list[str]:        def _encode(_text: str) -> list[int]:            return self._tokenizer.encode(                _text,                allowed_special=self._allowed_special,                disallowed_special=self._disallowed_special,            )        tokenizer = Tokenizer(            chunk_overlap=self._chunk_overlap,            tokens_per_chunk=self._chunk_size,            decode=self._tokenizer.decode,            encode=_encode,        )        return split_text_on_tokens(text=text, tokenizer=tokenizer)class Language(str, Enum):    """Enum of the programming languages."""    CPP = "cpp"    GO = "go"    JAVA = "java"    JS = "js"    PHP = "php"    PROTO = "proto"    PYTHON = "python"    RST = "rst"    RUBY = "ruby"    RUST = "rust"    SCALA = "scala"    SWIFT = "swift"    MARKDOWN = "markdown"    LATEX = "latex"    HTML = "html"    SOL = "sol"class RecursiveCharacterTextSplitter(TextSplitter):    """Splitting text by recursively look at characters.    Recursively tries to split by different characters to find one    that works.    """    def __init__(            self,            separators: Optional[list[str]] = None,            keep_separator: bool = True,            **kwargs: Any,    ) -> None:        """Create a new TextSplitter."""        super().__init__(keep_separator=keep_separator, **kwargs)        self._separators = separators or ["\n\n", "\n", " ", ""]    def _split_text(self, text: str, separators: list[str]) -> list[str]:        """Split incoming text and return chunks."""        final_chunks = []        # Get appropriate separator to use        separator = separators[-1]        new_separators = []        for i, _s in enumerate(separators):            if _s == "":                separator = _s                break            if re.search(_s, text):                separator = _s                new_separators = separators[i + 1:]                break        splits = _split_text_with_regex(text, separator, self._keep_separator)        # Now go merging things, recursively splitting longer texts.        _good_splits = []        _separator = "" if self._keep_separator else separator        for s in splits:            if self._length_function(s) < self._chunk_size:                _good_splits.append(s)            else:                if _good_splits:                    merged_text = self._merge_splits(_good_splits, _separator)                    final_chunks.extend(merged_text)                    _good_splits = []                if not new_separators:                    final_chunks.append(s)                else:                    other_info = self._split_text(s, new_separators)                    final_chunks.extend(other_info)        if _good_splits:            merged_text = self._merge_splits(_good_splits, _separator)            final_chunks.extend(merged_text)        return final_chunks    def split_text(self, text: str) -> list[str]:        return self._split_text(text, self._separators)    @classmethod    def from_language(            cls, language: Language, **kwargs: Any    ) -> RecursiveCharacterTextSplitter:        separators = cls.get_separators_for_language(language)        return cls(separators=separators, **kwargs)    @staticmethod    def get_separators_for_language(language: Language) -> list[str]:        if language == Language.CPP:            return [                # Split along class definitions                "\nclass ",                # Split along function definitions                "\nvoid ",                "\nint ",                "\nfloat ",                "\ndouble ",                # Split along control flow statements                "\nif ",                "\nfor ",                "\nwhile ",                "\nswitch ",                "\ncase ",                # Split by the normal type of lines                "\n\n",                "\n",                " ",                "",            ]        elif language == Language.GO:            return [                # Split along function definitions                "\nfunc ",                "\nvar ",                "\nconst ",                "\ntype ",                # Split along control flow statements                "\nif ",                "\nfor ",                "\nswitch ",                "\ncase ",                # Split by the normal type of lines                "\n\n",                "\n",                " ",                "",            ]        elif language == Language.JAVA:            return [                # Split along class definitions                "\nclass ",                # Split along method definitions                "\npublic ",                "\nprotected ",                "\nprivate ",                "\nstatic ",                # Split along control flow statements                "\nif ",                "\nfor ",                "\nwhile ",                "\nswitch ",                "\ncase ",                # Split by the normal type of lines                "\n\n",                "\n",                " ",                "",            ]        elif language == Language.JS:            return [                # Split along function definitions                "\nfunction ",                "\nconst ",                "\nlet ",                "\nvar ",                "\nclass ",                # Split along control flow statements                "\nif ",                "\nfor ",                "\nwhile ",                "\nswitch ",                "\ncase ",                "\ndefault ",                # Split by the normal type of lines                "\n\n",                "\n",                " ",                "",            ]        elif language == Language.PHP:            return [                # Split along function definitions                "\nfunction ",                # Split along class definitions                "\nclass ",                # Split along control flow statements                "\nif ",                "\nforeach ",                "\nwhile ",                "\ndo ",                "\nswitch ",                "\ncase ",                # Split by the normal type of lines                "\n\n",                "\n",                " ",                "",            ]        elif language == Language.PROTO:            return [                # Split along message definitions                "\nmessage ",                # Split along service definitions                "\nservice ",                # Split along enum definitions                "\nenum ",                # Split along option definitions                "\noption ",                # Split along import statements                "\nimport ",                # Split along syntax declarations                "\nsyntax ",                # Split by the normal type of lines                "\n\n",                "\n",                " ",                "",            ]        elif language == Language.PYTHON:            return [                # First, try to split along class definitions                "\nclass ",                "\ndef ",                "\n\tdef ",                # Now split by the normal type of lines                "\n\n",                "\n",                " ",                "",            ]        elif language == Language.RST:            return [                # Split along section titles                "\n=+\n",                "\n-+\n",                "\n\*+\n",                # Split along directive markers                "\n\n.. *\n\n",                # Split by the normal type of lines                "\n\n",                "\n",                " ",                "",            ]        elif language == Language.RUBY:            return [                # Split along method definitions                "\ndef ",                "\nclass ",                # Split along control flow statements                "\nif ",                "\nunless ",                "\nwhile ",                "\nfor ",                "\ndo ",                "\nbegin ",                "\nrescue ",                # Split by the normal type of lines                "\n\n",                "\n",                " ",                "",            ]        elif language == Language.RUST:            return [                # Split along function definitions                "\nfn ",                "\nconst ",                "\nlet ",                # Split along control flow statements                "\nif ",                "\nwhile ",                "\nfor ",                "\nloop ",                "\nmatch ",                "\nconst ",                # Split by the normal type of lines                "\n\n",                "\n",                " ",                "",            ]        elif language == Language.SCALA:            return [                # Split along class definitions                "\nclass ",                "\nobject ",                # Split along method definitions                "\ndef ",                "\nval ",                "\nvar ",                # Split along control flow statements                "\nif ",                "\nfor ",                "\nwhile ",                "\nmatch ",                "\ncase ",                # Split by the normal type of lines                "\n\n",                "\n",                " ",                "",            ]        elif language == Language.SWIFT:            return [                # Split along function definitions                "\nfunc ",                # Split along class definitions                "\nclass ",                "\nstruct ",                "\nenum ",                # Split along control flow statements                "\nif ",                "\nfor ",                "\nwhile ",                "\ndo ",                "\nswitch ",                "\ncase ",                # Split by the normal type of lines                "\n\n",                "\n",                " ",                "",            ]        elif language == Language.MARKDOWN:            return [                # First, try to split along Markdown headings (starting with level 2)                "\n#{1,6} ",                # Note the alternative syntax for headings (below) is not handled here                # Heading level 2                # ---------------                # End of code block                "```\n",                # Horizontal lines                "\n\*\*\*+\n",                "\n---+\n",                "\n___+\n",                # Note that this splitter doesn't handle horizontal lines defined                # by *three or more* of ***, ---, or ___, but this is not handled                "\n\n",                "\n",                " ",                "",            ]        elif language == Language.LATEX:            return [                # First, try to split along Latex sections                "\n\\\chapter{",                "\n\\\section{",                "\n\\\subsection{",                "\n\\\subsubsection{",                # Now split by environments                "\n\\\begin{enumerate}",                "\n\\\begin{itemize}",                "\n\\\begin{description}",                "\n\\\begin{list}",                "\n\\\begin{quote}",                "\n\\\begin{quotation}",                "\n\\\begin{verse}",                "\n\\\begin{verbatim}",                # Now split by math environments                "\n\\\begin{align}",                "$$",                "$",                # Now split by the normal type of lines                " ",                "",            ]        elif language == Language.HTML:            return [                # First, try to split along HTML tags                "<body",                "<div",                "<p",                "<br",                "<li",                "<h1",                "<h2",                "<h3",                "<h4",                "<h5",                "<h6",                "<span",                "<table",                "<tr",                "<td",                "<th",                "<ul",                "<ol",                "<header",                "<footer",                "<nav",                # Head                "<head",                "<style",                "<script",                "<meta",                "<title",                "",            ]        elif language == Language.SOL:            return [                # Split along compiler information definitions                "\npragma ",                "\nusing ",                # Split along contract definitions                "\ncontract ",                "\ninterface ",                "\nlibrary ",                # Split along method definitions                "\nconstructor ",                "\ntype ",                "\nfunction ",                "\nevent ",                "\nmodifier ",                "\nerror ",                "\nstruct ",                "\nenum ",                # Split along control flow statements                "\nif ",                "\nfor ",                "\nwhile ",                "\ndo while ",                "\nassembly ",                # Split by the normal type of lines                "\n\n",                "\n",                " ",                "",            ]        else:            raise ValueError(                f"Language {language} is not supported! "                f"Please choose from {list(Language)}"            )
 |