app_runner.py

import logging
from typing import cast

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig
from core.app.entities.app_invoke_entities import (
    CompletionAppGenerateEntity,
)
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db
from models.model import App, Message

logger = logging.getLogger(__name__)


class CompletionAppRunner(AppRunner):
    """
    Completion Application Runner
    """

    def run(self, application_generate_entity: CompletionAppGenerateEntity,
            queue_manager: AppQueueManager,
            message: Message) -> None:
        """
        Run application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param message: message
        :return:
        """
        app_config = application_generate_entity.app_config
        app_config = cast(CompletionAppConfig, app_config)

        app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
        if not app_record:
            raise ValueError("App not found")

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Pre-calculate the number of tokens in the prompt messages and the number of tokens
        # left within the model's context size and max-tokens limit.
        # If the remaining token budget is not sufficient, an exception is raised.
        # Includes: prompt template, inputs, query (optional), files (optional)
        # Does not include: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=application_generate_entity.model_config,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
        )

        # organize all inputs and the prompt template into prompt messages
        # Includes: prompt template, inputs, query (optional), files (optional)
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=application_generate_entity.model_config,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
        )

        # moderation
        try:
            # process sensitive_word_avoidance
            _, inputs, query = self.moderation_for_inputs(
                app_id=app_record.id,
                tenant_id=app_config.tenant_id,
                app_generate_entity=application_generate_entity,
                inputs=inputs,
                query=query,
            )
        except ModerationException as e:
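            # moderation rejected the input: return the moderation message directly to the user and stop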
            self.direct_output(
                queue_manager=queue_manager,
                app_generate_entity=application_generate_entity,
                prompt_messages=prompt_messages,
                text=str(e),
                stream=application_generate_entity.stream
            )
            return

        # fill in variable inputs from external data tools, if configured
        external_data_tools = app_config.external_data_variables
        if external_data_tools:
            inputs = self.fill_in_inputs_from_external_data_tools(
                tenant_id=app_record.tenant_id,
                app_id=app_record.id,
                external_data_tools=external_data_tools,
                inputs=inputs,
                query=query
            )

        # get context from datasets
        context = None
        if app_config.dataset and app_config.dataset.dataset_ids:
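            # callback used to report dataset retrieval hits (e.g. for citation/source tracking) via the queue manager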
            hit_callback = DatasetIndexToolCallbackHandler(
                queue_manager,
                app_record.id,
                message.id,
                application_generate_entity.user_id,
                application_generate_entity.invoke_from
            )

            dataset_config = app_config.dataset
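            # if a dedicated query variable is configured for retrieval, use that input value as the retrieval query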
            if dataset_config and dataset_config.retrieve_config.query_variable:
                query = inputs.get(dataset_config.retrieve_config.query_variable, "")

            dataset_retrieval = DatasetRetrieval()
            context = dataset_retrieval.retrieve(
                app_id=app_record.id,
                user_id=application_generate_entity.user_id,
                tenant_id=app_record.tenant_id,
                model_config=application_generate_entity.model_config,
                config=dataset_config,
                query=query,
                invoke_from=application_generate_entity.invoke_from,
                show_retrieve_source=app_config.additional_features.show_retrieve_source,
                hit_callback=hit_callback
            )

        # reorganize all inputs and the prompt template into prompt messages
        # Includes: prompt template, inputs, query (optional), files (optional),
        #           memory (optional), external data, dataset context (optional)
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=application_generate_entity.model_config,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            context=context
        )

        # check hosting moderation
        hosting_moderation_result = self.check_hosting_moderation(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            prompt_messages=prompt_messages
        )

        if hosting_moderation_result:
            return

        # Re-calculate max_tokens if prompt tokens + max_tokens exceed the model's token limit
        self.recalc_llm_max_tokens(
            model_config=application_generate_entity.model_config,
            prompt_messages=prompt_messages
        )

        # Invoke model
        model_instance = ModelInstance(
            provider_model_bundle=application_generate_entity.model_config.provider_model_bundle,
            model=application_generate_entity.model_config.model
        )
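        # close the DB session before invoking the model, so the connection is not held open during the potentially long-running call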
        db.session.close()

        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=application_generate_entity.model_config.parameters,
            stop=stop,
            stream=application_generate_entity.stream,
            user=application_generate_entity.user_id,
        )

        # handle invoke result
        self._handle_invoke_result(
            invoke_result=invoke_result,
            queue_manager=queue_manager,
            stream=application_generate_entity.stream
        )