segment.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. from flask_login import current_user
  2. from flask_restful import reqparse, marshal
  3. from werkzeug.exceptions import NotFound
  4. from controllers.service_api import api
  5. from controllers.service_api.app.error import ProviderNotInitializeError
  6. from controllers.service_api.wraps import DatasetApiResource
  7. from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
  8. from core.model_providers.model_factory import ModelFactory
  9. from extensions.ext_database import db
  10. from fields.segment_fields import segment_fields
  11. from models.dataset import Dataset
  12. from services.dataset_service import DocumentService, SegmentService
  13. class SegmentApi(DatasetApiResource):
  14. """Resource for segments."""
  15. def post(self, tenant_id, dataset_id, document_id):
  16. """Create single segment."""
  17. # check dataset
  18. dataset_id = str(dataset_id)
  19. tenant_id = str(tenant_id)
  20. dataset = db.session.query(Dataset).filter(
  21. Dataset.tenant_id == tenant_id,
  22. Dataset.id == dataset_id
  23. ).first()
  24. # check document
  25. document_id = str(document_id)
  26. document = DocumentService.get_document(dataset.id, document_id)
  27. if not document:
  28. raise NotFound('Document not found.')
  29. # check embedding model setting
  30. if dataset.indexing_technique == 'high_quality':
  31. try:
  32. ModelFactory.get_embedding_model(
  33. tenant_id=current_user.current_tenant_id,
  34. model_provider_name=dataset.embedding_model_provider,
  35. model_name=dataset.embedding_model
  36. )
  37. except LLMBadRequestError:
  38. raise ProviderNotInitializeError(
  39. f"No Embedding Model available. Please configure a valid provider "
  40. f"in the Settings -> Model Provider.")
  41. except ProviderTokenNotInitError as ex:
  42. raise ProviderNotInitializeError(ex.description)
  43. # validate args
  44. parser = reqparse.RequestParser()
  45. parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
  46. args = parser.parse_args()
  47. for args_item in args['segments']:
  48. SegmentService.segment_create_args_validate(args_item, document)
  49. segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
  50. return {
  51. 'data': marshal(segments, segment_fields),
  52. 'doc_form': document.doc_form
  53. }, 200
  54. api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')