api_tool.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. import json
  2. from json import dumps
  3. from os import getenv
  4. from typing import Any, Union
  5. from urllib.parse import urlencode
  6. import httpx
  7. import requests
  8. import core.helper.ssrf_proxy as ssrf_proxy
  9. from core.tools.entities.tool_bundle import ApiToolBundle
  10. from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
  11. from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
  12. from core.tools.tool.tool import Tool
  13. API_TOOL_DEFAULT_TIMEOUT = (
  14. int(getenv('API_TOOL_DEFAULT_CONNECT_TIMEOUT', '10')),
  15. int(getenv('API_TOOL_DEFAULT_READ_TIMEOUT', '60'))
  16. )
  17. class ApiTool(Tool):
  18. api_bundle: ApiToolBundle
  19. """
  20. Api tool
  21. """
  22. def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
  23. """
  24. fork a new tool with meta data
  25. :param meta: the meta data of a tool call processing, tenant_id is required
  26. :return: the new tool
  27. """
  28. return self.__class__(
  29. identity=self.identity.copy() if self.identity else None,
  30. parameters=self.parameters.copy() if self.parameters else None,
  31. description=self.description.copy() if self.description else None,
  32. api_bundle=self.api_bundle.copy() if self.api_bundle else None,
  33. runtime=Tool.Runtime(**runtime)
  34. )
  35. def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str:
  36. """
  37. validate the credentials for Api tool
  38. """
  39. # assemble validate request and request parameters
  40. headers = self.assembling_request(parameters)
  41. if format_only:
  42. return
  43. response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
  44. # validate response
  45. return self.validate_and_parse_response(response)
  46. def tool_provider_type(self) -> ToolProviderType:
  47. return ToolProviderType.API
  48. def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
  49. headers = {}
  50. credentials = self.runtime.credentials or {}
  51. if 'auth_type' not in credentials:
  52. raise ToolProviderCredentialValidationError('Missing auth_type')
  53. if credentials['auth_type'] == 'api_key':
  54. api_key_header = 'api_key'
  55. if 'api_key_header' in credentials:
  56. api_key_header = credentials['api_key_header']
  57. if 'api_key_value' not in credentials:
  58. raise ToolProviderCredentialValidationError('Missing api_key_value')
  59. elif not isinstance(credentials['api_key_value'], str):
  60. raise ToolProviderCredentialValidationError('api_key_value must be a string')
  61. if 'api_key_header_prefix' in credentials:
  62. api_key_header_prefix = credentials['api_key_header_prefix']
  63. if api_key_header_prefix == 'basic' and credentials['api_key_value']:
  64. credentials['api_key_value'] = f'Basic {credentials["api_key_value"]}'
  65. elif api_key_header_prefix == 'bearer' and credentials['api_key_value']:
  66. credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}'
  67. elif api_key_header_prefix == 'custom':
  68. pass
  69. headers[api_key_header] = credentials['api_key_value']
  70. needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
  71. for parameter in needed_parameters:
  72. if parameter.required and parameter.name not in parameters:
  73. raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
  74. if parameter.default is not None and parameter.name not in parameters:
  75. parameters[parameter.name] = parameter.default
  76. return headers
  77. def validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> str:
  78. """
  79. validate the response
  80. """
  81. if isinstance(response, httpx.Response):
  82. if response.status_code >= 400:
  83. raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
  84. if not response.content:
  85. return 'Empty response from the tool, please check your parameters and try again.'
  86. try:
  87. response = response.json()
  88. try:
  89. return json.dumps(response, ensure_ascii=False)
  90. except Exception as e:
  91. return json.dumps(response)
  92. except Exception as e:
  93. return response.text
  94. elif isinstance(response, requests.Response):
  95. if not response.ok:
  96. raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
  97. if not response.content:
  98. return 'Empty response from the tool, please check your parameters and try again.'
  99. try:
  100. response = response.json()
  101. try:
  102. return json.dumps(response, ensure_ascii=False)
  103. except Exception as e:
  104. return json.dumps(response)
  105. except Exception as e:
  106. return response.text
  107. else:
  108. raise ValueError(f'Invalid response type {type(response)}')
  109. def do_http_request(self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]) -> httpx.Response:
  110. """
  111. do http request depending on api bundle
  112. """
  113. method = method.lower()
  114. params = {}
  115. path_params = {}
  116. body = {}
  117. cookies = {}
  118. # check parameters
  119. for parameter in self.api_bundle.openapi.get('parameters', []):
  120. if parameter['in'] == 'path':
  121. value = ''
  122. if parameter['name'] in parameters:
  123. value = parameters[parameter['name']]
  124. elif parameter['required']:
  125. raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
  126. else:
  127. value = (parameter.get('schema', {}) or {}).get('default', '')
  128. path_params[parameter['name']] = value
  129. elif parameter['in'] == 'query':
  130. value = ''
  131. if parameter['name'] in parameters:
  132. value = parameters[parameter['name']]
  133. elif parameter.get('required', False):
  134. raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
  135. else:
  136. value = (parameter.get('schema', {}) or {}).get('default', '')
  137. params[parameter['name']] = value
  138. elif parameter['in'] == 'cookie':
  139. value = ''
  140. if parameter['name'] in parameters:
  141. value = parameters[parameter['name']]
  142. elif parameter.get('required', False):
  143. raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
  144. else:
  145. value = (parameter.get('schema', {}) or {}).get('default', '')
  146. cookies[parameter['name']] = value
  147. elif parameter['in'] == 'header':
  148. value = ''
  149. if parameter['name'] in parameters:
  150. value = parameters[parameter['name']]
  151. elif parameter.get('required', False):
  152. raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
  153. else:
  154. value = (parameter.get('schema', {}) or {}).get('default', '')
  155. headers[parameter['name']] = value
  156. # check if there is a request body and handle it
  157. if 'requestBody' in self.api_bundle.openapi and self.api_bundle.openapi['requestBody'] is not None:
  158. # handle json request body
  159. if 'content' in self.api_bundle.openapi['requestBody']:
  160. for content_type in self.api_bundle.openapi['requestBody']['content']:
  161. headers['Content-Type'] = content_type
  162. body_schema = self.api_bundle.openapi['requestBody']['content'][content_type]['schema']
  163. required = body_schema['required'] if 'required' in body_schema else []
  164. properties = body_schema['properties'] if 'properties' in body_schema else {}
  165. for name, property in properties.items():
  166. if name in parameters:
  167. # convert type
  168. body[name] = self._convert_body_property_type(property, parameters[name])
  169. elif name in required:
  170. raise ToolParameterValidationError(
  171. f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
  172. )
  173. elif 'default' in property:
  174. body[name] = property['default']
  175. else:
  176. body[name] = None
  177. break
  178. # replace path parameters
  179. for name, value in path_params.items():
  180. url = url.replace(f'{{{name}}}', f'{value}')
  181. # parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored
  182. if 'Content-Type' in headers:
  183. if headers['Content-Type'] == 'application/json':
  184. body = dumps(body)
  185. elif headers['Content-Type'] == 'application/x-www-form-urlencoded':
  186. body = urlencode(body)
  187. else:
  188. body = body
  189. # do http request
  190. if method == 'get':
  191. response = ssrf_proxy.get(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
  192. elif method == 'post':
  193. response = ssrf_proxy.post(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
  194. elif method == 'put':
  195. response = ssrf_proxy.put(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
  196. elif method == 'delete':
  197. response = ssrf_proxy.delete(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, allow_redirects=True)
  198. elif method == 'patch':
  199. response = ssrf_proxy.patch(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
  200. elif method == 'head':
  201. response = ssrf_proxy.head(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
  202. elif method == 'options':
  203. response = ssrf_proxy.options(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
  204. else:
  205. raise ValueError(f'Invalid http method {method}')
  206. return response
  207. def _convert_body_property_any_of(self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10) -> Any:
  208. if max_recursive <= 0:
  209. raise Exception("Max recursion depth reached")
  210. for option in any_of or []:
  211. try:
  212. if 'type' in option:
  213. # Attempt to convert the value based on the type.
  214. if option['type'] == 'integer' or option['type'] == 'int':
  215. return int(value)
  216. elif option['type'] == 'number':
  217. if '.' in str(value):
  218. return float(value)
  219. else:
  220. return int(value)
  221. elif option['type'] == 'string':
  222. return str(value)
  223. elif option['type'] == 'boolean':
  224. if str(value).lower() in ['true', '1']:
  225. return True
  226. elif str(value).lower() in ['false', '0']:
  227. return False
  228. else:
  229. continue # Not a boolean, try next option
  230. elif option['type'] == 'null' and not value:
  231. return None
  232. else:
  233. continue # Unsupported type, try next option
  234. elif 'anyOf' in option and isinstance(option['anyOf'], list):
  235. # Recursive call to handle nested anyOf
  236. return self._convert_body_property_any_of(property, value, option['anyOf'], max_recursive - 1)
  237. except ValueError:
  238. continue # Conversion failed, try next option
  239. # If no option succeeded, you might want to return the value as is or raise an error
  240. return value # or raise ValueError(f"Cannot convert value '{value}' to any specified type in anyOf")
  241. def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any:
  242. try:
  243. if 'type' in property:
  244. if property['type'] == 'integer' or property['type'] == 'int':
  245. return int(value)
  246. elif property['type'] == 'number':
  247. # check if it is a float
  248. if '.' in value:
  249. return float(value)
  250. else:
  251. return int(value)
  252. elif property['type'] == 'string':
  253. return str(value)
  254. elif property['type'] == 'boolean':
  255. return bool(value)
  256. elif property['type'] == 'null':
  257. if value is None:
  258. return None
  259. elif property['type'] == 'object':
  260. if isinstance(value, str):
  261. try:
  262. return json.loads(value)
  263. except ValueError:
  264. return value
  265. elif isinstance(value, dict):
  266. return value
  267. else:
  268. return value
  269. else:
  270. raise ValueError(f"Invalid type {property['type']} for property {property}")
  271. elif 'anyOf' in property and isinstance(property['anyOf'], list):
  272. return self._convert_body_property_any_of(property, value, property['anyOf'])
  273. except ValueError as e:
  274. return value
  275. def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
  276. """
  277. invoke http request
  278. """
  279. # assemble request
  280. headers = self.assembling_request(tool_parameters)
  281. # do http request
  282. response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)
  283. # validate response
  284. response = self.validate_and_parse_response(response)
  285. # assemble invoke message
  286. return self.create_text_message(response)