Browse Source

fix: remove tiktoken from text splitter (#1876)

Yeuoly 1 year ago
parent
commit
9134849744
2 changed files with 38 additions and 8 deletions
  1. 7 5
      api/core/indexing_runner.py
  2. 31 3
      api/core/spiltter/fixed_text_splitter.py

+ 7 - 5
api/core/indexing_runner.py

@@ -5,12 +5,12 @@ import re
 import threading
 import time
 import uuid
-from typing import Optional, List, cast
+from typing import Optional, List, cast, Type, Union, Literal, AbstractSet, Collection, Any
 
 from flask import current_app, Flask
 from flask_login import current_user
 from langchain.schema import Document
-from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
+from langchain.text_splitter import TextSplitter, TS, TokenTextSplitter
 from sqlalchemy.orm.exc import ObjectDeletedError
 
 from core.data_loader.file_extractor import FileExtractor
@@ -23,7 +23,8 @@ from core.errors.error import ProviderTokenNotInitError
 from core.model_runtime.entities.model_entities import ModelType, PriceType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
+from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
+from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter, EnhanceRecursiveCharacterTextSplitter
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
@@ -502,7 +503,8 @@ class IndexingRunner:
             if separator:
                 separator = separator.replace('\\n', '\n')
 
-            character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
+
+            character_splitter = FixedRecursiveCharacterTextSplitter.from_gpt2_encoder(
                 chunk_size=segmentation["max_tokens"],
                 chunk_overlap=0,
                 fixed_separator=separator,
@@ -510,7 +512,7 @@ class IndexingRunner:
             )
         else:
             # Automatic segmentation
-            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_gpt2_encoder(
                 chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                 chunk_overlap=0,
                 separators=["\n\n", "。", ".", " ", ""]

+ 31 - 3
api/core/spiltter/fixed_text_splitter.py

@@ -7,10 +7,38 @@ from typing import (
     Optional,
 )
 
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter, TS, Type, Union, AbstractSet, Literal, Collection
 
+from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
 
-class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
+class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
+    """
+        This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
+    """
+    @classmethod
+    def from_gpt2_encoder(
+        cls: Type[TS],
+        encoding_name: str = "gpt2",
+        model_name: Optional[str] = None,
+        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
+        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        **kwargs: Any,
+    ):
+        def _token_encoder(text: str) -> int:
+            return GPT2Tokenizer.get_num_tokens(text)
+
+        if issubclass(cls, TokenTextSplitter):
+            extra_kwargs = {
+                "encoding_name": encoding_name,
+                "model_name": model_name,
+                "allowed_special": allowed_special,
+                "disallowed_special": disallowed_special,
+            }
+            kwargs = {**kwargs, **extra_kwargs}
+
+        return cls(length_function=_token_encoder, **kwargs)
+
+class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
     def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
         """Create a new TextSplitter."""
         super().__init__(**kwargs)
@@ -65,4 +93,4 @@ class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
         if _good_splits:
             merged_text = self._merge_splits(_good_splits, separator)
             final_chunks.extend(merged_text)
-        return final_chunks
+        return final_chunks