# agent_llm_callback.py

import logging
from typing import Optional

from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.__base.ai_model import AIModel

logger = logging.getLogger(__name__)


class AgentLLMCallback(Callback):
    """Forwards model-runtime callback events to an AgentLoopGatherCallbackHandler."""

    def __init__(self, agent_callback: AgentLoopGatherCallbackHandler) -> None:
        self.agent_callback = agent_callback

    def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
                         prompt_messages: list[PromptMessage], model_parameters: dict,
                         tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                         stream: bool = True, user: Optional[str] = None) -> None:
        """
        Before invoke callback

        :param llm_instance: LLM instance
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        self.agent_callback.on_llm_before_invoke(
            prompt_messages=prompt_messages
        )

    def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
                     prompt_messages: list[PromptMessage], model_parameters: dict,
                     tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                     stream: bool = True, user: Optional[str] = None):
        """
        On new chunk callback

        :param llm_instance: LLM instance
        :param chunk: chunk
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        # No per-chunk handling is needed here: the agent callback only records
        # whole prompts (on_before_invoke) and final results (on_after_invoke).
        pass

    def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
                        prompt_messages: list[PromptMessage], model_parameters: dict,
                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> None:
        """
        After invoke callback

        :param llm_instance: LLM instance
        :param result: result
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        self.agent_callback.on_llm_after_invoke(
            result=result
        )

    def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
                        prompt_messages: list[PromptMessage], model_parameters: dict,
                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> None:
        """
        Invoke error callback

        :param llm_instance: LLM instance
        :param ex: exception
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        self.agent_callback.on_llm_error(
            error=ex
        )
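
# --- Usage sketch (illustrative addition, not part of the original file) ---
# A minimal sketch of how this callback might be wired into an LLM call.
# `model_instance`, `agent_loop_handler`, and `invoke_llm(..., callbacks=...)`
# are assumptions inferred from the hook signatures above, not a verbatim API
# reference; adapt the names to the surrounding application.
#
# def run_llm_with_agent_hooks(model_instance, agent_loop_handler, prompt_messages):
#     agent_llm_callback = AgentLLMCallback(agent_callback=agent_loop_handler)
#     return model_instance.invoke_llm(
#         prompt_messages=prompt_messages,
#         model_parameters={"temperature": 0.7},
#         stream=True,
#         callbacks=[agent_llm_callback],  # hooks fire before/after invoke and on error
#     )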