prompt_transform.py

from typing import Optional

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.prompt.entities.advanced_prompt_entities import MemoryConfig

class PromptTransform:
    def _append_chat_histories(
        self,
        memory: TokenBufferMemory,
        memory_config: MemoryConfig,
        prompt_messages: list[PromptMessage],
        model_config: ModelConfigWithCredentialsEntity,
    ) -> list[PromptMessage]:
        # Fill the remaining token budget with as much chat history as fits.
        rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
        histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
        prompt_messages.extend(histories)

        return prompt_messages
    def _calculate_rest_token(
        self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
    ) -> int:
        # Default budget used when the model's context size is unknown.
        rest_tokens = 2000

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
            model_instance = ModelInstance(
                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
            )

            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

            # Reserve room for the completion: look up max_tokens in the model's
            # parameter rules, either by name or via the "max_tokens" template.
            max_tokens = 0
            for parameter_rule in model_config.model_schema.parameter_rules:
                if parameter_rule.name == "max_tokens" or (
                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
                ):
                    max_tokens = (
                        model_config.parameters.get(parameter_rule.name)
                        or model_config.parameters.get(parameter_rule.use_template)
                    ) or 0

            # Whatever remains after the current prompt and the reserved
            # completion tokens can be spent on chat history.
            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
            rest_tokens = max(rest_tokens, 0)

        return rest_tokens
    def _get_history_messages_from_memory(
        self,
        memory: TokenBufferMemory,
        memory_config: MemoryConfig,
        max_token_limit: int,
        human_prefix: Optional[str] = None,
        ai_prefix: Optional[str] = None,
    ) -> str:
        """Get memory messages as a single prompt text."""
        kwargs = {"max_token_limit": max_token_limit}

        if human_prefix:
            kwargs["human_prefix"] = human_prefix

        if ai_prefix:
            kwargs["ai_prefix"] = ai_prefix

        if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0:
            kwargs["message_limit"] = memory_config.window.size

        return memory.get_history_prompt_text(**kwargs)
    def _get_history_messages_list_from_memory(
        self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
    ) -> list[PromptMessage]:
        """Get memory messages as a list of prompt messages."""
        return memory.get_history_prompt_messages(
            max_token_limit=max_token_limit,
            message_limit=memory_config.window.size
            if (
                memory_config.window.enabled
                and memory_config.window.size is not None
                and memory_config.window.size > 0
            )
            else None,
        )
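

# Illustrative sketch (hypothetical, not from the upstream module): a minimal
# subclass showing how _append_chat_histories might be used when assembling a
# final prompt. The `memory`, `memory_config`, and `model_config` objects are
# assumed to come from the surrounding app-invocation context, and
# ExamplePromptTransform is purely illustrative.
class ExamplePromptTransform(PromptTransform):
    def build_prompt(
        self,
        system_message: PromptMessage,
        user_message: PromptMessage,
        memory: TokenBufferMemory,
        memory_config: MemoryConfig,
        model_config: ModelConfigWithCredentialsEntity,
    ) -> list[PromptMessage]:
        # Start with the system prompt, splice in as much history as the
        # remaining token budget allows, then append the current user turn.
        prompt_messages: list[PromptMessage] = [system_message]
        prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
        prompt_messages.append(user_message)
        return prompt_messages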