base_app_runner.py

import time
from collections.abc import Generator
from typing import Optional, Union, cast

from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
    AppGenerateEntity,
    EasyUIBasedAppGenerateEntity,
    InvokeFrom,
    ModelConfigWithCredentialsEntity,
)
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from models.model import App, AppMode, Message, MessageAnnotation


class AppRunner:
    def get_pre_calculate_rest_tokens(self, app_record: App,
                                      model_config: ModelConfigWithCredentialsEntity,
                                      prompt_template_entity: PromptTemplateEntity,
                                      inputs: dict[str, str],
                                      files: list[FileVar],
                                      query: Optional[str] = None) -> int:
        """
        Pre-calculate the rest tokens left for the completion.
        :param app_record: app record
        :param model_config: model config entity
        :param prompt_template_entity: prompt template entity
        :param inputs: inputs
        :param files: files
        :param query: query
        :return: rest tokens, or -1 if the model's context size is unknown
        """
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0

        if model_context_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        # get prompt messages without memory and context
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=model_config,
            prompt_template_entity=prompt_template_entity,
            inputs=inputs,
            files=files,
            query=query
        )

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            prompt_messages
        )
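
        # Remaining budget = context window size - reserved completion tokens (max_tokens) - prompt tokens.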
        rest_tokens = model_context_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
                                        "or shrink the max_tokens, or switch to an LLM with a larger token limit.")

        return rest_tokens

    def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
                              prompt_messages: list[PromptMessage]):
        # Recalculate max_tokens if prompt_tokens + max_tokens exceeds the model's context size.
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0

        if model_context_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            prompt_messages
        )
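
        # If the prompt plus the requested completion would overflow the context window,
        # shrink max_tokens to the remaining space, but keep a floor of 16 tokens.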
        if prompt_tokens + max_tokens > model_context_tokens:
            max_tokens = max(model_context_tokens - prompt_tokens, 16)

            for parameter_rule in model_config.model_schema.parameter_rules:
                if (parameter_rule.name == 'max_tokens'
                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                    model_config.parameters[parameter_rule.name] = max_tokens

    def organize_prompt_messages(self, app_record: App,
                                 model_config: ModelConfigWithCredentialsEntity,
                                 prompt_template_entity: PromptTemplateEntity,
                                 inputs: dict[str, str],
                                 files: list[FileVar],
                                 query: Optional[str] = None,
                                 context: Optional[str] = None,
                                 memory: Optional[TokenBufferMemory] = None) \
            -> tuple[list[PromptMessage], Optional[list[str]]]:
        """
        Organize prompt messages
        :param app_record: app record
        :param model_config: model config entity
        :param prompt_template_entity: prompt template entity
        :param inputs: inputs
        :param files: files
        :param query: query
        :param context: context
        :param memory: memory
        :return: prompt messages and optional stop words
        """
        # get prompt without memory and context
        if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
            prompt_transform = SimplePromptTransform()
            prompt_messages, stop = prompt_transform.get_prompt(
                app_mode=AppMode.value_of(app_record.mode),
                prompt_template_entity=prompt_template_entity,
                inputs=inputs,
                query=query if query else '',
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )
        else:
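            # Advanced prompt: build the template explicitly with the memory window disabled,
            # then let AdvancedPromptTransform assemble the final prompt messages.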
            memory_config = MemoryConfig(
                window=MemoryConfig.WindowConfig(
                    enabled=False
                )
            )

            model_mode = ModelMode.value_of(model_config.mode)
            if model_mode == ModelMode.COMPLETION:
                advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
                prompt_template = CompletionModelPromptTemplate(
                    text=advanced_completion_prompt_template.prompt
                )

                if advanced_completion_prompt_template.role_prefix:
                    memory_config.role_prefix = MemoryConfig.RolePrefix(
                        user=advanced_completion_prompt_template.role_prefix.user,
                        assistant=advanced_completion_prompt_template.role_prefix.assistant
                    )
            else:
                prompt_template = []
                for message in prompt_template_entity.advanced_chat_prompt_template.messages:
                    prompt_template.append(ChatModelMessage(
                        text=message.text,
                        role=message.role
                    ))

            prompt_transform = AdvancedPromptTransform()
            prompt_messages = prompt_transform.get_prompt(
                prompt_template=prompt_template,
                inputs=inputs,
                query=query if query else '',
                files=files,
                context=context,
                memory_config=memory_config,
                memory=memory,
                model_config=model_config
            )
            stop = model_config.stop

        return prompt_messages, stop

    def direct_output(self, queue_manager: AppQueueManager,
                      app_generate_entity: EasyUIBasedAppGenerateEntity,
                      prompt_messages: list,
                      text: str,
                      stream: bool,
                      usage: Optional[LLMUsage] = None) -> None:
        """
        Direct output
        :param queue_manager: application queue manager
        :param app_generate_entity: app generate entity
        :param prompt_messages: prompt messages
        :param text: text
        :param stream: stream
        :param usage: usage
        :return:
        """
        if stream:
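            # Emit the fixed text character by character as LLM chunk events so clients
            # still receive an incremental stream, with a short delay between chunks.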
            index = 0
            for token in text:
                chunk = LLMResultChunk(
                    model=app_generate_entity.model_config.model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=AssistantPromptMessage(content=token)
                    )
                )

                queue_manager.publish(
                    QueueLLMChunkEvent(
                        chunk=chunk
                    ), PublishFrom.APPLICATION_MANAGER
                )
                index += 1
                time.sleep(0.01)

        queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=app_generate_entity.model_config.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=text),
                    usage=usage if usage else LLMUsage.empty_usage()
                ),
            ), PublishFrom.APPLICATION_MANAGER
        )

    def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
                              queue_manager: AppQueueManager,
                              stream: bool,
                              agent: bool = False) -> None:
        """
        Handle invoke result
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :param stream: stream
        :param agent: whether this is an agent run
        :return:
        """
        if not stream:
            self._handle_invoke_result_direct(
                invoke_result=invoke_result,
                queue_manager=queue_manager,
                agent=agent
            )
        else:
            self._handle_invoke_result_stream(
                invoke_result=invoke_result,
                queue_manager=queue_manager,
                agent=agent
            )

    def _handle_invoke_result_direct(self, invoke_result: LLMResult,
                                     queue_manager: AppQueueManager,
                                     agent: bool) -> None:
        """
        Handle invoke result direct
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :param agent: whether this is an agent run
        :return:
        """
        queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=invoke_result,
            ), PublishFrom.APPLICATION_MANAGER
        )

    def _handle_invoke_result_stream(self, invoke_result: Generator,
                                     queue_manager: AppQueueManager,
                                     agent: bool) -> None:
        """
        Handle invoke result in streaming mode
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :param agent: whether this is an agent run
        :return:
        """
        model = None
        prompt_messages = []
        text = ''
        usage = None
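
        # Forward each chunk to the queue as it arrives and accumulate the model name,
        # prompt messages, generated text and usage for the final message-end event.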
        for result in invoke_result:
            if not agent:
                queue_manager.publish(
                    QueueLLMChunkEvent(
                        chunk=result
                    ), PublishFrom.APPLICATION_MANAGER
                )
            else:
                queue_manager.publish(
                    QueueAgentMessageEvent(
                        chunk=result
                    ), PublishFrom.APPLICATION_MANAGER
                )

            text += result.delta.message.content

            if not model:
                model = result.model

            if not prompt_messages:
                prompt_messages = result.prompt_messages

            if not usage and result.delta.usage:
                usage = result.delta.usage

        if not usage:
            usage = LLMUsage.empty_usage()

        llm_result = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(content=text),
            usage=usage
        )

        queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=llm_result,
            ), PublishFrom.APPLICATION_MANAGER
        )

    def moderation_for_inputs(self, app_id: str,
                              tenant_id: str,
                              app_generate_entity: AppGenerateEntity,
                              inputs: dict,
                              query: str) -> tuple[bool, dict, str]:
        """
        Process sensitive_word_avoidance.
        :param app_id: app id
        :param tenant_id: tenant id
        :param app_generate_entity: app generate entity
        :param inputs: inputs
        :param query: query
        :return:
        """
        moderation_feature = InputModeration()
        return moderation_feature.check(
            app_id=app_id,
            tenant_id=tenant_id,
            app_config=app_generate_entity.app_config,
            inputs=inputs,
            query=query if query else ''
        )

    def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
                                 queue_manager: AppQueueManager,
                                 prompt_messages: list[PromptMessage]) -> bool:
        """
        Check hosting moderation
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param prompt_messages: prompt messages
        :return:
        """
        hosting_moderation_feature = HostingModerationFeature()
        moderation_result = hosting_moderation_feature.check(
            application_generate_entity=application_generate_entity,
            prompt_messages=prompt_messages
        )
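
        # If hosting moderation flags the request, reply with a fixed refusal message
        # directly on the queue; the caller can use the returned flag to stop further processing.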
        if moderation_result:
            self.direct_output(
                queue_manager=queue_manager,
                app_generate_entity=application_generate_entity,
                prompt_messages=prompt_messages,
                text="I apologize for any confusion, "
                     "but I'm an AI assistant designed to be helpful, harmless, and honest.",
                stream=application_generate_entity.stream
            )

        return moderation_result

    def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
                                                app_id: str,
                                                external_data_tools: list[ExternalDataVariableEntity],
                                                inputs: dict,
                                                query: str) -> dict:
        """
        Fill in variable inputs from external data tools if they exist.
        :param tenant_id: workspace id
        :param app_id: app id
        :param external_data_tools: external data tools configs
        :param inputs: the inputs
        :param query: the query
        :return: the filled inputs
        """
        external_data_fetch_feature = ExternalDataFetch()
        return external_data_fetch_feature.fetch(
            tenant_id=tenant_id,
            app_id=app_id,
            external_data_tools=external_data_tools,
            inputs=inputs,
            query=query
        )

    def query_app_annotations_to_reply(self, app_record: App,
                                       message: Message,
                                       query: str,
                                       user_id: str,
                                       invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
        """
        Query app annotations to reply
        :param app_record: app record
        :param message: message
        :param query: query
        :param user_id: user id
        :param invoke_from: invoke from
        :return:
        """
        annotation_reply_feature = AnnotationReplyFeature()
        return annotation_reply_feature.query(
            app_record=app_record,
            message=message,
            query=query,
            user_id=user_id,
            invoke_from=invoke_from
        )