indexing_runner.py

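"""Indexing runner.

IndexingRunner drives the dataset document indexing pipeline: load the source
(upload file or Notion page), clean and split it into segments, then build the
vector ("high_quality") and keyword ("economy") indexes, updating document and
segment status along the way.
"""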
import datetime
import json
import logging
import re
import threading
import time
import uuid
from typing import Optional, List, cast

from flask import current_app, Flask
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from sqlalchemy.orm.exc import ObjectDeletedError

from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatasetDocumentStore
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import MessageType
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Document as DatasetDocument
from models.dataset import Dataset, DocumentSegment, DatasetProcessRule
from models.model import UploadFile
from models.source import DataSourceBinding

class IndexingRunner:

    def __init__(self):
        self.storage = storage

    def run(self, dataset_documents: List[DatasetDocument]):
        """Run the indexing process."""
        for dataset_document in dataset_documents:
            try:
                # get dataset
                dataset = Dataset.query.filter_by(
                    id=dataset_document.dataset_id
                ).first()

                if not dataset:
                    raise ValueError("no dataset found")

                # get the process rule
                processing_rule = db.session.query(DatasetProcessRule). \
                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                    first()

                # load file
                text_docs = self._load_data(dataset_document)

                # get splitter
                splitter = self._get_splitter(processing_rule)

                # split to documents
                documents = self._step_split(
                    text_docs=text_docs,
                    splitter=splitter,
                    dataset=dataset,
                    dataset_document=dataset_document,
                    processing_rule=processing_rule
                )

                # build index
                self._build_index(
                    dataset=dataset,
                    dataset_document=dataset_document,
                    documents=documents
                )
            except DocumentIsPausedException:
                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
            except ProviderTokenNotInitError as e:
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e.description)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()
            except ObjectDeletedError:
                logging.warning('Document deleted, document id: {}'.format(dataset_document.id))
            except Exception as e:
                logging.exception("consume document failed")
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()

    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get existing document segments and delete them
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            for document_segment in document_segments:
                db.session.delete(document_segment)
            db.session.commit()

            # load file
            text_docs = self._load_data(dataset_document)

            # get the process rule
            processing_rule = db.session.query(DatasetProcessRule). \
                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                first()

            # get splitter
            splitter = self._get_splitter(processing_rule)

            # split to documents
            documents = self._step_split(
                text_docs=text_docs,
                splitter=splitter,
                dataset=dataset,
                dataset_document=dataset_document,
                processing_rule=processing_rule
            )

            # build index
            self._build_index(
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def run_in_indexing_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is indexing."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get existing document segments; re-index only the uncompleted ones
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            documents = []
            if document_segments:
                for document_segment in document_segments:
                    # transform segment to node
                    if document_segment.status != "completed":
                        document = Document(
                            page_content=document_segment.content,
                            metadata={
                                "doc_id": document_segment.index_node_id,
                                "doc_hash": document_segment.index_node_hash,
                                "document_id": document_segment.document_id,
                                "dataset_id": document_segment.dataset_id,
                            }
                        )
                        documents.append(document)

            # build index
            self._build_index(
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                               indexing_technique: str = 'economy') -> dict:
        """
        Estimate the indexing cost (segments, tokens, price) for uploaded files.
        """
        embedding_model = None
        if dataset_id:
            dataset = Dataset.query.filter_by(
                id=dataset_id
            ).first()
            if not dataset:
                raise ValueError('Dataset not found.')
            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=dataset.tenant_id,
                    model_provider_name=dataset.embedding_model_provider,
                    model_name=dataset.embedding_model
                )
        else:
            if indexing_technique == 'high_quality':
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=tenant_id
                )

        tokens = 0
        preview_texts = []
        total_segments = 0
        for file_detail in file_details:
            # load data from file
            text_docs = FileExtractor.load(file_detail)

            processing_rule = DatasetProcessRule(
                mode=tmp_processing_rule["mode"],
                rules=json.dumps(tmp_processing_rule["rules"])
            )

            # get splitter
            splitter = self._get_splitter(processing_rule)

            # split to documents
            documents = self._split_to_documents_for_estimate(
                text_docs=text_docs,
                splitter=splitter,
                processing_rule=processing_rule
            )

            total_segments += len(documents)
            for document in documents:
                if len(preview_texts) < 5:
                    preview_texts.append(document.page_content)
                if indexing_technique == 'high_quality' or embedding_model:
                    tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))

        if doc_form and doc_form == 'qa_model':
            text_generation_model = ModelFactory.get_text_generation_model(
                tenant_id=tenant_id
            )
            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                             doc_language)
                document_qa_list = self.format_split_text(response)
                return {
                    # rough QA-generation heuristic: ~20 generated segments and
                    # ~2000 tokens per source segment
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(
                        text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
                    "currency": embedding_model.get_currency() if embedding_model else 'USD',
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }
        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
            "currency": embedding_model.get_currency() if embedding_model else 'USD',
            "preview": preview_texts
        }

    def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                                 indexing_technique: str = 'economy') -> dict:
        """
        Estimate the indexing cost (segments, tokens, price) for Notion pages.
        """
        embedding_model = None
        if dataset_id:
            dataset = Dataset.query.filter_by(
                id=dataset_id
            ).first()
            if not dataset:
                raise ValueError('Dataset not found.')
            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=dataset.tenant_id,
                    model_provider_name=dataset.embedding_model_provider,
                    model_name=dataset.embedding_model
                )
        else:
            if indexing_technique == 'high_quality':
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=tenant_id
                )

        # load data from notion
        tokens = 0
        preview_texts = []
        total_segments = 0
        for notion_info in notion_info_list:
            workspace_id = notion_info['workspace_id']
            data_source_binding = DataSourceBinding.query.filter(
                db.and_(
                    DataSourceBinding.tenant_id == current_user.current_tenant_id,
                    DataSourceBinding.provider == 'notion',
                    DataSourceBinding.disabled == False,
                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                )
            ).first()
            if not data_source_binding:
                raise ValueError('Data source binding not found.')

            for page in notion_info['pages']:
                loader = NotionLoader(
                    notion_access_token=data_source_binding.access_token,
                    notion_workspace_id=workspace_id,
                    notion_obj_id=page['page_id'],
                    notion_page_type=page['type']
                )
                documents = loader.load()

                processing_rule = DatasetProcessRule(
                    mode=tmp_processing_rule["mode"],
                    rules=json.dumps(tmp_processing_rule["rules"])
                )

                # get splitter
                splitter = self._get_splitter(processing_rule)

                # split to documents
                documents = self._split_to_documents_for_estimate(
                    text_docs=documents,
                    splitter=splitter,
                    processing_rule=processing_rule
                )

                total_segments += len(documents)
                for document in documents:
                    if len(preview_texts) < 5:
                        preview_texts.append(document.page_content)
                    if indexing_technique == 'high_quality' or embedding_model:
                        tokens += embedding_model.get_num_tokens(document.page_content)

        if doc_form and doc_form == 'qa_model':
            text_generation_model = ModelFactory.get_text_generation_model(
                tenant_id=tenant_id
            )
            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                             doc_language)
                document_qa_list = self.format_split_text(response)
                return {
                    # rough QA-generation heuristic: ~20 generated segments and
                    # ~2000 tokens per source segment
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(
                        text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
                    "currency": embedding_model.get_currency() if embedding_model else 'USD',
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }
        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
            "currency": embedding_model.get_currency() if embedding_model else 'USD',
            "preview": preview_texts
        }

    def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]:
        # load file
        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
            return []

        data_source_info = dataset_document.data_source_info_dict
        text_docs = []
        if dataset_document.data_source_type == 'upload_file':
            if not data_source_info or 'upload_file_id' not in data_source_info:
                raise ValueError("no upload file found")

            file_detail = db.session.query(UploadFile). \
                filter(UploadFile.id == data_source_info['upload_file_id']). \
                one_or_none()

            if file_detail:
                text_docs = FileExtractor.load(file_detail, is_automatic=automatic)
        elif dataset_document.data_source_type == 'notion_import':
            loader = NotionLoader.from_document(dataset_document)
            text_docs = loader.load()

        # update document status to splitting
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="splitting",
            extra_update_params={
                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
                DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
            }
        )

        # clean page content and attach document / dataset ids to the loaded docs
        text_docs = cast(List[Document], text_docs)
        for text_doc in text_docs:
            # remove invalid symbols
            text_doc.page_content = self.filter_string(text_doc.page_content)

            text_doc.metadata['document_id'] = dataset_document.id
            text_doc.metadata['dataset_id'] = dataset_document.dataset_id

        return text_docs

    def filter_string(self, text):
        """Replace '<|' / '|>' markers and strip control and high-byte characters that break indexing."""
        text = re.sub(r'<\|', '<', text)
        text = re.sub(r'\|>', '>', text)
        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]', '', text)
        return text

    def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
        """
        Get the TextSplitter object according to the processing rule.
        """
        if processing_rule.mode == "custom":
            # The user-defined segmentation rule
            rules = json.loads(processing_rule.rules)
            segmentation = rules["segmentation"]
            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
                raise ValueError("Custom segment length should be between 50 and 1000.")

            separator = segmentation["separator"]
            if separator:
                separator = separator.replace('\\n', '\n')

            character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=0,
                fixed_separator=separator,
                separators=["\n\n", "。", ".", " ", ""]
            )
        else:
            # Automatic segmentation
            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                chunk_overlap=0,
                separators=["\n\n", "。", ".", " ", ""]
            )

        return character_splitter
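
    # Example of the custom `rules` JSON consumed above and by `_document_clean`
    # (illustrative values, not taken from this file):
    #
    #   {
    #       "segmentation": {"separator": "\\n", "max_tokens": 500},
    #       "pre_processing_rules": [
    #           {"id": "remove_extra_spaces", "enabled": true},
    #           {"id": "remove_urls_emails", "enabled": false}
    #       ]
    #   }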

    def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
            -> List[Document]:
        """
        Split the text documents into segments and persist them as document segments.
        """
        documents = self._split_to_documents(
            text_docs=text_docs,
            splitter=splitter,
            processing_rule=processing_rule,
            tenant_id=dataset.tenant_id,
            document_form=dataset_document.doc_form,
            document_language=dataset_document.doc_language
        )

        # save nodes to document segments
        doc_store = DatasetDocumentStore(
            dataset=dataset,
            user_id=dataset_document.created_by,
            document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.utcnow()
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.utcnow()
            }
        )

        return documents

    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
                            processing_rule: DatasetProcessRule, tenant_id: str,
                            document_form: str, document_language: str) -> List[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        all_qa_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document_node in documents:
                if document_node.page_content.strip():
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(document_node.page_content)
                    document_node.metadata['doc_id'] = doc_id
                    document_node.metadata['doc_hash'] = hash

                    split_documents.append(document_node)

            all_documents.extend(split_documents)

        # processing qa document
        if document_form == 'qa_model':
            for i in range(0, len(all_documents), 10):
                threads = []
                sub_documents = all_documents[i:i + 10]
                for doc in sub_documents:
                    document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
                        'flask_app': current_app._get_current_object(),
                        'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents,
                        'document_language': document_language})
                    threads.append(document_format_thread)
                    document_format_thread.start()
                for thread in threads:
                    thread.join()
            return all_qa_documents

        return all_documents

    def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
        """Generate QA pairs for one document node and append them to the shared list (runs in a worker thread)."""
        format_documents = []
        if document_node.page_content is None or not document_node.page_content.strip():
            return
        with flask_app.app_context():
            try:
                # qa model document
                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
                document_qa_list = self.format_split_text(response)
                qa_documents = []
                for result in document_qa_list:
                    qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(result['question'])
                    qa_document.metadata['answer'] = result['answer']
                    qa_document.metadata['doc_id'] = doc_id
                    qa_document.metadata['doc_hash'] = hash
                    qa_documents.append(qa_document)
                format_documents.extend(qa_documents)
            except Exception as e:
                logging.exception(e)

            all_qa_documents.extend(format_documents)

    def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
                                         processing_rule: DatasetProcessRule) -> List[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])

            split_documents = []
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document.page_content)

                document.metadata['doc_id'] = doc_id
                document.metadata['doc_hash'] = hash

                split_documents.append(document)

            all_documents.extend(split_documents)

        return all_documents

    def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
        """
        Clean the document text according to the processing rules.
        """
        if processing_rule.mode == "automatic":
            rules = DatasetProcessRule.AUTOMATIC_RULES
        else:
            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}

        if 'pre_processing_rules' in rules:
            pre_processing_rules = rules["pre_processing_rules"]
            for pre_processing_rule in pre_processing_rules:
                if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
                    # Remove extra spaces
                    pattern = r'\n{3,}'
                    text = re.sub(pattern, '\n\n', text)
                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
                    text = re.sub(pattern, ' ', text)
                elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
                    # Remove email
                    pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
                    text = re.sub(pattern, '', text)

                    # Remove URL
                    pattern = r'https?://[^\s]+'
                    text = re.sub(pattern, '', text)

        return text

    def format_split_text(self, text):
        """Parse LLM output of the form 'Q1: ... A1: ...' into a list of question/answer dicts."""
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
        matches = re.findall(regex, text, re.MULTILINE)

        return [
            {
                "question": q,
                "answer": re.sub(r"\n\s*", "\n", a.strip())
            }
            for q, a in matches if q and a
        ]
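
    # For example, a response such as
    #   "Q1: What is X? A1: X is ... Q2: Why Y? A2: Because ..."
    # is parsed into
    #   [{"question": "What is X?", "answer": "X is ..."},
    #    {"question": "Why Y?", "answer": "Because ..."}]
    # (illustrative input/output, not taken from the LLM prompt itself).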

    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
        """
        Build the index for the document.
        """
        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
        embedding_model = None
        if dataset.indexing_technique == 'high_quality':
            embedding_model = ModelFactory.get_embedding_model(
                tenant_id=dataset.tenant_id,
                model_provider_name=dataset.embedding_model_provider,
                model_name=dataset.embedding_model
            )

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 100
        for i in range(0, len(documents), chunk_size):
            # check document is paused
            self._check_document_paused_status(dataset_document.id)
            chunk_documents = documents[i:i + chunk_size]

            if dataset.indexing_technique == 'high_quality' or embedding_model:
                tokens += sum(
                    embedding_model.get_num_tokens(document.page_content)
                    for document in chunk_documents
                )

            # save vector index
            if vector_index:
                vector_index.add_texts(chunk_documents)

            # save keyword index
            keyword_table_index.add_texts(chunk_documents)

            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.index_node_id.in_(document_ids),
                DocumentSegment.status == "indexing"
            ).update({
                DocumentSegment.status: "completed",
                DocumentSegment.enabled: True,
                DocumentSegment.completed_at: datetime.datetime.utcnow()
            })

            db.session.commit()

        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.utcnow(),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
            }
        )

    def _check_document_paused_status(self, document_id: str):
        indexing_cache_key = 'document_{}_is_paused'.format(document_id)
        result = redis_client.get(indexing_cache_key)
        if result:
            raise DocumentIsPausedException()

    def _update_document_index_status(self, document_id: str, after_indexing_status: str,
                                      extra_update_params: Optional[dict] = None) -> None:
        """
        Update the document indexing status.
        """
        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
        if count > 0:
            raise DocumentIsPausedException()

        document = DatasetDocument.query.filter_by(id=document_id).first()
        if not document:
            raise DocumentIsDeletedPausedException()

        update_params = {
            DatasetDocument.indexing_status: after_indexing_status
        }

        if extra_update_params:
            update_params.update(extra_update_params)

        DatasetDocument.query.filter_by(id=document_id).update(update_params)
        db.session.commit()

    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
        """
        Update all document segments belonging to the given document.
        """
        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
        db.session.commit()

    def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset):
        """
        Add a batch of existing segments to the vector and keyword indexes.
        """
        documents = []
        for segment in segments:
            document = Document(
                page_content=segment.content,
                metadata={
                    "doc_id": segment.index_node_id,
                    "doc_hash": segment.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                }
            )
            documents.append(document)

        # save vector index
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            index.add_texts(documents, duplicate_check=True)

        # save keyword index
        index = IndexBuilder.get_index(dataset, 'economy')
        if index:
            index.add_texts(documents)


class DocumentIsPausedException(Exception):
    pass


class DocumentIsDeletedPausedException(Exception):
    pass
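

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; how documents are queued and which
# task invokes the runner is outside this module):
#
#   runner = IndexingRunner()
#   runner.run(dataset_documents)                       # index new documents
#   runner.run_in_splitting_status(dataset_document)    # resume from 'splitting'
#   runner.run_in_indexing_status(dataset_document)     # resume from 'indexing'
# ---------------------------------------------------------------------------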