123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656 |
- import json
- import logging
- import threading
- import uuid
- from typing import cast, Optional, Any, Union, Generator, Tuple
- from flask import Flask, current_app
- from pydantic import ValidationError
- from core.app_runner.agent_app_runner import AgentApplicationRunner
- from core.app_runner.basic_app_runner import BasicApplicationRunner
- from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
- from core.entities.application_entities import ApplicationGenerateEntity, AppOrchestrationConfigEntity, \
- ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \
- AdvancedCompletionPromptTemplateEntity, ExternalDataVariableEntity, DatasetEntity, DatasetRetrieveConfigEntity, \
- AgentEntity, AgentToolEntity, FileUploadEntity, SensitiveWordAvoidanceEntity, InvokeFrom
- from core.entities.model_entities import ModelStatus
- from core.file.file_obj import FileObj
- from core.errors.error import QuotaExceededError, ProviderTokenNotInitError, ModelCurrentlyNotSupportError
- from core.model_runtime.entities.message_entities import PromptMessageRole
- from core.model_runtime.entities.model_entities import ModelType
- from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
- from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
- from core.prompt.prompt_template import PromptTemplateParser
- from core.provider_manager import ProviderManager
- from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException
- from extensions.ext_database import db
- from models.account import Account
- from models.model import EndUser, Conversation, Message, MessageFile, App
# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(__name__)
class ApplicationManager:
    """
    This class is responsible for managing application generation.

    It converts an app model config dict into an ``AppOrchestrationConfigEntity``, creates the
    conversation / message records, runs the application runner in a worker thread and hands the
    result back through a ``GenerateTaskPipeline``.
    """

    def generate(self, tenant_id: str,
                 app_id: str,
                 app_model_config_id: str,
                 app_model_config_dict: dict,
                 app_model_config_override: bool,
                 user: Union[Account, EndUser],
                 invoke_from: InvokeFrom,
                 inputs: dict[str, str],
                 query: Optional[str] = None,
                 files: Optional[list[FileObj]] = None,
                 conversation: Optional[Conversation] = None,
                 stream: bool = False,
                 extras: Optional[dict[str, Any]] = None) \
            -> Union[dict, Generator]:
        """
        Generate App response.

        :param tenant_id: workspace ID
        :param app_id: app ID
        :param app_model_config_id: app model config id
        :param app_model_config_dict: app model config dict
        :param app_model_config_override: whether the config dict overrides the stored app model
            config; when True the dict is persisted on new conversation / message records
        :param user: account or end user
        :param invoke_from: invoke from source
        :param inputs: inputs (ignored when continuing an existing conversation, whose stored
            inputs are used instead)
        :param query: query
        :param files: file obj list
        :param conversation: existing conversation to continue, or None to start a new one
        :param stream: is stream
        :param extras: extras
        :raises ProviderTokenNotInitError: model credentials are not initialized
        :raises ModelCurrentlyNotSupportError: the model is not permitted
        :raises QuotaExceededError: the model provider quota is exceeded
        :raises ValueError: the configured model does not exist
        :return: the result of ``GenerateTaskPipeline.process`` - per the annotation a dict or a
            generator (presumably a generator when ``stream`` is True)
        """
        # init task id
        task_id = str(uuid.uuid4())

        # init application generate entity
        application_generate_entity = ApplicationGenerateEntity(
            task_id=task_id,
            tenant_id=tenant_id,
            app_id=app_id,
            app_model_config_id=app_model_config_id,
            app_model_config_dict=app_model_config_dict,
            app_orchestration_config_entity=self._convert_from_app_model_config_dict(
                tenant_id=tenant_id,
                app_model_config_dict=app_model_config_dict
            ),
            app_model_config_override=app_model_config_override,
            conversation_id=conversation.id if conversation else None,
            # an existing conversation keeps the inputs it was started with
            inputs=conversation.inputs if conversation else inputs,
            # NUL characters are stripped from the query (presumably because the database
            # cannot store them in text columns)
            query=query.replace('\x00', '') if query else None,
            files=files or [],
            user_id=user.id,
            stream=stream,
            invoke_from=invoke_from,
            extras=extras
        )

        # init generate records
        conversation, message = self._init_generate_records(application_generate_entity)

        # init queue manager
        queue_manager = ApplicationQueueManager(
            task_id=application_generate_entity.task_id,
            user_id=application_generate_entity.user_id,
            invoke_from=application_generate_entity.invoke_from,
            conversation_id=conversation.id,
            app_mode=conversation.mode,
            message_id=message.id
        )

        # new thread; it receives the concrete Flask app object (not the context-local proxy)
        # so it can push its own app context, and record IDs rather than ORM objects
        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
            'flask_app': current_app._get_current_object(),
            'application_generate_entity': application_generate_entity,
            'queue_manager': queue_manager,
            'conversation_id': conversation.id,
            'message_id': message.id,
        })

        worker_thread.start()

        # return response or stream generator
        return self._handle_response(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            conversation=conversation,
            message=message,
            stream=stream
        )

    def _generate_worker(self, flask_app: Flask,
                         application_generate_entity: ApplicationGenerateEntity,
                         queue_manager: ApplicationQueueManager,
                         conversation_id: str,
                         message_id: str) -> None:
        """
        Generate worker in a new thread.

        Runs the agent or basic application runner and reports failures through the queue
        manager instead of raising them.

        :param flask_app: Flask app
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param conversation_id: conversation ID
        :param message_id: message ID
        :return:
        """
        with flask_app.app_context():
            try:
                # re-load the records by ID inside this thread
                conversation = self._get_conversation(conversation_id)
                message = self._get_message(message_id)

                # agent apps and basic apps are driven through the same run() call
                runner: Union[AgentApplicationRunner, BasicApplicationRunner]
                if application_generate_entity.app_orchestration_config_entity.agent:
                    runner = AgentApplicationRunner()
                else:
                    runner = BasicApplicationRunner()

                runner.run(
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    conversation=conversation,
                    message=message
                )
            except ConversationTaskStoppedException:
                # a stopped task is not an error; nothing is published
                pass
            except InvokeAuthorizationError:
                queue_manager.publish_error(InvokeAuthorizationError('Incorrect API key provided'))
            except ValidationError as e:
                logger.exception("Validation Error when generating")
                queue_manager.publish_error(e)
            except (ValueError, InvokeError) as e:
                queue_manager.publish_error(e)
            except Exception as e:
                logger.exception("Unknown Error when generating")
                queue_manager.publish_error(e)
            finally:
                db.session.remove()

    def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
                         queue_manager: ApplicationQueueManager,
                         conversation: Conversation,
                         message: Message,
                         stream: bool = False) -> Union[dict, Generator]:
        """
        Handle response.

        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param conversation: conversation
        :param message: message
        :param stream: is stream
        :raises ConversationTaskStoppedException: when processing fails with
            "I/O operation on closed file."
        :return: the result of ``GenerateTaskPipeline.process``
        """
        # init generate task pipeline
        generate_task_pipeline = GenerateTaskPipeline(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            conversation=conversation,
            message=message
        )

        try:
            return generate_task_pipeline.process(stream=stream)
        except ValueError as e:
            # check e.args first: a ValueError raised without arguments has no args[0]
            if e.args and e.args[0] == "I/O operation on closed file.":  # ignore this error
                raise ConversationTaskStoppedException() from e
            logger.exception(e)
            raise
        finally:
            db.session.remove()

    def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
            -> AppOrchestrationConfigEntity:
        """
        Convert app model config dict to entity.

        The input dict is only read; nested values (e.g. ``completion_params``) are not mutated.

        :param tenant_id: tenant ID
        :param app_model_config_dict: app model config dict
        :raises ProviderTokenNotInitError: provider token not init error
        :raises ModelCurrentlyNotSupportError: the model is not permitted
        :raises QuotaExceededError: the model provider quota is exceeded
        :raises ValueError: the configured model does not exist
        :return: app orchestration config entity
        """
        config = app_model_config_dict
        properties: dict[str, Any] = {}

        # model config (also validates that the model is usable)
        properties['model_config'] = self._convert_model_config(tenant_id, config)

        # prompt template
        properties['prompt_template'] = self._convert_prompt_template(config)

        # external data variables
        properties['external_data_variables'] = self._convert_external_data_variables(config)

        # show retrieve source
        properties['show_retrieve_source'] = self._is_enabled(config.get('retriever_resource'))

        # agent mode: either a dataset retrieval config or an agent config
        properties.update(self._convert_agent_mode(config, properties['model_config']))

        # file upload
        file_upload = self._convert_file_upload(config)
        if file_upload is not None:
            properties['file_upload'] = file_upload

        # opening statement
        properties['opening_statement'] = config.get('opening_statement')

        # suggested questions after answer
        if self._is_enabled(config.get('suggested_questions_after_answer')):
            properties['suggested_questions_after_answer'] = True

        # more like this: a boolean flag like the other features (it previously received the
        # opening statement text by mistake)
        if self._is_enabled(config.get('more_like_this')):
            properties['more_like_this'] = True

        # speech to text
        if self._is_enabled(config.get('speech_to_text')):
            properties['speech_to_text'] = True

        # sensitive word avoidance
        sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance')
        if self._is_enabled(sensitive_word_avoidance_dict):
            properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
                type=sensitive_word_avoidance_dict.get('type'),
                config=sensitive_word_avoidance_dict.get('config'),
            )

        return AppOrchestrationConfigEntity(**properties)

    @staticmethod
    def _is_enabled(feature_config: Optional[dict]) -> bool:
        """
        Return True when a feature config dict is present and its ``enabled`` value is truthy.

        :param feature_config: feature config dict, or None when the feature is not configured
        :return: whether the feature is enabled
        """
        return bool(feature_config) and bool(feature_config.get('enabled'))

    def _convert_model_config(self, tenant_id: str, config: dict) -> ModelConfigEntity:
        """
        Build the model config entity and validate that the configured LLM can be used.

        :param tenant_id: tenant ID
        :param config: app model config dict (only read)
        :raises ProviderTokenNotInitError: model credentials are not initialized
        :raises ModelCurrentlyNotSupportError: the model is not permitted
        :raises QuotaExceededError: the model provider quota is exceeded
        :raises ValueError: the model (or its schema) does not exist
        :return: model config entity
        """
        model_dict = config['model']
        model_name = model_dict['name']

        provider_manager = ProviderManager()
        provider_model_bundle = provider_manager.get_provider_model_bundle(
            tenant_id=tenant_id,
            provider=model_dict['provider'],
            model_type=ModelType.LLM
        )

        provider_name = provider_model_bundle.configuration.provider.provider
        model_type_instance = cast(LargeLanguageModel, provider_model_bundle.model_type_instance)

        # check model credentials
        model_credentials = provider_model_bundle.configuration.get_current_credentials(
            model_type=ModelType.LLM,
            model=model_name
        )
        if model_credentials is None:
            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")

        # check model
        provider_model = provider_model_bundle.configuration.get_provider_model(
            model=model_name,
            model_type=ModelType.LLM
        )
        if provider_model is None:
            raise ValueError(f"Model {model_name} not exist.")

        if provider_model.status == ModelStatus.NO_CONFIGURE:
            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
        elif provider_model.status == ModelStatus.NO_PERMISSION:
            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

        # Stop sequences are passed separately from the other completion parameters. Work on a
        # copy so that removing 'stop' does not mutate the caller's (and the persisted) config.
        completion_params = dict(model_dict.get('completion_params') or {})
        stop = completion_params.pop('stop', [])

        # get model mode, falling back to the mode reported by the model itself
        model_mode = model_dict.get('mode')
        if not model_mode:
            mode_enum = model_type_instance.get_model_mode(
                model=model_name,
                credentials=model_credentials
            )
            model_mode = mode_enum.value

        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
        if not model_schema:
            raise ValueError(f"Model {model_name} not exist.")

        return ModelConfigEntity(
            provider=model_dict['provider'],
            model=model_name,
            model_schema=model_schema,
            mode=model_mode,
            provider_model_bundle=provider_model_bundle,
            credentials=model_credentials,
            parameters=completion_params,
            stop=stop,
        )

    def _convert_prompt_template(self, config: dict) -> PromptTemplateEntity:
        """
        Build the prompt template entity (simple, or advanced chat / completion templates).

        :param config: app model config dict (only read)
        :return: prompt template entity
        """
        prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type'])
        if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
            return PromptTemplateEntity(
                prompt_type=prompt_type,
                simple_prompt_template=config.get("pre_prompt", "")
            )

        advanced_chat_prompt_template = None
        chat_prompt_config = config.get("chat_prompt_config", {})
        if chat_prompt_config:
            advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
                messages=[
                    {
                        "text": prompt_message["text"],
                        "role": PromptMessageRole.value_of(prompt_message["role"])
                    }
                    for prompt_message in chat_prompt_config.get("prompt", [])
                ]
            )

        advanced_completion_prompt_template = None
        completion_prompt_config = config.get("completion_prompt_config", {})
        if completion_prompt_config:
            completion_prompt_template_params = {
                'prompt': completion_prompt_config['prompt']['text'],
            }

            if 'conversation_histories_role' in completion_prompt_config:
                histories_role = completion_prompt_config['conversation_histories_role']
                completion_prompt_template_params['role_prefix'] = {
                    'user': histories_role['user_prefix'],
                    'assistant': histories_role['assistant_prefix']
                }

            advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
                **completion_prompt_template_params
            )

        return PromptTemplateEntity(
            prompt_type=prompt_type,
            advanced_chat_prompt_template=advanced_chat_prompt_template,
            advanced_completion_prompt_template=advanced_completion_prompt_template
        )

    def _convert_external_data_variables(self, config: dict) -> list[ExternalDataVariableEntity]:
        """
        Build entities for every enabled external data tool.

        :param config: app model config dict (only read)
        :return: external data variable entities, in config order
        """
        return [
            ExternalDataVariableEntity(
                variable=external_data_tool['variable'],
                type=external_data_tool['type'],
                config=external_data_tool['config']
            )
            for external_data_tool in config.get('external_data_tools') or []
            if external_data_tool.get('enabled')
        ]

    def _convert_agent_mode(self, config: dict, model_config: ModelConfigEntity) -> dict[str, Any]:
        """
        Build the agent-mode related orchestration properties.

        :param config: app model config dict (only read)
        :param model_config: already converted model config (provider/model reused by the agent)
        :return: ``{}`` when agent mode is disabled, ``{'dataset': DatasetEntity}`` for the
            'router' / 'react_router' strategies, otherwise ``{'agent': AgentEntity}``
        """
        agent_dict = config.get('agent_mode')
        if not self._is_enabled(agent_dict):
            return {}

        agent_strategy = agent_dict.get('strategy', 'router')
        if agent_strategy in ['router', 'react_router']:
            return {'dataset': self._convert_dataset(config, agent_dict, agent_strategy)}

        return {'agent': self._convert_agent(agent_dict, agent_strategy, model_config)}

    def _convert_dataset(self, config: dict, agent_dict: dict, agent_strategy: str) -> DatasetEntity:
        """
        Build the dataset entity from the enabled 'dataset' tools of the agent mode config.

        :param config: app model config dict (only read)
        :param agent_dict: the ``agent_mode`` section of the config
        :param agent_strategy: 'router' or 'react_router'
        :return: dataset entity
        """
        dataset_ids = []
        for tool in agent_dict.get('tools', []):
            # each tool is a single-key dict: {tool_type: tool_config}
            key = list(tool)[0]
            if key != 'dataset':
                continue

            tool_item = tool[key]
            if not tool_item.get("enabled"):
                continue

            dataset_ids.append(tool_item['id'])

        dataset_configs = config.get('dataset_configs', {'retrieval_model': 'single'})
        retrieval_model = dataset_configs['retrieval_model']
        retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(retrieval_model)
        query_variable = config.get('dataset_query_variable')

        if retrieval_model == 'single':
            retrieve_config = DatasetRetrieveConfigEntity(
                query_variable=query_variable,
                retrieve_strategy=retrieve_strategy,
                single_strategy=agent_strategy
            )
        else:
            retrieve_config = DatasetRetrieveConfigEntity(
                query_variable=query_variable,
                retrieve_strategy=retrieve_strategy,
                top_k=dataset_configs.get('top_k'),
                score_threshold=dataset_configs.get('score_threshold'),
                reranking_model=dataset_configs.get('reranking_model')
            )

        return DatasetEntity(
            dataset_ids=dataset_ids,
            retrieve_config=retrieve_config
        )

    def _convert_agent(self, agent_dict: dict, agent_strategy: str,
                       model_config: ModelConfigEntity) -> AgentEntity:
        """
        Build the agent entity from the enabled tools of the agent mode config.

        :param agent_dict: the ``agent_mode`` section of the config
        :param agent_strategy: 'react' selects chain-of-thought; anything else function calling
        :param model_config: already converted model config
        :return: agent entity
        """
        if agent_strategy == 'react':
            strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
        else:
            strategy = AgentEntity.Strategy.FUNCTION_CALLING

        agent_tools = []
        for tool in agent_dict.get('tools', []):
            # each tool is a single-key dict: {tool_id: tool_config}
            key = list(tool)[0]
            tool_item = tool[key]
            if not tool_item.get("enabled"):
                continue

            agent_tools.append(AgentToolEntity(tool_id=key, config=tool_item))

        return AgentEntity(
            provider=model_config.provider,
            model=model_config.model,
            strategy=strategy,
            tools=agent_tools
        )

    def _convert_file_upload(self, config: dict) -> Optional[FileUploadEntity]:
        """
        Build the file upload entity when image upload is enabled.

        :param config: app model config dict (only read)
        :return: file upload entity, or None when image upload is not enabled
        """
        file_upload_dict = config.get('file_upload')
        if not file_upload_dict:
            return None

        image_dict = file_upload_dict.get('image')
        if not self._is_enabled(image_dict):
            return None

        return FileUploadEntity(
            image_config={
                'number_limits': image_dict['number_limits'],
                'detail': image_dict['detail'],
                'transfer_methods': image_dict['transfer_methods']
            }
        )

    def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
            -> Tuple[Conversation, Message]:
        """
        Initialize generate records: create (or load) the conversation, create the message and
        its message files, and commit them.

        :param application_generate_entity: application generate entity
        :raises ValueError: when the app, or the conversation referenced by
            ``application_generate_entity.conversation_id``, cannot be found
        :return: the conversation and the newly created message
        """
        app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
        model_config = app_orchestration_config_entity.model_config

        # The schema was already fetched (and checked to be non-empty) while building the model
        # config, so reuse it instead of querying the model provider a second time.
        model_schema = model_config.model_schema

        app_record = (
            db.session.query(App)
            .filter(App.id == application_generate_entity.app_id)
            .first()
        )
        if app_record is None:
            raise ValueError(f"App {application_generate_entity.app_id} not found.")

        app_mode = app_record.mode

        # get from source: web app / service API calls are attributed to an end user,
        # everything else to a console account
        end_user_id = None
        account_id = None
        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
            from_source = 'api'
            end_user_id = application_generate_entity.user_id
        else:
            from_source = 'console'
            account_id = application_generate_entity.user_id

        # the config dict is only persisted when it overrides the stored app model config
        override_model_configs = None
        if application_generate_entity.app_model_config_override:
            override_model_configs = application_generate_entity.app_model_config_dict
        override_model_configs_json = json.dumps(override_model_configs) if override_model_configs else None

        introduction = ''
        if app_mode == 'chat':
            # get conversation introduction
            introduction = self._get_conversation_introduction(application_generate_entity)

        if not application_generate_entity.conversation_id:
            conversation = Conversation(
                app_id=app_record.id,
                app_model_config_id=application_generate_entity.app_model_config_id,
                model_provider=model_config.provider,
                model_id=model_config.model,
                override_model_configs=override_model_configs_json,
                mode=app_mode,
                name='New conversation',
                inputs=application_generate_entity.inputs,
                introduction=introduction,
                system_instruction="",
                system_instruction_tokens=0,
                status='normal',
                from_source=from_source,
                from_end_user_id=end_user_id,
                from_account_id=account_id,
            )

            db.session.add(conversation)
            db.session.commit()
        else:
            conversation = (
                db.session.query(Conversation)
                .filter(
                    Conversation.id == application_generate_entity.conversation_id,
                    Conversation.app_id == app_record.id
                ).first()
            )
            if conversation is None:
                raise ValueError(
                    f"Conversation {application_generate_entity.conversation_id} not exists."
                )

        currency = model_schema.pricing.currency if model_schema.pricing else 'USD'

        message = Message(
            app_id=app_record.id,
            model_provider=model_config.provider,
            model_id=model_config.model,
            override_model_configs=override_model_configs_json,
            conversation_id=conversation.id,
            inputs=application_generate_entity.inputs,
            query=application_generate_entity.query or "",
            message="",
            message_tokens=0,
            message_unit_price=0,
            message_price_unit=0,
            answer="",
            answer_tokens=0,
            answer_unit_price=0,
            answer_price_unit=0,
            provider_response_latency=0,
            total_price=0,
            currency=currency,
            from_source=from_source,
            from_end_user_id=end_user_id,
            from_account_id=account_id,
            agent_based=app_orchestration_config_entity.agent is not None
        )

        db.session.add(message)
        db.session.commit()

        for file_obj in application_generate_entity.files:
            db.session.add(MessageFile(
                message_id=message.id,
                type=file_obj.type.value,
                transfer_method=file_obj.transfer_method.value,
                url=file_obj.url,
                upload_file_id=file_obj.upload_file_id,
                created_by_role=('account' if account_id else 'end_user'),
                created_by=account_id or end_user_id,
            ))

        if application_generate_entity.files:
            # one commit for all message files instead of one per file
            db.session.commit()

        return conversation, message

    def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) \
            -> Optional[str]:
        """
        Get conversation introduction.

        Formats the opening statement with the entity's inputs; template variables missing from
        the inputs are not passed to the formatter.

        :param application_generate_entity: application generate entity
        :return: conversation introduction (the raw opening statement, possibly None, when it is
            empty or formatting raises KeyError)
        """
        app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
        introduction = app_orchestration_config_entity.opening_statement

        if introduction:
            try:
                inputs = application_generate_entity.inputs
                prompt_template = PromptTemplateParser(template=introduction)
                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
                introduction = prompt_template.format(prompt_inputs)
            except KeyError:
                # best effort: fall back to the unformatted opening statement
                pass

        return introduction

    def _get_conversation(self, conversation_id: str) -> Optional[Conversation]:
        """
        Get conversation by conversation id.

        :param conversation_id: conversation id
        :return: the conversation, or None when no row matches
        """
        return (
            db.session.query(Conversation)
            .filter(Conversation.id == conversation_id)
            .first()
        )

    def _get_message(self, message_id: str) -> Optional[Message]:
        """
        Get message by message id.

        :param message_id: message id
        :return: the message, or None when no row matches
        """
        return (
            db.session.query(Message)
            .filter(Message.id == message_id)
            .first()
        )
|