# fixed_text_splitter.py
"""Functionality for splitting text."""
from __future__ import annotations

from typing import Any, Optional

from core.model_manager import ModelInstance
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.rag.splitter.text_splitter import (
    TS,
    Collection,
    Literal,
    RecursiveCharacterTextSplitter,
    Set,
    TokenTextSplitter,
    Union,
)
  15. class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
  16. """
  17. This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
  18. """
  19. @classmethod
  20. def from_encoder(
  21. cls: type[TS],
  22. embedding_model_instance: Optional[ModelInstance],
  23. allowed_special: Union[Literal[all], Set[str]] = set(),
  24. disallowed_special: Union[Literal[all], Collection[str]] = "all",
  25. **kwargs: Any,
  26. ):
  27. def _token_encoder(text: str) -> int:
  28. if not text:
  29. return 0
  30. if embedding_model_instance:
  31. return embedding_model_instance.get_text_embedding_num_tokens(
  32. texts=[text]
  33. )
  34. else:
  35. return GPT2Tokenizer.get_num_tokens(text)
  36. if issubclass(cls, TokenTextSplitter):
  37. extra_kwargs = {
  38. "model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2',
  39. "allowed_special": allowed_special,
  40. "disallowed_special": disallowed_special,
  41. }
  42. kwargs = {**kwargs, **extra_kwargs}
  43. return cls(length_function=_token_encoder, **kwargs)
  44. class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
  45. def __init__(self, fixed_separator: str = "\n\n", separators: Optional[list[str]] = None, **kwargs: Any):
  46. """Create a new TextSplitter."""
  47. super().__init__(**kwargs)
  48. self._fixed_separator = fixed_separator
  49. self._separators = separators or ["\n\n", "\n", " ", ""]
  50. def split_text(self, text: str) -> list[str]:
  51. """Split incoming text and return chunks."""
  52. if self._fixed_separator:
  53. chunks = text.split(self._fixed_separator)
  54. else:
  55. chunks = list(text)
  56. final_chunks = []
  57. for chunk in chunks:
  58. if self._length_function(chunk) > self._chunk_size:
  59. final_chunks.extend(self.recursive_split_text(chunk))
  60. else:
  61. final_chunks.append(chunk)
  62. return final_chunks
  63. def recursive_split_text(self, text: str) -> list[str]:
  64. """Split incoming text and return chunks."""
  65. final_chunks = []
  66. # Get appropriate separator to use
  67. separator = self._separators[-1]
  68. for _s in self._separators:
  69. if _s == "":
  70. separator = _s
  71. break
  72. if _s in text:
  73. separator = _s
  74. break
  75. # Now that we have the separator, split the text
  76. if separator:
  77. splits = text.split(separator)
  78. else:
  79. splits = list(text)
  80. # Now go merging things, recursively splitting longer texts.
  81. _good_splits = []
  82. for s in splits:
  83. if self._length_function(s) < self._chunk_size:
  84. _good_splits.append(s)
  85. else:
  86. if _good_splits:
  87. merged_text = self._merge_splits(_good_splits, separator)
  88. final_chunks.extend(merged_text)
  89. _good_splits = []
  90. other_info = self.recursive_split_text(s)
  91. final_chunks.extend(other_info)
  92. if _good_splits:
  93. merged_text = self._merge_splits(_good_splits, separator)
  94. final_chunks.extend(merged_text)
  95. return final_chunks