datasets_segments.py

# -*- coding:utf-8 -*-
import uuid
from datetime import datetime

import pandas as pd
from flask import request
from flask_login import current_user
from flask_restful import Resource, reqparse, fields, marshal
from werkzeug.exceptions import NotFound, Forbidden

import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.login.login import login_required
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
from libs.helper import TimestampField
from services.dataset_service import DatasetService, DocumentService, SegmentService
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
segment_fields = {
    'id': fields.String,
    'position': fields.Integer,
    'document_id': fields.String,
    'content': fields.String,
    'answer': fields.String,
    'word_count': fields.Integer,
    'tokens': fields.Integer,
    'keywords': fields.List(fields.String),
    'index_node_id': fields.String,
    'index_node_hash': fields.String,
    'hit_count': fields.Integer,
    'enabled': fields.Boolean,
    'disabled_at': TimestampField,
    'disabled_by': fields.String,
    'status': fields.String,
    'created_by': fields.String,
    'created_at': TimestampField,
    'indexing_at': TimestampField,
    'completed_at': TimestampField,
    'error': fields.String,
    'stopped_at': TimestampField
}

segment_list_response = {
    'data': fields.List(fields.Nested(segment_fields)),
    'has_more': fields.Boolean,
    'limit': fields.Integer
}
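
# Illustrative only (made-up values): a segment marshalled with segment_fields above
# serializes to JSON along these lines, with TimestampField rendering the *_at columns:
#
#   {
#       "id": "9f3c0c9e-...", "position": 1, "document_id": "d0c1a2b3-...",
#       "content": "What is a dataset segment?", "answer": null,
#       "word_count": 42, "tokens": 57, "keywords": ["dataset", "segment"],
#       "enabled": true, "status": "completed", "hit_count": 0, "created_at": 1690000000
#   }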
class DatasetDocumentSegmentListApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id, document_id):
        dataset_id = str(dataset_id)
        document_id = str(document_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound('Document not found.')

        parser = reqparse.RequestParser()
        parser.add_argument('last_id', type=str, default=None, location='args')
        parser.add_argument('limit', type=int, default=20, location='args')
        parser.add_argument('status', type=str,
                            action='append', default=[], location='args')
        parser.add_argument('hit_count_gte', type=int,
                            default=None, location='args')
        parser.add_argument('enabled', type=str, default='all', location='args')
        parser.add_argument('keyword', type=str, default=None, location='args')
        args = parser.parse_args()

        last_id = args['last_id']
        limit = min(args['limit'], 100)
        status_list = args['status']
        hit_count_gte = args['hit_count_gte']
        keyword = args['keyword']

        query = DocumentSegment.query.filter(
            DocumentSegment.document_id == str(document_id),
            DocumentSegment.tenant_id == current_user.current_tenant_id
        )

        if last_id is not None:
            last_segment = DocumentSegment.query.get(str(last_id))
            if last_segment:
                query = query.filter(
                    DocumentSegment.position > last_segment.position)
            else:
                return {'data': [], 'has_more': False, 'limit': limit}, 200

        if status_list:
            query = query.filter(DocumentSegment.status.in_(status_list))

        if hit_count_gte is not None:
            query = query.filter(DocumentSegment.hit_count >= hit_count_gte)

        if keyword:
            query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))

        if args['enabled'].lower() != 'all':
            if args['enabled'].lower() == 'true':
                query = query.filter(DocumentSegment.enabled == True)
            elif args['enabled'].lower() == 'false':
                query = query.filter(DocumentSegment.enabled == False)

        total = query.count()
        segments = query.order_by(DocumentSegment.position).limit(limit + 1).all()

        has_more = False
        if len(segments) > limit:
            has_more = True
            segments = segments[:-1]

        return {
            'data': marshal(segments, segment_fields),
            'doc_form': document.doc_form,
            'has_more': has_more,
            'limit': limit,
            'total': total
        }, 200
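
# Usage sketch (hypothetical client code, not part of this module): paging through
# segments with the cursor-style `last_id` argument parsed above. `base`, `auth_headers`
# and the IDs are placeholders.
#
#   import requests
#
#   url = f"{base}/datasets/{dataset_id}/documents/{document_id}/segments"
#   params = {"limit": 20, "enabled": "true", "keyword": "pricing"}
#   page = requests.get(url, params=params, headers=auth_headers).json()
#   while page["has_more"]:
#       params["last_id"] = page["data"][-1]["id"]
#       page = requests.get(url, params=params, headers=auth_headers).json()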
class DatasetDocumentSegmentApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, dataset_id, segment_id, action):
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        # check embedding model setting
        try:
            ModelFactory.get_embedding_model(
                tenant_id=current_user.current_tenant_id,
                model_provider_name=dataset.embedding_model_provider,
                model_name=dataset.embedding_model
            )
        except LLMBadRequestError:
            raise ProviderNotInitializeError(
                "No Embedding Model available. Please configure a valid provider "
                "in the Settings -> Model Provider.")
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)

        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id),
            DocumentSegment.tenant_id == current_user.current_tenant_id
        ).first()
        if not segment:
            raise NotFound('Segment not found.')

        document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id)
        cache_result = redis_client.get(document_indexing_cache_key)
        if cache_result is not None:
            raise InvalidActionError("Document is being indexed, please try again later")

        indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
        cache_result = redis_client.get(indexing_cache_key)
        if cache_result is not None:
            raise InvalidActionError("Segment is being indexed, please try again later")

        if action == "enable":
            if segment.enabled:
                raise InvalidActionError("Segment is already enabled.")

            segment.enabled = True
            segment.disabled_at = None
            segment.disabled_by = None
            db.session.commit()

            # Set cache to prevent indexing the same segment multiple times
            redis_client.setex(indexing_cache_key, 600, 1)
            enable_segment_to_index_task.delay(segment.id)

            return {'result': 'success'}, 200
        elif action == "disable":
            if not segment.enabled:
                raise InvalidActionError("Segment is already disabled.")

            segment.enabled = False
            segment.disabled_at = datetime.utcnow()
            segment.disabled_by = current_user.id
            db.session.commit()

            # Set cache to prevent indexing the same segment multiple times
            redis_client.setex(indexing_cache_key, 600, 1)
            disable_segment_from_index_task.delay(segment.id)

            return {'result': 'success'}, 200
        else:
            raise InvalidActionError()
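
# Example (hypothetical requests): enabling or disabling a segment is a PATCH against
# the route registered at the bottom of this file, with the action carried in the URL:
#
#   PATCH /datasets/<dataset_id>/segments/<segment_id>/enable
#   PATCH /datasets/<dataset_id>/segments/<segment_id>/disable
#
# Both branches set a 600-second Redis key (segment_<id>_indexing) so the segment is not
# re-indexed while the corresponding Celery task is still running.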
class DatasetDocumentSegmentAddApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, dataset_id, document_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound('Document not found.')
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()
        # check embedding model setting
        try:
            ModelFactory.get_embedding_model(
                tenant_id=current_user.current_tenant_id,
                model_provider_name=dataset.embedding_model_provider,
                model_name=dataset.embedding_model
            )
        except LLMBadRequestError:
            raise ProviderNotInitializeError(
                "No Embedding Model available. Please configure a valid provider "
                "in the Settings -> Model Provider.")
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
        parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
        args = parser.parse_args()

        SegmentService.segment_create_args_validate(args, document)
        segment = SegmentService.create_segment(args, document, dataset)

        return {
            'data': marshal(segment, segment_fields),
            'doc_form': document.doc_form
        }, 200
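
# Example request body (illustrative values) for POST .../documents/<document_id>/segment.
# `content` is required; `answer` and `keywords` are optional, matching the parser above.
# `answer` appears to apply to qa_model documents, judging from the batch-import handling
# further down in this file:
#
#   {
#       "content": "What plans are available?",
#       "answer": "Free, Team and Enterprise.",
#       "keywords": ["plans", "pricing"]
#   }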
class DatasetDocumentSegmentUpdateApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, dataset_id, document_id, segment_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound('Document not found.')
        # check embedding model setting
        try:
            ModelFactory.get_embedding_model(
                tenant_id=current_user.current_tenant_id,
                model_provider_name=dataset.embedding_model_provider,
                model_name=dataset.embedding_model
            )
        except LLMBadRequestError:
            raise ProviderNotInitializeError(
                "No Embedding Model available. Please configure a valid provider "
                "in the Settings -> Model Provider.")
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        # check segment
        segment_id = str(segment_id)
        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id),
            DocumentSegment.tenant_id == current_user.current_tenant_id
        ).first()
        if not segment:
            raise NotFound('Segment not found.')
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
        parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
        args = parser.parse_args()

        SegmentService.segment_create_args_validate(args, document)
        segment = SegmentService.update_segment(args, segment, document, dataset)

        return {
            'data': marshal(segment, segment_fields),
            'doc_form': document.doc_form
        }, 200

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, dataset_id, document_id, segment_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound('Document not found.')
        # check segment
        segment_id = str(segment_id)
        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id),
            DocumentSegment.tenant_id == current_user.current_tenant_id
        ).first()
        if not segment:
            raise NotFound('Segment not found.')
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        SegmentService.delete_segment(segment, document, dataset)
        return {'result': 'success'}, 200
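
# Sketch of the update/delete routes (hypothetical requests): both operate on
# .../documents/<document_id>/segments/<segment_id>. PATCH accepts the same JSON body as
# segment creation above; DELETE takes no body and returns {"result": "success"}.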
class DatasetDocumentSegmentBatchImportApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, dataset_id, document_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound('Document not found.')

        try:
            ModelFactory.get_embedding_model(
                tenant_id=current_user.current_tenant_id,
                model_provider_name=dataset.embedding_model_provider,
                model_name=dataset.embedding_model
            )
        except LLMBadRequestError:
            raise ProviderNotInitializeError(
                "No Embedding Model available. Please configure a valid provider "
                "in the Settings -> Model Provider.")
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)

        # check file before accessing it
        if 'file' not in request.files:
            raise NoFileUploadedError()
        if len(request.files) > 1:
            raise TooManyFilesError()
        # get file from request
        file = request.files['file']
        # check file type
        if not file.filename.endswith('.csv'):
            raise ValueError("Invalid file type. Only CSV files are allowed")

        try:
            # pandas reads the first row as the header, so it is skipped from the data
            df = pd.read_csv(file)
            result = []
            for index, row in df.iterrows():
                # positional access, so the header's column names do not matter
                if document.doc_form == 'qa_model':
                    data = {'content': row.iloc[0], 'answer': row.iloc[1]}
                else:
                    data = {'content': row.iloc[0]}
                result.append(data)
            if len(result) == 0:
                raise ValueError("The CSV file is empty.")
            # async job
            job_id = str(uuid.uuid4())
            indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
            # send batch add segments task
            redis_client.setnx(indexing_cache_key, 'waiting')
            batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
                                                     current_user.current_tenant_id, current_user.id)
        except Exception as e:
            return {'error': str(e)}, 500

        return {
            'job_id': job_id,
            'job_status': 'waiting'
        }, 200

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, job_id):
        job_id = str(job_id)
        indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
        cache_result = redis_client.get(indexing_cache_key)
        if cache_result is None:
            raise ValueError("The job does not exist.")

        return {
            'job_id': job_id,
            'job_status': cache_result.decode()
        }, 200
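
# Expected CSV layout for batch import (inferred from the parsing above): the first row
# is a header and is skipped; for qa_model documents the first column is the segment
# content and the second is the answer, otherwise only the first column is read.
#
#   content,answer
#   "What is a segment?","A chunk of a source document."
#   "How do I enable a segment?","PATCH the enable action."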
api.add_resource(DatasetDocumentSegmentListApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetDocumentSegmentApi,
                 '/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
api.add_resource(DatasetDocumentSegmentAddApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
api.add_resource(DatasetDocumentSegmentUpdateApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
api.add_resource(DatasetDocumentSegmentBatchImportApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
                 '/datasets/batch_import_status/<uuid:job_id>')
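
# Usage sketch (hypothetical client code): after POSTing a CSV to the batch_import route,
# poll the status route with the returned job_id until the Redis-backed status leaves
# 'waiting'. `base`, `auth_headers` and `job_id` are placeholders.
#
#   status = requests.get(f"{base}/datasets/batch_import_status/{job_id}",
#                         headers=auth_headers).json()
#   # -> {"job_id": "...", "job_status": "waiting"}  (or a later state set by the task)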