# basic_app_runner.py
  1. import logging
  2. from typing import Optional, Tuple
  3. from core.app_runner.app_runner import AppRunner
  4. from core.application_queue_manager import ApplicationQueueManager, PublishFrom
  5. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  6. from core.entities.application_entities import (ApplicationGenerateEntity, AppOrchestrationConfigEntity, DatasetEntity,
  7. ExternalDataVariableEntity, InvokeFrom, ModelConfigEntity)
  8. from core.features.annotation_reply import AnnotationReplyFeature
  9. from core.features.dataset_retrieval import DatasetRetrievalFeature
  10. from core.features.external_data_fetch import ExternalDataFetchFeature
  11. from core.features.hosting_moderation import HostingModerationFeature
  12. from core.features.moderation import ModerationFeature
  13. from core.memory.token_buffer_memory import TokenBufferMemory
  14. from core.model_manager import ModelInstance
  15. from core.model_runtime.entities.message_entities import PromptMessage
  16. from core.moderation.base import ModerationException
  17. from core.prompt.prompt_transform import AppMode
  18. from extensions.ext_database import db
  19. from models.model import App, Conversation, Message, MessageAnnotation
  20. logger = logging.getLogger(__name__)
  21. class BasicApplicationRunner(AppRunner):
  22. """
  23. Basic Application Runner
  24. """
  25. def run(self, application_generate_entity: ApplicationGenerateEntity,
  26. queue_manager: ApplicationQueueManager,
  27. conversation: Conversation,
  28. message: Message) -> None:
  29. """
  30. Run application
  31. :param application_generate_entity: application generate entity
  32. :param queue_manager: application queue manager
  33. :param conversation: conversation
  34. :param message: message
  35. :return:
  36. """
  37. app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
  38. if not app_record:
  39. raise ValueError(f"App not found")
  40. app_orchestration_config = application_generate_entity.app_orchestration_config_entity
  41. inputs = application_generate_entity.inputs
  42. query = application_generate_entity.query
  43. files = application_generate_entity.files
  44. # Pre-calculate the number of tokens of the prompt messages,
  45. # and return the rest number of tokens by model context token size limit and max token size limit.
  46. # If the rest number of tokens is not enough, raise exception.
  47. # Include: prompt template, inputs, query(optional), files(optional)
  48. # Not Include: memory, external data, dataset context
  49. self.get_pre_calculate_rest_tokens(
  50. app_record=app_record,
  51. model_config=app_orchestration_config.model_config,
  52. prompt_template_entity=app_orchestration_config.prompt_template,
  53. inputs=inputs,
  54. files=files,
  55. query=query
  56. )
  57. memory = None
  58. if application_generate_entity.conversation_id:
  59. # get memory of conversation (read-only)
  60. model_instance = ModelInstance(
  61. provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
  62. model=app_orchestration_config.model_config.model
  63. )
  64. memory = TokenBufferMemory(
  65. conversation=conversation,
  66. model_instance=model_instance
  67. )
  68. # organize all inputs and template to prompt messages
  69. # Include: prompt template, inputs, query(optional), files(optional)
  70. # memory(optional)
  71. prompt_messages, stop = self.organize_prompt_messages(
  72. app_record=app_record,
  73. model_config=app_orchestration_config.model_config,
  74. prompt_template_entity=app_orchestration_config.prompt_template,
  75. inputs=inputs,
  76. files=files,
  77. query=query,
  78. memory=memory
  79. )
  80. # moderation
  81. try:
  82. # process sensitive_word_avoidance
  83. _, inputs, query = self.moderation_for_inputs(
  84. app_id=app_record.id,
  85. tenant_id=application_generate_entity.tenant_id,
  86. app_orchestration_config_entity=app_orchestration_config,
  87. inputs=inputs,
  88. query=query,
  89. )
  90. except ModerationException as e:
  91. self.direct_output(
  92. queue_manager=queue_manager,
  93. app_orchestration_config=app_orchestration_config,
  94. prompt_messages=prompt_messages,
  95. text=str(e),
  96. stream=application_generate_entity.stream
  97. )
  98. return
  99. if query:
  100. # annotation reply
  101. annotation_reply = self.query_app_annotations_to_reply(
  102. app_record=app_record,
  103. message=message,
  104. query=query,
  105. user_id=application_generate_entity.user_id,
  106. invoke_from=application_generate_entity.invoke_from
  107. )
  108. if annotation_reply:
  109. queue_manager.publish_annotation_reply(
  110. message_annotation_id=annotation_reply.id,
  111. pub_from=PublishFrom.APPLICATION_MANAGER
  112. )
  113. self.direct_output(
  114. queue_manager=queue_manager,
  115. app_orchestration_config=app_orchestration_config,
  116. prompt_messages=prompt_messages,
  117. text=annotation_reply.content,
  118. stream=application_generate_entity.stream
  119. )
  120. return
  121. # fill in variable inputs from external data tools if exists
  122. external_data_tools = app_orchestration_config.external_data_variables
  123. if external_data_tools:
  124. inputs = self.fill_in_inputs_from_external_data_tools(
  125. tenant_id=app_record.tenant_id,
  126. app_id=app_record.id,
  127. external_data_tools=external_data_tools,
  128. inputs=inputs,
  129. query=query
  130. )
  131. # get context from datasets
  132. context = None
  133. if app_orchestration_config.dataset:
  134. context = self.retrieve_dataset_context(
  135. tenant_id=app_record.tenant_id,
  136. app_record=app_record,
  137. queue_manager=queue_manager,
  138. model_config=app_orchestration_config.model_config,
  139. show_retrieve_source=app_orchestration_config.show_retrieve_source,
  140. dataset_config=app_orchestration_config.dataset,
  141. message=message,
  142. inputs=inputs,
  143. query=query,
  144. user_id=application_generate_entity.user_id,
  145. invoke_from=application_generate_entity.invoke_from,
  146. memory=memory
  147. )
  148. # reorganize all inputs and template to prompt messages
  149. # Include: prompt template, inputs, query(optional), files(optional)
  150. # memory(optional), external data, dataset context(optional)
  151. prompt_messages, stop = self.organize_prompt_messages(
  152. app_record=app_record,
  153. model_config=app_orchestration_config.model_config,
  154. prompt_template_entity=app_orchestration_config.prompt_template,
  155. inputs=inputs,
  156. files=files,
  157. query=query,
  158. context=context,
  159. memory=memory
  160. )
  161. # check hosting moderation
  162. hosting_moderation_result = self.check_hosting_moderation(
  163. application_generate_entity=application_generate_entity,
  164. queue_manager=queue_manager,
  165. prompt_messages=prompt_messages
  166. )
  167. if hosting_moderation_result:
  168. return
  169. # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
  170. self.recale_llm_max_tokens(
  171. model_config=app_orchestration_config.model_config,
  172. prompt_messages=prompt_messages
  173. )
  174. # Invoke model
  175. model_instance = ModelInstance(
  176. provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
  177. model=app_orchestration_config.model_config.model
  178. )
  179. invoke_result = model_instance.invoke_llm(
  180. prompt_messages=prompt_messages,
  181. model_parameters=app_orchestration_config.model_config.parameters,
  182. stop=stop,
  183. stream=application_generate_entity.stream,
  184. user=application_generate_entity.user_id,
  185. )
  186. # handle invoke result
  187. self._handle_invoke_result(
  188. invoke_result=invoke_result,
  189. queue_manager=queue_manager,
  190. stream=application_generate_entity.stream
  191. )
  192. def moderation_for_inputs(self, app_id: str,
  193. tenant_id: str,
  194. app_orchestration_config_entity: AppOrchestrationConfigEntity,
  195. inputs: dict,
  196. query: str) -> Tuple[bool, dict, str]:
  197. """
  198. Process sensitive_word_avoidance.
  199. :param app_id: app id
  200. :param tenant_id: tenant id
  201. :param app_orchestration_config_entity: app orchestration config entity
  202. :param inputs: inputs
  203. :param query: query
  204. :return:
  205. """
  206. moderation_feature = ModerationFeature()
  207. return moderation_feature.check(
  208. app_id=app_id,
  209. tenant_id=tenant_id,
  210. app_orchestration_config_entity=app_orchestration_config_entity,
  211. inputs=inputs,
  212. query=query,
  213. )
  214. def query_app_annotations_to_reply(self, app_record: App,
  215. message: Message,
  216. query: str,
  217. user_id: str,
  218. invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
  219. """
  220. Query app annotations to reply
  221. :param app_record: app record
  222. :param message: message
  223. :param query: query
  224. :param user_id: user id
  225. :param invoke_from: invoke from
  226. :return:
  227. """
  228. annotation_reply_feature = AnnotationReplyFeature()
  229. return annotation_reply_feature.query(
  230. app_record=app_record,
  231. message=message,
  232. query=query,
  233. user_id=user_id,
  234. invoke_from=invoke_from
  235. )
  236. def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
  237. app_id: str,
  238. external_data_tools: list[ExternalDataVariableEntity],
  239. inputs: dict,
  240. query: str) -> dict:
  241. """
  242. Fill in variable inputs from external data tools if exists.
  243. :param tenant_id: workspace id
  244. :param app_id: app id
  245. :param external_data_tools: external data tools configs
  246. :param inputs: the inputs
  247. :param query: the query
  248. :return: the filled inputs
  249. """
  250. external_data_fetch_feature = ExternalDataFetchFeature()
  251. return external_data_fetch_feature.fetch(
  252. tenant_id=tenant_id,
  253. app_id=app_id,
  254. external_data_tools=external_data_tools,
  255. inputs=inputs,
  256. query=query
  257. )
  258. def retrieve_dataset_context(self, tenant_id: str,
  259. app_record: App,
  260. queue_manager: ApplicationQueueManager,
  261. model_config: ModelConfigEntity,
  262. dataset_config: DatasetEntity,
  263. show_retrieve_source: bool,
  264. message: Message,
  265. inputs: dict,
  266. query: str,
  267. user_id: str,
  268. invoke_from: InvokeFrom,
  269. memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
  270. """
  271. Retrieve dataset context
  272. :param tenant_id: tenant id
  273. :param app_record: app record
  274. :param queue_manager: queue manager
  275. :param model_config: model config
  276. :param dataset_config: dataset config
  277. :param show_retrieve_source: show retrieve source
  278. :param message: message
  279. :param inputs: inputs
  280. :param query: query
  281. :param user_id: user id
  282. :param invoke_from: invoke from
  283. :param memory: memory
  284. :return:
  285. """
  286. hit_callback = DatasetIndexToolCallbackHandler(
  287. queue_manager,
  288. app_record.id,
  289. message.id,
  290. user_id,
  291. invoke_from
  292. )
  293. if (app_record.mode == AppMode.COMPLETION.value and dataset_config
  294. and dataset_config.retrieve_config.query_variable):
  295. query = inputs.get(dataset_config.retrieve_config.query_variable, "")
  296. dataset_retrieval = DatasetRetrievalFeature()
  297. return dataset_retrieval.retrieve(
  298. tenant_id=tenant_id,
  299. model_config=model_config,
  300. config=dataset_config,
  301. query=query,
  302. invoke_from=invoke_from,
  303. show_retrieve_source=show_retrieve_source,
  304. hit_callback=hit_callback,
  305. memory=memory
  306. )
  307. def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity,
  308. queue_manager: ApplicationQueueManager,
  309. prompt_messages: list[PromptMessage]) -> bool:
  310. """
  311. Check hosting moderation
  312. :param application_generate_entity: application generate entity
  313. :param queue_manager: queue manager
  314. :param prompt_messages: prompt messages
  315. :return:
  316. """
  317. hosting_moderation_feature = HostingModerationFeature()
  318. moderation_result = hosting_moderation_feature.check(
  319. application_generate_entity=application_generate_entity,
  320. prompt_messages=prompt_messages
  321. )
  322. if moderation_result:
  323. self.direct_output(
  324. queue_manager=queue_manager,
  325. app_orchestration_config=application_generate_entity.app_orchestration_config_entity,
  326. prompt_messages=prompt_messages,
  327. text="I apologize for any confusion, " \
  328. "but I'm an AI assistant to be helpful, harmless, and honest.",
  329. stream=application_generate_entity.stream
  330. )
  331. return moderation_result