indexing_runner.py

import concurrent.futures
import datetime
import json
import logging
import re
import threading
import time
import uuid
from typing import Optional, cast

from flask import Flask, current_app
from flask_login import current_user
from sqlalchemy.orm.exc import ObjectDeletedError

from core.docstore.dataset_docstore import DatasetDocumentStore
from core.errors.error import ProviderTokenNotInitError
from core.generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.splitter.fixed_text_splitter import (
    EnhanceRecursiveCharacterTextSplitter,
    FixedRecursiveCharacterTextSplitter,
)
from core.splitter.text_splitter import TextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.feature_service import FeatureService


class IndexingRunner:

    def __init__(self):
        self.storage = storage
        self.model_manager = ModelManager()

    def run(self, dataset_documents: list[DatasetDocument]):
        """Run the indexing process."""
        for dataset_document in dataset_documents:
            try:
                # get dataset
                dataset = Dataset.query.filter_by(
                    id=dataset_document.dataset_id
                ).first()

                if not dataset:
                    raise ValueError("no dataset found")

                # get the process rule
                processing_rule = db.session.query(DatasetProcessRule). \
                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                    first()
                index_type = dataset_document.doc_form
                index_processor = IndexProcessorFactory(index_type).init_index_processor()

                # extract
                text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

                # transform
                documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
                                            processing_rule.to_dict())

                # save segment
                self._load_segments(dataset, dataset_document, documents)

                # load
                self._load(
                    index_processor=index_processor,
                    dataset=dataset,
                    dataset_document=dataset_document,
                    documents=documents
                )
            except DocumentIsPausedException:
                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
            except ProviderTokenNotInitError as e:
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e.description)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()
            except ObjectDeletedError:
                logging.warning('Document deleted, document id: {}'.format(dataset_document.id))
            except Exception as e:
                logging.exception("consume document failed")
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()

    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get existing document segments and delete them
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            for document_segment in document_segments:
                db.session.delete(document_segment)
            db.session.commit()

            # get the process rule
            processing_rule = db.session.query(DatasetProcessRule). \
                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                first()
            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()

            # extract
            text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

            # transform
            documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
                                        processing_rule.to_dict())

            # save segment
            self._load_segments(dataset, dataset_document, documents)

            # load
            self._load(
                index_processor=index_processor,
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def run_in_indexing_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is indexing."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get existing document segments that have not finished indexing
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            documents = []
            if document_segments:
                for document_segment in document_segments:
                    # transform segment to node
                    if document_segment.status != "completed":
                        document = Document(
                            page_content=document_segment.content,
                            metadata={
                                "doc_id": document_segment.index_node_id,
                                "doc_hash": document_segment.index_node_hash,
                                "document_id": document_segment.document_id,
                                "dataset_id": document_segment.dataset_id,
                            }
                        )
                        documents.append(document)

            # build index
            # get the process rule
            processing_rule = db.session.query(DatasetProcessRule). \
                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                first()
            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()
            self._load(
                index_processor=index_processor,
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
                          doc_form: Optional[str] = None, doc_language: str = 'English',
                          dataset_id: Optional[str] = None,
                          indexing_technique: str = 'economy') -> dict:
  191. """
  192. Estimate the indexing for the document.
  193. """
  194. # check document limit
  195. features = FeatureService.get_features(tenant_id)
  196. if features.billing.enabled:
  197. count = len(extract_settings)
  198. batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
  199. if count > batch_upload_limit:
  200. raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
  201. embedding_model_instance = None
  202. if dataset_id:
  203. dataset = Dataset.query.filter_by(
  204. id=dataset_id
  205. ).first()
  206. if not dataset:
  207. raise ValueError('Dataset not found.')
  208. if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
  209. if dataset.embedding_model_provider:
  210. embedding_model_instance = self.model_manager.get_model_instance(
  211. tenant_id=tenant_id,
  212. provider=dataset.embedding_model_provider,
  213. model_type=ModelType.TEXT_EMBEDDING,
  214. model=dataset.embedding_model
  215. )
  216. else:
  217. embedding_model_instance = self.model_manager.get_default_model_instance(
  218. tenant_id=tenant_id,
  219. model_type=ModelType.TEXT_EMBEDDING,
  220. )
  221. else:
  222. if indexing_technique == 'high_quality':
  223. embedding_model_instance = self.model_manager.get_default_model_instance(
  224. tenant_id=tenant_id,
  225. model_type=ModelType.TEXT_EMBEDDING,
  226. )
  227. tokens = 0
  228. preview_texts = []
  229. total_segments = 0
  230. total_price = 0
  231. currency = 'USD'
  232. index_type = doc_form
  233. index_processor = IndexProcessorFactory(index_type).init_index_processor()
  234. all_text_docs = []
  235. for extract_setting in extract_settings:
  236. # extract
  237. text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
  238. all_text_docs.extend(text_docs)
  239. processing_rule = DatasetProcessRule(
  240. mode=tmp_processing_rule["mode"],
  241. rules=json.dumps(tmp_processing_rule["rules"])
  242. )
  243. # get splitter
  244. splitter = self._get_splitter(processing_rule, embedding_model_instance)
  245. # split to documents
  246. documents = self._split_to_documents_for_estimate(
  247. text_docs=text_docs,
  248. splitter=splitter,
  249. processing_rule=processing_rule
  250. )
  251. total_segments += len(documents)
  252. for document in documents:
  253. if len(preview_texts) < 5:
  254. preview_texts.append(document.page_content)
  255. if indexing_technique == 'high_quality' or embedding_model_instance:
  256. embedding_model_type_instance = embedding_model_instance.model_type_instance
  257. embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
  258. tokens += embedding_model_type_instance.get_num_tokens(
  259. model=embedding_model_instance.model,
  260. credentials=embedding_model_instance.credentials,
  261. texts=[self.filter_string(document.page_content)]
  262. )
  263. if doc_form and doc_form == 'qa_model':
  264. model_instance = self.model_manager.get_default_model_instance(
  265. tenant_id=tenant_id,
  266. model_type=ModelType.LLM
  267. )
  268. model_type_instance = model_instance.model_type_instance
  269. model_type_instance = cast(LargeLanguageModel, model_type_instance)
  270. if len(preview_texts) > 0:
  271. # qa model document
  272. response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
  273. doc_language)
  274. document_qa_list = self.format_split_text(response)
  275. price_info = model_type_instance.get_price(
  276. model=model_instance.model,
  277. credentials=model_instance.credentials,
  278. price_type=PriceType.INPUT,
  279. tokens=total_segments * 2000,
  280. )
  281. return {
  282. "total_segments": total_segments * 20,
  283. "tokens": total_segments * 2000,
  284. "total_price": '{:f}'.format(price_info.total_amount),
  285. "currency": price_info.currency,
  286. "qa_preview": document_qa_list,
  287. "preview": preview_texts
  288. }
  289. if embedding_model_instance:
  290. embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
  291. embedding_price_info = embedding_model_type_instance.get_price(
  292. model=embedding_model_instance.model,
  293. credentials=embedding_model_instance.credentials,
  294. price_type=PriceType.INPUT,
  295. tokens=tokens
  296. )
  297. total_price = '{:f}'.format(embedding_price_info.total_amount)
  298. currency = embedding_price_info.currency
  299. return {
  300. "total_segments": total_segments,
  301. "tokens": tokens,
  302. "total_price": total_price,
  303. "currency": currency,
  304. "preview": preview_texts
  305. }

    def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
            -> list[Document]:
        # load file
        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
            return []

        data_source_info = dataset_document.data_source_info_dict
        text_docs = []
        if dataset_document.data_source_type == 'upload_file':
            if not data_source_info or 'upload_file_id' not in data_source_info:
                raise ValueError("no upload file found")

            file_detail = db.session.query(UploadFile). \
                filter(UploadFile.id == data_source_info['upload_file_id']). \
                one_or_none()

            if file_detail:
                extract_setting = ExtractSetting(
                    datasource_type="upload_file",
                    upload_file=file_detail,
                    document_model=dataset_document.doc_form
                )
                text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
        elif dataset_document.data_source_type == 'notion_import':
            if (not data_source_info or 'notion_workspace_id' not in data_source_info
                    or 'notion_page_id' not in data_source_info):
                raise ValueError("no notion import info found")
            extract_setting = ExtractSetting(
                datasource_type="notion_import",
                notion_info={
                    "notion_workspace_id": data_source_info['notion_workspace_id'],
                    "notion_obj_id": data_source_info['notion_page_id'],
                    "notion_page_type": data_source_info['type'],
                    "document": dataset_document,
                    "tenant_id": dataset_document.tenant_id
                },
                document_model=dataset_document.doc_form
            )
            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])

        # update document status to splitting
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="splitting",
            extra_update_params={
                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
                DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
            }
        )

        # replace doc id with document model id
        text_docs = cast(list[Document], text_docs)
        for text_doc in text_docs:
            text_doc.metadata['document_id'] = dataset_document.id
            text_doc.metadata['dataset_id'] = dataset_document.dataset_id

        return text_docs

    def filter_string(self, text):
        text = re.sub(r'<\|', '<', text)
        text = re.sub(r'\|>', '>', text)
        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
        # Unicode U+FFFE
        text = re.sub('\uFFFE', '', text)
        return text

    def _get_splitter(self, processing_rule: DatasetProcessRule,
                      embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
        """
        Get the TextSplitter according to the processing rule.
        """
        if processing_rule.mode == "custom":
            # The user-defined segmentation rule
            rules = json.loads(processing_rule.rules)
            segmentation = rules["segmentation"]
            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
                raise ValueError("Custom segment length should be between 50 and 1000.")

            separator = segmentation["separator"]
            if separator:
                separator = separator.replace('\\n', '\n')

            if 'chunk_overlap' in segmentation and segmentation['chunk_overlap']:
                chunk_overlap = segmentation['chunk_overlap']
            else:
                chunk_overlap = 0

            character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=chunk_overlap,
                fixed_separator=separator,
                separators=["\n\n", "。", ".", " ", ""],
                embedding_model_instance=embedding_model_instance
            )
        else:
            # Automatic segmentation
            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
                separators=["\n\n", "。", ".", " ", ""],
                embedding_model_instance=embedding_model_instance
            )

        return character_splitter

    def _step_split(self, text_docs: list[Document], splitter: TextSplitter,
                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
            -> list[Document]:
        """
        Split the text documents into documents and save them to the document segment.
        """
        documents = self._split_to_documents(
            text_docs=text_docs,
            splitter=splitter,
            processing_rule=processing_rule,
            tenant_id=dataset.tenant_id,
            document_form=dataset_document.doc_form,
            document_language=dataset_document.doc_language
        )

        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset,
            user_id=dataset_document.created_by,
            document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.utcnow()
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.utcnow()
            }
        )

        return documents

    def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter,
                            processing_rule: DatasetProcessRule, tenant_id: str,
                            document_form: str, document_language: str) -> list[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        all_qa_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document_node in documents:
                if document_node.page_content.strip():
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(document_node.page_content)
                    document_node.metadata['doc_id'] = doc_id
                    document_node.metadata['doc_hash'] = hash

                    # remove leading splitter character
                    page_content = document_node.page_content
                    if page_content.startswith(".") or page_content.startswith("。"):
                        page_content = page_content[1:]
                    document_node.page_content = page_content

                    if document_node.page_content:
                        split_documents.append(document_node)
            all_documents.extend(split_documents)

        # processing qa document
        if document_form == 'qa_model':
            for i in range(0, len(all_documents), 10):
                threads = []
                sub_documents = all_documents[i:i + 10]
                for doc in sub_documents:
                    document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
                        'flask_app': current_app._get_current_object(),
                        'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents,
                        'document_language': document_language})
                    threads.append(document_format_thread)
                    document_format_thread.start()
                for thread in threads:
                    thread.join()
            return all_qa_documents

        return all_documents

    def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
        format_documents = []
        if document_node.page_content is None or not document_node.page_content.strip():
            return
        with flask_app.app_context():
            try:
                # qa model document
                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
                document_qa_list = self.format_split_text(response)
                qa_documents = []
                for result in document_qa_list:
                    qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(result['question'])
                    qa_document.metadata['answer'] = result['answer']
                    qa_document.metadata['doc_id'] = doc_id
                    qa_document.metadata['doc_hash'] = hash
                    qa_documents.append(qa_document)
                format_documents.extend(qa_documents)
            except Exception as e:
                logging.exception(e)

            all_qa_documents.extend(format_documents)

    def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter,
                                         processing_rule: DatasetProcessRule) -> list[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document.page_content)
                document.metadata['doc_id'] = doc_id
                document.metadata['doc_hash'] = hash
                split_documents.append(document)

            all_documents.extend(split_documents)

        return all_documents

    def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
        """
        Clean the document text according to the processing rules.
        """
        if processing_rule.mode == "automatic":
            rules = DatasetProcessRule.AUTOMATIC_RULES
        else:
            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}

        if 'pre_processing_rules' in rules:
            pre_processing_rules = rules["pre_processing_rules"]
            for pre_processing_rule in pre_processing_rules:
                if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
                    # Remove extra spaces
                    pattern = r'\n{3,}'
                    text = re.sub(pattern, '\n\n', text)
                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
                    text = re.sub(pattern, ' ', text)
                elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
                    # Remove email
                    pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
                    text = re.sub(pattern, '', text)

                    # Remove URL
                    pattern = r'https?://[^\s]+'
                    text = re.sub(pattern, '', text)

        return text

    def format_split_text(self, text):
        # the QA generator is expected to return pairs like "Q1: ... A1: ...";
        # parse them into question/answer dicts
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
        matches = re.findall(regex, text, re.UNICODE)

        return [
            {
                "question": q,
                "answer": re.sub(r"\n\s*", "\n", a.strip())
            }
            for q, a in matches if q and a
        ]

    def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset,
              dataset_document: DatasetDocument, documents: list[Document]) -> None:
        """
        insert index and update document/segment status to completed
        """
        embedding_model_instance = None
        if dataset.indexing_technique == 'high_quality':
            embedding_model_instance = self.model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model
            )

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 10

        embedding_model_type_instance = None
        if embedding_model_instance:
            embedding_model_type_instance = embedding_model_instance.model_type_instance
            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)

        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            futures = []
            for i in range(0, len(documents), chunk_size):
                chunk_documents = documents[i:i + chunk_size]
                futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
                                               chunk_documents, dataset,
                                               dataset_document, embedding_model_instance,
                                               embedding_model_type_instance))

            for future in futures:
                tokens += future.result()

        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.utcnow(),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
            }
        )

    def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
                       embedding_model_instance, embedding_model_type_instance):
        with flask_app.app_context():
            # check document is paused
            self._check_document_paused_status(dataset_document.id)

            tokens = 0
            if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
                tokens += sum(
                    embedding_model_type_instance.get_num_tokens(
                        embedding_model_instance.model,
                        embedding_model_instance.credentials,
                        [document.page_content]
                    )
                    for document in chunk_documents
                )

            # load index
            index_processor.load(dataset, chunk_documents)

            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.index_node_id.in_(document_ids),
                DocumentSegment.status == "indexing"
            ).update({
                DocumentSegment.status: "completed",
                DocumentSegment.enabled: True,
                DocumentSegment.completed_at: datetime.datetime.utcnow()
            })

            db.session.commit()

            return tokens

    def _check_document_paused_status(self, document_id: str):
        indexing_cache_key = 'document_{}_is_paused'.format(document_id)
        result = redis_client.get(indexing_cache_key)
        if result:
            raise DocumentIsPausedException()

    def _update_document_index_status(self, document_id: str, after_indexing_status: str,
                                      extra_update_params: Optional[dict] = None) -> None:
        """
        Update the document indexing status.
        """
        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
        if count > 0:
            raise DocumentIsPausedException()
        document = DatasetDocument.query.filter_by(id=document_id).first()
        if not document:
            raise DocumentIsDeletedPausedException()

        update_params = {
            DatasetDocument.indexing_status: after_indexing_status
        }

        if extra_update_params:
            update_params.update(extra_update_params)

        DatasetDocument.query.filter_by(id=document_id).update(update_params)
        db.session.commit()

    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
        """
        Update the document segment by document id.
        """
        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
        db.session.commit()

    def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset):
        """
        Batch add segments index processing
        """
        documents = []
        for segment in segments:
            document = Document(
                page_content=segment.content,
                metadata={
                    "doc_id": segment.index_node_id,
                    "doc_hash": segment.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                }
            )
            documents.append(document)

        # save vector index
        index_type = dataset.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        index_processor.load(dataset, documents)

    def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset,
                   text_docs: list[Document], doc_language: str, process_rule: dict) -> list[Document]:
        # get embedding model instance
        embedding_model_instance = None
        if dataset.indexing_technique == 'high_quality':
            if dataset.embedding_model_provider:
                embedding_model_instance = self.model_manager.get_model_instance(
                    tenant_id=dataset.tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model
                )
            else:
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=dataset.tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance,
                                              process_rule=process_rule, tenant_id=dataset.tenant_id,
                                              doc_language=doc_language)

        return documents

    def _load_segments(self, dataset, dataset_document, documents):
        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset,
            user_id=dataset_document.created_by,
            document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.utcnow()
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.utcnow()
            }
        )


class DocumentIsPausedException(Exception):
    pass


class DocumentIsDeletedPausedException(Exception):
    pass
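

# Usage sketch (an illustration, not part of the original module; the surrounding
# task layer is assumed): the runner is presumably driven from a background worker
# that already holds a Flask app context and a batch of DatasetDocument rows, e.g.
#
#   runner = IndexingRunner()
#   try:
#       runner.run(dataset_documents)
#   except DocumentIsPausedException:
#       logging.info("indexing paused by the user; it can be resumed later")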