# openai_function_call.py
from typing import List, Tuple, Any, Union, Sequence, Optional

from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
    _format_intermediate_steps
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool

from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
  13. class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
  14. @classmethod
  15. def from_llm_and_tools(
  16. cls,
  17. llm: BaseLanguageModel,
  18. tools: Sequence[BaseTool],
  19. callback_manager: Optional[BaseCallbackManager] = None,
  20. extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
  21. system_message: Optional[SystemMessage] = SystemMessage(
  22. content="You are a helpful AI assistant."
  23. ),
  24. **kwargs: Any,
  25. ) -> BaseSingleActionAgent:
  26. return super().from_llm_and_tools(
  27. llm=llm,
  28. tools=tools,
  29. callback_manager=callback_manager,
  30. extra_prompt_messages=extra_prompt_messages,
  31. system_message=cls.get_system_message(),
  32. **kwargs,
  33. )
  34. def should_use_agent(self, query: str):
  35. """
  36. return should use agent
  37. :param query:
  38. :return:
  39. """
  40. original_max_tokens = self.llm.max_tokens
  41. self.llm.max_tokens = 40
  42. prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
  43. messages = prompt.to_messages()
  44. try:
  45. predicted_message = self.llm.predict_messages(
  46. messages, functions=self.functions, callbacks=None
  47. )
  48. except Exception as e:
  49. new_exception = self.model_instance.handle_exceptions(e)
  50. raise new_exception
  51. function_call = predicted_message.additional_kwargs.get("function_call", {})
  52. self.llm.max_tokens = original_max_tokens
  53. return True if function_call else False
  54. def plan(
  55. self,
  56. intermediate_steps: List[Tuple[AgentAction, str]],
  57. callbacks: Callbacks = None,
  58. **kwargs: Any,
  59. ) -> Union[AgentAction, AgentFinish]:
  60. """Given input, decided what to do.
  61. Args:
  62. intermediate_steps: Steps the LLM has taken to date, along with observations
  63. **kwargs: User inputs.
  64. Returns:
  65. Action specifying what tool to use.
  66. """
  67. agent_scratchpad = _format_intermediate_steps(intermediate_steps)
  68. selected_inputs = {
  69. k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
  70. }
  71. full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
  72. prompt = self.prompt.format_prompt(**full_inputs)
  73. messages = prompt.to_messages()
  74. # summarize messages if rest_tokens < 0
  75. try:
  76. messages = self.summarize_messages_if_needed(messages, functions=self.functions)
  77. except ExceededLLMTokensLimitError as e:
  78. return AgentFinish(return_values={"output": str(e)}, log=str(e))
  79. predicted_message = self.llm.predict_messages(
  80. messages, functions=self.functions, callbacks=callbacks
  81. )
  82. agent_decision = _parse_ai_message(predicted_message)
  83. if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
  84. tool_inputs = agent_decision.tool_input
  85. if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
  86. tool_inputs['query'] = kwargs['input']
  87. agent_decision.tool_input = tool_inputs
  88. return agent_decision
  89. @classmethod
  90. def get_system_message(cls):
  91. return SystemMessage(content="You are a helpful AI assistant.\n"
  92. "The current date or current time you know is wrong.\n"
  93. "Respond directly if appropriate.")
  94. def return_stopped_response(
  95. self,
  96. early_stopping_method: str,
  97. intermediate_steps: List[Tuple[AgentAction, str]],
  98. **kwargs: Any,
  99. ) -> AgentFinish:
  100. try:
  101. return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
  102. except ValueError:
  103. return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")