agent_executor.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import enum
  2. import logging
  3. from typing import Union, Optional
  4. from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
  5. from langchain.callbacks.manager import Callbacks
  6. from langchain.tools import BaseTool
  7. from pydantic import BaseModel, Extra
  8. from core.agent.agent.agent_llm_callback import AgentLLMCallback
  9. from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
  10. from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
  11. from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
  12. from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
  13. from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
  14. from langchain.agents import AgentExecutor as LCAgentExecutor
  15. from core.entities.application_entities import ModelConfigEntity
  16. from core.entities.message_entities import prompt_messages_to_lc_messages
  17. from core.helper import moderation
  18. from core.memory.token_buffer_memory import TokenBufferMemory
  19. from core.model_runtime.errors.invoke import InvokeError
  20. from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
  21. from core.tool.dataset_retriever_tool import DatasetRetrieverTool
  22. class PlanningStrategy(str, enum.Enum):
  23. ROUTER = 'router'
  24. REACT_ROUTER = 'react_router'
  25. REACT = 'react'
  26. FUNCTION_CALL = 'function_call'
  27. class AgentConfiguration(BaseModel):
  28. strategy: PlanningStrategy
  29. model_config: ModelConfigEntity
  30. tools: list[BaseTool]
  31. summary_model_config: Optional[ModelConfigEntity] = None
  32. memory: Optional[TokenBufferMemory] = None
  33. callbacks: Callbacks = None
  34. max_iterations: int = 6
  35. max_execution_time: Optional[float] = None
  36. early_stopping_method: str = "generate"
  37. agent_llm_callback: Optional[AgentLLMCallback] = None
  38. # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
  39. class Config:
  40. """Configuration for this pydantic object."""
  41. extra = Extra.forbid
  42. arbitrary_types_allowed = True
  43. class AgentExecuteResult(BaseModel):
  44. strategy: PlanningStrategy
  45. output: Optional[str]
  46. configuration: AgentConfiguration
  47. class AgentExecutor:
  48. def __init__(self, configuration: AgentConfiguration):
  49. self.configuration = configuration
  50. self.agent = self._init_agent()
  51. def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
  52. if self.configuration.strategy == PlanningStrategy.REACT:
  53. agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
  54. model_config=self.configuration.model_config,
  55. tools=self.configuration.tools,
  56. output_parser=StructuredChatOutputParser(),
  57. summary_model_config=self.configuration.summary_model_config
  58. if self.configuration.summary_model_config else None,
  59. agent_llm_callback=self.configuration.agent_llm_callback,
  60. verbose=True
  61. )
  62. elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
  63. agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
  64. model_config=self.configuration.model_config,
  65. tools=self.configuration.tools,
  66. extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
  67. if self.configuration.memory else None, # used for read chat histories memory
  68. summary_model_config=self.configuration.summary_model_config
  69. if self.configuration.summary_model_config else None,
  70. agent_llm_callback=self.configuration.agent_llm_callback,
  71. verbose=True
  72. )
  73. elif self.configuration.strategy == PlanningStrategy.ROUTER:
  74. self.configuration.tools = [t for t in self.configuration.tools
  75. if isinstance(t, DatasetRetrieverTool)
  76. or isinstance(t, DatasetMultiRetrieverTool)]
  77. agent = MultiDatasetRouterAgent.from_llm_and_tools(
  78. model_config=self.configuration.model_config,
  79. tools=self.configuration.tools,
  80. extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
  81. if self.configuration.memory else None,
  82. verbose=True
  83. )
  84. elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
  85. self.configuration.tools = [t for t in self.configuration.tools
  86. if isinstance(t, DatasetRetrieverTool)
  87. or isinstance(t, DatasetMultiRetrieverTool)]
  88. agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
  89. model_config=self.configuration.model_config,
  90. tools=self.configuration.tools,
  91. output_parser=StructuredChatOutputParser(),
  92. verbose=True
  93. )
  94. else:
  95. raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
  96. return agent
  97. def should_use_agent(self, query: str) -> bool:
  98. return self.agent.should_use_agent(query)
  99. def run(self, query: str) -> AgentExecuteResult:
  100. moderation_result = moderation.check_moderation(
  101. self.configuration.model_config,
  102. query
  103. )
  104. if moderation_result:
  105. return AgentExecuteResult(
  106. output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
  107. strategy=self.configuration.strategy,
  108. configuration=self.configuration
  109. )
  110. agent_executor = LCAgentExecutor.from_agent_and_tools(
  111. agent=self.agent,
  112. tools=self.configuration.tools,
  113. max_iterations=self.configuration.max_iterations,
  114. max_execution_time=self.configuration.max_execution_time,
  115. early_stopping_method=self.configuration.early_stopping_method,
  116. callbacks=self.configuration.callbacks
  117. )
  118. try:
  119. output = agent_executor.run(input=query)
  120. except InvokeError as ex:
  121. raise ex
  122. except Exception as ex:
  123. logging.exception("agent_executor run failed")
  124. output = None
  125. return AgentExecuteResult(
  126. output=output,
  127. strategy=self.configuration.strategy,
  128. configuration=self.configuration
  129. )