# datasets.py — console dataset API controllers
  1. # -*- coding:utf-8 -*-
  2. from flask import request
  3. from flask_login import login_required, current_user
  4. from flask_restful import Resource, reqparse, fields, marshal, marshal_with
  5. from werkzeug.exceptions import NotFound, Forbidden
  6. import services
  7. from controllers.console import api
  8. from controllers.console.datasets.error import DatasetNameDuplicateError
  9. from controllers.console.setup import setup_required
  10. from controllers.console.wraps import account_initialization_required
  11. from core.indexing_runner import IndexingRunner
  12. from libs.helper import TimestampField
  13. from extensions.ext_database import db
  14. from models.dataset import DocumentSegment, Document
  15. from models.model import UploadFile
  16. from services.dataset_service import DatasetService, DocumentService
# Marshal schema for serializing a Dataset model in API responses
# (used by DatasetListApi and DatasetApi below).
dataset_detail_fields = {
    'id': fields.String,
    'name': fields.String,
    'description': fields.String,
    'provider': fields.String,
    'permission': fields.String,
    'data_source_type': fields.String,
    'indexing_technique': fields.String,
    'app_count': fields.Integer,
    'document_count': fields.Integer,
    'word_count': fields.Integer,
    'created_by': fields.String,
    'created_at': TimestampField,
    'updated_by': fields.String,
    'updated_at': TimestampField,
}

# Marshal schema for serializing a dataset query-history record
# (used by DatasetQueryApi below).
dataset_query_detail_fields = {
    "id": fields.String,
    "content": fields.String,
    "source": fields.String,
    "source_app_id": fields.String,
    "created_by_role": fields.String,
    "created_by": fields.String,
    "created_at": TimestampField
}
  42. def _validate_name(name):
  43. if not name or len(name) < 1 or len(name) > 40:
  44. raise ValueError('Name must be between 1 to 40 characters.')
  45. return name
  46. def _validate_description_length(description):
  47. if len(description) > 400:
  48. raise ValueError('Description cannot exceed 400 characters.')
  49. return description
  50. class DatasetListApi(Resource):
  51. @setup_required
  52. @login_required
  53. @account_initialization_required
  54. def get(self):
  55. page = request.args.get('page', default=1, type=int)
  56. limit = request.args.get('limit', default=20, type=int)
  57. ids = request.args.getlist('ids')
  58. provider = request.args.get('provider', default="vendor")
  59. if ids:
  60. datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
  61. else:
  62. datasets, total = DatasetService.get_datasets(page, limit, provider,
  63. current_user.current_tenant_id, current_user)
  64. response = {
  65. 'data': marshal(datasets, dataset_detail_fields),
  66. 'has_more': len(datasets) == limit,
  67. 'limit': limit,
  68. 'total': total,
  69. 'page': page
  70. }
  71. return response, 200
  72. @setup_required
  73. @login_required
  74. @account_initialization_required
  75. def post(self):
  76. parser = reqparse.RequestParser()
  77. parser.add_argument('name', nullable=False, required=True,
  78. help='type is required. Name must be between 1 to 40 characters.',
  79. type=_validate_name)
  80. parser.add_argument('indexing_technique', type=str, location='json',
  81. choices=('high_quality', 'economy'),
  82. help='Invalid indexing technique.')
  83. args = parser.parse_args()
  84. # The role of the current user in the ta table must be admin or owner
  85. if current_user.current_tenant.current_role not in ['admin', 'owner']:
  86. raise Forbidden()
  87. try:
  88. dataset = DatasetService.create_empty_dataset(
  89. tenant_id=current_user.current_tenant_id,
  90. name=args['name'],
  91. indexing_technique=args['indexing_technique'],
  92. account=current_user
  93. )
  94. except services.errors.dataset.DatasetNameDuplicateError:
  95. raise DatasetNameDuplicateError()
  96. return marshal(dataset, dataset_detail_fields), 201
  97. class DatasetApi(Resource):
  98. @setup_required
  99. @login_required
  100. @account_initialization_required
  101. def get(self, dataset_id):
  102. dataset_id_str = str(dataset_id)
  103. dataset = DatasetService.get_dataset(dataset_id_str)
  104. if dataset is None:
  105. raise NotFound("Dataset not found.")
  106. try:
  107. DatasetService.check_dataset_permission(
  108. dataset, current_user)
  109. except services.errors.account.NoPermissionError as e:
  110. raise Forbidden(str(e))
  111. return marshal(dataset, dataset_detail_fields), 200
  112. @setup_required
  113. @login_required
  114. @account_initialization_required
  115. def patch(self, dataset_id):
  116. dataset_id_str = str(dataset_id)
  117. parser = reqparse.RequestParser()
  118. parser.add_argument('name', nullable=False,
  119. help='type is required. Name must be between 1 to 40 characters.',
  120. type=_validate_name)
  121. parser.add_argument('description',
  122. location='json', store_missing=False,
  123. type=_validate_description_length)
  124. parser.add_argument('indexing_technique', type=str, location='json',
  125. choices=('high_quality', 'economy'),
  126. help='Invalid indexing technique.')
  127. parser.add_argument('permission', type=str, location='json', choices=(
  128. 'only_me', 'all_team_members'), help='Invalid permission.')
  129. args = parser.parse_args()
  130. # The role of the current user in the ta table must be admin or owner
  131. if current_user.current_tenant.current_role not in ['admin', 'owner']:
  132. raise Forbidden()
  133. dataset = DatasetService.update_dataset(
  134. dataset_id_str, args, current_user)
  135. if dataset is None:
  136. raise NotFound("Dataset not found.")
  137. return marshal(dataset, dataset_detail_fields), 200
  138. @setup_required
  139. @login_required
  140. @account_initialization_required
  141. def delete(self, dataset_id):
  142. dataset_id_str = str(dataset_id)
  143. # The role of the current user in the ta table must be admin or owner
  144. if current_user.current_tenant.current_role not in ['admin', 'owner']:
  145. raise Forbidden()
  146. if DatasetService.delete_dataset(dataset_id_str, current_user):
  147. return {'result': 'success'}, 204
  148. else:
  149. raise NotFound("Dataset not found.")
  150. class DatasetQueryApi(Resource):
  151. @setup_required
  152. @login_required
  153. @account_initialization_required
  154. def get(self, dataset_id):
  155. dataset_id_str = str(dataset_id)
  156. dataset = DatasetService.get_dataset(dataset_id_str)
  157. if dataset is None:
  158. raise NotFound("Dataset not found.")
  159. try:
  160. DatasetService.check_dataset_permission(dataset, current_user)
  161. except services.errors.account.NoPermissionError as e:
  162. raise Forbidden(str(e))
  163. page = request.args.get('page', default=1, type=int)
  164. limit = request.args.get('limit', default=20, type=int)
  165. dataset_queries, total = DatasetService.get_dataset_queries(
  166. dataset_id=dataset.id,
  167. page=page,
  168. per_page=limit
  169. )
  170. response = {
  171. 'data': marshal(dataset_queries, dataset_query_detail_fields),
  172. 'has_more': len(dataset_queries) == limit,
  173. 'limit': limit,
  174. 'total': total,
  175. 'page': page
  176. }
  177. return response, 200
  178. class DatasetIndexingEstimateApi(Resource):
  179. @setup_required
  180. @login_required
  181. @account_initialization_required
  182. def post(self):
  183. parser = reqparse.RequestParser()
  184. parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
  185. parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
  186. args = parser.parse_args()
  187. # validate args
  188. DocumentService.estimate_args_validate(args)
  189. if args['info_list']['data_source_type'] == 'upload_file':
  190. file_ids = args['info_list']['file_info_list']['file_ids']
  191. file_details = db.session.query(UploadFile).filter(
  192. UploadFile.tenant_id == current_user.current_tenant_id,
  193. UploadFile.id.in_(file_ids)
  194. ).all()
  195. if file_details is None:
  196. raise NotFound("File not found.")
  197. indexing_runner = IndexingRunner()
  198. response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'])
  199. elif args['info_list']['data_source_type'] == 'notion_import':
  200. indexing_runner = IndexingRunner()
  201. response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
  202. args['process_rule'])
  203. else:
  204. raise ValueError('Data source type not support')
  205. return response, 200
  206. class DatasetRelatedAppListApi(Resource):
  207. app_detail_kernel_fields = {
  208. 'id': fields.String,
  209. 'name': fields.String,
  210. 'mode': fields.String,
  211. 'icon': fields.String,
  212. 'icon_background': fields.String,
  213. }
  214. related_app_list = {
  215. 'data': fields.List(fields.Nested(app_detail_kernel_fields)),
  216. 'total': fields.Integer,
  217. }
  218. @setup_required
  219. @login_required
  220. @account_initialization_required
  221. @marshal_with(related_app_list)
  222. def get(self, dataset_id):
  223. dataset_id_str = str(dataset_id)
  224. dataset = DatasetService.get_dataset(dataset_id_str)
  225. if dataset is None:
  226. raise NotFound("Dataset not found.")
  227. try:
  228. DatasetService.check_dataset_permission(dataset, current_user)
  229. except services.errors.account.NoPermissionError as e:
  230. raise Forbidden(str(e))
  231. app_dataset_joins = DatasetService.get_related_apps(dataset.id)
  232. related_apps = []
  233. for app_dataset_join in app_dataset_joins:
  234. app_model = app_dataset_join.app
  235. if app_model:
  236. related_apps.append(app_model)
  237. return {
  238. 'data': related_apps,
  239. 'total': len(related_apps)
  240. }, 200
  241. class DatasetIndexingStatusApi(Resource):
  242. document_status_fields = {
  243. 'id': fields.String,
  244. 'indexing_status': fields.String,
  245. 'processing_started_at': TimestampField,
  246. 'parsing_completed_at': TimestampField,
  247. 'cleaning_completed_at': TimestampField,
  248. 'splitting_completed_at': TimestampField,
  249. 'completed_at': TimestampField,
  250. 'paused_at': TimestampField,
  251. 'error': fields.String,
  252. 'stopped_at': TimestampField,
  253. 'completed_segments': fields.Integer,
  254. 'total_segments': fields.Integer,
  255. }
  256. document_status_fields_list = {
  257. 'data': fields.List(fields.Nested(document_status_fields))
  258. }
  259. @setup_required
  260. @login_required
  261. @account_initialization_required
  262. def get(self, dataset_id):
  263. dataset_id = str(dataset_id)
  264. documents = db.session.query(Document).filter(
  265. Document.dataset_id == dataset_id,
  266. Document.tenant_id == current_user.current_tenant_id
  267. ).all()
  268. documents_status = []
  269. for document in documents:
  270. completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
  271. DocumentSegment.document_id == str(document.id),
  272. DocumentSegment.status != 're_segment').count()
  273. total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
  274. DocumentSegment.status != 're_segment').count()
  275. document.completed_segments = completed_segments
  276. document.total_segments = total_segments
  277. documents_status.append(marshal(document, self.document_status_fields))
  278. data = {
  279. 'data': documents_status
  280. }
  281. return data
# Route registration: bind each console dataset resource to its URL rule.
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')