# base_callback.py — base callback interface for LLM invocations.
from abc import ABC
from typing import Optional

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.__base.ai_model import AIModel
  6. _TEXT_COLOR_MAPPING = {
  7. "blue": "36;1",
  8. "yellow": "33;1",
  9. "pink": "38;5;200",
  10. "green": "32;1",
  11. "red": "31;1",
  12. }
  13. class Callback(ABC):
  14. """
  15. Base class for callbacks.
  16. Only for LLM.
  17. """
  18. raise_error: bool = False
  19. def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
  20. prompt_messages: list[PromptMessage], model_parameters: dict,
  21. tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
  22. stream: bool = True, user: Optional[str] = None) -> None:
  23. """
  24. Before invoke callback
  25. :param llm_instance: LLM instance
  26. :param model: model name
  27. :param credentials: model credentials
  28. :param prompt_messages: prompt messages
  29. :param model_parameters: model parameters
  30. :param tools: tools for tool calling
  31. :param stop: stop words
  32. :param stream: is stream response
  33. :param user: unique user id
  34. """
  35. raise NotImplementedError()
  36. def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
  37. prompt_messages: list[PromptMessage], model_parameters: dict,
  38. tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
  39. stream: bool = True, user: Optional[str] = None):
  40. """
  41. On new chunk callback
  42. :param llm_instance: LLM instance
  43. :param chunk: chunk
  44. :param model: model name
  45. :param credentials: model credentials
  46. :param prompt_messages: prompt messages
  47. :param model_parameters: model parameters
  48. :param tools: tools for tool calling
  49. :param stop: stop words
  50. :param stream: is stream response
  51. :param user: unique user id
  52. """
  53. raise NotImplementedError()
  54. def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
  55. prompt_messages: list[PromptMessage], model_parameters: dict,
  56. tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
  57. stream: bool = True, user: Optional[str] = None) -> None:
  58. """
  59. After invoke callback
  60. :param llm_instance: LLM instance
  61. :param result: result
  62. :param model: model name
  63. :param credentials: model credentials
  64. :param prompt_messages: prompt messages
  65. :param model_parameters: model parameters
  66. :param tools: tools for tool calling
  67. :param stop: stop words
  68. :param stream: is stream response
  69. :param user: unique user id
  70. """
  71. raise NotImplementedError()
  72. def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
  73. prompt_messages: list[PromptMessage], model_parameters: dict,
  74. tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
  75. stream: bool = True, user: Optional[str] = None) -> None:
  76. """
  77. Invoke error callback
  78. :param llm_instance: LLM instance
  79. :param ex: exception
  80. :param model: model name
  81. :param credentials: model credentials
  82. :param prompt_messages: prompt messages
  83. :param model_parameters: model parameters
  84. :param tools: tools for tool calling
  85. :param stop: stop words
  86. :param stream: is stream response
  87. :param user: unique user id
  88. """
  89. raise NotImplementedError()
  90. def print_text(
  91. self, text: str, color: Optional[str] = None, end: str = ""
  92. ) -> None:
  93. """Print text with highlighting and no end characters."""
  94. text_to_print = self._get_colored_text(text, color) if color else text
  95. print(text_to_print, end=end)
  96. def _get_colored_text(self, text: str, color: str) -> str:
  97. """Get colored text."""
  98. color_str = _TEXT_COLOR_MAPPING[color]
  99. return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"