annotation_reply.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. import logging
  2. from typing import Optional
  3. from core.embedding.cached_embedding import CacheEmbedding
  4. from core.entities.application_entities import InvokeFrom
  5. from core.index.vector_index.vector_index import VectorIndex
  6. from core.model_manager import ModelManager
  7. from core.model_runtime.entities.model_entities import ModelType
  8. from extensions.ext_database import db
  9. from flask import current_app
  10. from models.dataset import Dataset
  11. from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
  12. from services.annotation_service import AppAnnotationService
  13. from services.dataset_service import DatasetCollectionBindingService
  14. logger = logging.getLogger(__name__)
  15. class AnnotationReplyFeature:
  16. def query(self, app_record: App,
  17. message: Message,
  18. query: str,
  19. user_id: str,
  20. invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
  21. """
  22. Query app annotations to reply
  23. :param app_record: app record
  24. :param message: message
  25. :param query: query
  26. :param user_id: user id
  27. :param invoke_from: invoke from
  28. :return:
  29. """
  30. annotation_setting = db.session.query(AppAnnotationSetting).filter(
  31. AppAnnotationSetting.app_id == app_record.id).first()
  32. if not annotation_setting:
  33. return None
  34. collection_binding_detail = annotation_setting.collection_binding_detail
  35. try:
  36. score_threshold = annotation_setting.score_threshold or 1
  37. embedding_provider_name = collection_binding_detail.provider_name
  38. embedding_model_name = collection_binding_detail.model_name
  39. model_manager = ModelManager()
  40. model_instance = model_manager.get_model_instance(
  41. tenant_id=app_record.tenant_id,
  42. provider=embedding_provider_name,
  43. model_type=ModelType.TEXT_EMBEDDING,
  44. model=embedding_model_name
  45. )
  46. # get embedding model
  47. embeddings = CacheEmbedding(model_instance)
  48. dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
  49. embedding_provider_name,
  50. embedding_model_name,
  51. 'annotation'
  52. )
  53. dataset = Dataset(
  54. id=app_record.id,
  55. tenant_id=app_record.tenant_id,
  56. indexing_technique='high_quality',
  57. embedding_model_provider=embedding_provider_name,
  58. embedding_model=embedding_model_name,
  59. collection_binding_id=dataset_collection_binding.id
  60. )
  61. vector_index = VectorIndex(
  62. dataset=dataset,
  63. config=current_app.config,
  64. embeddings=embeddings,
  65. attributes=['doc_id', 'annotation_id', 'app_id']
  66. )
  67. documents = vector_index.search(
  68. query=query,
  69. search_type='similarity_score_threshold',
  70. search_kwargs={
  71. 'k': 1,
  72. 'score_threshold': score_threshold,
  73. 'filter': {
  74. 'group_id': [dataset.id]
  75. }
  76. }
  77. )
  78. if documents:
  79. annotation_id = documents[0].metadata['annotation_id']
  80. score = documents[0].metadata['score']
  81. annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
  82. if annotation:
  83. if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
  84. from_source = 'api'
  85. else:
  86. from_source = 'console'
  87. # insert annotation history
  88. AppAnnotationService.add_annotation_history(annotation.id,
  89. app_record.id,
  90. annotation.question,
  91. annotation.content,
  92. query,
  93. user_id,
  94. message.id,
  95. from_source,
  96. score)
  97. return annotation
  98. except Exception as e:
  99. logger.warning(f'Query annotation failed, exception: {str(e)}.')
  100. return None
  101. return None