fixed_text_splitter.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. """Functionality for splitting text."""
  2. from __future__ import annotations
  3. from typing import (
  4. Any,
  5. List,
  6. Optional,
  7. )
  8. from langchain.text_splitter import RecursiveCharacterTextSplitter
  9. class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
  10. def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
  11. """Create a new TextSplitter."""
  12. super().__init__(**kwargs)
  13. self._fixed_separator = fixed_separator
  14. self._separators = separators or ["\n\n", "\n", " ", ""]
  15. def split_text(self, text: str) -> List[str]:
  16. """Split incoming text and return chunks."""
  17. if self._fixed_separator:
  18. chunks = text.split(self._fixed_separator)
  19. else:
  20. chunks = list(text)
  21. final_chunks = []
  22. for chunk in chunks:
  23. if self._length_function(chunk) > self._chunk_size:
  24. final_chunks.extend(self.recursive_split_text(chunk))
  25. else:
  26. final_chunks.append(chunk)
  27. return final_chunks
  28. def recursive_split_text(self, text: str) -> List[str]:
  29. """Split incoming text and return chunks."""
  30. final_chunks = []
  31. # Get appropriate separator to use
  32. separator = self._separators[-1]
  33. for _s in self._separators:
  34. if _s == "":
  35. separator = _s
  36. break
  37. if _s in text:
  38. separator = _s
  39. break
  40. # Now that we have the separator, split the text
  41. if separator:
  42. splits = text.split(separator)
  43. else:
  44. splits = list(text)
  45. # Now go merging things, recursively splitting longer texts.
  46. _good_splits = []
  47. for s in splits:
  48. if self._length_function(s) < self._chunk_size:
  49. _good_splits.append(s)
  50. else:
  51. if _good_splits:
  52. merged_text = self._merge_splits(_good_splits, separator)
  53. final_chunks.extend(merged_text)
  54. _good_splits = []
  55. other_info = self.recursive_split_text(s)
  56. final_chunks.extend(other_info)
  57. if _good_splits:
  58. merged_text = self._merge_splits(_good_splits, separator)
  59. final_chunks.extend(merged_text)
  60. return final_chunks