dataset_service.py

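"""Service layer for datasets and documents.

DatasetService handles dataset CRUD, permission checks and related lookups;
DocumentService handles document creation, pause/resume, deletion and
validation of document-creation payloads.
"""
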
import json
import logging
import datetime
import time
import random
from typing import Optional

from flask_login import current_user

from core.index.index_builder import IndexBuilder
from events.dataset_event import dataset_was_deleted
from events.document_event import document_was_deleted
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin
from models.model import UploadFile
from services.errors.account import NoPermissionError
from services.errors.dataset import DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.document_indexing_task import document_indexing_task

class DatasetService:

    @staticmethod
    def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None):
        if user:
            permission_filter = db.or_(Dataset.created_by == user.id,
                                       Dataset.permission == 'all_team_members')
        else:
            permission_filter = Dataset.permission == 'all_team_members'

        datasets = Dataset.query.filter(
            db.and_(Dataset.provider == provider, Dataset.tenant_id == tenant_id, permission_filter)) \
            .paginate(
                page=page,
                per_page=per_page,
                max_per_page=100,
                error_out=False
            )

        return datasets.items, datasets.total

    @staticmethod
    def get_process_rules(dataset_id):
        # get the latest process rule
        dataset_process_rule = db.session.query(DatasetProcessRule). \
            filter(DatasetProcessRule.dataset_id == dataset_id). \
            order_by(DatasetProcessRule.created_at.desc()). \
            limit(1). \
            one_or_none()

        if dataset_process_rule:
            mode = dataset_process_rule.mode
            rules = dataset_process_rule.rules_dict
        else:
            mode = DocumentService.DEFAULT_RULES['mode']
            rules = DocumentService.DEFAULT_RULES['rules']

        return {
            'mode': mode,
            'rules': rules
        }

    @staticmethod
    def get_datasets_by_ids(ids, tenant_id):
        datasets = Dataset.query.filter(Dataset.id.in_(ids),
                                        Dataset.tenant_id == tenant_id).paginate(
            page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)

        return datasets.items, datasets.total
    @staticmethod
    def create_empty_dataset(tenant_id: str, name: str, indexing_technique: Optional[str], account: Account):
        # check if dataset name already exists
        if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
            raise DatasetNameDuplicateError(
                f'Dataset with name {name} already exists.')

        dataset = Dataset(name=name, indexing_technique=indexing_technique, data_source_type='upload_file')
        dataset.created_by = account.id
        dataset.updated_by = account.id
        dataset.tenant_id = tenant_id

        db.session.add(dataset)
        db.session.commit()
        return dataset
    @staticmethod
    def get_dataset(dataset_id):
        return Dataset.query.filter_by(id=dataset_id).first()
    @staticmethod
    def update_dataset(dataset_id, data, user):
        dataset = DatasetService.get_dataset(dataset_id)
        DatasetService.check_dataset_permission(dataset, user)

        if dataset.indexing_technique != data['indexing_technique']:
            # the indexing technique changed: add or remove the vector index accordingly
            if data['indexing_technique'] == 'economy':
                deal_dataset_vector_index_task.delay(dataset_id, 'remove')
            elif data['indexing_technique'] == 'high_quality':
                deal_dataset_vector_index_task.delay(dataset_id, 'add')

        # keep only the provided fields; 'description' may be explicitly cleared
        filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
        filtered_data['updated_by'] = user.id
        filtered_data['updated_at'] = datetime.datetime.now()

        Dataset.query.filter_by(id=dataset_id).update(filtered_data)
        db.session.commit()
        return dataset
    @staticmethod
    def delete_dataset(dataset_id, user):
        # todo: cannot delete dataset if it is being processed
        dataset = DatasetService.get_dataset(dataset_id)
        if dataset is None:
            return False

        DatasetService.check_dataset_permission(dataset, user)

        dataset_was_deleted.send(dataset)

        db.session.delete(dataset)
        db.session.commit()
        return True

    @staticmethod
    def check_dataset_permission(dataset, user):
        if dataset.tenant_id != user.current_tenant_id:
            logging.debug(
                f'User {user.id} does not have permission to access dataset {dataset.id}')
            raise NoPermissionError(
                'You do not have permission to access this dataset.')
        if dataset.permission == 'only_me' and dataset.created_by != user.id:
            logging.debug(
                f'User {user.id} does not have permission to access dataset {dataset.id}')
            raise NoPermissionError(
                'You do not have permission to access this dataset.')

    @staticmethod
    def get_dataset_queries(dataset_id: str, page: int, per_page: int):
        dataset_queries = DatasetQuery.query.filter_by(dataset_id=dataset_id) \
            .order_by(db.desc(DatasetQuery.created_at)) \
            .paginate(
                page=page, per_page=per_page, max_per_page=100, error_out=False
            )

        return dataset_queries.items, dataset_queries.total

    @staticmethod
    def get_related_apps(dataset_id: str):
        return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \
            .order_by(db.desc(AppDatasetJoin.created_at)).all()
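
# Usage sketch (illustrative only, not part of this module's API). It assumes a
# Flask request context with an authenticated `account` whose current_tenant_id
# is set; 'product-docs' is just a placeholder dataset name.
#
#   dataset = DatasetService.create_empty_dataset(
#       tenant_id=account.current_tenant_id, name='product-docs',
#       indexing_technique='high_quality', account=account)
#   items, total = DatasetService.get_datasets(
#       page=1, per_page=20, tenant_id=account.current_tenant_id, user=account)
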
class DocumentService:
    DEFAULT_RULES = {
        'mode': 'custom',
        'rules': {
            'pre_processing_rules': [
                {'id': 'remove_extra_spaces', 'enabled': True},
                {'id': 'remove_urls_emails', 'enabled': False}
            ],
            'segmentation': {
                'delimiter': '\n',
                'max_tokens': 500
            }
        }
    }

    DOCUMENT_METADATA_SCHEMA = {
        "book": {
            "title": str,
            "language": str,
            "author": str,
            "publisher": str,
            "publication_date": str,
            "isbn": str,
            "category": str,
        },
        "web_page": {
            "title": str,
            "url": str,
            "language": str,
            "publish_date": str,
            "author/publisher": str,
            "topic/keywords": str,
            "description": str,
        },
        "paper": {
            "title": str,
            "language": str,
            "author": str,
            "publish_date": str,
            "journal/conference_name": str,
            "volume/issue/page_numbers": str,
            "doi": str,
            "topic/keywords": str,
            "abstract": str,
        },
        "social_media_post": {
            "platform": str,
            "author/username": str,
            "publish_date": str,
            "post_url": str,
            "topic/tags": str,
        },
        "wikipedia_entry": {
            "title": str,
            "language": str,
            "web_page_url": str,
            "last_edit_date": str,
            "editor/contributor": str,
            "summary/introduction": str,
        },
        "personal_document": {
            "title": str,
            "author": str,
            "creation_date": str,
            "last_modified_date": str,
            "document_type": str,
            "tags/category": str,
        },
        "business_document": {
            "title": str,
            "author": str,
            "creation_date": str,
            "last_modified_date": str,
            "document_type": str,
            "department/team": str,
        },
        "im_chat_log": {
            "chat_platform": str,
            "chat_participants/group_name": str,
            "start_date": str,
            "end_date": str,
            "summary": str,
        },
        "synced_from_notion": {
            "title": str,
            "language": str,
            "author/creator": str,
            "creation_date": str,
            "last_modified_date": str,
            "notion_page_link": str,
            "category/tags": str,
            "description": str,
        },
        "synced_from_github": {
            "repository_name": str,
            "repository_description": str,
            "repository_owner/organization": str,
            "code_filename": str,
            "code_file_path": str,
            "programming_language": str,
            "github_link": str,
            "open_source_license": str,
            "commit_date": str,
            "commit_author": str
        }
    }
    @staticmethod
    def get_document(dataset_id: str, document_id: str) -> Optional[Document]:
        document = db.session.query(Document).filter(
            Document.id == document_id,
            Document.dataset_id == dataset_id
        ).first()

        return document

    @staticmethod
    def get_document_file_detail(file_id: str):
        file_detail = db.session.query(UploadFile). \
            filter(UploadFile.id == file_id). \
            one_or_none()

        return file_detail

    @staticmethod
    def check_archived(document):
        if document.archived:
            return True
        else:
            return False

    @staticmethod
    def delete_document(document):
        if document.indexing_status in ["parsing", "cleaning", "splitting", "indexing"]:
            raise DocumentIndexingError()

        # trigger document_was_deleted signal
        document_was_deleted.send(document.id, dataset_id=document.dataset_id)

        db.session.delete(document)
        db.session.commit()

    @staticmethod
    def pause_document(document):
        if document.indexing_status not in ["waiting", "parsing", "cleaning", "splitting", "indexing"]:
            raise DocumentIndexingError()

        # update document to be paused
        document.is_paused = True
        document.paused_by = current_user.id
        document.paused_at = datetime.datetime.utcnow()

        db.session.add(document)
        db.session.commit()

        # set document paused flag
        indexing_cache_key = 'document_{}_is_paused'.format(document.id)
        redis_client.setnx(indexing_cache_key, "True")
    @staticmethod
    def recover_document(document):
        if not document.is_paused:
            raise DocumentIndexingError()

        # mark the document as resumed
        document.is_paused = False
        document.paused_by = current_user.id
        document.paused_at = datetime.datetime.utcnow()

        db.session.add(document)
        db.session.commit()

        # delete the paused flag
        indexing_cache_key = 'document_{}_is_paused'.format(document.id)
        redis_client.delete(indexing_cache_key)

        # trigger async task to resume indexing
        document_indexing_task.delay(document.dataset_id, document.id)
    @staticmethod
    def get_documents_position(dataset_id):
        documents = Document.query.filter_by(dataset_id=dataset_id).all()
        if documents:
            return len(documents) + 1
        else:
            return 1
    @staticmethod
    def save_document_with_dataset_id(dataset: Dataset, document_data: dict,
                                      account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
                                      created_from: str = 'web'):
        if not dataset.indexing_technique:
            if 'indexing_technique' not in document_data \
                    or document_data['indexing_technique'] not in Dataset.INDEXING_TECHNIQUE_LIST:
                raise ValueError("Indexing technique is required")

            dataset.indexing_technique = document_data["indexing_technique"]

        if dataset.indexing_technique == 'high_quality':
            IndexBuilder.get_default_service_context(dataset.tenant_id)

        # save process rule
        if not dataset_process_rule:
            process_rule = document_data["process_rule"]
            if process_rule["mode"] == "custom":
                dataset_process_rule = DatasetProcessRule(
                    dataset_id=dataset.id,
                    mode=process_rule["mode"],
                    rules=json.dumps(process_rule["rules"]),
                    created_by=account.id
                )
            elif process_rule["mode"] == "automatic":
                dataset_process_rule = DatasetProcessRule(
                    dataset_id=dataset.id,
                    mode=process_rule["mode"],
                    rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                    created_by=account.id
                )

            db.session.add(dataset_process_rule)
            db.session.commit()

        file_name = ''
        data_source_info = {}
        if document_data["data_source"]["type"] == "upload_file":
            file_id = document_data["data_source"]["info"]
            file = db.session.query(UploadFile).filter(
                UploadFile.tenant_id == dataset.tenant_id,
                UploadFile.id == file_id
            ).first()

            # raise error if file not found
            if not file:
                raise FileNotExistsError()

            file_name = file.name
            data_source_info = {
                "upload_file_id": file_id,
            }

        # save document
        position = DocumentService.get_documents_position(dataset.id)
        document = Document(
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            position=position,
            data_source_type=document_data["data_source"]["type"],
            data_source_info=json.dumps(data_source_info),
            dataset_process_rule_id=dataset_process_rule.id,
            batch=time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)),
            name=file_name,
            created_from=created_from,
            created_by=account.id,
            # created_api_request_id = db.Column(UUID, nullable=True)
        )

        db.session.add(document)
        db.session.commit()

        # trigger async task
        document_indexing_task.delay(document.dataset_id, document.id)

        return document
    @staticmethod
    def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
        # save dataset
        dataset = Dataset(
            tenant_id=tenant_id,
            name='',
            data_source_type=document_data["data_source"]["type"],
            indexing_technique=document_data["indexing_technique"],
            created_by=account.id
        )

        db.session.add(dataset)
        db.session.flush()

        document = DocumentService.save_document_with_dataset_id(dataset, document_data, account)

        cut_length = 18
        cut_name = document.name[:cut_length]
        dataset.name = cut_name + '...' if len(document.name) > cut_length else cut_name
        dataset.description = 'useful for when you want to answer queries about the ' + document.name
        db.session.commit()

        return dataset, document
    @classmethod
    def document_create_args_validate(cls, args: dict):
        if 'data_source' not in args or not args['data_source']:
            raise ValueError("Data source is required")

        if not isinstance(args['data_source'], dict):
            raise ValueError("Data source is invalid")

        if 'type' not in args['data_source'] or not args['data_source']['type']:
            raise ValueError("Data source type is required")

        if args['data_source']['type'] not in Document.DATA_SOURCES:
            raise ValueError("Data source type is invalid")

        if args['data_source']['type'] == 'upload_file':
            if 'info' not in args['data_source'] or not args['data_source']['info']:
                raise ValueError("Data source info is required")

        if 'process_rule' not in args or not args['process_rule']:
            raise ValueError("Process rule is required")

        if not isinstance(args['process_rule'], dict):
            raise ValueError("Process rule is invalid")

        if 'mode' not in args['process_rule'] or not args['process_rule']['mode']:
            raise ValueError("Process rule mode is required")

        if args['process_rule']['mode'] not in DatasetProcessRule.MODES:
            raise ValueError("Process rule mode is invalid")

        if args['process_rule']['mode'] == 'automatic':
            args['process_rule']['rules'] = {}
        else:
            if 'rules' not in args['process_rule'] or not args['process_rule']['rules']:
                raise ValueError("Process rule rules is required")

            if not isinstance(args['process_rule']['rules'], dict):
                raise ValueError("Process rule rules is invalid")

            if 'pre_processing_rules' not in args['process_rule']['rules'] \
                    or args['process_rule']['rules']['pre_processing_rules'] is None:
                raise ValueError("Process rule pre_processing_rules is required")

            if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):
                raise ValueError("Process rule pre_processing_rules is invalid")

            unique_pre_processing_rule_dicts = {}
            for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']:
                if 'id' not in pre_processing_rule or not pre_processing_rule['id']:
                    raise ValueError("Process rule pre_processing_rules id is required")

                if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES:
                    raise ValueError("Process rule pre_processing_rules id is invalid")

                if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None:
                    raise ValueError("Process rule pre_processing_rules enabled is required")

                if not isinstance(pre_processing_rule['enabled'], bool):
                    raise ValueError("Process rule pre_processing_rules enabled is invalid")

                unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule

            args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())

            if 'segmentation' not in args['process_rule']['rules'] \
                    or args['process_rule']['rules']['segmentation'] is None:
                raise ValueError("Process rule segmentation is required")

            if not isinstance(args['process_rule']['rules']['segmentation'], dict):
                raise ValueError("Process rule segmentation is invalid")

            if 'separator' not in args['process_rule']['rules']['segmentation'] \
                    or not args['process_rule']['rules']['segmentation']['separator']:
                raise ValueError("Process rule segmentation separator is required")

            if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):
                raise ValueError("Process rule segmentation separator is invalid")

            if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \
                    or not args['process_rule']['rules']['segmentation']['max_tokens']:
                raise ValueError("Process rule segmentation max_tokens is required")

            if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
                raise ValueError("Process rule segmentation max_tokens is invalid")
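
# Usage sketch (illustrative only, not part of this module's API). It assumes a
# Flask app context and an already-loaded `account`; the payload shape mirrors
# what document_create_args_validate accepts, and '<upload_file_id>' is a
# placeholder for a real UploadFile id belonging to the tenant.
#
#   args = {
#       'data_source': {'type': 'upload_file', 'info': '<upload_file_id>'},
#       'indexing_technique': 'high_quality',
#       'process_rule': {
#           'mode': 'custom',
#           'rules': {
#               'pre_processing_rules': [
#                   {'id': 'remove_extra_spaces', 'enabled': True},
#               ],
#               'segmentation': {'separator': '\n', 'max_tokens': 500},
#           },
#       },
#   }
#   DocumentService.document_create_args_validate(args)
#   dataset, document = DocumentService.save_document_without_dataset_id(
#       tenant_id=account.current_tenant_id, document_data=args, account=account)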