openai_function_call.py

from collections.abc import Sequence
from typing import Any, Optional, Union

from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import (
    AgentAction,
    AgentFinish,
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    get_buffer_string,
)
from langchain.tools import BaseTool
from pydantic import root_validator

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM
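

# Overview: this agent keeps LangChain's OpenAI function-calling loop from
# OpenAIFunctionsAgent, but routes every real model call through the
# application's ModelInstance abstraction instead of ChatOpenAI. When the
# prompt overflows the context window, older function-call turns are folded
# into a rolling summary tracked by moving_summary_buffer /
# moving_summary_index below.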
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_model_config: Optional[ModelConfigEntity] = None
    model_config: ModelConfigEntity
    agent_llm_callback: Optional[AgentLLMCallback] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator
    def validate_llm(cls, values: dict) -> dict:
        # Override the parent validator, which requires `llm` to be a
        # ChatOpenAI instance; this agent uses a FakeLLM placeholder instead
        # (see from_llm_and_tools) and invokes the model via ModelInstance.
        return values

    @classmethod
    def from_llm_and_tools(
        cls,
        model_config: ModelConfigEntity,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
        system_message: Optional[SystemMessage] = SystemMessage(
            content="You are a helpful AI assistant."
        ),
        agent_llm_callback: Optional[AgentLLMCallback] = None,
        **kwargs: Any,
    ) -> BaseSingleActionAgent:
        prompt = cls.create_prompt(
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
        )
        return cls(
            model_config=model_config,
            # FakeLLM only satisfies the parent class's `llm` field; the real
            # calls go through ModelInstance in should_use_agent() and plan().
            llm=FakeLLM(response=''),
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,
            agent_llm_callback=agent_llm_callback,
            **kwargs,
        )
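
    # Usage sketch (hypothetical objects; `model_config`, `summary_model_config`
    # and `tools` must come from the surrounding application):
    #
    #   agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
    #       model_config=model_config,
    #       tools=tools,
    #       summary_model_config=summary_model_config,
    #   )
    #   if agent.should_use_agent(query):
    #       decision = agent.plan(intermediate_steps=[], input=query)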

    def should_use_agent(self, query: str) -> bool:
        """
        Probe the model with the query and report whether it wants to call a function.

        :param query: user input
        :return: True if the model responds with a tool call
        """
        original_max_tokens = 0
        for parameter_rule in self.model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
                                       or self.model_config.parameters.get(parameter_rule.use_template)) or 0

        # temporarily cap max_tokens on the shared model config while probing
        self.model_config.parameters['max_tokens'] = 40

        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
        messages = prompt.to_messages()

        try:
            prompt_messages = lc_messages_to_prompt_messages(messages)
            model_instance = ModelInstance(
                provider_model_bundle=self.model_config.provider_model_bundle,
                model=self.model_config.model,
            )

            tools = [PromptMessageTool(**function) for function in self.functions]

            result = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                tools=tools,
                stream=False,
                model_parameters={
                    'temperature': 0.2,
                    'top_p': 0.3,
                    'max_tokens': 1500
                }
            )
        finally:
            # restore the original max_tokens even if the probe call raises
            self.model_config.parameters['max_tokens'] = original_max_tokens

        return bool(result.message.tool_calls)

    def plan(
        self,
        intermediate_steps: list[tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use, or an AgentFinish.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()

        prompt_messages = lc_messages_to_prompt_messages(messages)

        # summarize messages if rest_tokens < 0
        try:
            prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
        except ExceededLLMTokensLimitError as e:
            return AgentFinish(return_values={"output": str(e)}, log=str(e))

        model_instance = ModelInstance(
            provider_model_bundle=self.model_config.provider_model_bundle,
            model=self.model_config.model,
        )

        tools = [PromptMessageTool(**function) for function in self.functions]

        result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            tools=tools,
            stream=False,
            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
            model_parameters={
                'temperature': 0.2,
                'top_p': 0.3,
                'max_tokens': 1500
            }
        )

        # Convert the first tool call back into LangChain's `function_call`
        # shape so _parse_ai_message can produce an AgentAction/AgentFinish.
        ai_message = AIMessage(
            content=result.message.content or "",
            additional_kwargs={
                'function_call': {
                    'id': result.message.tool_calls[0].id,
                    **result.message.tool_calls[0].function.dict()
                } if result.message.tool_calls else None
            }
        )

        agent_decision = _parse_ai_message(ai_message)

        if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
            # for dataset retrieval, query with the user's raw input rather
            # than the model-rewritten query
            tool_inputs = agent_decision.tool_input
            if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                tool_inputs['query'] = kwargs['input']
                agent_decision.tool_input = tool_inputs

        return agent_decision
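
    # For reference, _parse_ai_message reads the `function_call` dict built
    # above from additional_kwargs; after reshaping it looks like, e.g.
    # (hypothetical values):
    #   {'id': 'call_abc123', 'name': 'current_weather',
    #    'arguments': '{"location": "Paris"}'}
    # where `arguments` stays a JSON string until the parser decodes it.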

    @classmethod
    def get_system_message(cls):
        return SystemMessage(content="You are a helpful AI assistant.\n"
                                     "The current date or current time you know is wrong.\n"
                                     "Respond directly if appropriate.")

    def return_stopped_response(
        self,
        early_stopping_method: str,
        intermediate_steps: list[tuple[AgentAction, str]],
        **kwargs: Any,
    ) -> AgentFinish:
        try:
            return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
        except ValueError:
            return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")

    def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
        rest_tokens = self.get_message_rest_tokens(
            self.model_config,
            messages,
            **kwargs
        )
        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
        if rest_tokens >= 0:
            return messages

        system_message = None
        human_message = None
        should_summary_messages = []
        for message in messages:
            if isinstance(message, SystemMessage):
                system_message = message
            elif isinstance(message, HumanMessage):
                human_message = message
            else:
                should_summary_messages.append(message)

        if len(should_summary_messages) > 2:
            # keep the newest AI/function pair verbatim; summarize the rest
            ai_message = should_summary_messages[-2]
            function_message = should_summary_messages[-1]
            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
            self.moving_summary_index = len(should_summary_messages)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        new_messages = [system_message, human_message]

        if self.moving_summary_index == 0:
            should_summary_messages.insert(0, human_message)

        self.moving_summary_buffer = self.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        new_messages.append(AIMessage(content=self.moving_summary_buffer))
        new_messages.append(ai_message)
        new_messages.append(function_message)

        return new_messages
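
    # Illustrative walk-through (message roles only, not real token counts):
    # if [system, human, ai_1, fn_1, ai_2, fn_2] overflows the window, the
    # newest pair ai_2/fn_2 is kept verbatim, ai_1/fn_1 are folded into
    # moving_summary_buffer, and the next model call sees
    # [system, human, AIMessage(summary), ai_2, fn_2].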

    def predict_new_summary(
            self, messages: list[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix="Human",
            ai_prefix="AI",
        )

        # SUMMARY_PROMPT is LangChain's progressive-summarization prompt; it
        # merges `existing_summary` with the newly buffered lines.
        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)

    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage],
                                     **kwargs) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation:
        https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
        """
        if model_config.provider == 'azure_openai':
            model = model_config.model
            model = model.replace("gpt-35", "gpt-3.5")
        else:
            model = model_config.credentials.get("base_model_name")

        tiktoken_ = _import_tiktoken()
        try:
            encoding = tiktoken_.encoding_for_model(model)
        except KeyError:
            # fall back to the cl100k_base encoding for unknown model names
            model = "cl100k_base"
            encoding = tiktoken_.get_encoding(model)

        if model.startswith("gpt-3.5-turbo"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}. "
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )

        num_tokens = 0
        for m in messages:
            message = _convert_message_to_dict(m)
            num_tokens += tokens_per_message
            for key, value in message.items():
                if key == "function_call":
                    for f_key, f_value in value.items():
                        num_tokens += len(encoding.encode(f_key))
                        num_tokens += len(encoding.encode(f_value))
                else:
                    num_tokens += len(encoding.encode(value))

                if key == "name":
                    num_tokens += tokens_per_name

        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if kwargs.get('functions'):
            for function in kwargs.get('functions'):
                num_tokens += len(encoding.encode('name'))
                num_tokens += len(encoding.encode(function.get("name")))
                num_tokens += len(encoding.encode('description'))
                num_tokens += len(encoding.encode(function.get("description")))
                parameters = function.get("parameters")
                num_tokens += len(encoding.encode('parameters'))
                if 'title' in parameters:
                    num_tokens += len(encoding.encode('title'))
                    num_tokens += len(encoding.encode(parameters.get("title")))
                num_tokens += len(encoding.encode('type'))
                num_tokens += len(encoding.encode(parameters.get("type")))
                if 'properties' in parameters:
                    num_tokens += len(encoding.encode('properties'))
                    for key, value in parameters.get('properties').items():
                        num_tokens += len(encoding.encode(key))
                        for field_key, field_value in value.items():
                            num_tokens += len(encoding.encode(field_key))
                            if field_key == 'enum':
                                for enum_field in field_value:
                                    num_tokens += 3
                                    num_tokens += len(encoding.encode(enum_field))
                            else:
                                # field_key is already counted above, so only
                                # the field value is encoded here
                                num_tokens += len(encoding.encode(str(field_value)))
                if 'required' in parameters:
                    num_tokens += len(encoding.encode('required'))
                    for required_field in parameters['required']:
                        num_tokens += 3
                        num_tokens += len(encoding.encode(required_field))

        return num_tokens
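

# Rough worked example of the message-token recipe above (illustrative only;
# actual totals depend on the tokenizer version). For "gpt-4", the
# two-message conversation
#     [{"role": "system", "content": "You are a helpful AI assistant."},
#      {"role": "user", "content": "Hi"}]
# is estimated as 2 * tokens_per_message (3) + the encoded lengths of every
# role/content value + 3 tokens to prime the assistant reply, matching the
# OpenAI cookbook recipe linked in the docstring.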