dataset_service.py

import json
import logging
import datetime
import time
import random
from typing import Optional, List

from flask import current_app
from flask_login import current_user

from events.dataset_event import dataset_was_deleted
from events.document_event import document_was_deleted
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
from models.model import UploadFile
from models.source import DataSourceBinding
from services.errors.account import NoPermissionError
from services.errors.dataset import DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task


class DatasetService:

    @staticmethod
    def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None):
        if user:
            permission_filter = db.or_(Dataset.created_by == user.id,
                                       Dataset.permission == 'all_team_members')
        else:
            permission_filter = Dataset.permission == 'all_team_members'

        datasets = Dataset.query.filter(
            db.and_(Dataset.provider == provider, Dataset.tenant_id == tenant_id, permission_filter)) \
            .order_by(Dataset.created_at.desc()) \
            .paginate(
                page=page,
                per_page=per_page,
                max_per_page=100,
                error_out=False
            )

        return datasets.items, datasets.total

    @staticmethod
    def get_process_rules(dataset_id):
        # get the latest process rule
        dataset_process_rule = db.session.query(DatasetProcessRule). \
            filter(DatasetProcessRule.dataset_id == dataset_id). \
            order_by(DatasetProcessRule.created_at.desc()). \
            limit(1). \
            one_or_none()
        if dataset_process_rule:
            mode = dataset_process_rule.mode
            rules = dataset_process_rule.rules_dict
        else:
            mode = DocumentService.DEFAULT_RULES['mode']
            rules = DocumentService.DEFAULT_RULES['rules']

        return {
            'mode': mode,
            'rules': rules
        }

    @staticmethod
    def get_datasets_by_ids(ids, tenant_id):
        datasets = Dataset.query.filter(Dataset.id.in_(ids),
                                        Dataset.tenant_id == tenant_id).paginate(
            page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
        return datasets.items, datasets.total

    @staticmethod
    def create_empty_dataset(tenant_id: str, name: str, indexing_technique: Optional[str], account: Account):
        # check if dataset name already exists
        if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
            raise DatasetNameDuplicateError(
                f'Dataset with name {name} already exists.')

        dataset = Dataset(name=name, indexing_technique=indexing_technique)
        # dataset = Dataset(name=name, provider=provider, config=config)
        dataset.created_by = account.id
        dataset.updated_by = account.id
        dataset.tenant_id = tenant_id
        db.session.add(dataset)
        db.session.commit()
        return dataset

    @staticmethod
    def get_dataset(dataset_id):
        dataset = Dataset.query.filter_by(
            id=dataset_id
        ).first()
        if dataset is None:
            return None
        else:
            return dataset

    @staticmethod
    def update_dataset(dataset_id, data, user):
        dataset = DatasetService.get_dataset(dataset_id)
        DatasetService.check_dataset_permission(dataset, user)

        if dataset.indexing_technique != data['indexing_technique']:
            # if update indexing_technique
            if data['indexing_technique'] == 'economy':
                deal_dataset_vector_index_task.delay(dataset_id, 'remove')
            elif data['indexing_technique'] == 'high_quality':
                deal_dataset_vector_index_task.delay(dataset_id, 'add')

        filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
        filtered_data['updated_by'] = user.id
        filtered_data['updated_at'] = datetime.datetime.now()

        dataset.query.filter_by(id=dataset_id).update(filtered_data)
        db.session.commit()

        return dataset

    @staticmethod
    def delete_dataset(dataset_id, user):
        # todo: cannot delete dataset if it is being processed
        dataset = DatasetService.get_dataset(dataset_id)
        if dataset is None:
            return False

        DatasetService.check_dataset_permission(dataset, user)

        dataset_was_deleted.send(dataset)

        db.session.delete(dataset)
        db.session.commit()
        return True

    @staticmethod
    def check_dataset_permission(dataset, user):
        if dataset.tenant_id != user.current_tenant_id:
            logging.debug(
                f'User {user.id} does not have permission to access dataset {dataset.id}')
            raise NoPermissionError(
                'You do not have permission to access this dataset.')
        if dataset.permission == 'only_me' and dataset.created_by != user.id:
            logging.debug(
                f'User {user.id} does not have permission to access dataset {dataset.id}')
            raise NoPermissionError(
                'You do not have permission to access this dataset.')

    @staticmethod
    def get_dataset_queries(dataset_id: str, page: int, per_page: int):
        dataset_queries = DatasetQuery.query.filter_by(dataset_id=dataset_id) \
            .order_by(db.desc(DatasetQuery.created_at)) \
            .paginate(
                page=page, per_page=per_page, max_per_page=100, error_out=False
            )
        return dataset_queries.items, dataset_queries.total

    @staticmethod
    def get_related_apps(dataset_id: str):
        return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \
            .order_by(db.desc(AppDatasetJoin.created_at)).all()


class DocumentService:
    DEFAULT_RULES = {
        'mode': 'custom',
        'rules': {
            'pre_processing_rules': [
                {'id': 'remove_extra_spaces', 'enabled': True},
                {'id': 'remove_urls_emails', 'enabled': False}
            ],
            'segmentation': {
                'delimiter': '\n',
                'max_tokens': 500
            }
        }
    }

    DOCUMENT_METADATA_SCHEMA = {
        "book": {
            "title": str,
            "language": str,
            "author": str,
            "publisher": str,
            "publication_date": str,
            "isbn": str,
            "category": str,
        },
        "web_page": {
            "title": str,
            "url": str,
            "language": str,
            "publish_date": str,
            "author/publisher": str,
            "topic/keywords": str,
            "description": str,
        },
        "paper": {
            "title": str,
            "language": str,
            "author": str,
            "publish_date": str,
            "journal/conference_name": str,
            "volume/issue/page_numbers": str,
            "doi": str,
            "topic/keywords": str,
            "abstract": str,
        },
        "social_media_post": {
            "platform": str,
            "author/username": str,
            "publish_date": str,
            "post_url": str,
            "topic/tags": str,
        },
        "wikipedia_entry": {
            "title": str,
            "language": str,
            "web_page_url": str,
            "last_edit_date": str,
            "editor/contributor": str,
            "summary/introduction": str,
        },
        "personal_document": {
            "title": str,
            "author": str,
            "creation_date": str,
            "last_modified_date": str,
            "document_type": str,
            "tags/category": str,
        },
        "business_document": {
            "title": str,
            "author": str,
            "creation_date": str,
            "last_modified_date": str,
            "document_type": str,
            "department/team": str,
        },
        "im_chat_log": {
            "chat_platform": str,
            "chat_participants/group_name": str,
            "start_date": str,
            "end_date": str,
            "summary": str,
        },
        "synced_from_notion": {
            "title": str,
            "language": str,
            "author/creator": str,
            "creation_date": str,
            "last_modified_date": str,
            "notion_page_link": str,
            "category/tags": str,
            "description": str,
        },
        "synced_from_github": {
            "repository_name": str,
            "repository_description": str,
            "repository_owner/organization": str,
            "code_filename": str,
            "code_file_path": str,
            "programming_language": str,
            "github_link": str,
            "open_source_license": str,
            "commit_date": str,
            "commit_author": str
        }
    }

    @staticmethod
    def get_document(dataset_id: str, document_id: str) -> Optional[Document]:
        document = db.session.query(Document).filter(
            Document.id == document_id,
            Document.dataset_id == dataset_id
        ).first()
        return document

    @staticmethod
    def get_document_by_id(document_id: str) -> Optional[Document]:
        document = db.session.query(Document).filter(
            Document.id == document_id
        ).first()
        return document

    @staticmethod
    def get_document_by_dataset_id(dataset_id: str) -> List[Document]:
        documents = db.session.query(Document).filter(
            Document.dataset_id == dataset_id,
            Document.enabled == True
        ).all()
        return documents

    @staticmethod
    def get_batch_documents(dataset_id: str, batch: str) -> List[Document]:
        documents = db.session.query(Document).filter(
            Document.batch == batch,
            Document.dataset_id == dataset_id,
            Document.tenant_id == current_user.current_tenant_id
        ).all()
        return documents

    @staticmethod
    def get_document_file_detail(file_id: str):
        file_detail = db.session.query(UploadFile). \
            filter(UploadFile.id == file_id). \
            one_or_none()
        return file_detail

    @staticmethod
    def check_archived(document):
        if document.archived:
            return True
        else:
            return False

    @staticmethod
    def delete_document(document):
        if document.indexing_status in ["parsing", "cleaning", "splitting", "indexing"]:
            raise DocumentIndexingError()

        # trigger document_was_deleted signal
        document_was_deleted.send(document.id, dataset_id=document.dataset_id)

        db.session.delete(document)
        db.session.commit()

    @staticmethod
    def pause_document(document):
        if document.indexing_status not in ["waiting", "parsing", "cleaning", "splitting", "indexing"]:
            raise DocumentIndexingError()
        # update document to be paused
        document.is_paused = True
        document.paused_by = current_user.id
        document.paused_at = datetime.datetime.utcnow()

        db.session.add(document)
        db.session.commit()
        # set document paused flag
        indexing_cache_key = 'document_{}_is_paused'.format(document.id)
        redis_client.setnx(indexing_cache_key, "True")

    @staticmethod
    def recover_document(document):
        if not document.is_paused:
            raise DocumentIndexingError()
        # update document to be recovered
        document.is_paused = False
        document.paused_by = current_user.id
        document.paused_at = time.time()

        db.session.add(document)
        db.session.commit()
        # delete paused flag
        indexing_cache_key = 'document_{}_is_paused'.format(document.id)
        redis_client.delete(indexing_cache_key)
        # trigger async task
        document_indexing_task.delay(document.dataset_id, document.id)

    @staticmethod
    def get_documents_position(dataset_id):
        document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
        if document:
            return document.position + 1
        else:
            return 1
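
    # Expected shape of `document_data` for the save/update methods below
    # (a sketch inferred from the validation and persistence logic in this
    # class, not a formal schema):
    #
    #   {
    #       "original_document_id": "...",             # optional; update instead of create
    #       "indexing_technique": "high_quality" | "economy",
    #       "data_source": {
    #           "type": "upload_file" | "notion_import",
    #           "info_list": {
    #               "file_info_list": {"file_ids": [...]},                         # upload_file
    #               "notion_info_list": [{"workspace_id": "...", "pages": [...]}]  # notion_import
    #           }
    #       },
    #       "process_rule": {"mode": "custom" | "automatic", "rules": {...}}
    #   }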

    @staticmethod
    def save_document_with_dataset_id(dataset: Dataset, document_data: dict,
                                      account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
                                      created_from: str = 'web'):
        # check document limit
        if current_app.config['EDITION'] == 'CLOUD':
            documents_count = DocumentService.get_tenant_documents_count()
            tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
            if documents_count > tenant_document_count:
                raise ValueError(f"over document limit {tenant_document_count}.")
        # if dataset is empty, update dataset data_source_type
        if not dataset.data_source_type:
            dataset.data_source_type = document_data["data_source"]["type"]
            db.session.commit()

        if not dataset.indexing_technique:
            if 'indexing_technique' not in document_data \
                    or document_data['indexing_technique'] not in Dataset.INDEXING_TECHNIQUE_LIST:
                raise ValueError("Indexing technique is required")

            dataset.indexing_technique = document_data["indexing_technique"]

        documents = []
        batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
        if 'original_document_id' in document_data and document_data["original_document_id"]:
            document = DocumentService.update_document_with_dataset_id(dataset, document_data, account)
            documents.append(document)
        else:
            # save process rule
            if not dataset_process_rule:
                process_rule = document_data["process_rule"]
                if process_rule["mode"] == "custom":
                    dataset_process_rule = DatasetProcessRule(
                        dataset_id=dataset.id,
                        mode=process_rule["mode"],
                        rules=json.dumps(process_rule["rules"]),
                        created_by=account.id
                    )
                elif process_rule["mode"] == "automatic":
                    dataset_process_rule = DatasetProcessRule(
                        dataset_id=dataset.id,
                        mode=process_rule["mode"],
                        rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                        created_by=account.id
                    )
                db.session.add(dataset_process_rule)
                db.session.commit()

            position = DocumentService.get_documents_position(dataset.id)
            document_ids = []
            if document_data["data_source"]["type"] == "upload_file":
                upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
                for file_id in upload_file_list:
                    file = db.session.query(UploadFile).filter(
                        UploadFile.tenant_id == dataset.tenant_id,
                        UploadFile.id == file_id
                    ).first()

                    # raise error if file not found
                    if not file:
                        raise FileNotExistsError()

                    file_name = file.name
                    data_source_info = {
                        "upload_file_id": file_id,
                    }
                    document = DocumentService.save_document(dataset, dataset_process_rule.id,
                                                             document_data["data_source"]["type"],
                                                             data_source_info, created_from, position,
                                                             account, file_name, batch)
                    db.session.add(document)
                    db.session.flush()
                    document_ids.append(document.id)
                    documents.append(document)
                    position += 1
            elif document_data["data_source"]["type"] == "notion_import":
                notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                exist_page_ids = []
                exist_document = dict()
                documents = Document.query.filter_by(
                    dataset_id=dataset.id,
                    tenant_id=current_user.current_tenant_id,
                    data_source_type='notion_import',
                    enabled=True
                ).all()
                if documents:
                    for document in documents:
                        data_source_info = json.loads(document.data_source_info)
                        exist_page_ids.append(data_source_info['notion_page_id'])
                        exist_document[data_source_info['notion_page_id']] = document.id
                for notion_info in notion_info_list:
                    workspace_id = notion_info['workspace_id']
                    data_source_binding = DataSourceBinding.query.filter(
                        db.and_(
                            DataSourceBinding.tenant_id == current_user.current_tenant_id,
                            DataSourceBinding.provider == 'notion',
                            DataSourceBinding.disabled == False,
                            DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                        )
                    ).first()
                    if not data_source_binding:
                        raise ValueError('Data source binding not found.')
                    for page in notion_info['pages']:
                        if page['page_id'] not in exist_page_ids:
                            data_source_info = {
                                "notion_workspace_id": workspace_id,
                                "notion_page_id": page['page_id'],
                                "notion_page_icon": page['page_icon'],
                                "type": page['type']
                            }
                            document = DocumentService.save_document(dataset, dataset_process_rule.id,
                                                                     document_data["data_source"]["type"],
                                                                     data_source_info, created_from, position,
                                                                     account, page['page_name'], batch)
                            # if page['type'] == 'database':
                            #     document.splitting_completed_at = datetime.datetime.utcnow()
                            #     document.cleaning_completed_at = datetime.datetime.utcnow()
                            #     document.parsing_completed_at = datetime.datetime.utcnow()
                            #     document.completed_at = datetime.datetime.utcnow()
                            #     document.indexing_status = 'completed'
                            #     document.word_count = 0
                            #     document.tokens = 0
                            #     document.indexing_latency = 0
                            db.session.add(document)
                            db.session.flush()
                            # if page['type'] != 'database':
                            document_ids.append(document.id)
                            documents.append(document)
                            position += 1
                        else:
                            exist_document.pop(page['page_id'])
                # delete not selected documents
                if len(exist_document) > 0:
                    clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
            db.session.commit()

            # trigger async task
            document_indexing_task.delay(dataset.id, document_ids)

        return documents, batch

    @staticmethod
    def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
                      created_from: str, position: int, account: Account, name: str, batch: str):
        document = Document(
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            position=position,
            data_source_type=data_source_type,
            data_source_info=json.dumps(data_source_info),
            dataset_process_rule_id=process_rule_id,
            batch=batch,
            name=name,
            created_from=created_from,
            created_by=account.id,
        )
        return document

    @staticmethod
    def get_tenant_documents_count():
        documents_count = Document.query.filter(Document.completed_at.isnot(None),
                                                Document.enabled == True,
                                                Document.archived == False,
                                                Document.tenant_id == current_user.current_tenant_id).count()
        return documents_count

    @staticmethod
    def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
                                        account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
                                        created_from: str = 'web'):
        document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
        if document.display_status != 'available':
            raise ValueError("Document is not available")
        # save process rule
        if 'process_rule' in document_data and document_data['process_rule']:
            process_rule = document_data["process_rule"]
            if process_rule["mode"] == "custom":
                dataset_process_rule = DatasetProcessRule(
                    dataset_id=dataset.id,
                    mode=process_rule["mode"],
                    rules=json.dumps(process_rule["rules"]),
                    created_by=account.id
                )
            elif process_rule["mode"] == "automatic":
                dataset_process_rule = DatasetProcessRule(
                    dataset_id=dataset.id,
                    mode=process_rule["mode"],
                    rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                    created_by=account.id
                )
            db.session.add(dataset_process_rule)
            db.session.commit()
            document.dataset_process_rule_id = dataset_process_rule.id
        # update document data source
        if 'data_source' in document_data and document_data['data_source']:
            file_name = ''
            data_source_info = {}
            if document_data["data_source"]["type"] == "upload_file":
                upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
                for file_id in upload_file_list:
                    file = db.session.query(UploadFile).filter(
                        UploadFile.tenant_id == dataset.tenant_id,
                        UploadFile.id == file_id
                    ).first()

                    # raise error if file not found
                    if not file:
                        raise FileNotExistsError()

                    file_name = file.name
                    data_source_info = {
                        "upload_file_id": file_id,
                    }
            elif document_data["data_source"]["type"] == "notion_import":
                notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                for notion_info in notion_info_list:
                    workspace_id = notion_info['workspace_id']
                    data_source_binding = DataSourceBinding.query.filter(
                        db.and_(
                            DataSourceBinding.tenant_id == current_user.current_tenant_id,
                            DataSourceBinding.provider == 'notion',
                            DataSourceBinding.disabled == False,
                            DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                        )
                    ).first()
                    if not data_source_binding:
                        raise ValueError('Data source binding not found.')
                    for page in notion_info['pages']:
                        data_source_info = {
                            "notion_workspace_id": workspace_id,
                            "notion_page_id": page['page_id'],
                            "notion_page_icon": page['page_icon'],
                            "type": page['type']
                        }
            document.data_source_type = document_data["data_source"]["type"]
            document.data_source_info = json.dumps(data_source_info)
            document.name = file_name
        # update document to be waiting
        document.indexing_status = 'waiting'
        document.completed_at = None
        document.processing_started_at = None
        document.parsing_completed_at = None
        document.cleaning_completed_at = None
        document.splitting_completed_at = None
        document.updated_at = datetime.datetime.utcnow()
        document.created_from = created_from
        db.session.add(document)
        db.session.commit()
        # update document segment
        update_params = {
            DocumentSegment.status: 're_segment'
        }
        DocumentSegment.query.filter_by(document_id=document.id).update(update_params)
        db.session.commit()
        # trigger async task
        document_indexing_update_task.delay(document.dataset_id, document.id)

        return document

    @staticmethod
    def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
        # check document limit
        if current_app.config['EDITION'] == 'CLOUD':
            documents_count = DocumentService.get_tenant_documents_count()
            tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
            if documents_count > tenant_document_count:
                raise ValueError(f"over document limit {tenant_document_count}.")
        # save dataset
        dataset = Dataset(
            tenant_id=tenant_id,
            name='',
            data_source_type=document_data["data_source"]["type"],
            indexing_technique=document_data["indexing_technique"],
            created_by=account.id
        )

        db.session.add(dataset)
        db.session.flush()

        documents, batch = DocumentService.save_document_with_dataset_id(dataset, document_data, account)

        cut_length = 18
        cut_name = documents[0].name[:cut_length]
        dataset.name = cut_name + '...'
        dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name
        db.session.commit()

        return dataset, documents, batch

    @classmethod
    def document_create_args_validate(cls, args: dict):
        if 'original_document_id' not in args or not args['original_document_id']:
            DocumentService.data_source_args_validate(args)
            DocumentService.process_rule_args_validate(args)
        else:
            # when updating an existing document, at least one of data_source
            # or process_rule must be provided
            if ('data_source' not in args or not args['data_source']) \
                    and ('process_rule' not in args or not args['process_rule']):
                raise ValueError("Data source or Process rule is required")
            else:
                if 'data_source' in args and args['data_source']:
                    DocumentService.data_source_args_validate(args)
                if 'process_rule' in args and args['process_rule']:
                    DocumentService.process_rule_args_validate(args)

    @classmethod
    def data_source_args_validate(cls, args: dict):
        if 'data_source' not in args or not args['data_source']:
            raise ValueError("Data source is required")

        if not isinstance(args['data_source'], dict):
            raise ValueError("Data source is invalid")

        if 'type' not in args['data_source'] or not args['data_source']['type']:
            raise ValueError("Data source type is required")

        if args['data_source']['type'] not in Document.DATA_SOURCES:
            raise ValueError("Data source type is invalid")

        if 'info_list' not in args['data_source'] or not args['data_source']['info_list']:
            raise ValueError("Data source info is required")

        if args['data_source']['type'] == 'upload_file':
            if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['file_info_list']:
                raise ValueError("File source info is required")
        if args['data_source']['type'] == 'notion_import':
            if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['notion_info_list']:
                raise ValueError("Notion source info is required")

    @classmethod
    def process_rule_args_validate(cls, args: dict):
        if 'process_rule' not in args or not args['process_rule']:
            raise ValueError("Process rule is required")

        if not isinstance(args['process_rule'], dict):
            raise ValueError("Process rule is invalid")

        if 'mode' not in args['process_rule'] or not args['process_rule']['mode']:
            raise ValueError("Process rule mode is required")

        if args['process_rule']['mode'] not in DatasetProcessRule.MODES:
            raise ValueError("Process rule mode is invalid")

        if args['process_rule']['mode'] == 'automatic':
            args['process_rule']['rules'] = {}
        else:
            if 'rules' not in args['process_rule'] or not args['process_rule']['rules']:
                raise ValueError("Process rule rules is required")

            if not isinstance(args['process_rule']['rules'], dict):
                raise ValueError("Process rule rules is invalid")

            if 'pre_processing_rules' not in args['process_rule']['rules'] \
                    or args['process_rule']['rules']['pre_processing_rules'] is None:
                raise ValueError("Process rule pre_processing_rules is required")

            if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):
                raise ValueError("Process rule pre_processing_rules is invalid")

            unique_pre_processing_rule_dicts = {}
            for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']:
                if 'id' not in pre_processing_rule or not pre_processing_rule['id']:
                    raise ValueError("Process rule pre_processing_rules id is required")

                if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES:
                    raise ValueError("Process rule pre_processing_rules id is invalid")

                if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None:
                    raise ValueError("Process rule pre_processing_rules enabled is required")

                if not isinstance(pre_processing_rule['enabled'], bool):
                    raise ValueError("Process rule pre_processing_rules enabled is invalid")

                unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule

            args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())

            if 'segmentation' not in args['process_rule']['rules'] \
                    or args['process_rule']['rules']['segmentation'] is None:
                raise ValueError("Process rule segmentation is required")

            if not isinstance(args['process_rule']['rules']['segmentation'], dict):
                raise ValueError("Process rule segmentation is invalid")

            if 'separator' not in args['process_rule']['rules']['segmentation'] \
                    or not args['process_rule']['rules']['segmentation']['separator']:
                raise ValueError("Process rule segmentation separator is required")

            if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):
                raise ValueError("Process rule segmentation separator is invalid")

            if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \
                    or not args['process_rule']['rules']['segmentation']['max_tokens']:
                raise ValueError("Process rule segmentation max_tokens is required")

            if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
                raise ValueError("Process rule segmentation max_tokens is invalid")

    @classmethod
    def estimate_args_validate(cls, args: dict):
        if 'info_list' not in args or not args['info_list']:
            raise ValueError("Data source info is required")

        if not isinstance(args['info_list'], dict):
            raise ValueError("Data info is invalid")

        if 'process_rule' not in args or not args['process_rule']:
            raise ValueError("Process rule is required")

        if not isinstance(args['process_rule'], dict):
            raise ValueError("Process rule is invalid")

        if 'mode' not in args['process_rule'] or not args['process_rule']['mode']:
            raise ValueError("Process rule mode is required")

        if args['process_rule']['mode'] not in DatasetProcessRule.MODES:
            raise ValueError("Process rule mode is invalid")

        if args['process_rule']['mode'] == 'automatic':
            args['process_rule']['rules'] = {}
        else:
            if 'rules' not in args['process_rule'] or not args['process_rule']['rules']:
                raise ValueError("Process rule rules is required")

            if not isinstance(args['process_rule']['rules'], dict):
                raise ValueError("Process rule rules is invalid")

            if 'pre_processing_rules' not in args['process_rule']['rules'] \
                    or args['process_rule']['rules']['pre_processing_rules'] is None:
                raise ValueError("Process rule pre_processing_rules is required")

            if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):
                raise ValueError("Process rule pre_processing_rules is invalid")

            unique_pre_processing_rule_dicts = {}
            for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']:
                if 'id' not in pre_processing_rule or not pre_processing_rule['id']:
                    raise ValueError("Process rule pre_processing_rules id is required")

                if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES:
                    raise ValueError("Process rule pre_processing_rules id is invalid")

                if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None:
                    raise ValueError("Process rule pre_processing_rules enabled is required")

                if not isinstance(pre_processing_rule['enabled'], bool):
                    raise ValueError("Process rule pre_processing_rules enabled is invalid")

                unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule

            args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())

            if 'segmentation' not in args['process_rule']['rules'] \
                    or args['process_rule']['rules']['segmentation'] is None:
                raise ValueError("Process rule segmentation is required")

            if not isinstance(args['process_rule']['rules']['segmentation'], dict):
                raise ValueError("Process rule segmentation is invalid")

            if 'separator' not in args['process_rule']['rules']['segmentation'] \
                    or not args['process_rule']['rules']['segmentation']['separator']:
                raise ValueError("Process rule segmentation separator is required")

            if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):
                raise ValueError("Process rule segmentation separator is invalid")

            if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \
                    or not args['process_rule']['rules']['segmentation']['max_tokens']:
                raise ValueError("Process rule segmentation max_tokens is required")

            if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
                raise ValueError("Process rule segmentation max_tokens is invalid")
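

# Illustrative usage sketch (assumptions: a Flask app context with a logged-in
# `current_user` that is an Account; the payload shape is inferred from the
# validation methods above, and "<upload-file-id>" is a placeholder):
#
#   args = {
#       "indexing_technique": "high_quality",
#       "data_source": {
#           "type": "upload_file",
#           "info_list": {"file_info_list": {"file_ids": ["<upload-file-id>"]}},
#       },
#       "process_rule": {"mode": "automatic", "rules": {}},
#   }
#   DocumentService.document_create_args_validate(args)
#   dataset, documents, batch = DocumentService.save_document_without_dataset_id(
#       tenant_id=current_user.current_tenant_id, document_data=args, account=current_user)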