# basic_app_runner.py
  1. import logging
  2. from typing import Optional
  3. from core.app_runner.app_runner import AppRunner
  4. from core.application_queue_manager import ApplicationQueueManager, PublishFrom
  5. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  6. from core.entities.application_entities import (ApplicationGenerateEntity, DatasetEntity,
  7. InvokeFrom, ModelConfigEntity)
  8. from core.features.dataset_retrieval import DatasetRetrievalFeature
  9. from core.memory.token_buffer_memory import TokenBufferMemory
  10. from core.model_manager import ModelInstance
  11. from core.moderation.base import ModerationException
  12. from core.prompt.prompt_transform import AppMode
  13. from extensions.ext_database import db
  14. from models.model import App, Conversation, Message
  15. logger = logging.getLogger(__name__)
  16. class BasicApplicationRunner(AppRunner):
  17. """
  18. Basic Application Runner
  19. """
  20. def run(self, application_generate_entity: ApplicationGenerateEntity,
  21. queue_manager: ApplicationQueueManager,
  22. conversation: Conversation,
  23. message: Message) -> None:
  24. """
  25. Run application
  26. :param application_generate_entity: application generate entity
  27. :param queue_manager: application queue manager
  28. :param conversation: conversation
  29. :param message: message
  30. :return:
  31. """
  32. app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
  33. if not app_record:
  34. raise ValueError(f"App not found")
  35. app_orchestration_config = application_generate_entity.app_orchestration_config_entity
  36. inputs = application_generate_entity.inputs
  37. query = application_generate_entity.query
  38. files = application_generate_entity.files
  39. # Pre-calculate the number of tokens of the prompt messages,
  40. # and return the rest number of tokens by model context token size limit and max token size limit.
  41. # If the rest number of tokens is not enough, raise exception.
  42. # Include: prompt template, inputs, query(optional), files(optional)
  43. # Not Include: memory, external data, dataset context
  44. self.get_pre_calculate_rest_tokens(
  45. app_record=app_record,
  46. model_config=app_orchestration_config.model_config,
  47. prompt_template_entity=app_orchestration_config.prompt_template,
  48. inputs=inputs,
  49. files=files,
  50. query=query
  51. )
  52. memory = None
  53. if application_generate_entity.conversation_id:
  54. # get memory of conversation (read-only)
  55. model_instance = ModelInstance(
  56. provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
  57. model=app_orchestration_config.model_config.model
  58. )
  59. memory = TokenBufferMemory(
  60. conversation=conversation,
  61. model_instance=model_instance
  62. )
  63. # organize all inputs and template to prompt messages
  64. # Include: prompt template, inputs, query(optional), files(optional)
  65. # memory(optional)
  66. prompt_messages, stop = self.organize_prompt_messages(
  67. app_record=app_record,
  68. model_config=app_orchestration_config.model_config,
  69. prompt_template_entity=app_orchestration_config.prompt_template,
  70. inputs=inputs,
  71. files=files,
  72. query=query,
  73. memory=memory
  74. )
  75. # moderation
  76. try:
  77. # process sensitive_word_avoidance
  78. _, inputs, query = self.moderation_for_inputs(
  79. app_id=app_record.id,
  80. tenant_id=application_generate_entity.tenant_id,
  81. app_orchestration_config_entity=app_orchestration_config,
  82. inputs=inputs,
  83. query=query,
  84. )
  85. except ModerationException as e:
  86. self.direct_output(
  87. queue_manager=queue_manager,
  88. app_orchestration_config=app_orchestration_config,
  89. prompt_messages=prompt_messages,
  90. text=str(e),
  91. stream=application_generate_entity.stream
  92. )
  93. return
  94. if query:
  95. # annotation reply
  96. annotation_reply = self.query_app_annotations_to_reply(
  97. app_record=app_record,
  98. message=message,
  99. query=query,
  100. user_id=application_generate_entity.user_id,
  101. invoke_from=application_generate_entity.invoke_from
  102. )
  103. if annotation_reply:
  104. queue_manager.publish_annotation_reply(
  105. message_annotation_id=annotation_reply.id,
  106. pub_from=PublishFrom.APPLICATION_MANAGER
  107. )
  108. self.direct_output(
  109. queue_manager=queue_manager,
  110. app_orchestration_config=app_orchestration_config,
  111. prompt_messages=prompt_messages,
  112. text=annotation_reply.content,
  113. stream=application_generate_entity.stream
  114. )
  115. return
  116. # fill in variable inputs from external data tools if exists
  117. external_data_tools = app_orchestration_config.external_data_variables
  118. if external_data_tools:
  119. inputs = self.fill_in_inputs_from_external_data_tools(
  120. tenant_id=app_record.tenant_id,
  121. app_id=app_record.id,
  122. external_data_tools=external_data_tools,
  123. inputs=inputs,
  124. query=query
  125. )
  126. # get context from datasets
  127. context = None
  128. if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids:
  129. context = self.retrieve_dataset_context(
  130. tenant_id=app_record.tenant_id,
  131. app_record=app_record,
  132. queue_manager=queue_manager,
  133. model_config=app_orchestration_config.model_config,
  134. show_retrieve_source=app_orchestration_config.show_retrieve_source,
  135. dataset_config=app_orchestration_config.dataset,
  136. message=message,
  137. inputs=inputs,
  138. query=query,
  139. user_id=application_generate_entity.user_id,
  140. invoke_from=application_generate_entity.invoke_from,
  141. memory=memory
  142. )
  143. # reorganize all inputs and template to prompt messages
  144. # Include: prompt template, inputs, query(optional), files(optional)
  145. # memory(optional), external data, dataset context(optional)
  146. prompt_messages, stop = self.organize_prompt_messages(
  147. app_record=app_record,
  148. model_config=app_orchestration_config.model_config,
  149. prompt_template_entity=app_orchestration_config.prompt_template,
  150. inputs=inputs,
  151. files=files,
  152. query=query,
  153. context=context,
  154. memory=memory
  155. )
  156. # check hosting moderation
  157. hosting_moderation_result = self.check_hosting_moderation(
  158. application_generate_entity=application_generate_entity,
  159. queue_manager=queue_manager,
  160. prompt_messages=prompt_messages
  161. )
  162. if hosting_moderation_result:
  163. return
  164. # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
  165. self.recale_llm_max_tokens(
  166. model_config=app_orchestration_config.model_config,
  167. prompt_messages=prompt_messages
  168. )
  169. # Invoke model
  170. model_instance = ModelInstance(
  171. provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
  172. model=app_orchestration_config.model_config.model
  173. )
  174. invoke_result = model_instance.invoke_llm(
  175. prompt_messages=prompt_messages,
  176. model_parameters=app_orchestration_config.model_config.parameters,
  177. stop=stop,
  178. stream=application_generate_entity.stream,
  179. user=application_generate_entity.user_id,
  180. )
  181. # handle invoke result
  182. self._handle_invoke_result(
  183. invoke_result=invoke_result,
  184. queue_manager=queue_manager,
  185. stream=application_generate_entity.stream
  186. )
  187. def retrieve_dataset_context(self, tenant_id: str,
  188. app_record: App,
  189. queue_manager: ApplicationQueueManager,
  190. model_config: ModelConfigEntity,
  191. dataset_config: DatasetEntity,
  192. show_retrieve_source: bool,
  193. message: Message,
  194. inputs: dict,
  195. query: str,
  196. user_id: str,
  197. invoke_from: InvokeFrom,
  198. memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
  199. """
  200. Retrieve dataset context
  201. :param tenant_id: tenant id
  202. :param app_record: app record
  203. :param queue_manager: queue manager
  204. :param model_config: model config
  205. :param dataset_config: dataset config
  206. :param show_retrieve_source: show retrieve source
  207. :param message: message
  208. :param inputs: inputs
  209. :param query: query
  210. :param user_id: user id
  211. :param invoke_from: invoke from
  212. :param memory: memory
  213. :return:
  214. """
  215. hit_callback = DatasetIndexToolCallbackHandler(
  216. queue_manager,
  217. app_record.id,
  218. message.id,
  219. user_id,
  220. invoke_from
  221. )
  222. if (app_record.mode == AppMode.COMPLETION.value and dataset_config
  223. and dataset_config.retrieve_config.query_variable):
  224. query = inputs.get(dataset_config.retrieve_config.query_variable, "")
  225. dataset_retrieval = DatasetRetrievalFeature()
  226. return dataset_retrieval.retrieve(
  227. tenant_id=tenant_id,
  228. model_config=model_config,
  229. config=dataset_config,
  230. query=query,
  231. invoke_from=invoke_from,
  232. show_retrieve_source=show_retrieve_source,
  233. hit_callback=hit_callback,
  234. memory=memory
  235. )