enable_annotation_reply_task.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import datetime
  2. import logging
  3. import time
  4. import click
  5. from celery import shared_task
  6. from werkzeug.exceptions import NotFound
  7. from core.rag.datasource.vdb.vector_factory import Vector
  8. from core.rag.models.document import Document
  9. from extensions.ext_database import db
  10. from extensions.ext_redis import redis_client
  11. from models.dataset import Dataset
  12. from models.model import App, AppAnnotationSetting, MessageAnnotation
  13. from services.dataset_service import DatasetCollectionBindingService
  14. @shared_task(queue="dataset")
  15. def enable_annotation_reply_task(
  16. job_id: str,
  17. app_id: str,
  18. user_id: str,
  19. tenant_id: str,
  20. score_threshold: float,
  21. embedding_provider_name: str,
  22. embedding_model_name: str,
  23. ):
  24. """
  25. Async enable annotation reply task
  26. """
  27. logging.info(click.style("Start add app annotation to index: {}".format(app_id), fg="green"))
  28. start_at = time.perf_counter()
  29. # get app info
  30. app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
  31. if not app:
  32. raise NotFound("App not found")
  33. annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all()
  34. enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id))
  35. enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id))
  36. try:
  37. documents = []
  38. dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
  39. embedding_provider_name, embedding_model_name, "annotation"
  40. )
  41. annotation_setting = (
  42. db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
  43. )
  44. if annotation_setting:
  45. annotation_setting.score_threshold = score_threshold
  46. annotation_setting.collection_binding_id = dataset_collection_binding.id
  47. annotation_setting.updated_user_id = user_id
  48. annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  49. db.session.add(annotation_setting)
  50. else:
  51. new_app_annotation_setting = AppAnnotationSetting(
  52. app_id=app_id,
  53. score_threshold=score_threshold,
  54. collection_binding_id=dataset_collection_binding.id,
  55. created_user_id=user_id,
  56. updated_user_id=user_id,
  57. )
  58. db.session.add(new_app_annotation_setting)
  59. dataset = Dataset(
  60. id=app_id,
  61. tenant_id=tenant_id,
  62. indexing_technique="high_quality",
  63. embedding_model_provider=embedding_provider_name,
  64. embedding_model=embedding_model_name,
  65. collection_binding_id=dataset_collection_binding.id,
  66. )
  67. if annotations:
  68. for annotation in annotations:
  69. document = Document(
  70. page_content=annotation.question,
  71. metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
  72. )
  73. documents.append(document)
  74. vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
  75. try:
  76. vector.delete_by_metadata_field("app_id", app_id)
  77. except Exception as e:
  78. logging.info(click.style("Delete annotation index error: {}".format(str(e)), fg="red"))
  79. vector.create(documents)
  80. db.session.commit()
  81. redis_client.setex(enable_app_annotation_job_key, 600, "completed")
  82. end_at = time.perf_counter()
  83. logging.info(
  84. click.style("App annotations added to index: {} latency: {}".format(app_id, end_at - start_at), fg="green")
  85. )
  86. except Exception as e:
  87. logging.exception("Annotation batch created index failed:{}".format(str(e)))
  88. redis_client.setex(enable_app_annotation_job_key, 600, "error")
  89. enable_app_annotation_error_key = "enable_app_annotation_error_{}".format(str(job_id))
  90. redis_client.setex(enable_app_annotation_error_key, 600, str(e))
  91. db.session.rollback()
  92. finally:
  93. redis_client.delete(enable_app_annotation_key)