# base_callback.py
  1. from typing import Optional
  2. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
  3. from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
  4. from core.model_runtime.model_providers.__base.ai_model import AIModel
  5. _TEXT_COLOR_MAPPING = {
  6. "blue": "36;1",
  7. "yellow": "33;1",
  8. "pink": "38;5;200",
  9. "green": "32;1",
  10. "red": "31;1",
  11. }
  12. class Callback:
  13. """
  14. Base class for callbacks.
  15. Only for LLM.
  16. """
  17. raise_error: bool = False
  18. def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
  19. prompt_messages: list[PromptMessage], model_parameters: dict,
  20. tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
  21. stream: bool = True, user: Optional[str] = None) -> None:
  22. """
  23. Before invoke callback
  24. :param llm_instance: LLM instance
  25. :param model: model name
  26. :param credentials: model credentials
  27. :param prompt_messages: prompt messages
  28. :param model_parameters: model parameters
  29. :param tools: tools for tool calling
  30. :param stop: stop words
  31. :param stream: is stream response
  32. :param user: unique user id
  33. """
  34. raise NotImplementedError()
  35. def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
  36. prompt_messages: list[PromptMessage], model_parameters: dict,
  37. tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
  38. stream: bool = True, user: Optional[str] = None):
  39. """
  40. On new chunk callback
  41. :param llm_instance: LLM instance
  42. :param chunk: chunk
  43. :param model: model name
  44. :param credentials: model credentials
  45. :param prompt_messages: prompt messages
  46. :param model_parameters: model parameters
  47. :param tools: tools for tool calling
  48. :param stop: stop words
  49. :param stream: is stream response
  50. :param user: unique user id
  51. """
  52. raise NotImplementedError()
  53. def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
  54. prompt_messages: list[PromptMessage], model_parameters: dict,
  55. tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
  56. stream: bool = True, user: Optional[str] = None) -> None:
  57. """
  58. After invoke callback
  59. :param llm_instance: LLM instance
  60. :param result: result
  61. :param model: model name
  62. :param credentials: model credentials
  63. :param prompt_messages: prompt messages
  64. :param model_parameters: model parameters
  65. :param tools: tools for tool calling
  66. :param stop: stop words
  67. :param stream: is stream response
  68. :param user: unique user id
  69. """
  70. raise NotImplementedError()
  71. def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
  72. prompt_messages: list[PromptMessage], model_parameters: dict,
  73. tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
  74. stream: bool = True, user: Optional[str] = None) -> None:
  75. """
  76. Invoke error callback
  77. :param llm_instance: LLM instance
  78. :param ex: exception
  79. :param model: model name
  80. :param credentials: model credentials
  81. :param prompt_messages: prompt messages
  82. :param model_parameters: model parameters
  83. :param tools: tools for tool calling
  84. :param stop: stop words
  85. :param stream: is stream response
  86. :param user: unique user id
  87. """
  88. raise NotImplementedError()
  89. def print_text(
  90. self, text: str, color: Optional[str] = None, end: str = ""
  91. ) -> None:
  92. """Print text with highlighting and no end characters."""
  93. text_to_print = self._get_colored_text(text, color) if color else text
  94. print(text_to_print, end=end)
  95. def _get_colored_text(self, text: str, color: str) -> str:
  96. """Get colored text."""
  97. color_str = _TEXT_COLOR_MAPPING[color]
  98. return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"