import logging
from typing import cast

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig
from core.app.entities.app_invoke_entities import (
    CompletionAppGenerateEntity,
)
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db
from models.model import App, Message

logger = logging.getLogger(__name__)


class CompletionAppRunner(AppRunner):
    """
    Completion Application Runner
    """

    def run(
        self,
        application_generate_entity: CompletionAppGenerateEntity,
        queue_manager: AppQueueManager,
        message: Message,
    ) -> None:
        """
        Run application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param message: message
        :return:
        """
        app_config = application_generate_entity.app_config
        app_config = cast(CompletionAppConfig, app_config)

        app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
        if not app_record:
            raise ValueError("App not found")

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Pre-calculate the number of tokens in the prompt messages, and return
        # the number of tokens remaining under the model's context size and
        # max-token limits. Raises an exception if too few tokens remain.
        # Includes: prompt template, inputs, query (optional), files (optional)
        # Excludes: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=application_generate_entity.model_config,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
        )
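        # Roughly: rest_tokens = model_context_size - max_tokens - prompt_tokens;
        # a negative result means the prompt cannot fit (hypothetical numbers:
        # 4096 - 512 - 4000 < 0 would raise here).
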
        # Organize all inputs and the template into prompt messages.
        # Includes: prompt template, inputs, query (optional), files (optional)
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=application_generate_entity.model_config,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
        )
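        # This first assembly runs before moderation so that, on a moderation
        # hit, direct_output() below has prompt messages to attach to the
        # refusal text; the messages are re-assembled later once external data
        # and dataset context are available.
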
        # Moderation
        try:
            # Process sensitive_word_avoidance
            _, inputs, query = self.moderation_for_inputs(
                app_id=app_record.id,
                tenant_id=app_config.tenant_id,
                app_generate_entity=application_generate_entity,
                inputs=inputs,
                query=query,
            )
        except ModerationException as e:
            self.direct_output(
                queue_manager=queue_manager,
                app_generate_entity=application_generate_entity,
                prompt_messages=prompt_messages,
                text=str(e),
                stream=application_generate_entity.stream,
            )
            return

        # Fill in variable inputs from external data tools, if any are configured.
        external_data_tools = app_config.external_data_variables
        if external_data_tools:
            inputs = self.fill_in_inputs_from_external_data_tools(
                tenant_id=app_record.tenant_id,
                app_id=app_record.id,
                external_data_tools=external_data_tools,
                inputs=inputs,
                query=query,
            )
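        # Each external data variable is resolved by its tool and merged back
        # into `inputs`, so the resolved values take part in the final prompt
        # assembly below.
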
        # Get context from datasets
        context = None
        if app_config.dataset and app_config.dataset.dataset_ids:
            hit_callback = DatasetIndexToolCallbackHandler(
                queue_manager,
                app_record.id,
                message.id,
                application_generate_entity.user_id,
                application_generate_entity.invoke_from,
            )

            dataset_config = app_config.dataset
            if dataset_config and dataset_config.retrieve_config.query_variable:
                query = inputs.get(dataset_config.retrieve_config.query_variable, "")

            dataset_retrieval = DatasetRetrieval()
            context = dataset_retrieval.retrieve(
                app_id=app_record.id,
                user_id=application_generate_entity.user_id,
                tenant_id=app_record.tenant_id,
                model_config=application_generate_entity.model_config,
                config=dataset_config,
                query=query,
                invoke_from=application_generate_entity.invoke_from,
                show_retrieve_source=app_config.additional_features.show_retrieve_source,
                hit_callback=hit_callback,
            )
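        # `context` may stay None when no datasets are configured or retrieval
        # finds nothing; the reassembly below handles both cases.
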
        # Reorganize all inputs and the template into prompt messages.
        # Includes: prompt template, inputs, query (optional), files (optional),
        # memory (optional), external data, dataset context (optional)
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=application_generate_entity.model_config,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            context=context,
        )

        # Check hosting moderation
        hosting_moderation_result = self.check_hosting_moderation(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            prompt_messages=prompt_messages,
        )

        if hosting_moderation_result:
            return
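        # (If hosting moderation flagged the prompt, a replacement answer is
        # assumed to have already been emitted by the base runner above.)
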
        # Re-calculate max_tokens if prompt_tokens + max_tokens exceeds the
        # model's token limit.
        self.recalc_llm_max_tokens(
            model_config=application_generate_entity.model_config,
            prompt_messages=prompt_messages,
        )

        # Invoke model
        model_instance = ModelInstance(
            provider_model_bundle=application_generate_entity.model_config.provider_model_bundle,
            model=application_generate_entity.model_config.model,
        )

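        # Close the session to release the database connection before the
        # potentially long-running (and possibly streaming) model invocation.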
        db.session.close()

        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=application_generate_entity.model_config.parameters,
            stop=stop,
            stream=application_generate_entity.stream,
            user=application_generate_entity.user_id,
        )

        # Handle invoke result
        self._handle_invoke_result(
            invoke_result=invoke_result,
            queue_manager=queue_manager,
            stream=application_generate_entity.stream,
        )