import enum
import json
import os
from typing import TYPE_CHECKING, Optional

from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageContent,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.model import AppMode

if TYPE_CHECKING:
    from core.file.models import File


class ModelMode(str, enum.Enum):
    COMPLETION = "completion"
    CHAT = "chat"

    @classmethod
    def value_of(cls, value: str) -> "ModelMode":
        """
        Get the mode member matching the given string value.

        :param value: mode value
        :return: matching mode
        """
        for mode in cls:
            if mode.value == value:
                return mode
        raise ValueError(f"invalid mode value {value}")
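
# Illustrative usage (not part of the original module): value_of maps the
# mode string reported by the model config onto the enum member, e.g.
#
#     ModelMode.value_of("chat")        # -> ModelMode.CHAT
#     ModelMode.value_of("completion")  # -> ModelMode.COMPLETION
#     ModelMode.value_of("embedding")   # raises ValueError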

# Module-level cache of parsed prompt rule files, keyed by file name, so each
# JSON template is read from disk at most once per process.
prompt_file_contents = {}


class SimplePromptTransform(PromptTransform):
    """
    Simple Prompt Transform for Chatbot App Basic Mode.
    """

    def get_prompt(
        self,
        app_mode: AppMode,
        prompt_template_entity: PromptTemplateEntity,
        inputs: dict,
        query: str,
        files: list["File"],
        context: Optional[str],
        memory: Optional[TokenBufferMemory],
        model_config: ModelConfigWithCredentialsEntity,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        # template substitution expects string values
        inputs = {key: str(value) for key, value in inputs.items()}

        model_mode = ModelMode.value_of(model_config.mode)
        if model_mode == ModelMode.CHAT:
            prompt_messages, stops = self._get_chat_model_prompt_messages(
                app_mode=app_mode,
                pre_prompt=prompt_template_entity.simple_prompt_template,
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config,
            )
        else:
            prompt_messages, stops = self._get_completion_model_prompt_messages(
                app_mode=app_mode,
                pre_prompt=prompt_template_entity.simple_prompt_template,
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config,
            )

        return prompt_messages, stops
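
    # A minimal call sketch (hypothetical entity and config objects, shown for
    # orientation only): a chat-mode model_config routes through
    # _get_chat_model_prompt_messages and yields (messages, None), while a
    # completion-mode config may also yield stop words from the prompt rules.
    #
    #     messages, stops = SimplePromptTransform().get_prompt(
    #         app_mode=AppMode.CHAT,
    #         prompt_template_entity=template_entity,  # simple_prompt_template="You are {{role}}."
    #         inputs={"role": "a helpful assistant"},
    #         query="Hello!",
    #         files=[],
    #         context=None,
    #         memory=None,
    #         model_config=model_config,
    #     )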

    def get_prompt_str_and_rules(
        self,
        app_mode: AppMode,
        model_config: ModelConfigWithCredentialsEntity,
        pre_prompt: str,
        inputs: dict,
        query: Optional[str] = None,
        context: Optional[str] = None,
        histories: Optional[str] = None,
    ) -> tuple[str, dict]:
        # get prompt template
        prompt_template_config = self.get_prompt_template(
            app_mode=app_mode,
            provider=model_config.provider,
            model=model_config.model,
            pre_prompt=pre_prompt,
            has_context=context is not None,
            query_in_prompt=query is not None,
            with_memory_prompt=histories is not None,
        )

        variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs}

        for v in prompt_template_config["special_variable_keys"]:
            # support #context#, #query# and #histories#
            if v == "#context#":
                variables["#context#"] = context or ""
            elif v == "#query#":
                variables["#query#"] = query or ""
            elif v == "#histories#":
                variables["#histories#"] = histories or ""

        prompt_template = prompt_template_config["prompt_template"]
        prompt = prompt_template.format(variables)

        return prompt, prompt_template_config["prompt_rules"]
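
    # Rendering sketch (assumed values): with pre_prompt "You are {{role}}."
    # and a context passed in, the assembled template includes the rule file's
    # context block, so a call like
    #
    #     prompt, rules = self.get_prompt_str_and_rules(
    #         app_mode=AppMode.CHAT,
    #         model_config=model_config,
    #         pre_prompt="You are {{role}}.",
    #         inputs={"role": "a pirate"},
    #         context="Retrieved passage...",
    #     )
    #
    # substitutes both the custom {{role}} variable and the special
    # {{#context#}} placeholder in one pass.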

    def get_prompt_template(
        self,
        app_mode: AppMode,
        provider: str,
        model: str,
        pre_prompt: str,
        has_context: bool,
        query_in_prompt: bool,
        with_memory_prompt: bool = False,
    ) -> dict:
        prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)

        custom_variable_keys = []
        special_variable_keys = []

        prompt = ""
        for order in prompt_rules["system_prompt_orders"]:
            if order == "context_prompt" and has_context:
                prompt += prompt_rules["context_prompt"]
                special_variable_keys.append("#context#")
            elif order == "pre_prompt" and pre_prompt:
                prompt += pre_prompt + "\n"
                pre_prompt_template = PromptTemplateParser(template=pre_prompt)
                custom_variable_keys = pre_prompt_template.variable_keys
            elif order == "histories_prompt" and with_memory_prompt:
                prompt += prompt_rules["histories_prompt"]
                special_variable_keys.append("#histories#")

        if query_in_prompt:
            prompt += prompt_rules.get("query_prompt", "{{#query#}}")
            special_variable_keys.append("#query#")

        return {
            "prompt_template": PromptTemplateParser(template=prompt),
            "custom_variable_keys": custom_variable_keys,
            "special_variable_keys": special_variable_keys,
            "prompt_rules": prompt_rules,
        }
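
    # Shape of the returned mapping (illustrative values):
    #
    #     {
    #         "prompt_template": PromptTemplateParser(template="..."),
    #         "custom_variable_keys": ["role"],                  # from {{role}} in pre_prompt
    #         "special_variable_keys": ["#context#", "#query#"],
    #         "prompt_rules": {...},                             # parsed JSON rule file
    #     }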

    def _get_chat_model_prompt_messages(
        self,
        app_mode: AppMode,
        pre_prompt: str,
        inputs: dict,
        query: str,
        context: Optional[str],
        files: list["File"],
        memory: Optional[TokenBufferMemory],
        model_config: ModelConfigWithCredentialsEntity,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        prompt_messages = []

        # get prompt; the query is sent as its own user message below, so it
        # is excluded from the rendered system prompt here
        prompt, _ = self.get_prompt_str_and_rules(
            app_mode=app_mode,
            model_config=model_config,
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=None,
            context=context,
        )

        if prompt and query:
            prompt_messages.append(SystemPromptMessage(content=prompt))

        if memory:
            prompt_messages = self._append_chat_histories(
                memory=memory,
                memory_config=MemoryConfig(
                    window=MemoryConfig.WindowConfig(
                        enabled=False,
                    )
                ),
                prompt_messages=prompt_messages,
                model_config=model_config,
            )

        if query:
            prompt_messages.append(self.get_last_user_message(query, files))
        else:
            prompt_messages.append(self.get_last_user_message(prompt, files))

        return prompt_messages, None
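
    # Resulting list shapes (illustrative): with both a rendered prompt and a
    # query, the output is roughly
    #
    #     [SystemPromptMessage("You are ..."), <history messages...>, UserPromptMessage("Hello!")]
    #
    # With no query, the rendered prompt itself becomes the trailing user message.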

    def _get_completion_model_prompt_messages(
        self,
        app_mode: AppMode,
        pre_prompt: str,
        inputs: dict,
        query: str,
        context: Optional[str],
        files: list["File"],
        memory: Optional[TokenBufferMemory],
        model_config: ModelConfigWithCredentialsEntity,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        # first pass: render the prompt without histories to measure its size
        prompt, prompt_rules = self.get_prompt_str_and_rules(
            app_mode=app_mode,
            model_config=model_config,
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=query,
            context=context,
        )

        if memory:
            tmp_human_message = UserPromptMessage(content=prompt)

            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
            histories = self._get_history_messages_from_memory(
                memory=memory,
                memory_config=MemoryConfig(
                    window=MemoryConfig.WindowConfig(
                        enabled=False,
                    )
                ),
                max_token_limit=rest_tokens,
                human_prefix=prompt_rules.get("human_prefix", "Human"),
                ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
            )

            # second pass: re-render the prompt with the trimmed histories included
            prompt, prompt_rules = self.get_prompt_str_and_rules(
                app_mode=app_mode,
                model_config=model_config,
                pre_prompt=pre_prompt,
                inputs=inputs,
                query=query,
                context=context,
                histories=histories,
            )

        stops = prompt_rules.get("stops")
        if stops is not None and len(stops) == 0:
            stops = None

        return [self.get_last_user_message(prompt, files)], stops
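
    # Completion-mode output sketch (illustrative): the whole conversation is
    # flattened into a single user message, optionally paired with stop words
    # from the rule file, e.g.
    #
    #     ([UserPromptMessage(content="...rendered prompt...\nHuman: Hello!\nAssistant: ")], ["\nHuman:"])
    #
    # An empty "stops" list in the rule file is normalized to None above so
    # callers can treat it uniformly as "no stop words".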

    def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage:
        if files:
            prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=prompt)]
            for file in files:
                prompt_message_contents.append(file_manager.to_prompt_message_content(file))

            prompt_message = UserPromptMessage(content=prompt_message_contents)
        else:
            prompt_message = UserPromptMessage(content=prompt)

        return prompt_message
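
    # Multimodal layout sketch (hypothetical image_file): attached files are
    # appended after the text part, so the content becomes a list such as
    #
    #     UserPromptMessage(content=[
    #         TextPromptMessageContent(data="What is in this image?"),
    #         file_manager.to_prompt_message_content(image_file),
    #     ])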

    def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict:
        """
        Get simple prompt rule.

        :param app_mode: app mode
        :param provider: model provider
        :param model: model name
        :return: prompt rule dict loaded from the matching JSON template
        """
        prompt_file_name = self._prompt_file_name(app_mode=app_mode, provider=provider, model=model)

        # Check if the prompt file is already loaded
        if prompt_file_name in prompt_file_contents:
            return prompt_file_contents[prompt_file_name]

        # Get the absolute path of the subdirectory
        prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates")
        json_file_path = os.path.join(prompt_path, f"{prompt_file_name}.json")

        # Open the JSON file and read its content
        with open(json_file_path, encoding="utf-8") as json_file:
            content = json.load(json_file)

            # Store the content of the prompt file
            prompt_file_contents[prompt_file_name] = content

        return content
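
    # A rule file is plain JSON. A minimal sketch of the expected shape,
    # inferred from the keys this class reads (not a copy of any shipped
    # template):
    #
    #     {
    #         "system_prompt_orders": ["context_prompt", "pre_prompt", "histories_prompt"],
    #         "context_prompt": "Use the following context as your knowledge:\n{{#context#}}\n",
    #         "histories_prompt": "{{#histories#}}\n",
    #         "query_prompt": "\nHuman: {{#query#}}\nAssistant: ",
    #         "human_prefix": "Human",
    #         "assistant_prefix": "Assistant",
    #         "stops": ["\nHuman:"]
    #     }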

    def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
        # baichuan models get dedicated templates, whether served natively or
        # through a hosting provider
        is_baichuan = False
        if provider == "baichuan":
            is_baichuan = True
        else:
            baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"]
            if provider in baichuan_supported_providers and "baichuan" in model.lower():
                is_baichuan = True

        if is_baichuan:
            if app_mode == AppMode.COMPLETION:
                return "baichuan_completion"
            else:
                return "baichuan_chat"

        # common
        if app_mode == AppMode.COMPLETION:
            return "common_completion"
        else:
            return "common_chat"