text_splitter.py

from __future__ import annotations

import copy
import logging
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence, Set
from dataclasses import dataclass
from enum import Enum
from typing import (
    Any,
    Literal,
    Optional,
    TypedDict,
    TypeVar,
    Union,
)

from core.rag.models.document import BaseDocumentTransformer, Document

logger = logging.getLogger(__name__)

TS = TypeVar("TS", bound="TextSplitter")


def _split_text_with_regex(
    text: str, separator: str, keep_separator: bool
) -> list[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({separator})", text)
            splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
            if len(_splits) % 2 == 0:
                splits += _splits[-1:]
            splits = [_splits[0]] + splits
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s != ""]
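
# Illustrative example (not part of the module): splitting on a literal ". "
# with keep_separator=True attaches each delimiter to the piece that follows it:
#
#     >>> _split_text_with_regex("a. b. c", r"\. ", keep_separator=True)
#     ['a', '. b', '. c']
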
class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap in characters between chunks
            length_function: Function that measures the length of given chunks
            keep_separator: Whether to keep the separator in the chunks
            add_start_index: If `True`, includes chunk's start index in metadata
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index

    @abstractmethod
    def split_text(self, text: str) -> list[str]:
        """Split text into multiple components."""

    def create_documents(
        self, texts: list[str], metadatas: Optional[list[dict]] = None
    ) -> list[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = -1
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    index = text.find(chunk, index + 1)
                    metadata["start_index"] = index
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: list[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs
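
    # Illustrative sketch of the merge behaviour (character-based lengths):
    # with chunk_size=10, chunk_overlap=4 and separator " ", the splits
    # ["aaaa", "bbbb", "cccc"] merge into ["aaaa bbbb", "bbbb cccc"]:
    # "cccc" does not fit after "aaaa bbbb", so "aaaa" is dropped and the
    # 4-character "bbbb" is carried forward as overlap.
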
    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses HuggingFace tokenizer to count length."""
        try:
            from transformers import PreTrainedTokenizerBase

            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise ValueError(
                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"
                )

            def _huggingface_tokenizer_length(text: str) -> int:
                return len(tokenizer.encode(text))

        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )
        return cls(length_function=_huggingface_tokenizer_length, **kwargs)
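
    # Hedged usage sketch (assumes the optional `transformers` dependency and a
    # concrete subclass such as CharacterTextSplitter):
    #
    #     from transformers import AutoTokenizer
    #     tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    #     splitter = CharacterTextSplitter.from_huggingface_tokenizer(
    #         tok, chunk_size=100, chunk_overlap=10
    #     )
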
    @classmethod
    def from_tiktoken_encoder(
        cls: type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], Set[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length."""
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        if issubclass(cls, TokenTextSplitter):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)
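
    # Hedged usage sketch (assumes the optional `tiktoken` dependency; lengths
    # are then measured in tokens rather than characters):
    #
    #     splitter = TokenTextSplitter.from_tiktoken_encoder(
    #         encoding_name="cl100k_base", chunk_size=512, chunk_overlap=32
    #     )
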
    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))

    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Asynchronously transform a sequence of documents by splitting them."""
        raise NotImplementedError


class CharacterTextSplitter(TextSplitter):
    """Split text by looking at characters."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = _split_text_with_regex(text, self._separator, self._keep_separator)
        _separator = "" if self._keep_separator else self._separator
        return self._merge_splits(splits, _separator)
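
# Illustrative usage (default "\n\n" separator, character-based lengths):
#
#     >>> splitter = CharacterTextSplitter(chunk_size=20, chunk_overlap=0)
#     >>> splitter.split_text("first paragraph\n\nsecond paragraph")
#     ['first paragraph', 'second paragraph']
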
class LineType(TypedDict):
    """Line type as typed dict."""

    metadata: dict[str, str]
    content: str


class HeaderType(TypedDict):
    """Header type as typed dict."""

    level: int
    name: str
    data: str


class MarkdownHeaderTextSplitter:
    """Splitting markdown files based on specified headers."""

    def __init__(
        self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False
    ):
        """Create a new MarkdownHeaderTextSplitter.

        Args:
            headers_to_split_on: Headers we want to track
            return_each_line: Return each line w/ associated headers
        """
        # Output line-by-line or aggregated into chunks w/ common headers
        self.return_each_line = return_each_line
        # Given the headers we want to split on (e.g., "#", "##"),
        # order by length, longest first, so "##" is tested before "#"
        self.headers_to_split_on = sorted(
            headers_to_split_on, key=lambda split: len(split[0]), reverse=True
        )

    def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:
        """Combine lines with common metadata into chunks.

        Args:
            lines: Line of text / associated header metadata
        """
        aggregated_chunks: list[LineType] = []

        for line in lines:
            if (
                aggregated_chunks
                and aggregated_chunks[-1]["metadata"] == line["metadata"]
            ):
                # If the last line in the aggregated list
                # has the same metadata as the current line,
                # append the current content to the last line's content
                aggregated_chunks[-1]["content"] += " \n" + line["content"]
            else:
                # Otherwise, append the current line to the aggregated list
                aggregated_chunks.append(line)

        return [
            Document(page_content=chunk["content"], metadata=chunk["metadata"])
            for chunk in aggregated_chunks
        ]

    def split_text(self, text: str) -> list[Document]:
        """Split markdown file.

        Args:
            text: Markdown file
        """
        # Split the input text by newline character ("\n").
        lines = text.split("\n")
        # Final output
        lines_with_metadata: list[LineType] = []
        # Content and metadata of the chunk currently being processed
        current_content: list[str] = []
        current_metadata: dict[str, str] = {}
        # Keep track of the nested header structure
        # header_stack: List[Dict[str, Union[int, str]]] = []
        header_stack: list[HeaderType] = []
        initial_metadata: dict[str, str] = {}

        for line in lines:
            stripped_line = line.strip()
            # Check each line against each of the header types (e.g., #, ##)
            for sep, name in self.headers_to_split_on:
                # Check if line starts with a header that we intend to split on
                if stripped_line.startswith(sep) and (
                    # Header with no text OR header is followed by space
                    # Both are valid conditions that sep is being used as a header
                    len(stripped_line) == len(sep)
                    or stripped_line[len(sep)] == " "
                ):
                    # Ensure we are tracking the header as metadata
                    if name is not None:
                        # Get the current header level
                        current_header_level = sep.count("#")
                        # Pop out headers of lower or same level from the stack
                        while (
                            header_stack
                            and header_stack[-1]["level"] >= current_header_level
                        ):
                            # We have encountered a new header
                            # at the same or higher level
                            popped_header = header_stack.pop()
                            # Clear the metadata for the
                            # popped header in initial_metadata
                            if popped_header["name"] in initial_metadata:
                                initial_metadata.pop(popped_header["name"])

                        # Push the current header to the stack
                        header: HeaderType = {
                            "level": current_header_level,
                            "name": name,
                            "data": stripped_line[len(sep):].strip(),
                        }
                        header_stack.append(header)
                        # Update initial_metadata with the current header
                        initial_metadata[name] = header["data"]

                    # Add the previous line to the lines_with_metadata
                    # only if current_content is not empty
                    if current_content:
                        lines_with_metadata.append(
                            {
                                "content": "\n".join(current_content),
                                "metadata": current_metadata.copy(),
                            }
                        )
                        current_content.clear()

                    break
            else:
                # No header matched (for-else): treat the line as content
                if stripped_line:
                    current_content.append(stripped_line)
                elif current_content:
                    lines_with_metadata.append(
                        {
                            "content": "\n".join(current_content),
                            "metadata": current_metadata.copy(),
                        }
                    )
                    current_content.clear()

            current_metadata = initial_metadata.copy()

        if current_content:
            lines_with_metadata.append(
                {"content": "\n".join(current_content), "metadata": current_metadata}
            )

        # lines_with_metadata has each line with associated header metadata
        # aggregate these into chunks based on common metadata
        if not self.return_each_line:
            return self.aggregate_lines_to_chunks(lines_with_metadata)
        else:
            return [
                Document(page_content=chunk["content"], metadata=chunk["metadata"])
                for chunk in lines_with_metadata
            ]
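
# Illustrative usage: each returned Document carries the headers above it
# as metadata.
#
#     >>> headers = [("#", "Header 1"), ("##", "Header 2")]
#     >>> splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers)
#     >>> docs = splitter.split_text("# Intro\nHello\n## Detail\nWorld")
#     >>> [(d.page_content, d.metadata) for d in docs]
#     [('Hello', {'Header 1': 'Intro'}),
#      ('World', {'Header 1': 'Intro', 'Header 2': 'Detail'})]
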
# should be in newer Python versions (3.10+)
# @dataclass(frozen=True, kw_only=True, slots=True)
@dataclass(frozen=True)
class Tokenizer:
    chunk_overlap: int
    tokens_per_chunk: int
    decode: Callable[[list[int]], str]
    encode: Callable[[str], list[int]]


def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
    """Split incoming text and return chunks using tokenizer."""
    splits: list[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits
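
# Worked example of the sliding window: with tokens_per_chunk=4 and
# chunk_overlap=1, ten token ids [0..9] yield the windows [0:4], [3:7],
# [6:10], [9:10]; each step advances by tokens_per_chunk - chunk_overlap = 3.
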
class TokenTextSplitter(TextSplitter):
    """Split text into tokens using a model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], Set[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> list[str]:
        def _encode(_text: str) -> list[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )
        return split_text_on_tokens(text=text, tokenizer=tokenizer)
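
# Hedged usage sketch (assumes the optional `tiktoken` dependency; here
# chunk_size counts tokens, not characters):
#
#     splitter = TokenTextSplitter(
#         encoding_name="cl100k_base", chunk_size=256, chunk_overlap=16
#     )
#     chunks = splitter.split_text(long_text)  # long_text is a placeholder
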
class Language(str, Enum):
    """Enum of the programming languages."""

    CPP = "cpp"
    GO = "go"
    JAVA = "java"
    JS = "js"
    PHP = "php"
    PROTO = "proto"
    PYTHON = "python"
    RST = "rst"
    RUBY = "ruby"
    RUST = "rust"
    SCALA = "scala"
    SWIFT = "swift"
    MARKDOWN = "markdown"
    LATEX = "latex"
    HTML = "html"
    SOL = "sol"


class RecursiveCharacterTextSplitter(TextSplitter):
    """Split text by recursively looking at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: Optional[list[str]] = None,
        keep_separator: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            if _s == "":
                separator = _s
                break
            if re.search(_s, text):
                separator = _s
                new_separators = separators[i + 1:]
                break

        splits = _split_text_with_regex(text, separator, self._keep_separator)
        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> list[str]:
        return self._split_text(text, self._separators)
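
    # Illustrative usage: with the default separators ["\n\n", "\n", " ", ""],
    # paragraphs are tried first; any piece still longer than chunk_size falls
    # through to single newlines, then spaces, then individual characters.
    #
    #     splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
    #     chunks = splitter.split_text(long_text)  # long_text is a placeholder
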
    @classmethod
    def from_language(
        cls, language: Language, **kwargs: Any
    ) -> RecursiveCharacterTextSplitter:
        separators = cls.get_separators_for_language(language)
        return cls(separators=separators, **kwargs)

    @staticmethod
    def get_separators_for_language(language: Language) -> list[str]:
        if language == Language.CPP:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along function definitions
                "\nvoid ",
                "\nint ",
                "\nfloat ",
                "\ndouble ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.GO:
            return [
                # Split along function definitions
                "\nfunc ",
                "\nvar ",
                "\nconst ",
                "\ntype ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.JAVA:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along method definitions
                "\npublic ",
                "\nprotected ",
                "\nprivate ",
                "\nstatic ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.JS:
            return [
                # Split along function definitions
                "\nfunction ",
                "\nconst ",
                "\nlet ",
                "\nvar ",
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\ndefault ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PHP:
            return [
                # Split along function definitions
                "\nfunction ",
                # Split along class definitions
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nforeach ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PROTO:
            return [
                # Split along message definitions
                "\nmessage ",
                # Split along service definitions
                "\nservice ",
                # Split along enum definitions
                "\nenum ",
                # Split along option definitions
                "\noption ",
                # Split along import statements
                "\nimport ",
                # Split along syntax declarations
                "\nsyntax ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PYTHON:
            return [
                # First, try to split along class definitions
                "\nclass ",
                "\ndef ",
                "\n\tdef ",
                # Now split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RST:
            return [
                # Split along section titles
                "\n=+\n",
                "\n-+\n",
                "\n\\*+\n",
                # Split along directive markers
                "\n\n.. *\n\n",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RUBY:
            return [
                # Split along method definitions
                "\ndef ",
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nunless ",
                "\nwhile ",
                "\nfor ",
                "\ndo ",
                "\nbegin ",
                "\nrescue ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RUST:
            return [
                # Split along function definitions
                "\nfn ",
                "\nconst ",
                "\nlet ",
                # Split along control flow statements
                "\nif ",
                "\nwhile ",
                "\nfor ",
                "\nloop ",
                "\nmatch ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.SCALA:
            return [
                # Split along class definitions
                "\nclass ",
                "\nobject ",
                # Split along method definitions
                "\ndef ",
                "\nval ",
                "\nvar ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nmatch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.SWIFT:
            return [
                # Split along function definitions
                "\nfunc ",
                # Split along class definitions
                "\nclass ",
                "\nstruct ",
                "\nenum ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.MARKDOWN:
            return [
                # First, try to split along Markdown headings (levels 1 through 6)
                "\n#{1,6} ",
                # Note: the alternative "setext" heading syntax (a title line
                # underlined with === or ---) is not handled here
                # End of code block
                "```\n",
                # Horizontal rules (a line of three or more *, -, or _)
                "\n\\*\\*\\*+\n",
                "\n---+\n",
                "\n___+\n",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.LATEX:
            return [
                # First, try to split along LaTeX sections
                "\n\\\\chapter{",
                "\n\\\\section{",
                "\n\\\\subsection{",
                "\n\\\\subsubsection{",
                # Now split by environments
                "\n\\\\begin{enumerate}",
                "\n\\\\begin{itemize}",
                "\n\\\\begin{description}",
                "\n\\\\begin{list}",
                "\n\\\\begin{quote}",
                "\n\\\\begin{quotation}",
                "\n\\\\begin{verse}",
                "\n\\\\begin{verbatim}",
                # Now split by math environments
                "\n\\\\begin{align}",
                "\\$\\$",
                "\\$",
                # Now split by the normal type of lines
                " ",
                "",
            ]
        elif language == Language.HTML:
            return [
                # First, try to split along HTML tags
                "<body",
                "<div",
                "<p",
                "<br",
                "<li",
                "<h1",
                "<h2",
                "<h3",
                "<h4",
                "<h5",
                "<h6",
                "<span",
                "<table",
                "<tr",
                "<td",
                "<th",
                "<ul",
                "<ol",
                "<header",
                "<footer",
                "<nav",
                # Head
                "<head",
                "<style",
                "<script",
                "<meta",
                "<title",
                "",
            ]
        elif language == Language.SOL:
            return [
                # Split along compiler information definitions
                "\npragma ",
                "\nusing ",
                # Split along contract definitions
                "\ncontract ",
                "\ninterface ",
                "\nlibrary ",
                # Split along method definitions
                "\nconstructor ",
                "\ntype ",
                "\nfunction ",
                "\nevent ",
                "\nmodifier ",
                "\nerror ",
                "\nstruct ",
                "\nenum ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo while ",
                "\nassembly ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        else:
            raise ValueError(
                f"Language {language} is not supported! "
                f"Please choose from {list(Language)}"
            )
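
# Illustrative usage of the language-aware presets:
#
#     splitter = RecursiveCharacterTextSplitter.from_language(
#         Language.PYTHON, chunk_size=200, chunk_overlap=0
#     )
#     chunks = splitter.split_text(python_source)  # python_source is a placeholder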