datasets_segments.py

import uuid
from datetime import datetime, timezone

import pandas as pd
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound

import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
from controllers.console.setup import setup_required
from controllers.console.wraps import (
    account_initialization_required,
    cloud_edition_billing_knowledge_limit_check,
    cloud_edition_billing_resource_check,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import segment_fields
from libs.login import login_required
from models.dataset import DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
from tasks.enable_segment_to_index_task import enable_segment_to_index_task


class DatasetDocumentSegmentListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id, document_id):
        dataset_id = str(dataset_id)
        document_id = str(document_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound('Document not found.')

        parser = reqparse.RequestParser()
        parser.add_argument('last_id', type=str, default=None, location='args')
        parser.add_argument('limit', type=int, default=20, location='args')
        parser.add_argument('status', type=str, action='append', default=[], location='args')
        parser.add_argument('hit_count_gte', type=int, default=None, location='args')
        parser.add_argument('enabled', type=str, default='all', location='args')
        parser.add_argument('keyword', type=str, default=None, location='args')
        args = parser.parse_args()

        last_id = args['last_id']
        limit = min(args['limit'], 100)
        status_list = args['status']
        hit_count_gte = args['hit_count_gte']
        keyword = args['keyword']

        query = DocumentSegment.query.filter(
            DocumentSegment.document_id == str(document_id),
            DocumentSegment.tenant_id == current_user.current_tenant_id
        )

        # Keyset pagination: resume after the position of the last segment seen.
        if last_id is not None:
            last_segment = DocumentSegment.query.get(str(last_id))
            if last_segment:
                query = query.filter(DocumentSegment.position > last_segment.position)
            else:
                return {'data': [], 'has_more': False, 'limit': limit}, 200

        if status_list:
            query = query.filter(DocumentSegment.status.in_(status_list))
        if hit_count_gte is not None:
            query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
        if keyword:
            query = query.filter(DocumentSegment.content.ilike(f'%{keyword}%'))
        if args['enabled'].lower() != 'all':
            if args['enabled'].lower() == 'true':
                query = query.filter(DocumentSegment.enabled == True)
            elif args['enabled'].lower() == 'false':
                query = query.filter(DocumentSegment.enabled == False)

        total = query.count()
        # Fetch one extra row to detect whether another page exists.
        segments = query.order_by(DocumentSegment.position).limit(limit + 1).all()

        has_more = False
        if len(segments) > limit:
            has_more = True
            segments = segments[:-1]

        return {
            'data': marshal(segments, segment_fields),
            'doc_form': document.doc_form,
            'has_more': has_more,
            'limit': limit,
            'total': total
        }, 200
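
# Illustrative only (not part of the original module): a typical listing call
# against the route registered at the bottom of this file, paging forward with
# the keyset cursor. The ids in the path are placeholders.
#
#   GET /datasets/<dataset_id>/documents/<document_id>/segments?limit=20&status=completed
#
# The response carries 'has_more'; to fetch the next page, repeat the request
# with last_id set to the id of the final segment in 'data'.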


class DatasetDocumentSegmentApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('vector_space')
    def patch(self, dataset_id, segment_id, action):
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # The role of the current user in the ta table must be admin, owner, or editor
        if not current_user.is_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        if dataset.indexing_technique == 'high_quality':
            # check embedding model setting
            try:
                model_manager = ModelManager()
                model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider "
                    "in Settings -> Model Provider.")
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)

        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id),
            DocumentSegment.tenant_id == current_user.current_tenant_id
        ).first()
        if not segment:
            raise NotFound('Segment not found.')
        if segment.status != 'completed':
            raise NotFound('Segment is not completed; enabling or disabling it is not allowed.')

        # Refuse to toggle while the parent document or this segment is still indexing.
        document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id)
        cache_result = redis_client.get(document_indexing_cache_key)
        if cache_result is not None:
            raise InvalidActionError("Document is being indexed, please try again later")
        indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
        cache_result = redis_client.get(indexing_cache_key)
        if cache_result is not None:
            raise InvalidActionError("Segment is being indexed, please try again later")

        if action == "enable":
            if segment.enabled:
                raise InvalidActionError("Segment is already enabled.")
            segment.enabled = True
            segment.disabled_at = None
            segment.disabled_by = None
            db.session.commit()

            # Set cache to prevent indexing the same segment multiple times
            redis_client.setex(indexing_cache_key, 600, 1)
            enable_segment_to_index_task.delay(segment.id)
            return {'result': 'success'}, 200
        elif action == "disable":
            if not segment.enabled:
                raise InvalidActionError("Segment is already disabled.")
            segment.enabled = False
            segment.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None)
            segment.disabled_by = current_user.id
            db.session.commit()

            # Set cache to prevent indexing the same segment multiple times
            redis_client.setex(indexing_cache_key, 600, 1)
            disable_segment_from_index_task.delay(segment.id)
            return {'result': 'success'}, 200
        else:
            raise InvalidActionError()
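
# Illustrative only: enable/disable is asynchronous. The handler flips the flag,
# sets a 10-minute Redis guard key, and defers the index work to a Celery task,
# so a second PATCH inside the guard window is rejected as "being indexed".
#
#   PATCH /datasets/<dataset_id>/segments/<segment_id>/enable   -> {'result': 'success'}
#   PATCH /datasets/<dataset_id>/segments/<segment_id>/disable  -> {'result': 'success'}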


class DatasetDocumentSegmentAddApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('vector_space')
    @cloud_edition_billing_knowledge_limit_check('add_segment')
    def post(self, dataset_id, document_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound('Document not found.')
        # The role of the current user in the ta table must be admin or owner
        if not current_user.is_admin_or_owner:
            raise Forbidden()
        # check embedding model setting
        if dataset.indexing_technique == 'high_quality':
            try:
                model_manager = ModelManager()
                model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider "
                    "in Settings -> Model Provider.")
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
        parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
        args = parser.parse_args()
        SegmentService.segment_create_args_validate(args, document)
        segment = SegmentService.create_segment(args, document, dataset)
        return {
            'data': marshal(segment, segment_fields),
            'doc_form': document.doc_form
        }, 200
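
# Illustrative only: a minimal create payload for the single-segment route.
# 'answer' is meaningful only when the document's doc_form is 'qa_model', and
# 'keywords' is optional; the values below are made up.
#
#   POST /datasets/<dataset_id>/documents/<document_id>/segment
#   {"content": "Example segment text.", "answer": "Example answer.",
#    "keywords": ["example"]}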


class DatasetDocumentSegmentUpdateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('vector_space')
    def patch(self, dataset_id, document_id, segment_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound('Document not found.')
        if dataset.indexing_technique == 'high_quality':
            # check embedding model setting
            try:
                model_manager = ModelManager()
                model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider "
                    "in Settings -> Model Provider.")
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)
        # check segment
        segment_id = str(segment_id)
        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id),
            DocumentSegment.tenant_id == current_user.current_tenant_id
        ).first()
        if not segment:
            raise NotFound('Segment not found.')
        # The role of the current user in the ta table must be admin, owner, or editor
        if not current_user.is_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
        parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
        args = parser.parse_args()
        SegmentService.segment_create_args_validate(args, document)
        segment = SegmentService.update_segment(args, segment, document, dataset)
        return {
            'data': marshal(segment, segment_fields),
            'doc_form': document.doc_form
        }, 200

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, dataset_id, document_id, segment_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound('Document not found.')
        # check segment
        segment_id = str(segment_id)
        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id),
            DocumentSegment.tenant_id == current_user.current_tenant_id
        ).first()
        if not segment:
            raise NotFound('Segment not found.')
        # The role of the current user in the ta table must be admin or owner
        if not current_user.is_admin_or_owner:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        SegmentService.delete_segment(segment, document, dataset)
        return {'result': 'success'}, 200
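
# Illustrative only: update and delete address one segment by id on the nested
# route registered below.
#
#   PATCH  /datasets/<dataset_id>/documents/<document_id>/segments/<segment_id>
#   {"content": "Revised segment text."}
#
#   DELETE /datasets/<dataset_id>/documents/<document_id>/segments/<segment_id>
#   -> {'result': 'success'}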


class DatasetDocumentSegmentBatchImportApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('vector_space')
    @cloud_edition_billing_knowledge_limit_check('add_segment')
    def post(self, dataset_id, document_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound('Document not found.')
        # validate the upload before touching it: exactly one file is expected
        if 'file' not in request.files:
            raise NoFileUploadedError()
        if len(request.files) > 1:
            raise TooManyFilesError()
        file = request.files['file']
        # check file type
        if not file.filename.endswith('.csv'):
            raise ValueError("Invalid file type. Only CSV files are allowed")
        try:
            # pd.read_csv consumes the first row as the header, so data rows
            # start at the second line of the file
            df = pd.read_csv(file)
            result = []
            for _, row in df.iterrows():
                if document.doc_form == 'qa_model':
                    data = {'content': row.iloc[0], 'answer': row.iloc[1]}
                else:
                    data = {'content': row.iloc[0]}
                result.append(data)
            if len(result) == 0:
                raise ValueError("The CSV file is empty.")
            # async job
            job_id = str(uuid.uuid4())
            indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
            # send batch add segments task
            redis_client.setnx(indexing_cache_key, 'waiting')
            batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
                                                     current_user.current_tenant_id, current_user.id)
        except Exception as e:
            return {'error': str(e)}, 500
        return {
            'job_id': job_id,
            'job_status': 'waiting'
        }, 200
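
    # Illustrative only: CSV shapes accepted by the importer above. pd.read_csv
    # consumes the first row as the header, so data begins on line 2.
    #
    #   text documents (one column):    qa_model documents (two columns):
    #     content                         content,answer
    #     "First segment text"            "Question 1","Answer 1"
    #     "Second segment text"           "Question 2","Answer 2"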

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, job_id):
        job_id = str(job_id)
        indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
        cache_result = redis_client.get(indexing_cache_key)
        if cache_result is None:
            raise ValueError("The job does not exist.")
        return {
            'job_id': job_id,
            'job_status': cache_result.decode()
        }, 200
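
# Illustrative only: the POST above returns a job_id; clients poll the status
# route until the indexing task updates the cached value.
#
#   GET /datasets/batch_import_status/<job_id>  -> {'job_id': ..., 'job_status': 'waiting'}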


api.add_resource(DatasetDocumentSegmentListApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetDocumentSegmentApi,
                 '/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
api.add_resource(DatasetDocumentSegmentAddApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
api.add_resource(DatasetDocumentSegmentUpdateApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
api.add_resource(DatasetDocumentSegmentBatchImportApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
                 '/datasets/batch_import_status/<uuid:job_id>')