simple_prompt_transform.py

import enum
import json
import os
from typing import Optional

from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.model import AppMode


class ModelMode(enum.Enum):
    COMPLETION = 'completion'
    CHAT = 'chat'

    @classmethod
    def value_of(cls, value: str) -> 'ModelMode':
        """
        Get the mode that corresponds to the given value.

        :param value: mode value
        :return: mode
        """
        for mode in cls:
            if mode.value == value:
                return mode

        raise ValueError(f'invalid mode value {value}')
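

# Module-level cache for prompt-rule JSON files, keyed by file name, so each
# rule file is read from disk at most once per process.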
prompt_file_contents = {}


class SimplePromptTransform(PromptTransform):
    """
    Simple Prompt Transform for Chatbot App Basic Mode.
    """

    def get_prompt(self,
                   app_mode: AppMode,
                   prompt_template_entity: PromptTemplateEntity,
                   inputs: dict,
                   query: str,
                   files: list[FileVar],
                   context: Optional[str],
                   memory: Optional[TokenBufferMemory],
                   model_config: ModelConfigWithCredentialsEntity) -> \
            tuple[list[PromptMessage], Optional[list[str]]]:
        model_mode = ModelMode.value_of(model_config.mode)
        if model_mode == ModelMode.CHAT:
            prompt_messages, stops = self._get_chat_model_prompt_messages(
                app_mode=app_mode,
                pre_prompt=prompt_template_entity.simple_prompt_template,
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )
        else:
            prompt_messages, stops = self._get_completion_model_prompt_messages(
                app_mode=app_mode,
                pre_prompt=prompt_template_entity.simple_prompt_template,
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )

        return prompt_messages, stops
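
    # Build the prompt template for this app/model pair, substitute the
    # caller's inputs plus the special #context#/#query#/#histories#
    # placeholders, and return the rendered prompt string together with the
    # provider-specific prompt rules it was built from.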
    def get_prompt_str_and_rules(self, app_mode: AppMode,
                                 model_config: ModelConfigWithCredentialsEntity,
                                 pre_prompt: str,
                                 inputs: dict,
                                 query: Optional[str] = None,
                                 context: Optional[str] = None,
                                 histories: Optional[str] = None,
                                 ) -> tuple[str, dict]:
        # get prompt template
        prompt_template_config = self.get_prompt_template(
            app_mode=app_mode,
            provider=model_config.provider,
            model=model_config.model,
            pre_prompt=pre_prompt,
            has_context=context is not None,
            query_in_prompt=query is not None,
            with_memory_prompt=histories is not None
        )

        variables = {k: inputs[k] for k in prompt_template_config['custom_variable_keys'] if k in inputs}

        for v in prompt_template_config['special_variable_keys']:
            # support #context#, #query# and #histories#
            if v == '#context#':
                variables['#context#'] = context if context else ''
            elif v == '#query#':
                variables['#query#'] = query if query else ''
            elif v == '#histories#':
                variables['#histories#'] = histories if histories else ''

        prompt_template = prompt_template_config['prompt_template']
        prompt = prompt_template.format(variables)

        return prompt, prompt_template_config['prompt_rules']
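
    # The rule file's 'system_prompt_orders' dictates how the context,
    # pre-prompt and histories sections are concatenated; variable keys are
    # collected so the caller knows which placeholders need substitution.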
    def get_prompt_template(self, app_mode: AppMode,
                            provider: str,
                            model: str,
                            pre_prompt: str,
                            has_context: bool,
                            query_in_prompt: bool,
                            with_memory_prompt: bool = False) -> dict:
        prompt_rules = self._get_prompt_rule(
            app_mode=app_mode,
            provider=provider,
            model=model
        )

        custom_variable_keys = []
        special_variable_keys = []

        prompt = ''
        for order in prompt_rules['system_prompt_orders']:
            if order == 'context_prompt' and has_context:
                prompt += prompt_rules['context_prompt']
                special_variable_keys.append('#context#')
            elif order == 'pre_prompt' and pre_prompt:
                prompt += pre_prompt + '\n'
                pre_prompt_template = PromptTemplateParser(template=pre_prompt)
                custom_variable_keys = pre_prompt_template.variable_keys
            elif order == 'histories_prompt' and with_memory_prompt:
                prompt += prompt_rules['histories_prompt']
                special_variable_keys.append('#histories#')

        if query_in_prompt:
            prompt += prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{#query#}}'
            special_variable_keys.append('#query#')

        return {
            "prompt_template": PromptTemplateParser(template=prompt),
            "custom_variable_keys": custom_variable_keys,
            "special_variable_keys": special_variable_keys,
            "prompt_rules": prompt_rules
        }
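
    # For chat models the rendered prompt becomes the system message; the
    # query (or, failing that, the prompt itself) becomes the final user
    # message, with chat histories appended in between when memory is set.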
    def _get_chat_model_prompt_messages(self, app_mode: AppMode,
                                        pre_prompt: str,
                                        inputs: dict,
                                        query: str,
                                        context: Optional[str],
                                        files: list[FileVar],
                                        memory: Optional[TokenBufferMemory],
                                        model_config: ModelConfigWithCredentialsEntity) \
            -> tuple[list[PromptMessage], Optional[list[str]]]:
        prompt_messages = []

        # get prompt
        prompt, _ = self.get_prompt_str_and_rules(
            app_mode=app_mode,
            model_config=model_config,
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=None,
            context=context
        )

        if prompt and query:
            prompt_messages.append(SystemPromptMessage(content=prompt))

        if memory:
            prompt_messages = self._append_chat_histories(
                memory=memory,
                memory_config=MemoryConfig(
                    window=MemoryConfig.WindowConfig(
                        enabled=False,
                    )
                ),
                prompt_messages=prompt_messages,
                model_config=model_config
            )

        if query:
            prompt_messages.append(self.get_last_user_message(query, files))
        else:
            prompt_messages.append(self.get_last_user_message(prompt, files))

        return prompt_messages, None
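
    # Completion models take a single flattened prompt. It is rendered twice:
    # once without histories to measure the remaining token budget, then again
    # with chat histories truncated to fit that budget.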
    def _get_completion_model_prompt_messages(self, app_mode: AppMode,
                                              pre_prompt: str,
                                              inputs: dict,
                                              query: str,
                                              context: Optional[str],
                                              files: list[FileVar],
                                              memory: Optional[TokenBufferMemory],
                                              model_config: ModelConfigWithCredentialsEntity) \
            -> tuple[list[PromptMessage], Optional[list[str]]]:
        # get prompt
        prompt, prompt_rules = self.get_prompt_str_and_rules(
            app_mode=app_mode,
            model_config=model_config,
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=query,
            context=context
        )

        if memory:
            tmp_human_message = UserPromptMessage(
                content=prompt
            )

            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
            histories = self._get_history_messages_from_memory(
                memory=memory,
                memory_config=MemoryConfig(
                    window=MemoryConfig.WindowConfig(
                        enabled=False,
                    )
                ),
                max_token_limit=rest_tokens,
                human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
                ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
            )

            # re-render the prompt with histories included
            prompt, prompt_rules = self.get_prompt_str_and_rules(
                app_mode=app_mode,
                model_config=model_config,
                pre_prompt=pre_prompt,
                inputs=inputs,
                query=query,
                context=context,
                histories=histories
            )

        stops = prompt_rules.get('stops')
        if stops is not None and len(stops) == 0:
            stops = None

        return [self.get_last_user_message(prompt, files)], stops
    def get_last_user_message(self, prompt: str, files: list[FileVar]) -> UserPromptMessage:
        if files:
            prompt_message_contents = [TextPromptMessageContent(data=prompt)]
            for file in files:
                prompt_message_contents.append(file.prompt_message_content)

            prompt_message = UserPromptMessage(content=prompt_message_contents)
        else:
            prompt_message = UserPromptMessage(content=prompt)

        return prompt_message
    def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict:
        """
        Get simple prompt rule.

        :param app_mode: app mode
        :param provider: model provider
        :param model: model name
        :return: prompt rules loaded from the matching template JSON file
        """
        prompt_file_name = self._prompt_file_name(
            app_mode=app_mode,
            provider=provider,
            model=model
        )

        # Check if the prompt file is already loaded
        if prompt_file_name in prompt_file_contents:
            return prompt_file_contents[prompt_file_name]

        # Get the absolute path of the subdirectory
        prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prompt_templates')
        json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json')

        # Open the JSON file and read its content
        with open(json_file_path, encoding='utf-8') as json_file:
            content = json.load(json_file)

            # Store the content of the prompt file
            prompt_file_contents[prompt_file_name] = content

        return content
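
    # Baichuan models get dedicated rule files; everything else falls back to
    # the common templates. For self-hosted providers the match is made on the
    # model name rather than the provider.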
    def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
        # baichuan
        is_baichuan = False
        if provider == 'baichuan':
            is_baichuan = True
        else:
            baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"]
            if provider in baichuan_supported_providers and 'baichuan' in model.lower():
                is_baichuan = True

        if is_baichuan:
            if app_mode == AppMode.COMPLETION:
                return 'baichuan_completion'
            else:
                return 'baichuan_chat'

        # common
        if app_mode == AppMode.COMPLETION:
            return 'common_completion'
        else:
            return 'common_chat'
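

# Illustrative usage sketch (not part of the module). prompt_template_entity
# and model_config are assumed to come from the app's configuration layer, and
# AppMode.CHAT is assumed to exist alongside AppMode.COMPLETION; only the call
# signature is taken from this file:
#
#   transform = SimplePromptTransform()
#   prompt_messages, stops = transform.get_prompt(
#       app_mode=AppMode.CHAT,
#       prompt_template_entity=prompt_template_entity,
#       inputs={'language': 'English'},
#       query='Hello!',
#       files=[],
#       context=None,
#       memory=None,
#       model_config=model_config,
#   )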