# indexing_runner.py

import datetime
import json
import logging
import re
import threading
import time
import uuid
from typing import Optional, List, cast

from flask import current_app, Flask
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import MessageType
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Document as DatasetDocument
from models.dataset import Dataset, DocumentSegment, DatasetProcessRule
from models.model import UploadFile
from models.source import DataSourceBinding

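# Usage sketch (illustrative only, not part of the original module): the runner is
# normally driven from an async worker that loads the queued DatasetDocument rows
# and hands them to run(). The 'waiting' status value below is an assumption used
# only for this example.
#
#     dataset_documents = DatasetDocument.query.filter_by(indexing_status='waiting').all()
#     IndexingRunner().run(dataset_documents)
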
class IndexingRunner:

    def __init__(self):
        self.storage = storage

    def run(self, dataset_documents: List[DatasetDocument]):
        """Run the indexing process."""
        for dataset_document in dataset_documents:
            try:
                # get dataset
                dataset = Dataset.query.filter_by(
                    id=dataset_document.dataset_id
                ).first()

                if not dataset:
                    raise ValueError("no dataset found")

                # load file
                text_docs = self._load_data(dataset_document)

                # get the process rule
                processing_rule = db.session.query(DatasetProcessRule). \
                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                    first()

                # get splitter
                splitter = self._get_splitter(processing_rule)

                # split to documents
                documents = self._step_split(
                    text_docs=text_docs,
                    splitter=splitter,
                    dataset=dataset,
                    dataset_document=dataset_document,
                    processing_rule=processing_rule
                )

                # build index
                self._build_index(
                    dataset=dataset,
                    dataset_document=dataset_document,
                    documents=documents
                )
            except DocumentIsPausedException:
                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
            except ProviderTokenNotInitError as e:
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e.description)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()
            except Exception as e:
                logging.exception("consume document failed")
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()

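    # The helper below parses LLM output of the form "Q1: ... A1: ... Q2: ... A2: ..."
    # into question/answer dicts. An illustrative example (values made up):
    #     format_split_text("Q1: What is the capital of France?\nA1: Paris.\nQ2: And of Japan?\nA2: Tokyo.")
    #     -> [{"question": "What is the capital of France?", "answer": "Paris."},
    #         {"question": "And of Japan?", "answer": "Tokyo."}]
    # Multi-line answers are kept; "\n" followed by indentation is collapsed to a plain "\n".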
    def format_split_text(self, text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
        matches = re.findall(regex, text, re.MULTILINE)

        result = []
        for match in matches:
            q = match[0]
            a = match[1]
            if q and a:
                result.append({
                    "question": q,
                    "answer": re.sub(r"\n\s*", "\n", a.strip())
                })

        return result

    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get the existing document segments and delete them
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            # session.delete() expects a single mapped instance, not a list,
            # so delete the segments one by one
            for document_segment in document_segments:
                db.session.delete(document_segment)
            db.session.commit()

            # load file
            text_docs = self._load_data(dataset_document)

            # get the process rule
            processing_rule = db.session.query(DatasetProcessRule). \
                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                first()

            # get splitter
            splitter = self._get_splitter(processing_rule)

            # split to documents
            documents = self._step_split(
                text_docs=text_docs,
                splitter=splitter,
                dataset=dataset,
                dataset_document=dataset_document,
                processing_rule=processing_rule
            )

            # build index
            self._build_index(
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def run_in_indexing_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is indexing."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get the existing document segments
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            documents = []
            if document_segments:
                for document_segment in document_segments:
                    # transform the uncompleted segments back into nodes
                    if document_segment.status != "completed":
                        document = Document(
                            page_content=document_segment.content,
                            metadata={
                                "doc_id": document_segment.index_node_id,
                                "doc_hash": document_segment.index_node_hash,
                                "document_id": document_segment.document_id,
                                "dataset_id": document_segment.dataset_id,
                            }
                        )
                        documents.append(document)

            # build index
            self._build_index(
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                               indexing_technique: str = 'economy') -> dict:
        """
        Estimate the indexing for the document.
        """
        embedding_model = None
        if dataset_id:
            dataset = Dataset.query.filter_by(
                id=dataset_id
            ).first()
            if not dataset:
                raise ValueError('Dataset not found.')
            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=dataset.tenant_id,
                    model_provider_name=dataset.embedding_model_provider,
                    model_name=dataset.embedding_model
                )
        else:
            if indexing_technique == 'high_quality':
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=tenant_id
                )

        tokens = 0
        preview_texts = []
        total_segments = 0
        for file_detail in file_details:
            # load data from file
            text_docs = FileExtractor.load(file_detail)

            processing_rule = DatasetProcessRule(
                mode=tmp_processing_rule["mode"],
                rules=json.dumps(tmp_processing_rule["rules"])
            )

            # get splitter
            splitter = self._get_splitter(processing_rule)

            # split to documents
            documents = self._split_to_documents_for_estimate(
                text_docs=text_docs,
                splitter=splitter,
                processing_rule=processing_rule
            )

            total_segments += len(documents)
            for document in documents:
                if len(preview_texts) < 5:
                    preview_texts.append(document.page_content)
                if indexing_technique == 'high_quality' or embedding_model:
                    tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))

        if doc_form and doc_form == 'qa_model':
            text_generation_model = ModelFactory.get_text_generation_model(
                tenant_id=tenant_id
            )
            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                             doc_language)
                document_qa_list = self.format_split_text(response)
                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(
                        text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
                    "currency": embedding_model.get_currency() if embedding_model else 'USD',
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }

        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
            "currency": embedding_model.get_currency() if embedding_model else 'USD',
            "preview": preview_texts
        }

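    # Illustrative shape of the estimate returned by file_indexing_estimate (and by
    # notion_indexing_estimate below); the concrete numbers are made up:
    #     {"total_segments": 12, "tokens": 3480, "total_price": "0.000348",
    #      "currency": "USD", "preview": ["first chunk ...", "second chunk ...", ...]}
    # In qa_model mode the totals are scaled (x20 segments, 2000 tokens per segment)
    # and a "qa_preview" list of question/answer dicts is added.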
    def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                                 indexing_technique: str = 'economy') -> dict:
        """
        Estimate the indexing for the document.
        """
        embedding_model = None
        if dataset_id:
            dataset = Dataset.query.filter_by(
                id=dataset_id
            ).first()
            if not dataset:
                raise ValueError('Dataset not found.')
            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=dataset.tenant_id,
                    model_provider_name=dataset.embedding_model_provider,
                    model_name=dataset.embedding_model
                )
        else:
            if indexing_technique == 'high_quality':
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=tenant_id
                )

        # load data from notion
        tokens = 0
        preview_texts = []
        total_segments = 0
        for notion_info in notion_info_list:
            workspace_id = notion_info['workspace_id']
            data_source_binding = DataSourceBinding.query.filter(
                db.and_(
                    DataSourceBinding.tenant_id == current_user.current_tenant_id,
                    DataSourceBinding.provider == 'notion',
                    DataSourceBinding.disabled == False,
                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                )
            ).first()
            if not data_source_binding:
                raise ValueError('Data source binding not found.')

            for page in notion_info['pages']:
                loader = NotionLoader(
                    notion_access_token=data_source_binding.access_token,
                    notion_workspace_id=workspace_id,
                    notion_obj_id=page['page_id'],
                    notion_page_type=page['type']
                )
                documents = loader.load()

                processing_rule = DatasetProcessRule(
                    mode=tmp_processing_rule["mode"],
                    rules=json.dumps(tmp_processing_rule["rules"])
                )

                # get splitter
                splitter = self._get_splitter(processing_rule)

                # split to documents
                documents = self._split_to_documents_for_estimate(
                    text_docs=documents,
                    splitter=splitter,
                    processing_rule=processing_rule
                )

                total_segments += len(documents)
                for document in documents:
                    if len(preview_texts) < 5:
                        preview_texts.append(document.page_content)
                    if indexing_technique == 'high_quality' or embedding_model:
                        tokens += embedding_model.get_num_tokens(document.page_content)

        if doc_form and doc_form == 'qa_model':
            text_generation_model = ModelFactory.get_text_generation_model(
                tenant_id=tenant_id
            )
            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                             doc_language)
                document_qa_list = self.format_split_text(response)
                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(
                        text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
                    "currency": embedding_model.get_currency() if embedding_model else 'USD',
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }

        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
            "currency": embedding_model.get_currency() if embedding_model else 'USD',
            "preview": preview_texts
        }

    def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
        # load file
        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
            return []

        data_source_info = dataset_document.data_source_info_dict
        text_docs = []
        if dataset_document.data_source_type == 'upload_file':
            if not data_source_info or 'upload_file_id' not in data_source_info:
                raise ValueError("no upload file found")

            file_detail = db.session.query(UploadFile). \
                filter(UploadFile.id == data_source_info['upload_file_id']). \
                one_or_none()

            if file_detail:
                text_docs = FileExtractor.load(file_detail)
        elif dataset_document.data_source_type == 'notion_import':
            loader = NotionLoader.from_document(dataset_document)
            text_docs = loader.load()

        # update document status to splitting
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="splitting",
            extra_update_params={
                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
                DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
            }
        )

        # replace doc id to document model id
        text_docs = cast(List[Document], text_docs)
        for text_doc in text_docs:
            # remove invalid symbol
            text_doc.page_content = self.filter_string(text_doc.page_content)
            text_doc.metadata['document_id'] = dataset_document.id
            text_doc.metadata['dataset_id'] = dataset_document.dataset_id

        return text_docs

    def filter_string(self, text):
        text = re.sub(r'<\|', '<', text)
        text = re.sub(r'\|>', '>', text)
        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]', '', text)
        return text

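    # filter_string example (illustrative): "<|" / "|>" are unwrapped to plain angle
    # brackets and control characters are dropped, e.g.
    #     filter_string("chunk<|end|>\x00")  ->  "chunk<end>"
    # Note that the final character class also spans U+0080-U+00FF, so Latin-1
    # characters in that range are stripped as well.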
    def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
        """
        Get the NodeParser object according to the processing rule.
        """
        if processing_rule.mode == "custom":
            # the user-defined segmentation rule
            rules = json.loads(processing_rule.rules)
            segmentation = rules["segmentation"]
            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
                raise ValueError("Custom segment length should be between 50 and 1000.")

            separator = segmentation["separator"]
            if separator:
                separator = separator.replace('\\n', '\n')

            character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=0,
                fixed_separator=separator,
                separators=["\n\n", "。", ".", " ", ""]
            )
        else:
            # automatic segmentation
            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                chunk_overlap=0,
                separators=["\n\n", "。", ".", " ", ""]
            )

        return character_splitter

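    # Shape of the "custom" rules JSON consumed by _get_splitter above (field names
    # come from the code; the concrete values are only an example):
    #     {"segmentation": {"max_tokens": 500, "separator": "\\n"}}
    # i.e. chunks of at most 500 tokens, split on the fixed "\n" separator first and
    # then recursively on "\n\n", "。", ".", " " as needed.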
    def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
            -> List[Document]:
        """
        Split the text documents into documents and save them to the document segment.
        """
        documents = self._split_to_documents(
            text_docs=text_docs,
            splitter=splitter,
            processing_rule=processing_rule,
            tenant_id=dataset.tenant_id,
            document_form=dataset_document.doc_form,
            document_language=dataset_document.doc_language
        )

        # save node to document segment
        doc_store = DatesetDocumentStore(
            dataset=dataset,
            user_id=dataset_document.created_by,
            document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.utcnow()
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.utcnow()
            }
        )

        return documents

    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
                            processing_rule: DatasetProcessRule, tenant_id: str,
                            document_form: str, document_language: str) -> List[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        all_qa_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document_node in documents:
                if document_node.page_content.strip():
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(document_node.page_content)
                    document_node.metadata['doc_id'] = doc_id
                    document_node.metadata['doc_hash'] = hash

                    split_documents.append(document_node)

            all_documents.extend(split_documents)

        # processing qa document
        if document_form == 'qa_model':
            for i in range(0, len(all_documents), 10):
                threads = []
                sub_documents = all_documents[i:i + 10]
                for doc in sub_documents:
                    document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
                        'flask_app': current_app._get_current_object(),
                        'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents,
                        'document_language': document_language})
                    threads.append(document_format_thread)
                    document_format_thread.start()
                for thread in threads:
                    thread.join()
            return all_qa_documents

        return all_documents

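    # In qa_model mode the split documents are processed in batches of 10, one thread
    # per document; each thread is handed the real Flask app object so it can open its
    # own app_context() outside the request thread, and all threads append into the
    # shared all_qa_documents list before the next batch starts.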
    def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
        format_documents = []
        if document_node.page_content is None or not document_node.page_content.strip():
            return
        with flask_app.app_context():
            try:
                # qa model document
                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
                document_qa_list = self.format_split_text(response)
                qa_documents = []
                for result in document_qa_list:
                    qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(result['question'])
                    qa_document.metadata['answer'] = result['answer']
                    qa_document.metadata['doc_id'] = doc_id
                    qa_document.metadata['doc_hash'] = hash
                    qa_documents.append(qa_document)
                format_documents.extend(qa_documents)
            except Exception as e:
                logging.exception(e)

            all_qa_documents.extend(format_documents)

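    # format_qa_document relies on the generator returning text in the
    # "Q1: ... A1: ... Q2: ... A2: ..." layout parsed by format_split_text; each parsed
    # pair becomes a Document whose page_content is the question and whose metadata
    # carries the answer plus a fresh doc_id/doc_hash.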
    def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
                                         processing_rule: DatasetProcessRule) -> List[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])

            split_documents = []
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document.page_content)

                document.metadata['doc_id'] = doc_id
                document.metadata['doc_hash'] = hash

                split_documents.append(document)

            all_documents.extend(split_documents)

        return all_documents

    def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
        """
        Clean the document text according to the processing rules.
        """
        if processing_rule.mode == "automatic":
            rules = DatasetProcessRule.AUTOMATIC_RULES
        else:
            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}

        if 'pre_processing_rules' in rules:
            pre_processing_rules = rules["pre_processing_rules"]
            for pre_processing_rule in pre_processing_rules:
                if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
                    # remove extra spaces
                    pattern = r'\n{3,}'
                    text = re.sub(pattern, '\n\n', text)
                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
                    text = re.sub(pattern, ' ', text)
                elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
                    # remove email
                    pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
                    text = re.sub(pattern, '', text)

                    # remove URL
                    pattern = r'https?://[^\s]+'
                    text = re.sub(pattern, '', text)

        return text

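    # Example of the pre-processing rules consumed by _document_clean above (the rule
    # ids come from the code; the flag values are only an example):
    #     {"pre_processing_rules": [
    #         {"id": "remove_extra_spaces", "enabled": True},
    #         {"id": "remove_urls_emails", "enabled": False}
    #     ]}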
    def format_split_text(self, text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"  # regex matching the Q and A pairs
        matches = re.findall(regex, text, re.MULTILINE)  # get all matches

        result = []  # store the final result
        for match in matches:
            q = match[0]
            a = match[1]
            if q and a:
                # only keep the pair if both the question and the answer are present
                result.append({
                    "question": q,
                    "answer": re.sub(r"\n\s*", "\n", a.strip())
                })

        return result

    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
        """
        Build the index for the document.
        """
        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
        embedding_model = None
        if dataset.indexing_technique == 'high_quality':
            embedding_model = ModelFactory.get_embedding_model(
                tenant_id=dataset.tenant_id,
                model_provider_name=dataset.embedding_model_provider,
                model_name=dataset.embedding_model
            )

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 100
        for i in range(0, len(documents), chunk_size):
            # check whether the document is paused
            self._check_document_paused_status(dataset_document.id)
            chunk_documents = documents[i:i + chunk_size]
            if dataset.indexing_technique == 'high_quality' or embedding_model:
                tokens += sum(
                    embedding_model.get_num_tokens(document.page_content)
                    for document in chunk_documents
                )

            # save vector index
            if vector_index:
                vector_index.add_texts(chunk_documents)

            # save keyword index
            keyword_table_index.add_texts(chunk_documents)

            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.index_node_id.in_(document_ids),
                DocumentSegment.status == "indexing"
            ).update({
                DocumentSegment.status: "completed",
                DocumentSegment.enabled: True,
                DocumentSegment.completed_at: datetime.datetime.utcnow()
            })

            db.session.commit()

        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.utcnow(),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
            }
        )

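    # _build_index works in batches of 100 documents: each batch is token-counted
    # (high_quality only), written to the vector and keyword indexes, and its
    # DocumentSegment rows are flipped to "completed" before the next batch starts,
    # so a pause request (checked below via Redis) takes effect between batches.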
    def _check_document_paused_status(self, document_id: str):
        indexing_cache_key = 'document_{}_is_paused'.format(document_id)
        result = redis_client.get(indexing_cache_key)
        if result:
            raise DocumentIsPausedException()

    def _update_document_index_status(self, document_id: str, after_indexing_status: str,
                                      extra_update_params: Optional[dict] = None) -> None:
        """
        Update the document indexing status.
        """
        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
        if count > 0:
            raise DocumentIsPausedException()

        update_params = {
            DatasetDocument.indexing_status: after_indexing_status
        }
        if extra_update_params:
            update_params.update(extra_update_params)

        DatasetDocument.query.filter_by(id=document_id).update(update_params)
        db.session.commit()

    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
        """
        Update the document segment by document id.
        """
        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
        db.session.commit()

    def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset):
        """
        Batch add segments index processing
        """
        documents = []
        for segment in segments:
            document = Document(
                page_content=segment.content,
                metadata={
                    "doc_id": segment.index_node_id,
                    "doc_hash": segment.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                }
            )
            documents.append(document)

        # save vector index
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            index.add_texts(documents, duplicate_check=True)

        # save keyword index
        index = IndexBuilder.get_index(dataset, 'economy')
        if index:
            index.add_texts(documents)

class DocumentIsPausedException(Exception):
    pass