# app_runner.py

import time
from typing import cast, Optional, List, Tuple, Generator, Union

from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity
from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, AssistantPromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform
from models.model import App


class AppRunner:
    def get_pre_calculate_rest_tokens(self, app_record: App,
                                      model_config: ModelConfigEntity,
                                      prompt_template_entity: PromptTemplateEntity,
                                      inputs: dict[str, str],
                                      files: list[FileObj],
                                      query: Optional[str] = None) -> int:
        """
        Get the remaining token budget for the completion, calculated before
        memory and context are attached to the prompt.
        :param app_record: app record
        :param model_config: model config entity
        :param prompt_template_entity: prompt template entity
        :param inputs: inputs
        :param files: files
        :param query: query
        :return: remaining tokens, or -1 if the model does not declare a context size
        """
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0

        if model_context_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        # get prompt messages without memory and context
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=model_config,
            prompt_template_entity=prompt_template_entity,
            inputs=inputs,
            files=files,
            query=query
        )

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            prompt_messages
        )

        rest_tokens = model_context_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
                                        "or shrink the max tokens, or switch to an LLM with a larger token limit size.")

        return rest_tokens
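
    # Worked example (hypothetical numbers, for illustration only): with a
    # 4096-token context window, max_tokens set to 512 and a prompt that
    # tokenizes to 1000 tokens, rest_tokens = 4096 - 512 - 1000 = 2584 tokens
    # remain available for memory and retrieved context.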

    def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
                              prompt_messages: List[PromptMessage]):
        # recalculate max_tokens if prompt_tokens + max_tokens exceeds the model token limit
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0

        if model_context_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            prompt_messages
        )

        if prompt_tokens + max_tokens > model_context_tokens:
            max_tokens = max(model_context_tokens - prompt_tokens, 16)

            for parameter_rule in model_config.model_schema.parameter_rules:
                if (parameter_rule.name == 'max_tokens'
                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                    model_config.parameters[parameter_rule.name] = max_tokens
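
    # Worked example (hypothetical numbers, for illustration only): with a
    # 4096-token context window, a 4000-token prompt and max_tokens requested
    # as 512, prompt_tokens + max_tokens exceeds the limit, so max_tokens is
    # clamped to max(4096 - 4000, 16) = 96 and written back into
    # model_config.parameters.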

    def organize_prompt_messages(self, app_record: App,
                                 model_config: ModelConfigEntity,
                                 prompt_template_entity: PromptTemplateEntity,
                                 inputs: dict[str, str],
                                 files: list[FileObj],
                                 query: Optional[str] = None,
                                 context: Optional[str] = None,
                                 memory: Optional[TokenBufferMemory] = None) \
            -> Tuple[List[PromptMessage], Optional[List[str]]]:
        """
        Organize prompt messages
        :param context: context
        :param app_record: app record
        :param model_config: model config entity
        :param prompt_template_entity: prompt template entity
        :param inputs: inputs
        :param files: files
        :param query: query
        :param memory: memory
        :return: prompt messages and optional stop words
        """
        prompt_transform = PromptTransform()

        if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
            # simple mode: the prompt transform also determines the stop words
            prompt_messages, stop = prompt_transform.get_prompt(
                app_mode=app_record.mode,
                prompt_template_entity=prompt_template_entity,
                inputs=inputs,
                query=query if query else '',
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )
        else:
            # advanced mode: stop words come from the model config
            prompt_messages = prompt_transform.get_advanced_prompt(
                app_mode=app_record.mode,
                prompt_template_entity=prompt_template_entity,
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )
            stop = model_config.stop

        return prompt_messages, stop
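
    # Usage sketch (hypothetical caller, illustrative only):
    #
    #     prompt_messages, stop = runner.organize_prompt_messages(
    #         app_record=app_record,
    #         model_config=model_config,
    #         prompt_template_entity=prompt_template_entity,
    #         inputs={'name': 'value'},
    #         files=[],
    #         query='Hello'
    #     )
    #
    # where `runner` is an AppRunner instance and the remaining arguments are
    # the entities assembled by the calling application runner.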

    def direct_output(self, queue_manager: ApplicationQueueManager,
                      app_orchestration_config: AppOrchestrationConfigEntity,
                      prompt_messages: list,
                      text: str,
                      stream: bool,
                      usage: Optional[LLMUsage] = None) -> None:
        """
        Publish a pre-determined text directly to the queue, emitting it
        character by character when streaming is enabled.
        :param queue_manager: application queue manager
        :param app_orchestration_config: app orchestration config
        :param prompt_messages: prompt messages
        :param text: text
        :param stream: stream
        :param usage: usage
        :return:
        """
        if stream:
            index = 0
            for token in text:
                queue_manager.publish_chunk_message(LLMResultChunk(
                    model=app_orchestration_config.model_config.model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=AssistantPromptMessage(content=token)
                    )
                ))

                index += 1
                time.sleep(0.01)

        queue_manager.publish_message_end(
            llm_result=LLMResult(
                model=app_orchestration_config.model_config.model,
                prompt_messages=prompt_messages,
                message=AssistantPromptMessage(content=text),
                usage=usage if usage else LLMUsage.empty_usage()
            )
        )
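
    # Note (illustrative): with stream=True the text is replayed one character
    # per chunk with a 10 ms pause between chunks, so a 50-character canned
    # answer takes roughly half a second to stream before the end message
    # carries the full text and usage.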

    def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
                              queue_manager: ApplicationQueueManager,
                              stream: bool) -> None:
        """
        Handle invoke result
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :param stream: stream
        :return:
        """
        if not stream:
            self._handle_invoke_result_direct(
                invoke_result=invoke_result,
                queue_manager=queue_manager
            )
        else:
            self._handle_invoke_result_stream(
                invoke_result=invoke_result,
                queue_manager=queue_manager
            )

    def _handle_invoke_result_direct(self, invoke_result: LLMResult,
                                     queue_manager: ApplicationQueueManager) -> None:
        """
        Handle invoke result direct
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :return:
        """
        queue_manager.publish_message_end(
            llm_result=invoke_result
        )

    def _handle_invoke_result_stream(self, invoke_result: Generator,
                                     queue_manager: ApplicationQueueManager) -> None:
        """
        Handle invoke result as a stream of chunks
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :return:
        """
        model = None
        prompt_messages = []
        text = ''
        usage = None
        for result in invoke_result:
            queue_manager.publish_chunk_message(result)

            text += result.delta.message.content

            if not model:
                model = result.model

            if not prompt_messages:
                prompt_messages = result.prompt_messages

            if not usage and result.delta.usage:
                usage = result.delta.usage

        llm_result = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(content=text),
            usage=usage if usage else LLMUsage.empty_usage()
        )

        queue_manager.publish_message_end(
            llm_result=llm_result
        )
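
# Usage sketch (hypothetical subclass, illustrative only): concrete runners are
# expected to build the prompt messages, invoke the model and hand the result
# to the queue manager, roughly along these lines:
#
#     class MyAppRunner(AppRunner):  # hypothetical subclass name
#         def run(self, ...):
#             prompt_messages, stop = self.organize_prompt_messages(...)
#             self.recale_llm_max_tokens(model_config, prompt_messages)
#             invoke_result = model_instance.invoke_llm(...)  # assumed model call
#             self._handle_invoke_result(invoke_result, queue_manager, stream)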