simple_prompt_transform.py

import enum
import json
import os
from typing import TYPE_CHECKING, Optional

from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageContent,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.model import AppMode

if TYPE_CHECKING:
    from core.file.models import File


class ModelMode(str, enum.Enum):
    COMPLETION = "completion"
    CHAT = "chat"

    @classmethod
    def value_of(cls, value: str) -> "ModelMode":
        """
        Get value of given mode.

        :param value: mode value
        :return: mode
        """
        for mode in cls:
            if mode.value == value:
                return mode

        raise ValueError(f"invalid mode value {value}")
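

# Example: ModelMode.value_of("chat") returns ModelMode.CHAT; an unrecognized
# value such as "embedding" raises ValueError rather than silently defaulting.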

# Module-level cache of parsed prompt rule JSON files, keyed by rule file name,
# so each template file under prompt_templates/ is read from disk at most once.
prompt_file_contents = {}


class SimplePromptTransform(PromptTransform):
    """
    Simple Prompt Transform for Chatbot App Basic Mode.
    """

    def get_prompt(
        self,
        app_mode: AppMode,
        prompt_template_entity: PromptTemplateEntity,
        inputs: dict,
        query: str,
        files: list["File"],
        context: Optional[str],
        memory: Optional[TokenBufferMemory],
        model_config: ModelConfigWithCredentialsEntity,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        inputs = {key: str(value) for key, value in inputs.items()}

        model_mode = ModelMode.value_of(model_config.mode)
        if model_mode == ModelMode.CHAT:
            prompt_messages, stops = self._get_chat_model_prompt_messages(
                app_mode=app_mode,
                pre_prompt=prompt_template_entity.simple_prompt_template,
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config,
            )
        else:
            prompt_messages, stops = self._get_completion_model_prompt_messages(
                app_mode=app_mode,
                pre_prompt=prompt_template_entity.simple_prompt_template,
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config,
            )

        return prompt_messages, stops
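
    # Note: `stops` is only ever populated by the completion-model branch; the
    # chat-model branch always returns None for it (see the helpers below).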

    def get_prompt_str_and_rules(
        self,
        app_mode: AppMode,
        model_config: ModelConfigWithCredentialsEntity,
        pre_prompt: str,
        inputs: dict,
        query: Optional[str] = None,
        context: Optional[str] = None,
        histories: Optional[str] = None,
    ) -> tuple[str, dict]:
        # get prompt template
        prompt_template_config = self.get_prompt_template(
            app_mode=app_mode,
            provider=model_config.provider,
            model=model_config.model,
            pre_prompt=pre_prompt,
            has_context=context is not None,
            query_in_prompt=query is not None,
            with_memory_prompt=histories is not None,
        )

        variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs}

        for v in prompt_template_config["special_variable_keys"]:
            # support #context#, #query# and #histories#
            if v == "#context#":
                variables["#context#"] = context or ""
            elif v == "#query#":
                variables["#query#"] = query or ""
            elif v == "#histories#":
                variables["#histories#"] = histories or ""

        prompt_template = prompt_template_config["prompt_template"]
        prompt = prompt_template.format(variables)

        return prompt, prompt_template_config["prompt_rules"]
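
    # Sketch with illustrative values: given pre_prompt "You are {{role}}." and
    # inputs {"role": "a translator"}, the rendered prompt begins with
    # "You are a translator.\n". The reserved keys {{#context#}}, {{#query#}}
    # and {{#histories#}} are filled from this method's arguments, never from
    # the user-supplied `inputs` dict.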

    def get_prompt_template(
        self,
        app_mode: AppMode,
        provider: str,
        model: str,
        pre_prompt: str,
        has_context: bool,
        query_in_prompt: bool,
        with_memory_prompt: bool = False,
    ) -> dict:
        prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)

        custom_variable_keys = []
        special_variable_keys = []

        prompt = ""
        for order in prompt_rules["system_prompt_orders"]:
            if order == "context_prompt" and has_context:
                prompt += prompt_rules["context_prompt"]
                special_variable_keys.append("#context#")
            elif order == "pre_prompt" and pre_prompt:
                prompt += pre_prompt + "\n"
                pre_prompt_template = PromptTemplateParser(template=pre_prompt)
                custom_variable_keys = pre_prompt_template.variable_keys
            elif order == "histories_prompt" and with_memory_prompt:
                prompt += prompt_rules["histories_prompt"]
                special_variable_keys.append("#histories#")

        if query_in_prompt:
            prompt += prompt_rules.get("query_prompt", "{{#query#}}")
            special_variable_keys.append("#query#")

        return {
            "prompt_template": PromptTemplateParser(template=prompt),
            "custom_variable_keys": custom_variable_keys,
            "special_variable_keys": special_variable_keys,
            "prompt_rules": prompt_rules,
        }
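
    # Illustrative only -- the exact wording lives in the JSON files under
    # prompt_templates/. Based on the keys this class reads, a rule file is
    # assumed to look roughly like:
    #
    #     {
    #         "system_prompt_orders": ["context_prompt", "pre_prompt", "histories_prompt"],
    #         "context_prompt": "Use the following context as your knowledge:\n{{#context#}}\n",
    #         "histories_prompt": "Conversation so far:\n{{#histories#}}\n",
    #         "query_prompt": "\nHuman: {{#query#}}\n\nAssistant: ",
    #         "human_prefix": "Human",
    #         "assistant_prefix": "Assistant",
    #         "stops": ["\nHuman:"]
    #     }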

    def _get_chat_model_prompt_messages(
        self,
        app_mode: AppMode,
        pre_prompt: str,
        inputs: dict,
        query: str,
        context: Optional[str],
        files: list["File"],
        memory: Optional[TokenBufferMemory],
        model_config: ModelConfigWithCredentialsEntity,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        prompt_messages = []

        # get prompt (without the query, which is sent as its own user message)
        prompt, _ = self.get_prompt_str_and_rules(
            app_mode=app_mode,
            model_config=model_config,
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=None,
            context=context,
        )

        # only emit a system message when there is also a user query;
        # otherwise the rendered prompt itself becomes the user message below
        if prompt and query:
            prompt_messages.append(SystemPromptMessage(content=prompt))

        if memory:
            prompt_messages = self._append_chat_histories(
                memory=memory,
                memory_config=MemoryConfig(
                    window=MemoryConfig.WindowConfig(
                        enabled=False,
                    )
                ),
                prompt_messages=prompt_messages,
                model_config=model_config,
            )

        if query:
            prompt_messages.append(self.get_last_user_message(query, files))
        else:
            prompt_messages.append(self.get_last_user_message(prompt, files))

        return prompt_messages, None
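
    # Resulting shape for chat models (sketch):
    #   [SystemPromptMessage(<pre-prompt/context>), <history messages...>,
    #    UserPromptMessage(<query + files>)]
    # with no stop sequences, since chat models delimit turns themselves.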

    def _get_completion_model_prompt_messages(
        self,
        app_mode: AppMode,
        pre_prompt: str,
        inputs: dict,
        query: str,
        context: Optional[str],
        files: list["File"],
        memory: Optional[TokenBufferMemory],
        model_config: ModelConfigWithCredentialsEntity,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        # first pass: render the prompt without histories to measure how many
        # tokens remain for the conversation history
        prompt, prompt_rules = self.get_prompt_str_and_rules(
            app_mode=app_mode,
            model_config=model_config,
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=query,
            context=context,
        )

        if memory:
            tmp_human_message = UserPromptMessage(content=prompt)

            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
            histories = self._get_history_messages_from_memory(
                memory=memory,
                memory_config=MemoryConfig(
                    window=MemoryConfig.WindowConfig(
                        enabled=False,
                    )
                ),
                max_token_limit=rest_tokens,
                human_prefix=prompt_rules.get("human_prefix", "Human"),
                ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
            )

            # second pass: re-render the prompt with the token-trimmed histories
            prompt, prompt_rules = self.get_prompt_str_and_rules(
                app_mode=app_mode,
                model_config=model_config,
                pre_prompt=pre_prompt,
                inputs=inputs,
                query=query,
                context=context,
                histories=histories,
            )

        stops = prompt_rules.get("stops")
        if stops is not None and len(stops) == 0:
            stops = None

        return [self.get_last_user_message(prompt, files)], stops
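
    # Resulting shape for completion models (sketch): a single UserPromptMessage
    # carrying the fully rendered prompt text plus any files, together with the
    # rule file's stop sequences (e.g. ["\nHuman:"]), or None when none are set.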

    def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage:
        if files:
            prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=prompt)]
            for file in files:
                prompt_message_contents.append(file_manager.to_prompt_message_content(file))

            prompt_message = UserPromptMessage(content=prompt_message_contents)
        else:
            prompt_message = UserPromptMessage(content=prompt)

        return prompt_message

    def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict:
        """
        Get simple prompt rule.

        :param app_mode: app mode
        :param provider: model provider
        :param model: model name
        :return: prompt rule dict
        """
        prompt_file_name = self._prompt_file_name(app_mode=app_mode, provider=provider, model=model)

        # Check if the prompt file is already loaded
        if prompt_file_name in prompt_file_contents:
            return prompt_file_contents[prompt_file_name]

        # Get the absolute path of the subdirectory
        prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates")
        json_file_path = os.path.join(prompt_path, f"{prompt_file_name}.json")

        # Open the JSON file and read its content
        with open(json_file_path, encoding="utf-8") as json_file:
            content = json.load(json_file)

            # Store the content of the prompt file
            prompt_file_contents[prompt_file_name] = content

        return content

    def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
        # baichuan models have their own prompt rule files
        is_baichuan = provider == "baichuan"
        if not is_baichuan:
            baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"]
            is_baichuan = provider in baichuan_supported_providers and "baichuan" in model.lower()

        if is_baichuan:
            if app_mode == AppMode.COMPLETION:
                return "baichuan_completion"
            else:
                return "baichuan_chat"

        # common
        if app_mode == AppMode.COMPLETION:
            return "common_completion"
        else:
            return "common_chat"
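
    # Mapping examples: (AppMode.CHAT, "openai", "gpt-4o") -> "common_chat";
    # (AppMode.COMPLETION, "xinference", "baichuan2-13b") -> "baichuan_completion".
    # Model names here are illustrative only.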