openai_function_call.py

from typing import Any, List, Optional, Sequence, Tuple, Union

from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import (
    AgentAction,
    AgentFinish,
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    get_buffer_string,
)
from langchain.tools import BaseTool
from pydantic import root_validator

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM
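

# This agent satisfies OpenAIFunctionsAgent's `llm` field with a FakeLLM stub and
# routes every real completion through ModelInstance.invoke_llm() instead. When the
# prompt would overflow the model's context window, older tool observations are
# folded into `moving_summary_buffer` (see summarize_messages_if_needed below).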
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_model_config: Optional[ModelConfigEntity] = None
    model_config: ModelConfigEntity
    agent_llm_callback: Optional[AgentLLMCallback] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator
    def validate_llm(cls, values: dict) -> dict:
        # Deliberately skip OpenAIFunctionsAgent's ChatOpenAI check: `llm` here is
        # only a FakeLLM placeholder, and real calls go through ModelInstance.
        return values

    @classmethod
    def from_llm_and_tools(
            cls,
            model_config: ModelConfigEntity,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
            system_message: Optional[SystemMessage] = SystemMessage(
                content="You are a helpful AI assistant."
            ),
            agent_llm_callback: Optional[AgentLLMCallback] = None,
            **kwargs: Any,
    ) -> BaseSingleActionAgent:
        prompt = cls.create_prompt(
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
        )
        return cls(
            model_config=model_config,
            llm=FakeLLM(response=''),
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,
            agent_llm_callback=agent_llm_callback,
            **kwargs,
        )
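
    # Pre-flight probe: send the query once with the tool schemas attached and treat
    # any tool_call in the reply as "use the agent". `max_tokens` in the stored
    # model parameters is clamped to 40 during the probe and restored afterwards.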
    def should_use_agent(self, query: str) -> bool:
        """
        Return whether the agent should be used for the given query.

        :param query: user query
        :return: True if the model responded with a tool call
        """
        original_max_tokens = 0
        for parameter_rule in self.model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
                                       or self.model_config.parameters.get(parameter_rule.use_template)) or 0

        self.model_config.parameters['max_tokens'] = 40

        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
        messages = prompt.to_messages()

        prompt_messages = lc_messages_to_prompt_messages(messages)
        model_instance = ModelInstance(
            provider_model_bundle=self.model_config.provider_model_bundle,
            model=self.model_config.model,
        )

        tools = [PromptMessageTool(**function) for function in self.functions]

        result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            tools=tools,
            stream=False,
            model_parameters={
                'temperature': 0.2,
                'top_p': 0.3,
                'max_tokens': 1500
            }
        )

        self.model_config.parameters['max_tokens'] = original_max_tokens

        return bool(result.message.tool_calls)
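
    # One planning step: rebuild the prompt with the scratchpad, shrink the history
    # if it no longer fits the context window, invoke the model with the tool
    # schemas, and parse the reply into an AgentAction or AgentFinish.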
    def plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use, or AgentFinish.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()

        prompt_messages = lc_messages_to_prompt_messages(messages)

        # summarize messages if rest_tokens < 0
        try:
            prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
        except ExceededLLMTokensLimitError as e:
            return AgentFinish(return_values={"output": str(e)}, log=str(e))

        model_instance = ModelInstance(
            provider_model_bundle=self.model_config.provider_model_bundle,
            model=self.model_config.model,
        )

        tools = [PromptMessageTool(**function) for function in self.functions]

        result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            tools=tools,
            stream=False,
            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
            model_parameters={
                'temperature': 0.2,
                'top_p': 0.3,
                'max_tokens': 1500
            }
        )

        ai_message = AIMessage(
            content=result.message.content or "",
            additional_kwargs={
                'function_call': {
                    'id': result.message.tool_calls[0].id,
                    **result.message.tool_calls[0].function.dict()
                } if result.message.tool_calls else None
            }
        )

        agent_decision = _parse_ai_message(ai_message)

        if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
            # for the dataset tool, query with the raw user input rather than
            # the model's rewritten query
            tool_inputs = agent_decision.tool_input
            if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                tool_inputs['query'] = kwargs['input']
                agent_decision.tool_input = tool_inputs

        return agent_decision

    @classmethod
    def get_system_message(cls):
        return SystemMessage(content="You are a helpful AI assistant.\n"
                                     "The current date or current time you know is wrong.\n"
                                     "Respond directly if appropriate.")
    def return_stopped_response(
            self,
            early_stopping_method: str,
            intermediate_steps: List[Tuple[AgentAction, str]],
            **kwargs: Any,
    ) -> AgentFinish:
        try:
            return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
        except ValueError:
            return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
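
    # Rolling summarization: everything between the system/human messages and the
    # final AI/function pair is compressed into `moving_summary_buffer`;
    # `moving_summary_index` tracks how much history has already been summarized,
    # so each pass only summarizes the new messages.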
    def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
        rest_tokens = self.get_message_rest_tokens(
            self.model_config,
            messages,
            **kwargs
        )
        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
        if rest_tokens >= 0:
            return messages

        system_message = None
        human_message = None
        should_summary_messages = []
        for message in messages:
            if isinstance(message, SystemMessage):
                system_message = message
            elif isinstance(message, HumanMessage):
                human_message = message
            else:
                should_summary_messages.append(message)

        if len(should_summary_messages) > 2:
            ai_message = should_summary_messages[-2]
            function_message = should_summary_messages[-1]
            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
            self.moving_summary_index = len(should_summary_messages)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        new_messages = [system_message, human_message]

        if self.moving_summary_index == 0:
            should_summary_messages.insert(0, human_message)

        self.moving_summary_buffer = self.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        new_messages.append(AIMessage(content=self.moving_summary_buffer))
        new_messages.append(ai_message)
        new_messages.append(function_message)

        return new_messages
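
    # The summary itself is produced by the configured summary model using
    # LangChain's stock SUMMARY_PROMPT, merging `existing_summary` with the new lines.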
    def predict_new_summary(
            self, messages: List[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix="Human",
            ai_prefix="AI",
        )

        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)
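
    # Token accounting follows the OpenAI cookbook recipe for chat messages; the
    # per-function schema counts below are a best-effort heuristic rather than an
    # exact contract from OpenAI.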
    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage],
                                     **kwargs) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        if model_config.provider == 'azure_openai':
            model = model_config.model
            model = model.replace("gpt-35", "gpt-3.5")
        else:
            model = model_config.credentials.get("base_model_name")

        tiktoken_ = _import_tiktoken()
        try:
            encoding = tiktoken_.encoding_for_model(model)
        except KeyError:
            model = "cl100k_base"
            encoding = tiktoken_.get_encoding(model)

        if model.startswith("gpt-3.5-turbo"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}. "
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )

        num_tokens = 0
        for m in messages:
            message = _convert_message_to_dict(m)
            num_tokens += tokens_per_message
            for key, value in message.items():
                if key == "function_call":
                    for f_key, f_value in value.items():
                        num_tokens += len(encoding.encode(f_key))
                        num_tokens += len(encoding.encode(f_value))
                else:
                    num_tokens += len(encoding.encode(value))

                if key == "name":
                    num_tokens += tokens_per_name

        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if kwargs.get('functions'):
            for function in kwargs.get('functions'):
                num_tokens += len(encoding.encode('name'))
                num_tokens += len(encoding.encode(function.get("name")))
                num_tokens += len(encoding.encode('description'))
                num_tokens += len(encoding.encode(function.get("description")))
                parameters = function.get("parameters")
                num_tokens += len(encoding.encode('parameters'))
                if 'title' in parameters:
                    num_tokens += len(encoding.encode('title'))
                    num_tokens += len(encoding.encode(parameters.get("title")))
                num_tokens += len(encoding.encode('type'))
                num_tokens += len(encoding.encode(parameters.get("type")))
                if 'properties' in parameters:
                    num_tokens += len(encoding.encode('properties'))
                    for key, value in parameters.get('properties').items():
                        num_tokens += len(encoding.encode(key))
                        for field_key, field_value in value.items():
                            num_tokens += len(encoding.encode(field_key))
                            if field_key == 'enum':
                                for enum_field in field_value:
                                    num_tokens += 3
                                    num_tokens += len(encoding.encode(enum_field))
                            else:
                                # field_key is already counted above; only add the value here
                                num_tokens += len(encoding.encode(str(field_value)))
                if 'required' in parameters:
                    num_tokens += len(encoding.encode('required'))
                    for required_field in parameters['required']:
                        num_tokens += 3
                        num_tokens += len(encoding.encode(required_field))

        return num_tokens