# fixed_text_splitter.py
"""Functionality for splitting text."""
from __future__ import annotations

from typing import Any, List, Optional, cast

from langchain.text_splitter import (
    TS,
    AbstractSet,
    Collection,
    Literal,
    RecursiveCharacterTextSplitter,
    TokenTextSplitter,
    Type,
    Union,
)

from core.model_manager import ModelInstance
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
  9. class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
  10. """
  11. This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
  12. """
  13. @classmethod
  14. def from_encoder(
  15. cls: Type[TS],
  16. embedding_model_instance: Optional[ModelInstance],
  17. allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
  18. disallowed_special: Union[Literal["all"], Collection[str]] = "all",
  19. **kwargs: Any,
  20. ):
  21. def _token_encoder(text: str) -> int:
  22. if embedding_model_instance:
  23. embedding_model_type_instance = embedding_model_instance.model_type_instance
  24. embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
  25. return embedding_model_type_instance.get_num_tokens(
  26. model=embedding_model_instance.model,
  27. credentials=embedding_model_instance.credentials,
  28. texts=[text]
  29. )
  30. else:
  31. return GPT2Tokenizer.get_num_tokens(text)
  32. if issubclass(cls, TokenTextSplitter):
  33. extra_kwargs = {
  34. "model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2',
  35. "allowed_special": allowed_special,
  36. "disallowed_special": disallowed_special,
  37. }
  38. kwargs = {**kwargs, **extra_kwargs}
  39. return cls(length_function=_token_encoder, **kwargs)
  40. class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
  41. def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
  42. """Create a new TextSplitter."""
  43. super().__init__(**kwargs)
  44. self._fixed_separator = fixed_separator
  45. self._separators = separators or ["\n\n", "\n", " ", ""]
  46. def split_text(self, text: str) -> List[str]:
  47. """Split incoming text and return chunks."""
  48. if self._fixed_separator:
  49. chunks = text.split(self._fixed_separator)
  50. else:
  51. chunks = list(text)
  52. final_chunks = []
  53. for chunk in chunks:
  54. if self._length_function(chunk) > self._chunk_size:
  55. final_chunks.extend(self.recursive_split_text(chunk))
  56. else:
  57. final_chunks.append(chunk)
  58. return final_chunks
  59. def recursive_split_text(self, text: str) -> List[str]:
  60. """Split incoming text and return chunks."""
  61. final_chunks = []
  62. # Get appropriate separator to use
  63. separator = self._separators[-1]
  64. for _s in self._separators:
  65. if _s == "":
  66. separator = _s
  67. break
  68. if _s in text:
  69. separator = _s
  70. break
  71. # Now that we have the separator, split the text
  72. if separator:
  73. splits = text.split(separator)
  74. else:
  75. splits = list(text)
  76. # Now go merging things, recursively splitting longer texts.
  77. _good_splits = []
  78. for s in splits:
  79. if self._length_function(s) < self._chunk_size:
  80. _good_splits.append(s)
  81. else:
  82. if _good_splits:
  83. merged_text = self._merge_splits(_good_splits, separator)
  84. final_chunks.extend(merged_text)
  85. _good_splits = []
  86. other_info = self.recursive_split_text(s)
  87. final_chunks.extend(other_info)
  88. if _good_splits:
  89. merged_text = self._merge_splits(_good_splits, separator)
  90. final_chunks.extend(merged_text)
  91. return final_chunks