# openai_function_call.py

from typing import List, Tuple, Any, Union, Sequence, Optional

from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import (
    _parse_ai_message,
    _format_intermediate_steps,
)
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import (
    AgentAction,
    AgentFinish,
    SystemMessage,
    AIMessage,
    HumanMessage,
    BaseMessage,
    get_buffer_string,
)
from langchain.tools import BaseTool
from pydantic import root_validator

from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
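

# AutoSummarizingOpenAIFunctionCallAgent extends LangChain's OpenAIFunctionsAgent
# with a rolling-summary memory: when the assembled prompt would exceed the
# model's context window, older function call/observation messages are folded
# into `moving_summary_buffer` so the agent can keep planning.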
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_model_instance: Optional[BaseLLM] = None
    model_instance: BaseLLM

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator
    def validate_llm(cls, values: dict) -> dict:
        # override the parent validator: `llm` is only a FakeLLM placeholder
        # here, so the parent's ChatOpenAI type check must not run
        return values
    @classmethod
    def from_llm_and_tools(
        cls,
        model_instance: BaseLLM,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
        system_message: Optional[SystemMessage] = SystemMessage(
            content="You are a helpful AI assistant."
        ),
        **kwargs: Any,
    ) -> BaseSingleActionAgent:
        prompt = cls.create_prompt(
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
        )
        return cls(
            model_instance=model_instance,
            # the real model is called via model_instance; `llm` only satisfies
            # the parent class interface, so a stub is enough
            llm=FakeLLM(response=''),
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,
            **kwargs,
        )
    def should_use_agent(self, query: str) -> bool:
        """
        Probe the model with the tool (function) schemas to check whether it
        wants to call a function for this query.

        :param query: user query
        :return: True if the model responds with a function call
        """
        original_max_tokens = self.model_instance.model_kwargs.max_tokens
        # cap the probe response; only the presence of a function call matters
        self.model_instance.model_kwargs.max_tokens = 40

        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
        messages = prompt.to_messages()

        try:
            prompt_messages = to_prompt_messages(messages)
            result = self.model_instance.run(
                messages=prompt_messages,
                functions=self.functions,
                callbacks=None
            )
        except Exception as e:
            raise self.model_instance.handle_exceptions(e)
        finally:
            # restore the original max_tokens even if the probe call fails
            self.model_instance.model_kwargs.max_tokens = original_max_tokens

        return bool(result.function_call)
    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use, or a finish signal.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()

        # summarize older messages if the remaining token budget is exhausted
        try:
            messages = self.summarize_messages_if_needed(messages, functions=self.functions)
        except ExceededLLMTokensLimitError as e:
            return AgentFinish(return_values={"output": str(e)}, log=str(e))

        prompt_messages = to_prompt_messages(messages)
        result = self.model_instance.run(
            messages=prompt_messages,
            functions=self.functions,
        )

        ai_message = AIMessage(
            content=result.content,
            additional_kwargs={
                'function_call': result.function_call
            }
        )
        agent_decision = _parse_ai_message(ai_message)

        # for dataset tool calls, always pass the original user input as the query
        if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
            tool_inputs = agent_decision.tool_input
            if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                tool_inputs['query'] = kwargs['input']
                agent_decision.tool_input = tool_inputs

        return agent_decision
    @classmethod
    def get_system_message(cls):
        return SystemMessage(content="You are a helpful AI assistant.\n"
                                     "The current date or current time you know is wrong.\n"
                                     "Respond directly if appropriate.")
    def return_stopped_response(
        self,
        early_stopping_method: str,
        intermediate_steps: List[Tuple[AgentAction, str]],
        **kwargs: Any,
    ) -> AgentFinish:
        try:
            return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
        except ValueError:
            # the parent raises ValueError for unsupported stopping methods;
            # fall back to a graceful apology instead of crashing
            return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
    def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
        # calculate the remaining token budget and summarize previous function
        # observation messages if it has been exhausted
        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
        if rest_tokens >= 0:
            return messages

        system_message = None
        human_message = None
        should_summary_messages = []
        for message in messages:
            if isinstance(message, SystemMessage):
                system_message = message
            elif isinstance(message, HumanMessage):
                human_message = message
            else:
                should_summary_messages.append(message)

        if len(should_summary_messages) > 2:
            # keep the latest AI/function message pair verbatim and
            # summarize everything not yet covered by the rolling summary
            ai_message = should_summary_messages[-2]
            function_message = should_summary_messages[-1]
            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
            self.moving_summary_index = len(should_summary_messages)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        new_messages = [system_message, human_message]

        if self.moving_summary_index == 0:
            should_summary_messages.insert(0, human_message)

        self.moving_summary_buffer = self.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        new_messages.append(AIMessage(content=self.moving_summary_buffer))
        new_messages.append(ai_message)
        new_messages.append(function_message)

        return new_messages
    def predict_new_summary(
        self, messages: List[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix="Human",
            ai_prefix="AI",
        )

        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)
    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        if model_instance.model_provider.provider_name == 'azure_openai':
            # Azure deployment names use "gpt-35"; normalize for tiktoken
            model = model_instance.base_model_name
            model = model.replace("gpt-35", "gpt-3.5")
        else:
            model = model_instance.base_model_name

        tiktoken_ = _import_tiktoken()
        try:
            encoding = tiktoken_.encoding_for_model(model)
        except KeyError:
            # fall back to cl100k_base for models unknown to tiktoken; keep
            # `model` untouched so the per-model branches below still match
            encoding = tiktoken_.get_encoding("cl100k_base")

        if model.startswith("gpt-3.5-turbo"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}. "
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )

        num_tokens = 0
        for m in messages:
            message = _convert_message_to_dict(m)
            num_tokens += tokens_per_message
            for key, value in message.items():
                if key == "function_call":
                    for f_key, f_value in value.items():
                        num_tokens += len(encoding.encode(f_key))
                        num_tokens += len(encoding.encode(f_value))
                else:
                    num_tokens += len(encoding.encode(value))

                if key == "name":
                    num_tokens += tokens_per_name

        # every reply is primed with <im_start>assistant
        num_tokens += 3

        # estimate the tokens consumed by the function (tool) schemas; OpenAI
        # does not document this serialization, so the counts are heuristic
        if kwargs.get('functions'):
            for function in kwargs.get('functions'):
                num_tokens += len(encoding.encode('name'))
                num_tokens += len(encoding.encode(function.get("name")))
                num_tokens += len(encoding.encode('description'))
                num_tokens += len(encoding.encode(function.get("description")))
                parameters = function.get("parameters")
                num_tokens += len(encoding.encode('parameters'))
                if 'title' in parameters:
                    num_tokens += len(encoding.encode('title'))
                    num_tokens += len(encoding.encode(parameters.get("title")))
                num_tokens += len(encoding.encode('type'))
                num_tokens += len(encoding.encode(parameters.get("type")))
                if 'properties' in parameters:
                    num_tokens += len(encoding.encode('properties'))
                    for key, value in parameters.get('properties').items():
                        num_tokens += len(encoding.encode(key))
                        for field_key, field_value in value.items():
                            num_tokens += len(encoding.encode(field_key))
                            if field_key == 'enum':
                                for enum_field in field_value:
                                    num_tokens += 3
                                    num_tokens += len(encoding.encode(enum_field))
                            else:
                                num_tokens += len(encoding.encode(str(field_value)))
                if 'required' in parameters:
                    num_tokens += len(encoding.encode('required'))
                    for required_field in parameters['required']:
                        num_tokens += 3
                        num_tokens += len(encoding.encode(required_field))

        return num_tokens
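

# Usage sketch (illustrative only): `model_instance`, `summary_model_instance`
# and `tools` are hypothetical objects supplied by the surrounding project,
# not defined in this module.
#
# agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
#     model_instance=model_instance,                    # BaseLLM chat model wrapper
#     tools=tools,                                      # Sequence[BaseTool]
#     summary_model_instance=summary_model_instance,    # model used for rolling summaries
# )
#
# query = "What's the weather like today?"
# if agent.should_use_agent(query):
#     decision = agent.plan(intermediate_steps=[], input=query)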