annotation_reply.py 3.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import logging
  2. from typing import Optional
  3. from core.entities.application_entities import InvokeFrom
  4. from core.rag.datasource.vdb.vector_factory import Vector
  5. from extensions.ext_database import db
  6. from models.dataset import Dataset
  7. from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
  8. from services.annotation_service import AppAnnotationService
  9. from services.dataset_service import DatasetCollectionBindingService
  10. logger = logging.getLogger(__name__)
  11. class AnnotationReplyFeature:
  12. def query(self, app_record: App,
  13. message: Message,
  14. query: str,
  15. user_id: str,
  16. invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
  17. """
  18. Query app annotations to reply
  19. :param app_record: app record
  20. :param message: message
  21. :param query: query
  22. :param user_id: user id
  23. :param invoke_from: invoke from
  24. :return:
  25. """
  26. annotation_setting = db.session.query(AppAnnotationSetting).filter(
  27. AppAnnotationSetting.app_id == app_record.id).first()
  28. if not annotation_setting:
  29. return None
  30. collection_binding_detail = annotation_setting.collection_binding_detail
  31. try:
  32. score_threshold = annotation_setting.score_threshold or 1
  33. embedding_provider_name = collection_binding_detail.provider_name
  34. embedding_model_name = collection_binding_detail.model_name
  35. dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
  36. embedding_provider_name,
  37. embedding_model_name,
  38. 'annotation'
  39. )
  40. dataset = Dataset(
  41. id=app_record.id,
  42. tenant_id=app_record.tenant_id,
  43. indexing_technique='high_quality',
  44. embedding_model_provider=embedding_provider_name,
  45. embedding_model=embedding_model_name,
  46. collection_binding_id=dataset_collection_binding.id
  47. )
  48. vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
  49. documents = vector.search_by_vector(
  50. query=query,
  51. top_k=1,
  52. score_threshold=score_threshold,
  53. filter={
  54. 'group_id': [dataset.id]
  55. }
  56. )
  57. if documents:
  58. annotation_id = documents[0].metadata['annotation_id']
  59. score = documents[0].metadata['score']
  60. annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
  61. if annotation:
  62. if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
  63. from_source = 'api'
  64. else:
  65. from_source = 'console'
  66. # insert annotation history
  67. AppAnnotationService.add_annotation_history(annotation.id,
  68. app_record.id,
  69. annotation.question,
  70. annotation.content,
  71. query,
  72. user_id,
  73. message.id,
  74. from_source,
  75. score)
  76. return annotation
  77. except Exception as e:
  78. logger.warning(f'Query annotation failed, exception: {str(e)}.')
  79. return None
  80. return None