agent_builder.py

from typing import Optional

from langchain import LLMChain
from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
from langchain.callbacks import CallbackManager
from langchain.memory.chat_memory import BaseChatMemory

from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.llm.llm_builder import LLMBuilder


class AgentBuilder:
    @classmethod
    def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
                       dataset_tool_callback_handler: DatasetToolCallbackHandler,
                       agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
        """Assemble an AgentExecutor wired with the tenant's LLM, tools and callback handlers."""
        # LLM callbacks: gather agent-loop events and echo output to stdout.
        llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=agent_loop_gather_callback_handler.model_name,
            temperature=0,
            max_tokens=1024,
            callback_manager=llm_callback_manager
        )

        # Tool callbacks additionally route dataset-tool events to their own handler.
        tool_callback_manager = CallbackManager([
            agent_loop_gather_callback_handler,
            dataset_tool_callback_handler,
            DifyStdOutCallbackHandler()
        ])

        for tool in tools:
            tool.callback_manager = tool_callback_manager

        prompt = cls.build_agent_prompt_template(
            tools=tools,
            memory=memory,
        )

        agent_llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
        )

        agent = cls.build_agent(agent_llm_chain=agent_llm_chain, memory=memory)

        agent_callback_manager = CallbackManager(
            [agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
        )

        agent_chain = AgentExecutor.from_agent_and_tools(
            tools=tools,
            agent=agent,
            memory=memory,
            callback_manager=agent_callback_manager,
            max_iterations=6,
            early_stopping_method="generate",
            # `generate` finishes the in-flight inference after the iteration limit or request time limit is reached
        )

        return agent_chain

    @classmethod
    def build_agent_prompt_template(cls, tools, memory: Optional[BaseChatMemory]):
        # Use a conversational prompt when chat memory is present, a zero-shot prompt otherwise.
        if memory:
            prompt = ConversationalAgent.create_prompt(
                tools=tools,
            )
        else:
            prompt = ZeroShotAgent.create_prompt(
                tools=tools,
            )

        return prompt

    @classmethod
    def build_agent(cls, agent_llm_chain: LLMChain, memory: Optional[BaseChatMemory]):
        # Mirror the prompt choice: ConversationalAgent with memory, ZeroShotAgent without.
        if memory:
            agent = ConversationalAgent(
                llm_chain=agent_llm_chain
            )
        else:
            agent = ZeroShotAgent(
                llm_chain=agent_llm_chain
            )

        return agent
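
# Usage sketch: a hedged example of how this builder would be called. It assumes
# `tools`, `memory` and both callback handlers are constructed by the caller;
# their exact constructors live elsewhere in the codebase.
#
#     agent_chain = AgentBuilder.to_agent_chain(
#         tenant_id=tenant_id,
#         tools=tools,
#         memory=memory,  # or None to fall back to a zero-shot agent
#         dataset_tool_callback_handler=dataset_tool_callback_handler,
#         agent_loop_gather_callback_handler=agent_loop_gather_callback_handler,
#     )
#     answer = agent_chain.run(query)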