openai_function_call.py

from typing import Any, List, Optional, Sequence, Tuple, Union, cast

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.third_party.langchain.llms.fake import FakeLLM
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import (
    AgentAction,
    AgentFinish,
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    get_buffer_string,
)
from langchain.tools import BaseTool
from pydantic import root_validator

class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
    # rolling summary of older tool-call exchanges, used when the prompt
    # would exceed the model's context window
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_model_config: Optional[ModelConfigEntity] = None
    model_config: ModelConfigEntity
    agent_llm_callback: Optional[AgentLLMCallback] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator
    def validate_llm(cls, values: dict) -> dict:
        return values

    @classmethod
    def from_llm_and_tools(
        cls,
        model_config: ModelConfigEntity,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
        system_message: Optional[SystemMessage] = SystemMessage(
            content="You are a helpful AI assistant."
        ),
        agent_llm_callback: Optional[AgentLLMCallback] = None,
        **kwargs: Any,
    ) -> BaseSingleActionAgent:
        prompt = cls.create_prompt(
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
        )
        return cls(
            model_config=model_config,
            llm=FakeLLM(response=''),
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,
            agent_llm_callback=agent_llm_callback,
            **kwargs,
        )
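
    # Typical construction (a hedged sketch; `model_config` and `tools` come
    # from the surrounding application, not from this module):
    #
    #   agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
    #       model_config=model_config,
    #       tools=tools,
    #   )
    #   decision = agent.plan(intermediate_steps=[], input="What's the weather?")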

    def should_use_agent(self, query: str) -> bool:
        """
        Probe the model with the given query and report whether it wants to
        call a function (i.e. whether the agent loop should be used).

        :param query: user query
        :return: True if the model proposed a tool call, else False
        """
        original_max_tokens = 0
        for parameter_rule in self.model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
                                       or self.model_config.parameters.get(parameter_rule.use_template)) or 0

        # temporarily shrink the stored max_tokens while probing
        self.model_config.parameters['max_tokens'] = 40

        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
        messages = prompt.to_messages()

        try:
            prompt_messages = lc_messages_to_prompt_messages(messages)
            model_instance = ModelInstance(
                provider_model_bundle=self.model_config.provider_model_bundle,
                model=self.model_config.model,
            )

            tools = [PromptMessageTool(**function) for function in self.functions]

            result = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                tools=tools,
                stream=False,
                model_parameters={
                    'temperature': 0.2,
                    'top_p': 0.3,
                    'max_tokens': 1500
                }
            )
        finally:
            # restore the caller's max_tokens even if the probe call fails
            self.model_config.parameters['max_tokens'] = original_max_tokens

        return bool(result.message.tool_calls)
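
    # For reference, a tool-call response from the model runtime looks roughly
    # like this (field names per core.model_runtime; values hypothetical):
    #
    #   result.message.tool_calls[0].id                 -> "call_abc123"
    #   result.message.tool_calls[0].function.name      -> "weather"
    #   result.message.tool_calls[0].function.arguments -> '{"query": "Berlin"}'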

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use, or the final answer.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()

        prompt_messages = lc_messages_to_prompt_messages(messages)

        # summarize messages if rest_tokens < 0
        try:
            prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
        except ExceededLLMTokensLimitError as e:
            return AgentFinish(return_values={"output": str(e)}, log=str(e))

        model_instance = ModelInstance(
            provider_model_bundle=self.model_config.provider_model_bundle,
            model=self.model_config.model,
        )

        tools = [PromptMessageTool(**function) for function in self.functions]

        result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            tools=tools,
            stream=False,
            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
            model_parameters={
                'temperature': 0.2,
                'top_p': 0.3,
                'max_tokens': 1500
            }
        )

        ai_message = AIMessage(
            content=result.message.content or "",
            additional_kwargs={
                'function_call': {
                    'id': result.message.tool_calls[0].id,
                    **result.message.tool_calls[0].function.dict()
                } if result.message.tool_calls else None
            }
        )

        agent_decision = _parse_ai_message(ai_message)

        if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
            # the dataset retrieval tool should see the raw user query,
            # not the model's paraphrase of it
            tool_inputs = agent_decision.tool_input
            if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                tool_inputs['query'] = kwargs['input']
                agent_decision.tool_input = tool_inputs

        return agent_decision
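
    # _parse_ai_message yields one of two outcomes (values hypothetical):
    #
    #   AgentAction(tool="weather", tool_input={"query": "Berlin"}, log="...")
    #       -> the executor runs the tool, then calls plan() again with the
    #          new observation appended to intermediate_steps
    #   AgentFinish(return_values={"output": "..."}, log="...")
    #       -> final answer; the loop stops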

    @classmethod
    def get_system_message(cls):
        return SystemMessage(content="You are a helpful AI assistant.\n"
                                     "The current date or current time you know is wrong.\n"
                                     "Respond directly if appropriate.")

    def return_stopped_response(
        self,
        early_stopping_method: str,
        intermediate_steps: List[Tuple[AgentAction, str]],
        **kwargs: Any,
    ) -> AgentFinish:
        try:
            return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
        except ValueError:
            return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
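
    # Note on return_stopped_response above: the agent executor calls it when
    # the loop hits its max-iterations / max-execution-time limit. LangChain's
    # base implementation only supports early_stopping_method="force" (a canned
    # AgentFinish) and raises ValueError for anything else, hence the fallback.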

    def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
        rest_tokens = self.get_message_rest_tokens(
            self.model_config,
            messages,
            **kwargs
        )
        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
        if rest_tokens >= 0:
            return messages

        system_message = None
        human_message = None
        should_summary_messages = []
        for message in messages:
            if isinstance(message, SystemMessage):
                system_message = message
            elif isinstance(message, HumanMessage):
                human_message = message
            else:
                should_summary_messages.append(message)

        if len(should_summary_messages) > 2:
            # keep the latest assistant/function exchange verbatim,
            # summarize everything before it
            ai_message = should_summary_messages[-2]
            function_message = should_summary_messages[-1]
            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
            self.moving_summary_index = len(should_summary_messages)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        new_messages = [system_message, human_message]

        if self.moving_summary_index == 0:
            should_summary_messages.insert(0, human_message)

        self.moving_summary_buffer = self.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        new_messages.append(AIMessage(content=self.moving_summary_buffer))
        new_messages.append(ai_message)
        new_messages.append(function_message)

        return new_messages
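
    # Compaction sketch: with moving_summary_index == 0 and messages
    #   [system, human, ai_1, fn_1, ai_2, fn_2]
    # the slice [ai_1, fn_1] is folded into moving_summary_buffer and the
    # method returns
    #   [system, human, AIMessage(summary), ai_2, fn_2]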

    def predict_new_summary(
        self, messages: List[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix="Human",
            ai_prefix="AI",
        )

        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)

    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage],
                                     **kwargs) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        if model_config.provider == 'azure_openai':
            model = model_config.model
            model = model.replace("gpt-35", "gpt-3.5")
        else:
            model = model_config.credentials.get("base_model_name")

        tiktoken_ = _import_tiktoken()
        try:
            encoding = tiktoken_.encoding_for_model(model)
        except KeyError:
            model = "cl100k_base"
            encoding = tiktoken_.get_encoding(model)

        if model.startswith("gpt-3.5-turbo"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}. "
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )

        num_tokens = 0
        for m in messages:
            message = _convert_message_to_dict(m)
            num_tokens += tokens_per_message
            for key, value in message.items():
                if key == "function_call":
                    for f_key, f_value in value.items():
                        num_tokens += len(encoding.encode(f_key))
                        num_tokens += len(encoding.encode(f_value))
                else:
                    num_tokens += len(encoding.encode(value))

                if key == "name":
                    num_tokens += tokens_per_name

        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if kwargs.get('functions'):
            for function in kwargs.get('functions'):
                num_tokens += len(encoding.encode('name'))
                num_tokens += len(encoding.encode(function.get("name")))
                num_tokens += len(encoding.encode('description'))
                num_tokens += len(encoding.encode(function.get("description")))
                parameters = function.get("parameters")
                num_tokens += len(encoding.encode('parameters'))
                if 'title' in parameters:
                    num_tokens += len(encoding.encode('title'))
                    num_tokens += len(encoding.encode(parameters.get("title")))
                num_tokens += len(encoding.encode('type'))
                num_tokens += len(encoding.encode(parameters.get("type")))
                if 'properties' in parameters:
                    num_tokens += len(encoding.encode('properties'))
                    for key, value in parameters.get('properties').items():
                        num_tokens += len(encoding.encode(key))
                        for field_key, field_value in value.items():
                            num_tokens += len(encoding.encode(field_key))
                            if field_key == 'enum':
                                for enum_field in field_value:
                                    num_tokens += 3
                                    num_tokens += len(encoding.encode(enum_field))
                            else:
                                num_tokens += len(encoding.encode(field_key))
                                num_tokens += len(encoding.encode(str(field_value)))
                if 'required' in parameters:
                    num_tokens += len(encoding.encode('required'))
                    for required_field in parameters['required']:
                        num_tokens += 3
                        num_tokens += len(encoding.encode(required_field))

        return num_tokens
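

# A minimal standalone sketch of the same counting scheme for quick checks,
# assuming only the `tiktoken` package is installed; the per-message constants
# mirror the OpenAI cookbook notebook cited in the docstring above.
# `_demo_num_tokens` is illustrative, not part of this module's API.
if __name__ == "__main__":
    import tiktoken

    def _demo_num_tokens(messages: List[dict], model: str = "gpt-3.5-turbo") -> int:
        encoding = tiktoken.encoding_for_model(model)
        # gpt-3.5-turbo: 4 tokens per message, name replaces role (-1);
        # gpt-4: 3 tokens per message, +1 per name
        tokens_per_message = 4 if model.startswith("gpt-3.5-turbo") else 3
        tokens_per_name = -1 if model.startswith("gpt-3.5-turbo") else 1
        num_tokens = 0
        for message in messages:
            num_tokens += tokens_per_message
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":
                    num_tokens += tokens_per_name
        # every reply is primed with <im_start>assistant
        return num_tokens + 3

    print(_demo_num_tokens([
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "Hello!"},
    ]))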