batch_import_annotations_task.py 3.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import logging
  2. import time
  3. import click
  4. from celery import shared_task
  5. from werkzeug.exceptions import NotFound
  6. from core.rag.datasource.vdb.vector_factory import Vector
  7. from core.rag.models.document import Document
  8. from extensions.ext_database import db
  9. from extensions.ext_redis import redis_client
  10. from models.dataset import Dataset
  11. from models.model import App, AppAnnotationSetting, MessageAnnotation
  12. from services.dataset_service import DatasetCollectionBindingService
  13. @shared_task(queue='dataset')
  14. def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str,
  15. user_id: str):
  16. """
  17. Add annotation to index.
  18. :param job_id: job_id
  19. :param content_list: content list
  20. :param tenant_id: tenant id
  21. :param app_id: app id
  22. :param user_id: user_id
  23. """
  24. logging.info(click.style('Start batch import annotation: {}'.format(job_id), fg='green'))
  25. start_at = time.perf_counter()
  26. indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
  27. # get app info
  28. app = db.session.query(App).filter(
  29. App.id == app_id,
  30. App.tenant_id == tenant_id,
  31. App.status == 'normal'
  32. ).first()
  33. if app:
  34. try:
  35. documents = []
  36. for content in content_list:
  37. annotation = MessageAnnotation(
  38. app_id=app.id,
  39. content=content['answer'],
  40. question=content['question'],
  41. account_id=user_id
  42. )
  43. db.session.add(annotation)
  44. db.session.flush()
  45. document = Document(
  46. page_content=content['question'],
  47. metadata={
  48. "annotation_id": annotation.id,
  49. "app_id": app_id,
  50. "doc_id": annotation.id
  51. }
  52. )
  53. documents.append(document)
  54. # if annotation reply is enabled , batch add annotations' index
  55. app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
  56. AppAnnotationSetting.app_id == app_id
  57. ).first()
  58. if app_annotation_setting:
  59. dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
  60. app_annotation_setting.collection_binding_id,
  61. 'annotation'
  62. )
  63. if not dataset_collection_binding:
  64. raise NotFound("App annotation setting not found")
  65. dataset = Dataset(
  66. id=app_id,
  67. tenant_id=tenant_id,
  68. indexing_technique='high_quality',
  69. embedding_model_provider=dataset_collection_binding.provider_name,
  70. embedding_model=dataset_collection_binding.model_name,
  71. collection_binding_id=dataset_collection_binding.id
  72. )
  73. vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
  74. vector.create(documents, duplicate_check=True)
  75. db.session.commit()
  76. redis_client.setex(indexing_cache_key, 600, 'completed')
  77. end_at = time.perf_counter()
  78. logging.info(
  79. click.style(
  80. 'Build index successful for batch import annotation: {} latency: {}'.format(job_id, end_at - start_at),
  81. fg='green'))
  82. except Exception as e:
  83. db.session.rollback()
  84. redis_client.setex(indexing_cache_key, 600, 'error')
  85. indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
  86. redis_client.setex(indexing_error_msg_key, 600, str(e))
  87. logging.exception("Build index for batch import annotations failed")