# -*- coding:utf-8 -*-
"""Console REST resources for datasets.

Registers the list/detail/query/estimate/related-apps endpoints on the
console ``api`` object imported from ``controllers.console``.
"""
from flask import request
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
from werkzeug.exceptions import BadRequest, NotFound, Forbidden

import services
from controllers.console import api
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from libs.helper import TimestampField
from extensions.ext_database import db
from models.model import UploadFile
from services.dataset_service import DatasetService

# Marshalling schema used for every dataset payload returned by this module.
dataset_detail_fields = {
    'id': fields.String,
    'name': fields.String,
    'description': fields.String,
    'provider': fields.String,
    'permission': fields.String,
    'data_source_type': fields.String,
    'indexing_technique': fields.String,
    'app_count': fields.Integer,
    'document_count': fields.Integer,
    'word_count': fields.Integer,
    'created_by': fields.String,
    'created_at': TimestampField,
    'updated_by': fields.String,
    'updated_at': TimestampField,
}

# Marshalling schema for the entries returned by DatasetQueryApi.
dataset_query_detail_fields = {
    "id": fields.String,
    "content": fields.String,
    "source": fields.String,
    "source_app_id": fields.String,
    "created_by_role": fields.String,
    "created_by": fields.String,
    "created_at": TimestampField
}

# Tenant roles allowed to create, modify or delete datasets.
_DATASET_ADMIN_ROLES = ('admin', 'owner')


def _validate_name(name):
    """Return ``name`` unchanged if it is 1-40 characters long.

    Used as a ``reqparse`` ``type=`` callable.

    Raises:
        ValueError: if ``name`` is empty/falsy or longer than 40 characters.
    """
    # ``not name`` already covers the zero-length case.
    if not name or len(name) > 40:
        raise ValueError('Name must be between 1 to 40 characters.')
    return name


def _validate_description_length(description):
    """Return ``description`` unchanged if it is at most 200 characters long.

    Used as a ``reqparse`` ``type=`` callable.

    Raises:
        ValueError: if ``description`` is longer than 200 characters.
    """
    if len(description) > 200:
        raise ValueError('Description cannot exceed 200 characters.')
    return description


def _require_tenant_admin():
    """Raise ``Forbidden`` unless the current user's role in the current
    tenant is ``admin`` or ``owner``."""
    if current_user.current_tenant.current_role not in _DATASET_ADMIN_ROLES:
        raise Forbidden()


def _get_permitted_dataset(dataset_id):
    """Load a dataset and check that the current user may access it.

    Args:
        dataset_id: dataset identifier from the URL; stringified before lookup.

    Returns:
        The dataset object returned by ``DatasetService.get_dataset``.

    Raises:
        NotFound: if the service returns ``None`` for this id.
        Forbidden: if ``DatasetService.check_dataset_permission`` raises
            ``NoPermissionError``; the error text is passed through.
    """
    dataset = DatasetService.get_dataset(str(dataset_id))
    if dataset is None:
        raise NotFound("Dataset not found.")

    try:
        DatasetService.check_dataset_permission(dataset, current_user)
    except services.errors.account.NoPermissionError as e:
        raise Forbidden(str(e))

    return dataset


class DatasetListApi(Resource):
    """Collection endpoint: list datasets and create an empty dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """List datasets of the current tenant.

        Query parameters: ``page`` (default 1), ``limit`` (default 20),
        repeated ``ids`` and ``provider`` (default ``"vendor"``). When any
        ``ids`` are supplied they take precedence and the lookup is done by
        id; otherwise a paginated listing is requested from the service.
        """
        page = request.args.get('page', default=1, type=int)
        limit = request.args.get('limit', default=20, type=int)
        ids = request.args.getlist('ids')
        provider = request.args.get('provider', default="vendor")

        if ids:
            datasets, total = DatasetService.get_datasets_by_ids(
                ids, current_user.current_tenant_id)
        else:
            datasets, total = DatasetService.get_datasets(
                page, limit, provider,
                current_user.current_tenant_id, current_user)

        response = {
            'data': marshal(datasets, dataset_detail_fields),
            # A full page is taken as a hint that more results may exist.
            'has_more': len(datasets) == limit,
            'limit': limit,
            'total': total,
            'page': page
        }
        return response, 200

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Create an empty dataset; requires the admin or owner role.

        Raises:
            Forbidden: if the current user is neither admin nor owner.
            DatasetNameDuplicateError: if the service reports a duplicate name.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('name', nullable=False, required=True,
                            help='type is required. Name must be between 1 to 40 characters.',
                            type=_validate_name)
        parser.add_argument('indexing_technique', type=str, location='json',
                            choices=('high_quality', 'economy'),
                            help='Invalid indexing technique.')
        args = parser.parse_args()

        # Arguments are validated first, so malformed requests are rejected
        # by the parser before the role check runs.
        _require_tenant_admin()

        try:
            dataset = DatasetService.create_empty_dataset(
                tenant_id=current_user.current_tenant_id,
                name=args['name'],
                indexing_technique=args['indexing_technique'],
                account=current_user
            )
        except services.errors.dataset.DatasetNameDuplicateError:
            # Translate the service-level error into the controller-level one.
            raise DatasetNameDuplicateError()

        return marshal(dataset, dataset_detail_fields), 201


class DatasetApi(Resource):
    """Single-dataset endpoint: fetch, update and delete."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Return one dataset after the existence and permission checks."""
        dataset = _get_permitted_dataset(dataset_id)
        return marshal(dataset, dataset_detail_fields), 200

    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, dataset_id):
        """Update a dataset; requires the admin or owner role.

        Raises:
            Forbidden: if the current user is neither admin nor owner.
            NotFound: if the service returns ``None`` for the update.
        """
        dataset_id_str = str(dataset_id)

        parser = reqparse.RequestParser()
        parser.add_argument('name', nullable=False,
                            help='type is required. Name must be between 1 to 40 characters.',
                            type=_validate_name)
        # store_missing=False keeps an absent description out of ``args``.
        parser.add_argument('description', location='json', store_missing=False,
                            type=_validate_description_length)
        parser.add_argument('indexing_technique', type=str, location='json',
                            choices=('high_quality', 'economy'),
                            help='Invalid indexing technique.')
        parser.add_argument('permission', type=str, location='json',
                            choices=('only_me', 'all_team_members'),
                            help='Invalid permission.')
        args = parser.parse_args()

        _require_tenant_admin()

        dataset = DatasetService.update_dataset(
            dataset_id_str, args, current_user)
        if dataset is None:
            raise NotFound("Dataset not found.")

        return marshal(dataset, dataset_detail_fields), 200

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, dataset_id):
        """Delete a dataset; requires the admin or owner role.

        Raises:
            Forbidden: if the current user is neither admin nor owner.
            NotFound: if the service reports that nothing was deleted.
        """
        dataset_id_str = str(dataset_id)

        _require_tenant_admin()

        if DatasetService.delete_dataset(dataset_id_str, current_user):
            return {'result': 'success'}, 204
        raise NotFound("Dataset not found.")


class DatasetQueryApi(Resource):
    """Paginated list of the queries recorded for one dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Return one page of dataset queries.

        Query parameters: ``page`` (default 1) and ``limit`` (default 20).
        """
        dataset = _get_permitted_dataset(dataset_id)

        page = request.args.get('page', default=1, type=int)
        limit = request.args.get('limit', default=20, type=int)

        dataset_queries, total = DatasetService.get_dataset_queries(
            dataset_id=dataset.id,
            page=page,
            per_page=limit
        )

        response = {
            'data': marshal(dataset_queries, dataset_query_detail_fields),
            # A full page is taken as a hint that more results may exist.
            'has_more': len(dataset_queries) == limit,
            'limit': limit,
            'total': total,
            'page': page
        }
        return response, 200


class DatasetIndexingEstimateApi(Resource):
    """Indexing estimate for an uploaded file."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Run ``IndexingRunner.indexing_estimate`` for an uploaded file.

        The JSON body must be an object containing ``file_id`` and
        ``process_rule``. The file is looked up within the current tenant.

        Raises:
            BadRequest: if the body is not a JSON object or a required key
                is missing.
            NotFound: if no matching upload exists in the current tenant.
        """
        segment_rule = request.get_json()

        # Reject malformed payloads with a 400 instead of letting a
        # TypeError/KeyError bubble up as an internal server error.
        if not isinstance(segment_rule, dict):
            raise BadRequest("Request body must be a JSON object.")
        missing = [key for key in ('file_id', 'process_rule')
                   if key not in segment_rule]
        if missing:
            raise BadRequest(
                "Missing required field(s): " + ", ".join(missing))

        file_detail = db.session.query(UploadFile).filter(
            UploadFile.tenant_id == current_user.current_tenant_id,
            UploadFile.id == segment_rule["file_id"]
        ).first()

        if file_detail is None:
            raise NotFound("File not found.")

        indexing_runner = IndexingRunner()
        response = indexing_runner.indexing_estimate(
            file_detail, segment_rule['process_rule'])
        return response, 200


class DatasetRelatedAppListApi(Resource):
    """Apps linked to one dataset."""

    # Fields exposed for each related app.
    app_detail_kernel_fields = {
        'id': fields.String,
        'name': fields.String,
        'mode': fields.String,
        'icon': fields.String,
        'icon_background': fields.String,
    }

    # Response envelope applied by ``marshal_with`` on ``get``.
    related_app_list = {
        'data': fields.List(fields.Nested(app_detail_kernel_fields)),
        'total': fields.Integer,
    }

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(related_app_list)
    def get(self, dataset_id):
        """Return the apps joined to the dataset, skipping joins without an app."""
        dataset = _get_permitted_dataset(dataset_id)

        app_dataset_joins = DatasetService.get_related_apps(dataset.id)

        related_apps = []
        for app_dataset_join in app_dataset_joins:
            # Read ``.app`` once per join and drop joins whose app is falsy.
            app_model = app_dataset_join.app
            if app_model:
                related_apps.append(app_model)

        return {
            'data': related_apps,
            'total': len(related_apps)
        }, 200


# Every per-dataset handler takes a ``dataset_id`` argument, so the routes
# must capture it from the URL; the handlers stringify it before use.
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetIndexingEstimateApi, '/datasets/file-indexing-estimate')
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')