123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118 |
- import logging
- from typing import Optional
- from core.embedding.cached_embedding import CacheEmbedding
- from core.entities.application_entities import InvokeFrom
- from core.index.vector_index.vector_index import VectorIndex
- from core.model_manager import ModelManager
- from core.model_runtime.entities.model_entities import ModelType
- from extensions.ext_database import db
- from flask import current_app
- from models.dataset import Dataset
- from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
- from services.annotation_service import AppAnnotationService
- from services.dataset_service import DatasetCollectionBindingService
logger = logging.getLogger(__name__)


class AnnotationReplyFeature:
    """
    Answers a user query from an app's stored annotations when the vector
    search returns a match for it.
    """

    def query(self, app_record: App,
              message: Message,
              query: str,
              user_id: str,
              invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
        """
        Query app annotations to reply

        Returns None when the app has no annotation setting, when the vector
        search yields no document, when the matched annotation cannot be
        loaded, or when any exception is raised during the lookup (the
        exception is logged as a warning, not propagated). Exceptions raised
        while loading the annotation setting itself are not caught here.

        On a hit, an annotation history entry is recorded before returning.

        :param app_record: app record
        :param message: message; its id is stored in the annotation history
        :param query: query text to search annotations with
        :param user_id: user id stored in the annotation history
        :param invoke_from: invoke from; decides the history source label
        :return: the matched annotation, or None
        """
        annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_record.id).first()

        if not annotation_setting:
            return None

        collection_binding_detail = annotation_setting.collection_binding_detail

        try:
            # A falsy threshold (None or 0) falls back to 1.
            score_threshold = annotation_setting.score_threshold or 1

            dataset, vector_index = AnnotationReplyFeature._build_vector_index(
                app_record,
                collection_binding_detail
            )

            # Only the single best match above the threshold is requested.
            documents = vector_index.search(
                query=query,
                search_type='similarity_score_threshold',
                search_kwargs={
                    'k': 1,
                    'score_threshold': score_threshold,
                    'filter': {
                        'group_id': [dataset.id]
                    }
                }
            )

            if not documents:
                return None

            metadata = documents[0].metadata
            annotation_id = metadata['annotation_id']
            score = metadata['score']

            annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
            if not annotation:
                return None

            if invoke_from in (InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP):
                from_source = 'api'
            else:
                from_source = 'console'

            # insert annotation history
            AppAnnotationService.add_annotation_history(annotation.id,
                                                        app_record.id,
                                                        annotation.question,
                                                        annotation.content,
                                                        query,
                                                        user_id,
                                                        message.id,
                                                        from_source,
                                                        score)

            return annotation
        except Exception as e:
            # Annotation reply is best-effort: a failure here must not break
            # the caller, but keep the traceback so the cause stays diagnosable.
            logger.warning('Query annotation failed, exception: %s.', e, exc_info=True)
            return None

    @staticmethod
    def _build_vector_index(app_record: App, collection_binding_detail):
        """
        Build the vector index used to search an app's annotations

        :param app_record: app record; its id and tenant_id are used
        :param collection_binding_detail: object exposing the provider_name
            and model_name of the embedding model
        :return: (dataset, vector_index) tuple
        """
        embedding_provider_name = collection_binding_detail.provider_name
        embedding_model_name = collection_binding_detail.model_name

        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=app_record.tenant_id,
            provider=embedding_provider_name,
            model_type=ModelType.TEXT_EMBEDDING,
            model=embedding_model_name
        )

        # get embedding model
        embeddings = CacheEmbedding(model_instance)

        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
            embedding_provider_name,
            embedding_model_name,
            'annotation'
        )

        # Dataset object keyed by the app id; it is only handed to VectorIndex
        # and this code never adds it to the database session.
        dataset = Dataset(
            id=app_record.id,
            tenant_id=app_record.tenant_id,
            indexing_technique='high_quality',
            embedding_model_provider=embedding_provider_name,
            embedding_model=embedding_model_name,
            collection_binding_id=dataset_collection_binding.id
        )

        vector_index = VectorIndex(
            dataset=dataset,
            config=current_app.config,
            embeddings=embeddings,
            attributes=['doc_id', 'annotation_id', 'app_id']
        )

        return dataset, vector_index
|