agent_executor.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. import enum
  2. import logging
  3. from typing import Optional, Union
  4. from core.agent.agent.agent_llm_callback import AgentLLMCallback
  5. from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
  6. from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
  7. from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
  8. from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
  9. from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
  10. from core.entities.application_entities import ModelConfigEntity
  11. from core.entities.message_entities import prompt_messages_to_lc_messages
  12. from core.helper import moderation
  13. from core.memory.token_buffer_memory import TokenBufferMemory
  14. from core.model_runtime.errors.invoke import InvokeError
  15. from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
  16. from core.tool.dataset_retriever_tool import DatasetRetrieverTool
  17. from langchain.agents import AgentExecutor as LCAgentExecutor
  18. from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent
  19. from langchain.callbacks.manager import Callbacks
  20. from langchain.tools import BaseTool
  21. from pydantic import BaseModel, Extra
  22. class PlanningStrategy(str, enum.Enum):
  23. ROUTER = 'router'
  24. REACT_ROUTER = 'react_router'
  25. REACT = 'react'
  26. FUNCTION_CALL = 'function_call'
  27. class AgentConfiguration(BaseModel):
  28. strategy: PlanningStrategy
  29. model_config: ModelConfigEntity
  30. tools: list[BaseTool]
  31. summary_model_config: Optional[ModelConfigEntity] = None
  32. memory: Optional[TokenBufferMemory] = None
  33. callbacks: Callbacks = None
  34. max_iterations: int = 6
  35. max_execution_time: Optional[float] = None
  36. early_stopping_method: str = "generate"
  37. agent_llm_callback: Optional[AgentLLMCallback] = None
  38. # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
  39. class Config:
  40. """Configuration for this pydantic object."""
  41. extra = Extra.forbid
  42. arbitrary_types_allowed = True
  43. class AgentExecuteResult(BaseModel):
  44. strategy: PlanningStrategy
  45. output: Optional[str]
  46. configuration: AgentConfiguration
  47. class AgentExecutor:
  48. def __init__(self, configuration: AgentConfiguration):
  49. self.configuration = configuration
  50. self.agent = self._init_agent()
  51. def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
  52. if self.configuration.strategy == PlanningStrategy.REACT:
  53. agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
  54. model_config=self.configuration.model_config,
  55. tools=self.configuration.tools,
  56. output_parser=StructuredChatOutputParser(),
  57. summary_model_config=self.configuration.summary_model_config
  58. if self.configuration.summary_model_config else None,
  59. agent_llm_callback=self.configuration.agent_llm_callback,
  60. verbose=True
  61. )
  62. elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
  63. agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
  64. model_config=self.configuration.model_config,
  65. tools=self.configuration.tools,
  66. extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
  67. if self.configuration.memory else None, # used for read chat histories memory
  68. summary_model_config=self.configuration.summary_model_config
  69. if self.configuration.summary_model_config else None,
  70. agent_llm_callback=self.configuration.agent_llm_callback,
  71. verbose=True
  72. )
  73. elif self.configuration.strategy == PlanningStrategy.ROUTER:
  74. self.configuration.tools = [t for t in self.configuration.tools
  75. if isinstance(t, DatasetRetrieverTool)
  76. or isinstance(t, DatasetMultiRetrieverTool)]
  77. agent = MultiDatasetRouterAgent.from_llm_and_tools(
  78. model_config=self.configuration.model_config,
  79. tools=self.configuration.tools,
  80. extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
  81. if self.configuration.memory else None,
  82. verbose=True
  83. )
  84. elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
  85. self.configuration.tools = [t for t in self.configuration.tools
  86. if isinstance(t, DatasetRetrieverTool)
  87. or isinstance(t, DatasetMultiRetrieverTool)]
  88. agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
  89. model_config=self.configuration.model_config,
  90. tools=self.configuration.tools,
  91. output_parser=StructuredChatOutputParser(),
  92. verbose=True
  93. )
  94. else:
  95. raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
  96. return agent
  97. def should_use_agent(self, query: str) -> bool:
  98. return self.agent.should_use_agent(query)
  99. def run(self, query: str) -> AgentExecuteResult:
  100. moderation_result = moderation.check_moderation(
  101. self.configuration.model_config,
  102. query
  103. )
  104. if moderation_result:
  105. return AgentExecuteResult(
  106. output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
  107. strategy=self.configuration.strategy,
  108. configuration=self.configuration
  109. )
  110. agent_executor = LCAgentExecutor.from_agent_and_tools(
  111. agent=self.agent,
  112. tools=self.configuration.tools,
  113. max_iterations=self.configuration.max_iterations,
  114. max_execution_time=self.configuration.max_execution_time,
  115. early_stopping_method=self.configuration.early_stopping_method,
  116. callbacks=self.configuration.callbacks
  117. )
  118. try:
  119. output = agent_executor.run(input=query)
  120. except InvokeError as ex:
  121. raise ex
  122. except Exception as ex:
  123. logging.exception("agent_executor run failed")
  124. output = None
  125. return AgentExecuteResult(
  126. output=output,
  127. strategy=self.configuration.strategy,
  128. configuration=self.configuration
  129. )