# indexing_runner.py
import datetime
import json
import logging
import re
import threading
import time
import uuid
from typing import Optional, List, cast, Type, Union, Literal, AbstractSet, Collection, Any

from flask import current_app, Flask
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import TextSplitter, TS, TokenTextSplitter
from sqlalchemy.orm.exc import ObjectDeletedError

from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatasetDocumentStore
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.model_manager import ModelManager
from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter, EnhanceRecursiveCharacterTextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Document as DatasetDocument
from models.dataset import Dataset, DocumentSegment, DatasetProcessRule
from models.model import UploadFile
from models.source import DataSourceBinding
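
# IndexingRunner drives the dataset document indexing pipeline used by the
# background indexing workers: load the raw source (upload file or Notion
# page), clean and split it into segments, then build the vector and/or
# keyword index while keeping the document and segment status columns in sync.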
class IndexingRunner:

    def __init__(self):
        self.storage = storage
        self.model_manager = ModelManager()

    def run(self, dataset_documents: List[DatasetDocument]):
        """Run the indexing process."""
        for dataset_document in dataset_documents:
            try:
                # get dataset
                dataset = Dataset.query.filter_by(
                    id=dataset_document.dataset_id
                ).first()

                if not dataset:
                    raise ValueError("no dataset found")

                # get the process rule
                processing_rule = db.session.query(DatasetProcessRule). \
                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                    first()

                # load file
                text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')

                # get splitter
                splitter = self._get_splitter(processing_rule)

                # split to documents
                documents = self._step_split(
                    text_docs=text_docs,
                    splitter=splitter,
                    dataset=dataset,
                    dataset_document=dataset_document,
                    processing_rule=processing_rule
                )

                self._build_index(
                    dataset=dataset,
                    dataset_document=dataset_document,
                    documents=documents
                )
            except DocumentIsPausedException:
                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
            except ProviderTokenNotInitError as e:
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e.description)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()
            except ObjectDeletedError:
                logging.warning('Document deleted, document id: {}'.format(dataset_document.id))
            except Exception as e:
                logging.exception("consume document failed")
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()
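
    # The two run_in_* methods below are recovery entry points: they re-run a
    # document whose worker was interrupted while the document was still in the
    # "splitting" or "indexing" status, reusing whatever state is recoverable.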
    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get existing document segments and delete them
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            for document_segment in document_segments:
                db.session.delete(document_segment)
            db.session.commit()

            # get the process rule
            processing_rule = db.session.query(DatasetProcessRule). \
                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                first()

            # load file
            text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')

            # get splitter
            splitter = self._get_splitter(processing_rule)

            # split to documents
            documents = self._step_split(
                text_docs=text_docs,
                splitter=splitter,
                dataset=dataset,
                dataset_document=dataset_document,
                processing_rule=processing_rule
            )

            # build index
            self._build_index(
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def run_in_indexing_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is indexing."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get existing document segments; re-index the ones that never completed
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            documents = []
            if document_segments:
                for document_segment in document_segments:
                    # transform segment to node
                    if document_segment.status != "completed":
                        document = Document(
                            page_content=document_segment.content,
                            metadata={
                                "doc_id": document_segment.index_node_id,
                                "doc_hash": document_segment.index_node_hash,
                                "document_id": document_segment.document_id,
                                "dataset_id": document_segment.dataset_id,
                            }
                        )
                        documents.append(document)

            # build index
            self._build_index(
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
    def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                               indexing_technique: str = 'economy') -> dict:
        """
        Estimate the indexing for the document.
        """
        embedding_model_instance = None
        if dataset_id:
            dataset = Dataset.query.filter_by(
                id=dataset_id
            ).first()
            if not dataset:
                raise ValueError('Dataset not found.')
            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
                if dataset.embedding_model_provider:
                    embedding_model_instance = self.model_manager.get_model_instance(
                        tenant_id=tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model
                    )
                else:
                    embedding_model_instance = self.model_manager.get_default_model_instance(
                        tenant_id=tenant_id,
                        model_type=ModelType.TEXT_EMBEDDING,
                    )
        else:
            if indexing_technique == 'high_quality':
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        tokens = 0
        preview_texts = []
        total_segments = 0
        for file_detail in file_details:
            processing_rule = DatasetProcessRule(
                mode=tmp_processing_rule["mode"],
                rules=json.dumps(tmp_processing_rule["rules"])
            )

            # load data from file
            text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')

            # get splitter
            splitter = self._get_splitter(processing_rule)

            # split to documents
            documents = self._split_to_documents_for_estimate(
                text_docs=text_docs,
                splitter=splitter,
                processing_rule=processing_rule
            )

            total_segments += len(documents)
            for document in documents:
                if len(preview_texts) < 5:
                    preview_texts.append(document.page_content)
                if indexing_technique == 'high_quality' or embedding_model_instance:
                    embedding_model_type_instance = embedding_model_instance.model_type_instance
                    embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
                    tokens += embedding_model_type_instance.get_num_tokens(
                        model=embedding_model_instance.model,
                        credentials=embedding_model_instance.credentials,
                        texts=[self.filter_string(document.page_content)]
                    )

        if doc_form and doc_form == 'qa_model':
            model_instance = self.model_manager.get_default_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM
            )
            model_type_instance = model_instance.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)

            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                             doc_language)
                document_qa_list = self.format_split_text(response)

                price_info = model_type_instance.get_price(
                    model=model_instance.model,
                    credentials=model_instance.credentials,
                    price_type=PriceType.INPUT,
                    tokens=total_segments * 2000,
                )

                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(price_info.total_amount),
                    "currency": price_info.currency,
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }

        if embedding_model_instance:
            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
            embedding_price_info = embedding_model_type_instance.get_price(
                model=embedding_model_instance.model,
                credentials=embedding_model_instance.credentials,
                price_type=PriceType.INPUT,
                tokens=tokens
            )

        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": '{:f}'.format(embedding_price_info.total_amount) if embedding_model_instance else 0,
            "currency": embedding_price_info.currency if embedding_model_instance else 'USD',
            "preview": preview_texts
        }
    def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                                 indexing_technique: str = 'economy') -> dict:
        """
        Estimate the indexing for the document.
        """
        embedding_model_instance = None
        if dataset_id:
            dataset = Dataset.query.filter_by(
                id=dataset_id
            ).first()
            if not dataset:
                raise ValueError('Dataset not found.')
            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
                if dataset.embedding_model_provider:
                    embedding_model_instance = self.model_manager.get_model_instance(
                        tenant_id=tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model
                    )
                else:
                    embedding_model_instance = self.model_manager.get_default_model_instance(
                        tenant_id=tenant_id,
                        model_type=ModelType.TEXT_EMBEDDING,
                    )
        else:
            if indexing_technique == 'high_quality':
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING
                )

        # load data from notion
        tokens = 0
        preview_texts = []
        total_segments = 0
        for notion_info in notion_info_list:
            workspace_id = notion_info['workspace_id']
            data_source_binding = DataSourceBinding.query.filter(
                db.and_(
                    DataSourceBinding.tenant_id == current_user.current_tenant_id,
                    DataSourceBinding.provider == 'notion',
                    DataSourceBinding.disabled == False,
                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                )
            ).first()
            if not data_source_binding:
                raise ValueError('Data source binding not found.')

            for page in notion_info['pages']:
                loader = NotionLoader(
                    notion_access_token=data_source_binding.access_token,
                    notion_workspace_id=workspace_id,
                    notion_obj_id=page['page_id'],
                    notion_page_type=page['type']
                )
                documents = loader.load()

                processing_rule = DatasetProcessRule(
                    mode=tmp_processing_rule["mode"],
                    rules=json.dumps(tmp_processing_rule["rules"])
                )

                # get splitter
                splitter = self._get_splitter(processing_rule)

                # split to documents
                documents = self._split_to_documents_for_estimate(
                    text_docs=documents,
                    splitter=splitter,
                    processing_rule=processing_rule
                )

                total_segments += len(documents)

                embedding_model_type_instance = None
                if embedding_model_instance:
                    embedding_model_type_instance = embedding_model_instance.model_type_instance
                    embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)

                for document in documents:
                    if len(preview_texts) < 5:
                        preview_texts.append(document.page_content)
                    if indexing_technique == 'high_quality' and embedding_model_type_instance:
                        tokens += embedding_model_type_instance.get_num_tokens(
                            model=embedding_model_instance.model,
                            credentials=embedding_model_instance.credentials,
                            texts=[document.page_content]
                        )

        if doc_form and doc_form == 'qa_model':
            model_instance = self.model_manager.get_default_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM
            )
            model_type_instance = model_instance.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)

            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                             doc_language)
                document_qa_list = self.format_split_text(response)

                price_info = model_type_instance.get_price(
                    model=model_instance.model,
                    credentials=model_instance.credentials,
                    price_type=PriceType.INPUT,
                    tokens=total_segments * 2000,
                )

                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(price_info.total_amount),
                    "currency": price_info.currency,
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }

        # only compute an embedding price when an embedding model is configured,
        # mirroring file_indexing_estimate (avoids an AttributeError for economy estimates)
        embedding_price_info = None
        if embedding_model_instance:
            embedding_model_type_instance = embedding_model_instance.model_type_instance
            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
            embedding_price_info = embedding_model_type_instance.get_price(
                model=embedding_model_instance.model,
                credentials=embedding_model_instance.credentials,
                price_type=PriceType.INPUT,
                tokens=tokens
            )

        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": '{:f}'.format(embedding_price_info.total_amount) if embedding_model_instance else 0,
            "currency": embedding_price_info.currency if embedding_model_instance else 'USD',
            "preview": preview_texts
        }
    def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]:
        # load file
        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
            return []

        data_source_info = dataset_document.data_source_info_dict
        text_docs = []
        if dataset_document.data_source_type == 'upload_file':
            if not data_source_info or 'upload_file_id' not in data_source_info:
                raise ValueError("no upload file found")

            file_detail = db.session.query(UploadFile). \
                filter(UploadFile.id == data_source_info['upload_file_id']). \
                one_or_none()

            if file_detail:
                text_docs = FileExtractor.load(file_detail, is_automatic=automatic)
        elif dataset_document.data_source_type == 'notion_import':
            loader = NotionLoader.from_document(dataset_document)
            text_docs = loader.load()

        # update document status to splitting
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="splitting",
            extra_update_params={
                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
                DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
            }
        )

        # clean invalid symbols and attach the document/dataset ids to the metadata
        text_docs = cast(List[Document], text_docs)
        for text_doc in text_docs:
            # remove invalid symbol
            text_doc.page_content = self.filter_string(text_doc.page_content)
            text_doc.metadata['document_id'] = dataset_document.id
            text_doc.metadata['dataset_id'] = dataset_document.dataset_id

        return text_docs
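
    # filter_string removes characters the pipeline treats as invalid: "<|" / "|>"
    # special-token markers are rewritten to plain "<" / ">", and control
    # characters plus the 0x7F-0xFF byte range are dropped (higher Unicode such
    # as CJK text is left untouched).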
    def filter_string(self, text):
        text = re.sub(r'<\|', '<', text)
        text = re.sub(r'\|>', '>', text)
        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]', '', text)
        return text
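
    # A sketch of the processing-rule JSON stored in DatasetProcessRule.rules,
    # inferred from the keys read in _get_splitter and _document_clean
    # (values are illustrative):
    # {
    #     "segmentation": {"max_tokens": 500, "separator": "\\n"},
    #     "pre_processing_rules": [
    #         {"id": "remove_extra_spaces", "enabled": true},
    #         {"id": "remove_urls_emails", "enabled": false}
    #     ]
    # }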
    def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
        """
        Get the TextSplitter according to the processing rule.
        """
        if processing_rule.mode == "custom":
            # user-defined segmentation rule
            rules = json.loads(processing_rule.rules)
            segmentation = rules["segmentation"]
            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
                raise ValueError("Custom segment length should be between 50 and 1000.")

            separator = segmentation["separator"]
            if separator:
                separator = separator.replace('\\n', '\n')

            character_splitter = FixedRecursiveCharacterTextSplitter.from_gpt2_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=0,
                fixed_separator=separator,
                separators=["\n\n", "。", ".", " ", ""]
            )
        else:
            # automatic segmentation
            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_gpt2_encoder(
                chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                chunk_overlap=0,
                separators=["\n\n", "。", ".", " ", ""]
            )

        return character_splitter
    def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
            -> List[Document]:
        """
        Split the text documents into documents and save them to the document segment.
        """
        documents = self._split_to_documents(
            text_docs=text_docs,
            splitter=splitter,
            processing_rule=processing_rule,
            tenant_id=dataset.tenant_id,
            document_form=dataset_document.doc_form,
            document_language=dataset_document.doc_language
        )

        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset,
            user_id=dataset_document.created_by,
            document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.utcnow()
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.utcnow()
            }
        )

        return documents

    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
                            processing_rule: DatasetProcessRule, tenant_id: str,
                            document_form: str, document_language: str) -> List[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        all_qa_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document_node in documents:
                if document_node.page_content.strip():
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(document_node.page_content)
                    document_node.metadata['doc_id'] = doc_id
                    document_node.metadata['doc_hash'] = hash
                    # remove a leading splitter character left over from splitting
                    page_content = document_node.page_content
                    if page_content.startswith(".") or page_content.startswith("。"):
                        page_content = page_content[1:]
                    document_node.page_content = page_content
                    split_documents.append(document_node)
            all_documents.extend(split_documents)

        # processing qa document
        if document_form == 'qa_model':
            for i in range(0, len(all_documents), 10):
                threads = []
                sub_documents = all_documents[i:i + 10]
                for doc in sub_documents:
                    document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
                        'flask_app': current_app._get_current_object(),
                        'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents,
                        'document_language': document_language})
                    threads.append(document_format_thread)
                    document_format_thread.start()
                for thread in threads:
                    thread.join()
            return all_qa_documents

        return all_documents
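
    # format_qa_document runs on a worker thread (see _split_to_documents): it
    # generates Q&A pairs for one chunk with the LLM and appends the results to
    # the shared all_qa_documents list; a Flask app context is pushed so the
    # code can use app-bound resources outside the request thread.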
    def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
        format_documents = []
        if document_node.page_content is None or not document_node.page_content.strip():
            return
        with flask_app.app_context():
            try:
                # qa model document
                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
                document_qa_list = self.format_split_text(response)
                qa_documents = []
                for result in document_qa_list:
                    qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(result['question'])
                    qa_document.metadata['answer'] = result['answer']
                    qa_document.metadata['doc_id'] = doc_id
                    qa_document.metadata['doc_hash'] = hash
                    qa_documents.append(qa_document)
                format_documents.extend(qa_documents)
            except Exception as e:
                logging.exception(e)

            all_qa_documents.extend(format_documents)
    def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
                                         processing_rule: DatasetProcessRule) -> List[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document.page_content)
                document.metadata['doc_id'] = doc_id
                document.metadata['doc_hash'] = hash
                split_documents.append(document)

            all_documents.extend(split_documents)

        return all_documents
    def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
        """
        Clean the document text according to the processing rules.
        """
        if processing_rule.mode == "automatic":
            rules = DatasetProcessRule.AUTOMATIC_RULES
        else:
            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}

        if 'pre_processing_rules' in rules:
            pre_processing_rules = rules["pre_processing_rules"]
            for pre_processing_rule in pre_processing_rules:
                if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
                    # Remove extra spaces
                    pattern = r'\n{3,}'
                    text = re.sub(pattern, '\n\n', text)
                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
                    text = re.sub(pattern, ' ', text)
                elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
                    # Remove email
                    pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
                    text = re.sub(pattern, '', text)

                    # Remove URL
                    pattern = r'https?://[^\s]+'
                    text = re.sub(pattern, '', text)

        return text
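
    # format_split_text parses the Q&A text returned by the LLM, which is
    # expected to look like (content below is illustrative):
    #   Q1: What is a dataset?
    #   A1: A collection of documents ...
    #   Q2: ...
    # Each matched pair becomes a {"question": ..., "answer": ...} dict.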
    def format_split_text(self, text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
        matches = re.findall(regex, text, re.UNICODE)

        return [
            {
                "question": q,
                "answer": re.sub(r"\n\s*", "\n", a.strip())
            }
            for q, a in matches if q and a
        ]
    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
        """
        Build the index for the document.
        """
        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
        embedding_model_instance = None
        if dataset.indexing_technique == 'high_quality':
            embedding_model_instance = self.model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model
            )

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 100

        embedding_model_type_instance = None
        if embedding_model_instance:
            embedding_model_type_instance = embedding_model_instance.model_type_instance
            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)

        for i in range(0, len(documents), chunk_size):
            # check document is paused
            self._check_document_paused_status(dataset_document.id)
            chunk_documents = documents[i:i + chunk_size]
            if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
                tokens += sum(
                    embedding_model_type_instance.get_num_tokens(
                        embedding_model_instance.model,
                        embedding_model_instance.credentials,
                        [document.page_content]
                    )
                    for document in chunk_documents
                )

            # save vector index
            if vector_index:
                vector_index.add_texts(chunk_documents)

            # save keyword index
            keyword_table_index.add_texts(chunk_documents)

            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.index_node_id.in_(document_ids),
                DocumentSegment.status == "indexing"
            ).update({
                DocumentSegment.status: "completed",
                DocumentSegment.enabled: True,
                DocumentSegment.completed_at: datetime.datetime.utcnow()
            })

            db.session.commit()

        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.utcnow(),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
            }
        )
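
    # Pausing is signalled out-of-band through a Redis flag: other parts of the
    # application are expected to set the "document_{id}_is_paused" key; the
    # indexing loop checks it before each chunk and raises
    # DocumentIsPausedException to abort the run.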
    def _check_document_paused_status(self, document_id: str):
        indexing_cache_key = 'document_{}_is_paused'.format(document_id)
        result = redis_client.get(indexing_cache_key)
        if result:
            raise DocumentIsPausedException()

    def _update_document_index_status(self, document_id: str, after_indexing_status: str,
                                      extra_update_params: Optional[dict] = None) -> None:
        """
        Update the document indexing status.
        """
        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
        if count > 0:
            raise DocumentIsPausedException()
        document = DatasetDocument.query.filter_by(id=document_id).first()
        if not document:
            raise DocumentIsDeletedPausedException()

        update_params = {
            DatasetDocument.indexing_status: after_indexing_status
        }

        if extra_update_params:
            update_params.update(extra_update_params)

        DatasetDocument.query.filter_by(id=document_id).update(update_params)
        db.session.commit()

    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
        """
        Update the document segments by document id.
        """
        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
        db.session.commit()
    def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset):
        """
        Batch add segments index processing
        """
        documents = []
        for segment in segments:
            document = Document(
                page_content=segment.content,
                metadata={
                    "doc_id": segment.index_node_id,
                    "doc_hash": segment.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                }
            )
            documents.append(document)

        # save vector index
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            index.add_texts(documents, duplicate_check=True)

        # save keyword index
        index = IndexBuilder.get_index(dataset, 'economy')
        if index:
            index.add_texts(documents)


class DocumentIsPausedException(Exception):
    pass


class DocumentIsDeletedPausedException(Exception):
    pass
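
# A minimal usage sketch (illustrative only, not part of this module): in the
# application, IndexingRunner.run() would typically be invoked from a
# background job with the DatasetDocument rows to index, along the lines of:
#
#     documents = DatasetDocument.query.filter_by(id=document_id).all()
#     try:
#         IndexingRunner().run(documents)
#     except DocumentIsPausedException:
#         logging.info('document is paused, stop indexing')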