# fixed_text_splitter.py
  1. """Functionality for splitting text."""
  2. from __future__ import annotations
  3. from typing import (
  4. Any,
  5. List,
  6. Optional,
  7. )
  8. from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter, TS, Type, Union, AbstractSet, Literal, Collection
  9. from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
  10. class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
  11. """
  12. This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
  13. """
  14. @classmethod
  15. def from_gpt2_encoder(
  16. cls: Type[TS],
  17. encoding_name: str = "gpt2",
  18. model_name: Optional[str] = None,
  19. allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
  20. disallowed_special: Union[Literal["all"], Collection[str]] = "all",
  21. **kwargs: Any,
  22. ):
  23. def _token_encoder(text: str) -> int:
  24. return GPT2Tokenizer.get_num_tokens(text)
  25. if issubclass(cls, TokenTextSplitter):
  26. extra_kwargs = {
  27. "encoding_name": encoding_name,
  28. "model_name": model_name,
  29. "allowed_special": allowed_special,
  30. "disallowed_special": disallowed_special,
  31. }
  32. kwargs = {**kwargs, **extra_kwargs}
  33. return cls(length_function=_token_encoder, **kwargs)
  34. class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
  35. def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
  36. """Create a new TextSplitter."""
  37. super().__init__(**kwargs)
  38. self._fixed_separator = fixed_separator
  39. self._separators = separators or ["\n\n", "\n", " ", ""]
  40. def split_text(self, text: str) -> List[str]:
  41. """Split incoming text and return chunks."""
  42. if self._fixed_separator:
  43. chunks = text.split(self._fixed_separator)
  44. else:
  45. chunks = list(text)
  46. final_chunks = []
  47. for chunk in chunks:
  48. if self._length_function(chunk) > self._chunk_size:
  49. final_chunks.extend(self.recursive_split_text(chunk))
  50. else:
  51. final_chunks.append(chunk)
  52. return final_chunks
  53. def recursive_split_text(self, text: str) -> List[str]:
  54. """Split incoming text and return chunks."""
  55. final_chunks = []
  56. # Get appropriate separator to use
  57. separator = self._separators[-1]
  58. for _s in self._separators:
  59. if _s == "":
  60. separator = _s
  61. break
  62. if _s in text:
  63. separator = _s
  64. break
  65. # Now that we have the separator, split the text
  66. if separator:
  67. splits = text.split(separator)
  68. else:
  69. splits = list(text)
  70. # Now go merging things, recursively splitting longer texts.
  71. _good_splits = []
  72. for s in splits:
  73. if self._length_function(s) < self._chunk_size:
  74. _good_splits.append(s)
  75. else:
  76. if _good_splits:
  77. merged_text = self._merge_splits(_good_splits, separator)
  78. final_chunks.extend(merged_text)
  79. _good_splits = []
  80. other_info = self.recursive_split_text(s)
  81. final_chunks.extend(other_info)
  82. if _good_splits:
  83. merged_text = self._merge_splits(_good_splits, separator)
  84. final_chunks.extend(merged_text)
  85. return final_chunks