indexing_runner.py

import datetime
import json
import logging
import re
import threading
import time
import uuid
from typing import AbstractSet, Any, Collection, List, Literal, Optional, Type, Union, cast

from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatasetDocumentStore
from core.errors.error import ProviderTokenNotInitError
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.model_manager import ModelManager, ModelInstance
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.spiltter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from flask import Flask, current_app
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import TS, TextSplitter, TokenTextSplitter
from libs import helper
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from models.dataset import DocumentSegment
from models.model import UploadFile
from models.source import DataSourceBinding
from sqlalchemy.orm.exc import ObjectDeletedError
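
# IndexingRunner drives the dataset document ingestion pipeline end to end:
#   1. _load_data    - extract raw text from an uploaded file or a Notion page
#   2. _step_split   - clean the text and split it into persisted document segments
#   3. _build_index  - embed / keyword-index the segments and mark the document completed
# run_in_splitting_status() and run_in_indexing_status() re-enter this pipeline for
# documents that were interrupted while in the corresponding status.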
class IndexingRunner:

    def __init__(self):
        self.storage = storage
        self.model_manager = ModelManager()

    def run(self, dataset_documents: List[DatasetDocument]):
        """Run the indexing process."""
        for dataset_document in dataset_documents:
            try:
                # get dataset
                dataset = Dataset.query.filter_by(
                    id=dataset_document.dataset_id
                ).first()

                if not dataset:
                    raise ValueError("no dataset found")

                # get the process rule
                processing_rule = db.session.query(DatasetProcessRule). \
                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                    first()

                # load file
                text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')

                # get embedding model instance
                embedding_model_instance = None
                if dataset.indexing_technique == 'high_quality':
                    if dataset.embedding_model_provider:
                        embedding_model_instance = self.model_manager.get_model_instance(
                            tenant_id=dataset.tenant_id,
                            provider=dataset.embedding_model_provider,
                            model_type=ModelType.TEXT_EMBEDDING,
                            model=dataset.embedding_model
                        )
                    else:
                        embedding_model_instance = self.model_manager.get_default_model_instance(
                            tenant_id=dataset.tenant_id,
                            model_type=ModelType.TEXT_EMBEDDING,
                        )

                # get splitter
                splitter = self._get_splitter(processing_rule, embedding_model_instance)

                # split to documents
                documents = self._step_split(
                    text_docs=text_docs,
                    splitter=splitter,
                    dataset=dataset,
                    dataset_document=dataset_document,
                    processing_rule=processing_rule
                )

                self._build_index(
                    dataset=dataset,
                    dataset_document=dataset_document,
                    documents=documents
                )
            except DocumentIsPausedException:
                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
            except ProviderTokenNotInitError as e:
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e.description)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()
            except ObjectDeletedError:
                logging.warning('Document deleted, document id: {}'.format(dataset_document.id))
            except Exception as e:
                logging.exception("consume document failed")
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()

    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # delete any existing document segments before re-splitting
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            for document_segment in document_segments:
                db.session.delete(document_segment)
            db.session.commit()

            # get the process rule
            processing_rule = db.session.query(DatasetProcessRule). \
                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                first()

            # load file
            text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')

            # get embedding model instance
            embedding_model_instance = None
            if dataset.indexing_technique == 'high_quality':
                if dataset.embedding_model_provider:
                    embedding_model_instance = self.model_manager.get_model_instance(
                        tenant_id=dataset.tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model
                    )
                else:
                    embedding_model_instance = self.model_manager.get_default_model_instance(
                        tenant_id=dataset.tenant_id,
                        model_type=ModelType.TEXT_EMBEDDING,
                    )

            # get splitter
            splitter = self._get_splitter(processing_rule, embedding_model_instance)

            # split to documents
            documents = self._step_split(
                text_docs=text_docs,
                splitter=splitter,
                dataset=dataset,
                dataset_document=dataset_document,
                processing_rule=processing_rule
            )

            # build index
            self._build_index(
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def run_in_indexing_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is indexing."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # collect existing document segments that have not finished indexing
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            documents = []
            if document_segments:
                for document_segment in document_segments:
                    # transform segment to node
                    if document_segment.status != "completed":
                        document = Document(
                            page_content=document_segment.content,
                            metadata={
                                "doc_id": document_segment.index_node_id,
                                "doc_hash": document_segment.index_node_hash,
                                "document_id": document_segment.document_id,
                                "dataset_id": document_segment.dataset_id,
                            }
                        )
                        documents.append(document)

            # build index
            self._build_index(
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                               indexing_technique: str = 'economy') -> dict:
        """
        Estimate the segment count, token usage and price of indexing the given uploaded files.
        """
        embedding_model_instance = None
        if dataset_id:
            dataset = Dataset.query.filter_by(
                id=dataset_id
            ).first()
            if not dataset:
                raise ValueError('Dataset not found.')
            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
                if dataset.embedding_model_provider:
                    embedding_model_instance = self.model_manager.get_model_instance(
                        tenant_id=tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model
                    )
                else:
                    embedding_model_instance = self.model_manager.get_default_model_instance(
                        tenant_id=tenant_id,
                        model_type=ModelType.TEXT_EMBEDDING,
                    )
        else:
            if indexing_technique == 'high_quality':
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        tokens = 0
        preview_texts = []
        total_segments = 0
        total_price = 0
        currency = 'USD'

        for file_detail in file_details:
            processing_rule = DatasetProcessRule(
                mode=tmp_processing_rule["mode"],
                rules=json.dumps(tmp_processing_rule["rules"])
            )

            # load data from file
            text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')

            # get splitter
            splitter = self._get_splitter(processing_rule, embedding_model_instance)

            # split to documents
            documents = self._split_to_documents_for_estimate(
                text_docs=text_docs,
                splitter=splitter,
                processing_rule=processing_rule
            )

            total_segments += len(documents)
            for document in documents:
                if len(preview_texts) < 5:
                    preview_texts.append(document.page_content)

                if indexing_technique == 'high_quality' or embedding_model_instance:
                    embedding_model_type_instance = embedding_model_instance.model_type_instance
                    embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
                    tokens += embedding_model_type_instance.get_num_tokens(
                        model=embedding_model_instance.model,
                        credentials=embedding_model_instance.credentials,
                        texts=[self.filter_string(document.page_content)]
                    )

        if doc_form and doc_form == 'qa_model':
            model_instance = self.model_manager.get_default_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM
            )

            model_type_instance = model_instance.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)

            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                             doc_language)
                document_qa_list = self.format_split_text(response)

                price_info = model_type_instance.get_price(
                    model=model_instance.model,
                    credentials=model_instance.credentials,
                    price_type=PriceType.INPUT,
                    tokens=total_segments * 2000,
                )

                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(price_info.total_amount),
                    "currency": price_info.currency,
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }

        if embedding_model_instance:
            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
            embedding_price_info = embedding_model_type_instance.get_price(
                model=embedding_model_instance.model,
                credentials=embedding_model_instance.credentials,
                price_type=PriceType.INPUT,
                tokens=tokens
            )
            total_price = '{:f}'.format(embedding_price_info.total_amount)
            currency = embedding_price_info.currency

        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": total_price,
            "currency": currency,
            "preview": preview_texts
        }
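
    # Both estimate helpers return the same shape of dict; the values below are
    # illustrative only, not taken from this file:
    #
    #     {"total_segments": 12, "tokens": 3456, "total_price": "0.000346",
    #      "currency": "USD", "preview": ["first segment text", ...]}
    #
    # and the 'qa_model' branch additionally returns a "qa_preview" list of
    # {"question": ..., "answer": ...} dicts.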

    def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                                 indexing_technique: str = 'economy') -> dict:
        """
        Estimate the segment count, token usage and price of indexing the given Notion pages.
        """
        embedding_model_instance = None
        if dataset_id:
            dataset = Dataset.query.filter_by(
                id=dataset_id
            ).first()
            if not dataset:
                raise ValueError('Dataset not found.')
            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
                if dataset.embedding_model_provider:
                    embedding_model_instance = self.model_manager.get_model_instance(
                        tenant_id=tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model
                    )
                else:
                    embedding_model_instance = self.model_manager.get_default_model_instance(
                        tenant_id=tenant_id,
                        model_type=ModelType.TEXT_EMBEDDING,
                    )
        else:
            if indexing_technique == 'high_quality':
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING
                )

        # load data from notion
        tokens = 0
        preview_texts = []
        total_segments = 0
        total_price = 0
        currency = 'USD'

        for notion_info in notion_info_list:
            workspace_id = notion_info['workspace_id']
            data_source_binding = DataSourceBinding.query.filter(
                db.and_(
                    DataSourceBinding.tenant_id == current_user.current_tenant_id,
                    DataSourceBinding.provider == 'notion',
                    DataSourceBinding.disabled == False,
                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                )
            ).first()
            if not data_source_binding:
                raise ValueError('Data source binding not found.')

            for page in notion_info['pages']:
                loader = NotionLoader(
                    notion_access_token=data_source_binding.access_token,
                    notion_workspace_id=workspace_id,
                    notion_obj_id=page['page_id'],
                    notion_page_type=page['type']
                )
                documents = loader.load()

                processing_rule = DatasetProcessRule(
                    mode=tmp_processing_rule["mode"],
                    rules=json.dumps(tmp_processing_rule["rules"])
                )

                # get splitter
                splitter = self._get_splitter(processing_rule, embedding_model_instance)

                # split to documents
                documents = self._split_to_documents_for_estimate(
                    text_docs=documents,
                    splitter=splitter,
                    processing_rule=processing_rule
                )

                total_segments += len(documents)

                embedding_model_type_instance = None
                if embedding_model_instance:
                    embedding_model_type_instance = embedding_model_instance.model_type_instance
                    embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)

                for document in documents:
                    if len(preview_texts) < 5:
                        preview_texts.append(document.page_content)
                    if indexing_technique == 'high_quality' and embedding_model_type_instance:
                        tokens += embedding_model_type_instance.get_num_tokens(
                            model=embedding_model_instance.model,
                            credentials=embedding_model_instance.credentials,
                            texts=[document.page_content]
                        )

        if doc_form and doc_form == 'qa_model':
            model_instance = self.model_manager.get_default_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM
            )

            model_type_instance = model_instance.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)

            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                             doc_language)
                document_qa_list = self.format_split_text(response)

                price_info = model_type_instance.get_price(
                    model=model_instance.model,
                    credentials=model_instance.credentials,
                    price_type=PriceType.INPUT,
                    tokens=total_segments * 2000,
                )

                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(price_info.total_amount),
                    "currency": price_info.currency,
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }

        if embedding_model_instance:
            embedding_model_type_instance = embedding_model_instance.model_type_instance
            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
            embedding_price_info = embedding_model_type_instance.get_price(
                model=embedding_model_instance.model,
                credentials=embedding_model_instance.credentials,
                price_type=PriceType.INPUT,
                tokens=tokens
            )
            total_price = '{:f}'.format(embedding_price_info.total_amount)
            currency = embedding_price_info.currency

        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": total_price,
            "currency": currency,
            "preview": preview_texts
        }

    def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]:
        # load file
        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
            return []

        data_source_info = dataset_document.data_source_info_dict
        text_docs = []
        if dataset_document.data_source_type == 'upload_file':
            if not data_source_info or 'upload_file_id' not in data_source_info:
                raise ValueError("no upload file found")

            file_detail = db.session.query(UploadFile). \
                filter(UploadFile.id == data_source_info['upload_file_id']). \
                one_or_none()

            if file_detail:
                text_docs = FileExtractor.load(file_detail, is_automatic=automatic)
        elif dataset_document.data_source_type == 'notion_import':
            loader = NotionLoader.from_document(dataset_document)
            text_docs = loader.load()

        # update document status to splitting
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="splitting",
            extra_update_params={
                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
                DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
            }
        )

        # clean the text and bind the document / dataset ids to each loaded doc
        text_docs = cast(List[Document], text_docs)
        for text_doc in text_docs:
            # remove invalid symbol
            text_doc.page_content = self.filter_string(text_doc.page_content)
            text_doc.metadata['document_id'] = dataset_document.id
            text_doc.metadata['dataset_id'] = dataset_document.dataset_id

        return text_docs

    def filter_string(self, text):
        text = re.sub(r'<\|', '<', text)
        text = re.sub(r'\|>', '>', text)
        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
        # Unicode U+FFFE
        text = re.sub(u'\uFFFE', '', text)
        return text
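
    # For example (illustrative only), filter_string('<|endoftext|>\x00 hello') returns
    # '<endoftext> hello': the '<|' / '|>' delimiters are softened so they cannot be
    # mistaken for special model tokens, and ASCII control characters plus U+FFFE are dropped.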

    def _get_splitter(self, processing_rule: DatasetProcessRule,
                      embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
        """
        Get the TextSplitter that matches the processing rule.
        """
        if processing_rule.mode == "custom":
            # the user-defined segmentation rule
            rules = json.loads(processing_rule.rules)
            segmentation = rules["segmentation"]
            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
                raise ValueError("Custom segment length should be between 50 and 1000.")

            separator = segmentation["separator"]
            if separator:
                separator = separator.replace('\\n', '\n')

            character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=segmentation.get('chunk_overlap', 0),
                fixed_separator=separator,
                separators=["\n\n", "。", ".", " ", ""],
                embedding_model_instance=embedding_model_instance
            )
        else:
            # automatic segmentation
            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
                separators=["\n\n", "。", ".", " ", ""],
                embedding_model_instance=embedding_model_instance
            )

        return character_splitter
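
    # A "custom" processing rule is expected to carry a JSON payload shaped roughly like
    # (illustrative values, not taken from this file):
    #
    #     {"segmentation": {"separator": "\\n", "max_tokens": 500, "chunk_overlap": 50}}
    #
    # Only the keys read above (separator, max_tokens, chunk_overlap) matter to _get_splitter;
    # the pre-processing toggles in the same rule are consumed by _document_clean.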

    def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
            -> List[Document]:
        """
        Split the text documents into segments and persist them as document segments.
        """
        documents = self._split_to_documents(
            text_docs=text_docs,
            splitter=splitter,
            processing_rule=processing_rule,
            tenant_id=dataset.tenant_id,
            document_form=dataset_document.doc_form,
            document_language=dataset_document.doc_language
        )

        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset,
            user_id=dataset_document.created_by,
            document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.utcnow()
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.utcnow()
            }
        )

        return documents

    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
                            processing_rule: DatasetProcessRule, tenant_id: str,
                            document_form: str, document_language: str) -> List[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        all_qa_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document_node in documents:
                if document_node.page_content.strip():
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(document_node.page_content)
                    document_node.metadata['doc_id'] = doc_id
                    document_node.metadata['doc_hash'] = hash
                    # strip a leading separator character left over from splitting
                    page_content = document_node.page_content
                    if page_content.startswith(".") or page_content.startswith("。"):
                        page_content = page_content[1:]
                    document_node.page_content = page_content

                    if document_node.page_content:
                        split_documents.append(document_node)
            all_documents.extend(split_documents)

        # processing qa document
        if document_form == 'qa_model':
            for i in range(0, len(all_documents), 10):
                threads = []
                sub_documents = all_documents[i:i + 10]
                for doc in sub_documents:
                    document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
                        'flask_app': current_app._get_current_object(),
                        'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents,
                        'document_language': document_language})
                    threads.append(document_format_thread)
                    document_format_thread.start()
                for thread in threads:
                    thread.join()
            return all_qa_documents

        return all_documents

    def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
        format_documents = []
        if document_node.page_content is None or not document_node.page_content.strip():
            return
        with flask_app.app_context():
            try:
                # qa model document
                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
                document_qa_list = self.format_split_text(response)
                qa_documents = []
                for result in document_qa_list:
                    qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(result['question'])
                    qa_document.metadata['answer'] = result['answer']
                    qa_document.metadata['doc_id'] = doc_id
                    qa_document.metadata['doc_hash'] = hash
                    qa_documents.append(qa_document)
                format_documents.extend(qa_documents)
            except Exception as e:
                logging.exception(e)

            all_qa_documents.extend(format_documents)

    def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
                                         processing_rule: DatasetProcessRule) -> List[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])

            split_documents = []
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document.page_content)

                document.metadata['doc_id'] = doc_id
                document.metadata['doc_hash'] = hash

                split_documents.append(document)

            all_documents.extend(split_documents)

        return all_documents

    def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
        """
        Clean the document text according to the processing rules.
        """
        if processing_rule.mode == "automatic":
            rules = DatasetProcessRule.AUTOMATIC_RULES
        else:
            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}

        if 'pre_processing_rules' in rules:
            pre_processing_rules = rules["pre_processing_rules"]
            for pre_processing_rule in pre_processing_rules:
                if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
                    # Remove extra spaces
                    pattern = r'\n{3,}'
                    text = re.sub(pattern, '\n\n', text)
                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
                    text = re.sub(pattern, ' ', text)
                elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
                    # Remove email
                    pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
                    text = re.sub(pattern, '', text)

                    # Remove URL
                    pattern = r'https?://[^\s]+'
                    text = re.sub(pattern, '', text)

        return text
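
    # The pre-processing rules consumed above are a list of toggles, roughly
    # (illustrative example, not taken from this file):
    #
    #     {"pre_processing_rules": [
    #         {"id": "remove_extra_spaces", "enabled": True},
    #         {"id": "remove_urls_emails", "enabled": False}
    #     ]}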

    def format_split_text(self, text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
        matches = re.findall(regex, text, re.UNICODE)

        return [
            {
                "question": q,
                "answer": re.sub(r"\n\s*", "\n", a.strip())
            }
            for q, a in matches if q and a
        ]
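
    # For example (illustrative input), a generated QA text such as
    #
    #     "Q1: What is indexing? A1: Turning text into searchable segments.\nQ2: ..."
    #
    # is parsed into [{"question": "What is indexing?",
    #                  "answer": "Turning text into searchable segments."}, ...].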

    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
        """
        Build the index for the document.
        """
        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
        embedding_model_instance = None
        if dataset.indexing_technique == 'high_quality':
            embedding_model_instance = self.model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model
            )

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 100

        embedding_model_type_instance = None
        if embedding_model_instance:
            embedding_model_type_instance = embedding_model_instance.model_type_instance
            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)

        for i in range(0, len(documents), chunk_size):
            # check document is paused
            self._check_document_paused_status(dataset_document.id)
            chunk_documents = documents[i:i + chunk_size]
            if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
                tokens += sum(
                    embedding_model_type_instance.get_num_tokens(
                        embedding_model_instance.model,
                        embedding_model_instance.credentials,
                        [document.page_content]
                    )
                    for document in chunk_documents
                )

            # save vector index
            if vector_index:
                vector_index.add_texts(chunk_documents)

            # save keyword index
            keyword_table_index.add_texts(chunk_documents)

            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.index_node_id.in_(document_ids),
                DocumentSegment.status == "indexing"
            ).update({
                DocumentSegment.status: "completed",
                DocumentSegment.enabled: True,
                DocumentSegment.completed_at: datetime.datetime.utcnow()
            })

            db.session.commit()

        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.utcnow(),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
            }
        )

    def _check_document_paused_status(self, document_id: str):
        indexing_cache_key = 'document_{}_is_paused'.format(document_id)
        result = redis_client.get(indexing_cache_key)
        if result:
            raise DocumentIsPausedException()
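
    # Pausing is signalled out of band: some other part of the application is assumed to
    # set the Redis key 'document_<id>_is_paused' (for example via redis_client.set(...))
    # when a user pauses a document, and _build_index checks it before every chunk of
    # 100 segments.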

    def _update_document_index_status(self, document_id: str, after_indexing_status: str,
                                      extra_update_params: Optional[dict] = None) -> None:
        """
        Update the document indexing status.
        """
        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
        if count > 0:
            raise DocumentIsPausedException()
        document = DatasetDocument.query.filter_by(id=document_id).first()
        if not document:
            raise DocumentIsDeletedPausedException()

        update_params = {
            DatasetDocument.indexing_status: after_indexing_status
        }

        if extra_update_params:
            update_params.update(extra_update_params)

        DatasetDocument.query.filter_by(id=document_id).update(update_params)
        db.session.commit()

    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
        """
        Update the document segments by document id.
        """
        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
        db.session.commit()

    def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset):
        """
        Batch-add existing segments to the vector and keyword indexes.
        """
        documents = []
        for segment in segments:
            document = Document(
                page_content=segment.content,
                metadata={
                    "doc_id": segment.index_node_id,
                    "doc_hash": segment.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                }
            )
            documents.append(document)

        # save vector index
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            index.add_texts(documents, duplicate_check=True)

        # save keyword index
        index = IndexBuilder.get_index(dataset, 'economy')
        if index:
            index.add_texts(documents)


class DocumentIsPausedException(Exception):
    pass


class DocumentIsDeletedPausedException(Exception):
    pass
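

# Minimal usage sketch (illustrative only; it assumes a Flask application context and a
# populated database, and the query below is hypothetical, not taken from this module):
#
#     waiting_docs = DatasetDocument.query.filter_by(indexing_status='waiting').all()
#     try:
#         IndexingRunner().run(waiting_docs)
#     except DocumentIsPausedException:
#         logging.info("indexing paused by the user")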