datasets_segments.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. import uuid
  2. from datetime import datetime
  3. import pandas as pd
  4. from flask import request
  5. from flask_login import current_user
  6. from flask_restful import Resource, marshal, reqparse
  7. from werkzeug.exceptions import Forbidden, NotFound
  8. import services
  9. from controllers.console import api
  10. from controllers.console.app.error import ProviderNotInitializeError
  11. from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
  12. from controllers.console.setup import setup_required
  13. from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
  14. from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
  15. from core.model_manager import ModelManager
  16. from core.model_runtime.entities.model_entities import ModelType
  17. from extensions.ext_database import db
  18. from extensions.ext_redis import redis_client
  19. from fields.segment_fields import segment_fields
  20. from libs.login import login_required
  21. from models.dataset import DocumentSegment
  22. from services.dataset_service import DatasetService, DocumentService, SegmentService
  23. from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
  24. from tasks.disable_segment_from_index_task import disable_segment_from_index_task
  25. from tasks.enable_segment_to_index_task import enable_segment_to_index_task
  26. class DatasetDocumentSegmentListApi(Resource):
  27. @setup_required
  28. @login_required
  29. @account_initialization_required
  30. def get(self, dataset_id, document_id):
  31. dataset_id = str(dataset_id)
  32. document_id = str(document_id)
  33. dataset = DatasetService.get_dataset(dataset_id)
  34. if not dataset:
  35. raise NotFound('Dataset not found.')
  36. try:
  37. DatasetService.check_dataset_permission(dataset, current_user)
  38. except services.errors.account.NoPermissionError as e:
  39. raise Forbidden(str(e))
  40. document = DocumentService.get_document(dataset_id, document_id)
  41. if not document:
  42. raise NotFound('Document not found.')
  43. parser = reqparse.RequestParser()
  44. parser.add_argument('last_id', type=str, default=None, location='args')
  45. parser.add_argument('limit', type=int, default=20, location='args')
  46. parser.add_argument('status', type=str,
  47. action='append', default=[], location='args')
  48. parser.add_argument('hit_count_gte', type=int,
  49. default=None, location='args')
  50. parser.add_argument('enabled', type=str, default='all', location='args')
  51. parser.add_argument('keyword', type=str, default=None, location='args')
  52. args = parser.parse_args()
  53. last_id = args['last_id']
  54. limit = min(args['limit'], 100)
  55. status_list = args['status']
  56. hit_count_gte = args['hit_count_gte']
  57. keyword = args['keyword']
  58. query = DocumentSegment.query.filter(
  59. DocumentSegment.document_id == str(document_id),
  60. DocumentSegment.tenant_id == current_user.current_tenant_id
  61. )
  62. if last_id is not None:
  63. last_segment = DocumentSegment.query.get(str(last_id))
  64. if last_segment:
  65. query = query.filter(
  66. DocumentSegment.position > last_segment.position)
  67. else:
  68. return {'data': [], 'has_more': False, 'limit': limit}, 200
  69. if status_list:
  70. query = query.filter(DocumentSegment.status.in_(status_list))
  71. if hit_count_gte is not None:
  72. query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
  73. if keyword:
  74. query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
  75. if args['enabled'].lower() != 'all':
  76. if args['enabled'].lower() == 'true':
  77. query = query.filter(DocumentSegment.enabled == True)
  78. elif args['enabled'].lower() == 'false':
  79. query = query.filter(DocumentSegment.enabled == False)
  80. total = query.count()
  81. segments = query.order_by(DocumentSegment.position).limit(limit + 1).all()
  82. has_more = False
  83. if len(segments) > limit:
  84. has_more = True
  85. segments = segments[:-1]
  86. return {
  87. 'data': marshal(segments, segment_fields),
  88. 'doc_form': document.doc_form,
  89. 'has_more': has_more,
  90. 'limit': limit,
  91. 'total': total
  92. }, 200
  93. class DatasetDocumentSegmentApi(Resource):
  94. @setup_required
  95. @login_required
  96. @account_initialization_required
  97. @cloud_edition_billing_resource_check('vector_space')
  98. def patch(self, dataset_id, segment_id, action):
  99. dataset_id = str(dataset_id)
  100. dataset = DatasetService.get_dataset(dataset_id)
  101. if not dataset:
  102. raise NotFound('Dataset not found.')
  103. # check user's model setting
  104. DatasetService.check_dataset_model_setting(dataset)
  105. # The role of the current user in the ta table must be admin or owner
  106. if not current_user.is_admin_or_owner:
  107. raise Forbidden()
  108. try:
  109. DatasetService.check_dataset_permission(dataset, current_user)
  110. except services.errors.account.NoPermissionError as e:
  111. raise Forbidden(str(e))
  112. if dataset.indexing_technique == 'high_quality':
  113. # check embedding model setting
  114. try:
  115. model_manager = ModelManager()
  116. model_manager.get_model_instance(
  117. tenant_id=current_user.current_tenant_id,
  118. provider=dataset.embedding_model_provider,
  119. model_type=ModelType.TEXT_EMBEDDING,
  120. model=dataset.embedding_model
  121. )
  122. except LLMBadRequestError:
  123. raise ProviderNotInitializeError(
  124. "No Embedding Model available. Please configure a valid provider "
  125. "in the Settings -> Model Provider.")
  126. except ProviderTokenNotInitError as ex:
  127. raise ProviderNotInitializeError(ex.description)
  128. segment = DocumentSegment.query.filter(
  129. DocumentSegment.id == str(segment_id),
  130. DocumentSegment.tenant_id == current_user.current_tenant_id
  131. ).first()
  132. if not segment:
  133. raise NotFound('Segment not found.')
  134. if segment.status != 'completed':
  135. raise NotFound('Segment is not completed, enable or disable function is not allowed')
  136. document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id)
  137. cache_result = redis_client.get(document_indexing_cache_key)
  138. if cache_result is not None:
  139. raise InvalidActionError("Document is being indexed, please try again later")
  140. indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
  141. cache_result = redis_client.get(indexing_cache_key)
  142. if cache_result is not None:
  143. raise InvalidActionError("Segment is being indexed, please try again later")
  144. if action == "enable":
  145. if segment.enabled:
  146. raise InvalidActionError("Segment is already enabled.")
  147. segment.enabled = True
  148. segment.disabled_at = None
  149. segment.disabled_by = None
  150. db.session.commit()
  151. # Set cache to prevent indexing the same segment multiple times
  152. redis_client.setex(indexing_cache_key, 600, 1)
  153. enable_segment_to_index_task.delay(segment.id)
  154. return {'result': 'success'}, 200
  155. elif action == "disable":
  156. if not segment.enabled:
  157. raise InvalidActionError("Segment is already disabled.")
  158. segment.enabled = False
  159. segment.disabled_at = datetime.utcnow()
  160. segment.disabled_by = current_user.id
  161. db.session.commit()
  162. # Set cache to prevent indexing the same segment multiple times
  163. redis_client.setex(indexing_cache_key, 600, 1)
  164. disable_segment_from_index_task.delay(segment.id)
  165. return {'result': 'success'}, 200
  166. else:
  167. raise InvalidActionError()
  168. class DatasetDocumentSegmentAddApi(Resource):
  169. @setup_required
  170. @login_required
  171. @account_initialization_required
  172. @cloud_edition_billing_resource_check('vector_space')
  173. def post(self, dataset_id, document_id):
  174. # check dataset
  175. dataset_id = str(dataset_id)
  176. dataset = DatasetService.get_dataset(dataset_id)
  177. if not dataset:
  178. raise NotFound('Dataset not found.')
  179. # check document
  180. document_id = str(document_id)
  181. document = DocumentService.get_document(dataset_id, document_id)
  182. if not document:
  183. raise NotFound('Document not found.')
  184. # The role of the current user in the ta table must be admin or owner
  185. if not current_user.is_admin_or_owner:
  186. raise Forbidden()
  187. # check embedding model setting
  188. if dataset.indexing_technique == 'high_quality':
  189. try:
  190. model_manager = ModelManager()
  191. model_manager.get_model_instance(
  192. tenant_id=current_user.current_tenant_id,
  193. provider=dataset.embedding_model_provider,
  194. model_type=ModelType.TEXT_EMBEDDING,
  195. model=dataset.embedding_model
  196. )
  197. except LLMBadRequestError:
  198. raise ProviderNotInitializeError(
  199. "No Embedding Model available. Please configure a valid provider "
  200. "in the Settings -> Model Provider.")
  201. except ProviderTokenNotInitError as ex:
  202. raise ProviderNotInitializeError(ex.description)
  203. try:
  204. DatasetService.check_dataset_permission(dataset, current_user)
  205. except services.errors.account.NoPermissionError as e:
  206. raise Forbidden(str(e))
  207. # validate args
  208. parser = reqparse.RequestParser()
  209. parser.add_argument('content', type=str, required=True, nullable=False, location='json')
  210. parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
  211. parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
  212. args = parser.parse_args()
  213. SegmentService.segment_create_args_validate(args, document)
  214. segment = SegmentService.create_segment(args, document, dataset)
  215. return {
  216. 'data': marshal(segment, segment_fields),
  217. 'doc_form': document.doc_form
  218. }, 200
  219. class DatasetDocumentSegmentUpdateApi(Resource):
  220. @setup_required
  221. @login_required
  222. @account_initialization_required
  223. @cloud_edition_billing_resource_check('vector_space')
  224. def patch(self, dataset_id, document_id, segment_id):
  225. # check dataset
  226. dataset_id = str(dataset_id)
  227. dataset = DatasetService.get_dataset(dataset_id)
  228. if not dataset:
  229. raise NotFound('Dataset not found.')
  230. # check user's model setting
  231. DatasetService.check_dataset_model_setting(dataset)
  232. # check document
  233. document_id = str(document_id)
  234. document = DocumentService.get_document(dataset_id, document_id)
  235. if not document:
  236. raise NotFound('Document not found.')
  237. if dataset.indexing_technique == 'high_quality':
  238. # check embedding model setting
  239. try:
  240. model_manager = ModelManager()
  241. model_manager.get_model_instance(
  242. tenant_id=current_user.current_tenant_id,
  243. provider=dataset.embedding_model_provider,
  244. model_type=ModelType.TEXT_EMBEDDING,
  245. model=dataset.embedding_model
  246. )
  247. except LLMBadRequestError:
  248. raise ProviderNotInitializeError(
  249. "No Embedding Model available. Please configure a valid provider "
  250. "in the Settings -> Model Provider.")
  251. except ProviderTokenNotInitError as ex:
  252. raise ProviderNotInitializeError(ex.description)
  253. # check segment
  254. segment_id = str(segment_id)
  255. segment = DocumentSegment.query.filter(
  256. DocumentSegment.id == str(segment_id),
  257. DocumentSegment.tenant_id == current_user.current_tenant_id
  258. ).first()
  259. if not segment:
  260. raise NotFound('Segment not found.')
  261. # The role of the current user in the ta table must be admin or owner
  262. if not current_user.is_admin_or_owner:
  263. raise Forbidden()
  264. try:
  265. DatasetService.check_dataset_permission(dataset, current_user)
  266. except services.errors.account.NoPermissionError as e:
  267. raise Forbidden(str(e))
  268. # validate args
  269. parser = reqparse.RequestParser()
  270. parser.add_argument('content', type=str, required=True, nullable=False, location='json')
  271. parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
  272. parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
  273. args = parser.parse_args()
  274. SegmentService.segment_create_args_validate(args, document)
  275. segment = SegmentService.update_segment(args, segment, document, dataset)
  276. return {
  277. 'data': marshal(segment, segment_fields),
  278. 'doc_form': document.doc_form
  279. }, 200
  280. @setup_required
  281. @login_required
  282. @account_initialization_required
  283. def delete(self, dataset_id, document_id, segment_id):
  284. # check dataset
  285. dataset_id = str(dataset_id)
  286. dataset = DatasetService.get_dataset(dataset_id)
  287. if not dataset:
  288. raise NotFound('Dataset not found.')
  289. # check user's model setting
  290. DatasetService.check_dataset_model_setting(dataset)
  291. # check document
  292. document_id = str(document_id)
  293. document = DocumentService.get_document(dataset_id, document_id)
  294. if not document:
  295. raise NotFound('Document not found.')
  296. # check segment
  297. segment_id = str(segment_id)
  298. segment = DocumentSegment.query.filter(
  299. DocumentSegment.id == str(segment_id),
  300. DocumentSegment.tenant_id == current_user.current_tenant_id
  301. ).first()
  302. if not segment:
  303. raise NotFound('Segment not found.')
  304. # The role of the current user in the ta table must be admin or owner
  305. if not current_user.is_admin_or_owner:
  306. raise Forbidden()
  307. try:
  308. DatasetService.check_dataset_permission(dataset, current_user)
  309. except services.errors.account.NoPermissionError as e:
  310. raise Forbidden(str(e))
  311. SegmentService.delete_segment(segment, document, dataset)
  312. return {'result': 'success'}, 200
  313. class DatasetDocumentSegmentBatchImportApi(Resource):
  314. @setup_required
  315. @login_required
  316. @account_initialization_required
  317. @cloud_edition_billing_resource_check('vector_space')
  318. def post(self, dataset_id, document_id):
  319. # check dataset
  320. dataset_id = str(dataset_id)
  321. dataset = DatasetService.get_dataset(dataset_id)
  322. if not dataset:
  323. raise NotFound('Dataset not found.')
  324. # check document
  325. document_id = str(document_id)
  326. document = DocumentService.get_document(dataset_id, document_id)
  327. if not document:
  328. raise NotFound('Document not found.')
  329. # get file from request
  330. file = request.files['file']
  331. # check file
  332. if 'file' not in request.files:
  333. raise NoFileUploadedError()
  334. if len(request.files) > 1:
  335. raise TooManyFilesError()
  336. # check file type
  337. if not file.filename.endswith('.csv'):
  338. raise ValueError("Invalid file type. Only CSV files are allowed")
  339. try:
  340. # Skip the first row
  341. df = pd.read_csv(file)
  342. result = []
  343. for index, row in df.iterrows():
  344. if document.doc_form == 'qa_model':
  345. data = {'content': row[0], 'answer': row[1]}
  346. else:
  347. data = {'content': row[0]}
  348. result.append(data)
  349. if len(result) == 0:
  350. raise ValueError("The CSV file is empty.")
  351. # async job
  352. job_id = str(uuid.uuid4())
  353. indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
  354. # send batch add segments task
  355. redis_client.setnx(indexing_cache_key, 'waiting')
  356. batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
  357. current_user.current_tenant_id, current_user.id)
  358. except Exception as e:
  359. return {'error': str(e)}, 500
  360. return {
  361. 'job_id': job_id,
  362. 'job_status': 'waiting'
  363. }, 200
  364. @setup_required
  365. @login_required
  366. @account_initialization_required
  367. def get(self, job_id):
  368. job_id = str(job_id)
  369. indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
  370. cache_result = redis_client.get(indexing_cache_key)
  371. if cache_result is None:
  372. raise ValueError("The job is not exist.")
  373. return {
  374. 'job_id': job_id,
  375. 'job_status': cache_result.decode()
  376. }, 200
  377. api.add_resource(DatasetDocumentSegmentListApi,
  378. '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
  379. api.add_resource(DatasetDocumentSegmentApi,
  380. '/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
  381. api.add_resource(DatasetDocumentSegmentAddApi,
  382. '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
  383. api.add_resource(DatasetDocumentSegmentUpdateApi,
  384. '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
  385. api.add_resource(DatasetDocumentSegmentBatchImportApi,
  386. '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
  387. '/datasets/batch_import_status/<uuid:job_id>')