agent_executor.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import enum
  2. import logging
  3. from typing import Union, Optional
  4. from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
  5. from langchain.callbacks.manager import Callbacks
  6. from langchain.memory.chat_memory import BaseChatMemory
  7. from langchain.tools import BaseTool
  8. from pydantic import BaseModel, Extra
  9. from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
  10. from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
  11. from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
  12. from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
  13. from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
  14. from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
  15. from langchain.agents import AgentExecutor as LCAgentExecutor
  16. from core.helper import moderation
  17. from core.model_providers.error import LLMError
  18. from core.model_providers.models.llm.base import BaseLLM
  19. from core.tool.dataset_retriever_tool import DatasetRetrieverTool
  20. class PlanningStrategy(str, enum.Enum):
  21. ROUTER = 'router'
  22. REACT_ROUTER = 'react_router'
  23. REACT = 'react'
  24. FUNCTION_CALL = 'function_call'
  25. MULTI_FUNCTION_CALL = 'multi_function_call'
  26. class AgentConfiguration(BaseModel):
  27. strategy: PlanningStrategy
  28. model_instance: BaseLLM
  29. tools: list[BaseTool]
  30. summary_model_instance: BaseLLM = None
  31. memory: Optional[BaseChatMemory] = None
  32. callbacks: Callbacks = None
  33. max_iterations: int = 6
  34. max_execution_time: Optional[float] = None
  35. early_stopping_method: str = "generate"
  36. # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
  37. class Config:
  38. """Configuration for this pydantic object."""
  39. extra = Extra.forbid
  40. arbitrary_types_allowed = True
  41. class AgentExecuteResult(BaseModel):
  42. strategy: PlanningStrategy
  43. output: Optional[str]
  44. configuration: AgentConfiguration
  45. class AgentExecutor:
  46. def __init__(self, configuration: AgentConfiguration):
  47. self.configuration = configuration
  48. self.agent = self._init_agent()
  49. def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
  50. if self.configuration.strategy == PlanningStrategy.REACT:
  51. agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
  52. model_instance=self.configuration.model_instance,
  53. llm=self.configuration.model_instance.client,
  54. tools=self.configuration.tools,
  55. output_parser=StructuredChatOutputParser(),
  56. summary_llm=self.configuration.summary_model_instance.client
  57. if self.configuration.summary_model_instance else None,
  58. verbose=True
  59. )
  60. elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
  61. agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
  62. model_instance=self.configuration.model_instance,
  63. llm=self.configuration.model_instance.client,
  64. tools=self.configuration.tools,
  65. extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
  66. summary_llm=self.configuration.summary_model_instance.client
  67. if self.configuration.summary_model_instance else None,
  68. verbose=True
  69. )
  70. elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
  71. agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
  72. model_instance=self.configuration.model_instance,
  73. llm=self.configuration.model_instance.client,
  74. tools=self.configuration.tools,
  75. extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
  76. summary_llm=self.configuration.summary_model_instance.client
  77. if self.configuration.summary_model_instance else None,
  78. verbose=True
  79. )
  80. elif self.configuration.strategy == PlanningStrategy.ROUTER:
  81. self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
  82. agent = MultiDatasetRouterAgent.from_llm_and_tools(
  83. model_instance=self.configuration.model_instance,
  84. llm=self.configuration.model_instance.client,
  85. tools=self.configuration.tools,
  86. extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
  87. verbose=True
  88. )
  89. elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
  90. self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
  91. agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
  92. model_instance=self.configuration.model_instance,
  93. llm=self.configuration.model_instance.client,
  94. tools=self.configuration.tools,
  95. output_parser=StructuredChatOutputParser(),
  96. verbose=True
  97. )
  98. else:
  99. raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
  100. return agent
  101. def should_use_agent(self, query: str) -> bool:
  102. return self.agent.should_use_agent(query)
  103. def run(self, query: str) -> AgentExecuteResult:
  104. moderation_result = moderation.check_moderation(
  105. self.configuration.model_instance.model_provider,
  106. query
  107. )
  108. if not moderation_result:
  109. return AgentExecuteResult(
  110. output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
  111. strategy=self.configuration.strategy,
  112. configuration=self.configuration
  113. )
  114. agent_executor = LCAgentExecutor.from_agent_and_tools(
  115. agent=self.agent,
  116. tools=self.configuration.tools,
  117. memory=self.configuration.memory,
  118. max_iterations=self.configuration.max_iterations,
  119. max_execution_time=self.configuration.max_execution_time,
  120. early_stopping_method=self.configuration.early_stopping_method,
  121. callbacks=self.configuration.callbacks
  122. )
  123. try:
  124. output = agent_executor.run(query)
  125. except LLMError as ex:
  126. raise ex
  127. except Exception as ex:
  128. logging.exception("agent_executor run failed")
  129. output = None
  130. return AgentExecuteResult(
  131. output=output,
  132. strategy=self.configuration.strategy,
  133. configuration=self.configuration
  134. )