hit_testing.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import logging
  2. import services
  3. from controllers.console import api
  4. from controllers.console.app.error import (CompletionRequestError, ProviderModelCurrentlyNotSupportError,
  5. ProviderNotInitializeError, ProviderQuotaExceededError)
  6. from controllers.console.datasets.error import DatasetNotInitializedError, HighQualityDatasetOnlyError
  7. from controllers.console.setup import setup_required
  8. from controllers.console.wraps import account_initialization_required
  9. from core.errors.error import (LLMBadRequestError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError,
  10. QuotaExceededError)
  11. from core.model_runtime.errors.invoke import InvokeError
  12. from fields.hit_testing_fields import hit_testing_record_fields
  13. from flask_login import current_user
  14. from flask_restful import Resource, marshal, reqparse
  15. from libs.login import login_required
  16. from services.dataset_service import DatasetService
  17. from services.hit_testing_service import HitTestingService
  18. from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
  19. class HitTestingApi(Resource):
  20. @setup_required
  21. @login_required
  22. @account_initialization_required
  23. def post(self, dataset_id):
  24. dataset_id_str = str(dataset_id)
  25. dataset = DatasetService.get_dataset(dataset_id_str)
  26. if dataset is None:
  27. raise NotFound("Dataset not found.")
  28. try:
  29. DatasetService.check_dataset_permission(dataset, current_user)
  30. except services.errors.account.NoPermissionError as e:
  31. raise Forbidden(str(e))
  32. # only high quality dataset can be used for hit testing
  33. if dataset.indexing_technique != 'high_quality':
  34. raise HighQualityDatasetOnlyError()
  35. parser = reqparse.RequestParser()
  36. parser.add_argument('query', type=str, location='json')
  37. parser.add_argument('retrieval_model', type=dict, required=False, location='json')
  38. args = parser.parse_args()
  39. HitTestingService.hit_testing_args_check(args)
  40. try:
  41. response = HitTestingService.retrieve(
  42. dataset=dataset,
  43. query=args['query'],
  44. account=current_user,
  45. retrieval_model=args['retrieval_model'],
  46. limit=10
  47. )
  48. return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
  49. except services.errors.index.IndexNotInitializedError:
  50. raise DatasetNotInitializedError()
  51. except ProviderTokenNotInitError as ex:
  52. raise ProviderNotInitializeError(ex.description)
  53. except QuotaExceededError:
  54. raise ProviderQuotaExceededError()
  55. except ModelCurrentlyNotSupportError:
  56. raise ProviderModelCurrentlyNotSupportError()
  57. except LLMBadRequestError:
  58. raise ProviderNotInitializeError(
  59. f"No Embedding Model or Reranking Model available. Please configure a valid provider "
  60. f"in the Settings -> Model Provider.")
  61. except InvokeError as e:
  62. raise CompletionRequestError(e.description)
  63. except ValueError as e:
  64. raise ValueError(str(e))
  65. except Exception as e:
  66. logging.exception("Hit testing failed.")
  67. raise InternalServerError(str(e))
  68. api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing')