# indexing_runner.py

import datetime
import json
import logging
import re
import time
import uuid
from typing import Optional, List, cast

from flask import current_app
from flask_login import current_user
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.embedding.cached_embedding import CacheEmbedding
from core.index.index import IndexBuilder
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.error import ProviderTokenNotInitError
from core.llm.llm_builder import LLMBuilder
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Document as DatasetDocument
from models.dataset import Dataset, DocumentSegment, DatasetProcessRule
from models.model import UploadFile
from models.source import DataSourceBinding


class IndexingRunner:

    def __init__(self, embedding_model_name: str = "text-embedding-ada-002"):
        self.storage = storage
        self.embedding_model_name = embedding_model_name

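    # The indexing pipeline runs in stages for each document:
    # extract (_load_data) -> clean and split (_step_split) -> persist segments
    # (DatesetDocumentStore) -> embed and index (_build_index). run() drives the
    # full pipeline, while the run_in_*_status() methods resume a document that
    # was interrupted partway through.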
    def run(self, dataset_documents: List[DatasetDocument]):
        """Run the indexing process."""
        for dataset_document in dataset_documents:
            try:
                # get dataset
                dataset = Dataset.query.filter_by(
                    id=dataset_document.dataset_id
                ).first()

                if not dataset:
                    raise ValueError("no dataset found")

                # load file
                text_docs = self._load_data(dataset_document)

                # get the process rule
                processing_rule = db.session.query(DatasetProcessRule). \
                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                    first()

                # get splitter
                splitter = self._get_splitter(processing_rule)

                # split to documents
                documents = self._step_split(
                    text_docs=text_docs,
                    splitter=splitter,
                    dataset=dataset,
                    dataset_document=dataset_document,
                    processing_rule=processing_rule
                )

                # build index
                self._build_index(
                    dataset=dataset,
                    dataset_document=dataset_document,
                    documents=documents
                )
            except DocumentIsPausedException:
                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
            except ProviderTokenNotInitError as e:
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e.description)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()
            except Exception as e:
                logging.exception("consume document failed")
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()

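    # Recovery entry point for documents stuck in the "splitting" state: any
    # previously written segments are discarded and the document is re-split
    # and re-indexed from the original source.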
    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get the existing document segments and delete them before re-splitting
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            # db.session.delete() accepts a single instance, so delete the segments one by one
            for document_segment in document_segments:
                db.session.delete(document_segment)
            db.session.commit()

            # load file
            text_docs = self._load_data(dataset_document)

            # get the process rule
            processing_rule = db.session.query(DatasetProcessRule). \
                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                first()

            # get splitter
            splitter = self._get_splitter(processing_rule)

            # split to documents
            documents = self._step_split(
                text_docs=text_docs,
                splitter=splitter,
                dataset=dataset,
                dataset_document=dataset_document,
                processing_rule=processing_rule
            )

            # build index
            self._build_index(
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def run_in_indexing_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is indexing."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # collect the existing document segments that have not finished indexing
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            documents = []
            if document_segments:
                for document_segment in document_segments:
                    # transform segment to node
                    if document_segment.status != "completed":
                        document = Document(
                            page_content=document_segment.content,
                            metadata={
                                "doc_id": document_segment.index_node_id,
                                "doc_hash": document_segment.index_node_hash,
                                "document_id": document_segment.document_id,
                                "dataset_id": document_segment.dataset_id,
                            }
                        )
                        documents.append(document)

            # build index
            self._build_index(
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

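    # The two *_indexing_estimate() methods reuse the same cleaning and
    # splitting logic as the real pipeline, but only to report segment counts,
    # token usage, an estimated price and a short content preview; nothing is
    # persisted or indexed.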
    def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
        """
        Estimate the number of segments, token usage and price for indexing the given uploaded files.
        """
        tokens = 0
        preview_texts = []
        total_segments = 0
        for file_detail in file_details:
            # load data from file
            text_docs = FileExtractor.load(file_detail)

            processing_rule = DatasetProcessRule(
                mode=tmp_processing_rule["mode"],
                rules=json.dumps(tmp_processing_rule["rules"])
            )

            # get splitter
            splitter = self._get_splitter(processing_rule)

            # split to documents
            documents = self._split_to_documents(
                text_docs=text_docs,
                splitter=splitter,
                processing_rule=processing_rule
            )
            total_segments += len(documents)
            for document in documents:
                if len(preview_texts) < 5:
                    preview_texts.append(document.page_content)

                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name,
                                                         self.filter_string(document.page_content))

        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
            "currency": TokenCalculator.get_currency(self.embedding_model_name),
            "preview": preview_texts
        }

    def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict) -> dict:
        """
        Estimate the number of segments, token usage and price for indexing the given Notion pages.
        """
        # load data from notion
        tokens = 0
        preview_texts = []
        total_segments = 0
        for notion_info in notion_info_list:
            workspace_id = notion_info['workspace_id']
            data_source_binding = DataSourceBinding.query.filter(
                db.and_(
                    DataSourceBinding.tenant_id == current_user.current_tenant_id,
                    DataSourceBinding.provider == 'notion',
                    DataSourceBinding.disabled == False,
                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                )
            ).first()
            if not data_source_binding:
                raise ValueError('Data source binding not found.')

            for page in notion_info['pages']:
                loader = NotionLoader(
                    notion_access_token=data_source_binding.access_token,
                    notion_workspace_id=workspace_id,
                    notion_obj_id=page['page_id'],
                    notion_page_type=page['type']
                )
                documents = loader.load()

                processing_rule = DatasetProcessRule(
                    mode=tmp_processing_rule["mode"],
                    rules=json.dumps(tmp_processing_rule["rules"])
                )

                # get splitter
                splitter = self._get_splitter(processing_rule)

                # split to documents
                documents = self._split_to_documents(
                    text_docs=documents,
                    splitter=splitter,
                    processing_rule=processing_rule
                )
                total_segments += len(documents)
                for document in documents:
                    if len(preview_texts) < 5:
                        preview_texts.append(document.page_content)

                    tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)

        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
            "currency": TokenCalculator.get_currency(self.embedding_model_name),
            "preview": preview_texts
        }

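    # Extraction stage: only the "upload_file" and "notion_import" data source
    # types are supported; any other data_source_type yields no documents.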
    def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
        # load file
        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
            return []

        data_source_info = dataset_document.data_source_info_dict
        text_docs = []
        if dataset_document.data_source_type == 'upload_file':
            if not data_source_info or 'upload_file_id' not in data_source_info:
                raise ValueError("no upload file found")

            file_detail = db.session.query(UploadFile). \
                filter(UploadFile.id == data_source_info['upload_file_id']). \
                one_or_none()

            text_docs = FileExtractor.load(file_detail)
        elif dataset_document.data_source_type == 'notion_import':
            loader = NotionLoader.from_document(dataset_document)
            text_docs = loader.load()

        # update document status to splitting
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="splitting",
            extra_update_params={
                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
                DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
            }
        )

        # attach the document model id and dataset id to every loaded document
        text_docs = cast(List[Document], text_docs)
        for text_doc in text_docs:
            # remove invalid symbols
            text_doc.page_content = self.filter_string(text_doc.page_content)
            text_doc.metadata['document_id'] = dataset_document.id
            text_doc.metadata['dataset_id'] = dataset_document.dataset_id

        return text_docs

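    # Removes ASCII control characters and characters in the \x80-\xFF range,
    # and rewrites the '<|' / '|>' marker sequences to plain '<' and '>', so
    # the text is safe to tokenize and embed.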
    def filter_string(self, text):
        text = text.replace('<|', '<')
        text = text.replace('|>', '>')
        pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]')
        return pattern.sub('', text)

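    # For reference, the processing rule payload read below and in
    # _document_clean() is expected to look roughly like this (an illustrative
    # shape inferred from the fields that are accessed, not a documented schema):
    #
    #     {
    #         "segmentation": {"max_tokens": 500, "separator": "\\n"},
    #         "pre_processing_rules": [
    #             {"id": "remove_extra_spaces", "enabled": True},
    #             {"id": "remove_urls_emails", "enabled": False}
    #         ]
    #     }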
    def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
        """
        Get the TextSplitter according to the processing rule.
        """
        if processing_rule.mode == "custom":
            # The user-defined segmentation rule
            rules = json.loads(processing_rule.rules)
            segmentation = rules["segmentation"]
            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
                raise ValueError("Custom segment length should be between 50 and 1000.")

            separator = segmentation["separator"]
            if separator:
                separator = separator.replace('\\n', '\n')

            character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=0,
                fixed_separator=separator,
                separators=["\n\n", "。", ".", " ", ""]
            )
        else:
            # Automatic segmentation
            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                chunk_overlap=0,
                separators=["\n\n", "。", ".", " ", ""]
            )

        return character_splitter

    def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
            -> List[Document]:
        """
        Split the text documents into segments and save them to the document store.
        """
        documents = self._split_to_documents(
            text_docs=text_docs,
            splitter=splitter,
            processing_rule=processing_rule
        )

        # save nodes to the document segment store
        doc_store = DatesetDocumentStore(
            dataset=dataset,
            user_id=dataset_document.created_by,
            embedding_model_name=self.embedding_model_name,
            document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.utcnow()
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.utcnow()
            }
        )

        return documents

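    # Each split chunk is assigned a generated UUID (doc_id) and a content hash
    # (doc_hash) in its metadata; _build_index() later matches these doc_ids
    # against DocumentSegment.index_node_id to mark segments as completed.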
    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
                            processing_rule: DatasetProcessRule) -> List[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])

            split_documents = []
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue

                doc_id = str(uuid.uuid4())
                doc_hash = helper.generate_text_hash(document.page_content)

                document.metadata['doc_id'] = doc_id
                document.metadata['doc_hash'] = doc_hash

                split_documents.append(document)

            all_documents.extend(split_documents)

        return all_documents

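    # Pre-processing rules currently handled: "remove_extra_spaces" (collapse
    # runs of blank lines and horizontal whitespace) and "remove_urls_emails"
    # (strip email addresses and http/https URLs).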
    def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
        """
        Clean the document text according to the processing rules.
        """
        if processing_rule.mode == "automatic":
            rules = DatasetProcessRule.AUTOMATIC_RULES
        else:
            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}

        if 'pre_processing_rules' in rules:
            pre_processing_rules = rules["pre_processing_rules"]
            for pre_processing_rule in pre_processing_rules:
                if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
                    # Remove extra spaces
                    pattern = r'\n{3,}'
                    text = re.sub(pattern, '\n\n', text)
                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
                    text = re.sub(pattern, ' ', text)
                elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
                    # Remove email
                    pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
                    text = re.sub(pattern, '', text)

                    # Remove URL
                    pattern = r'https?://[^\s]+'
                    text = re.sub(pattern, '', text)

        return text

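    # Indexing stage: segments are indexed in batches of 100. Token usage is
    # accumulated per batch, the vector index is written only when one is
    # available for the dataset, the keyword table index is always written,
    # and the matching DocumentSegment rows are flipped to "completed" batch
    # by batch. The pause flag is checked before every batch.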
    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
        """
        Build the index for the document.
        """
        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 100
        for i in range(0, len(documents), chunk_size):
            # check whether the document has been paused
            self._check_document_paused_status(dataset_document.id)

            chunk_documents = documents[i:i + chunk_size]

            tokens += sum(
                TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
                for document in chunk_documents
            )

            # save vector index
            if vector_index:
                vector_index.add_texts(chunk_documents)

            # save keyword index
            keyword_table_index.add_texts(chunk_documents)

            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.index_node_id.in_(document_ids),
                DocumentSegment.status == "indexing"
            ).update({
                DocumentSegment.status: "completed",
                DocumentSegment.completed_at: datetime.datetime.utcnow()
            })

            db.session.commit()

        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.utcnow(),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
            }
        )

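    # Pausing is signalled out-of-band: a Redis key named
    # "document_<id>_is_paused" aborts indexing before the next batch, and the
    # is_paused flag on the document row blocks further status updates.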
    def _check_document_paused_status(self, document_id: str):
        indexing_cache_key = 'document_{}_is_paused'.format(document_id)
        result = redis_client.get(indexing_cache_key)
        if result:
            raise DocumentIsPausedException()

    def _update_document_index_status(self, document_id: str, after_indexing_status: str,
                                      extra_update_params: Optional[dict] = None) -> None:
        """
        Update the document indexing status.
        """
        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
        if count > 0:
            raise DocumentIsPausedException()

        update_params = {
            DatasetDocument.indexing_status: after_indexing_status
        }

        if extra_update_params:
            update_params.update(extra_update_params)

        DatasetDocument.query.filter_by(id=document_id).update(update_params)
        db.session.commit()

    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
        """
        Update the document segments by document id.
        """
        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
        db.session.commit()


class DocumentIsPausedException(Exception):
    pass
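

# Illustrative usage sketch (not part of the original module): in the wider
# application the runner is expected to be driven from a background task after
# documents are queued for indexing, roughly along these lines. The query
# filter below is an assumption for illustration; the actual task wiring lives
# outside this file.
#
#     waiting_documents = DatasetDocument.query.filter_by(
#         indexing_status='waiting'
#     ).all()
#     IndexingRunner().run(waiting_documents)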