configuration.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. import os
  2. from copy import deepcopy
  3. from typing import Any, Union
  4. from pydantic import BaseModel
  5. from yaml import FullLoader, load
  6. from core.helper import encrypter
  7. from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
  8. from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
  9. from core.tools.entities.tool_entities import (
  10. ModelToolConfiguration,
  11. ModelToolProviderConfiguration,
  12. ToolParameter,
  13. ToolProviderCredentials,
  14. )
  15. from core.tools.provider.tool_provider import ToolProviderController
  16. from core.tools.tool.tool import Tool
  17. class ToolConfigurationManager(BaseModel):
  18. tenant_id: str
  19. provider_controller: ToolProviderController
  20. def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
  21. """
  22. deep copy credentials
  23. """
  24. return deepcopy(credentials)
  25. def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  26. """
  27. encrypt tool credentials with tenant id
  28. return a deep copy of credentials with encrypted values
  29. """
  30. credentials = self._deep_copy(credentials)
  31. # get fields need to be decrypted
  32. fields = self.provider_controller.get_credentials_schema()
  33. for field_name, field in fields.items():
  34. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  35. if field_name in credentials:
  36. encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
  37. credentials[field_name] = encrypted
  38. return credentials
  39. def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
  40. """
  41. mask tool credentials
  42. return a deep copy of credentials with masked values
  43. """
  44. credentials = self._deep_copy(credentials)
  45. # get fields need to be decrypted
  46. fields = self.provider_controller.get_credentials_schema()
  47. for field_name, field in fields.items():
  48. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  49. if field_name in credentials:
  50. if len(credentials[field_name]) > 6:
  51. credentials[field_name] = \
  52. credentials[field_name][:2] + \
  53. '*' * (len(credentials[field_name]) - 4) +\
  54. credentials[field_name][-2:]
  55. else:
  56. credentials[field_name] = '*' * len(credentials[field_name])
  57. return credentials
  58. def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  59. """
  60. decrypt tool credentials with tenant id
  61. return a deep copy of credentials with decrypted values
  62. """
  63. cache = ToolProviderCredentialsCache(
  64. tenant_id=self.tenant_id,
  65. identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
  66. cache_type=ToolProviderCredentialsCacheType.PROVIDER
  67. )
  68. cached_credentials = cache.get()
  69. if cached_credentials:
  70. return cached_credentials
  71. credentials = self._deep_copy(credentials)
  72. # get fields need to be decrypted
  73. fields = self.provider_controller.get_credentials_schema()
  74. for field_name, field in fields.items():
  75. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  76. if field_name in credentials:
  77. try:
  78. credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
  79. except:
  80. pass
  81. cache.set(credentials)
  82. return credentials
  83. def delete_tool_credentials_cache(self):
  84. cache = ToolProviderCredentialsCache(
  85. tenant_id=self.tenant_id,
  86. identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
  87. cache_type=ToolProviderCredentialsCacheType.PROVIDER
  88. )
  89. cache.delete()
  90. class ToolParameterConfigurationManager(BaseModel):
  91. """
  92. Tool parameter configuration manager
  93. """
  94. tenant_id: str
  95. tool_runtime: Tool
  96. provider_name: str
  97. provider_type: str
  98. def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
  99. """
  100. deep copy parameters
  101. """
  102. return {key: value for key, value in parameters.items()}
  103. def _merge_parameters(self) -> list[ToolParameter]:
  104. """
  105. merge parameters
  106. """
  107. # get tool parameters
  108. tool_parameters = self.tool_runtime.parameters or []
  109. # get tool runtime parameters
  110. runtime_parameters = self.tool_runtime.get_runtime_parameters() or []
  111. # override parameters
  112. current_parameters = tool_parameters.copy()
  113. for runtime_parameter in runtime_parameters:
  114. found = False
  115. for index, parameter in enumerate(current_parameters):
  116. if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
  117. current_parameters[index] = runtime_parameter
  118. found = True
  119. break
  120. if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
  121. current_parameters.append(runtime_parameter)
  122. return current_parameters
  123. def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  124. """
  125. mask tool parameters
  126. return a deep copy of parameters with masked values
  127. """
  128. parameters = self._deep_copy(parameters)
  129. # override parameters
  130. current_parameters = self._merge_parameters()
  131. for parameter in current_parameters:
  132. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  133. if parameter.name in parameters:
  134. if len(parameters[parameter.name]) > 6:
  135. parameters[parameter.name] = \
  136. parameters[parameter.name][:2] + \
  137. '*' * (len(parameters[parameter.name]) - 4) +\
  138. parameters[parameter.name][-2:]
  139. else:
  140. parameters[parameter.name] = '*' * len(parameters[parameter.name])
  141. return parameters
  142. def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  143. """
  144. encrypt tool parameters with tenant id
  145. return a deep copy of parameters with encrypted values
  146. """
  147. # override parameters
  148. current_parameters = self._merge_parameters()
  149. for parameter in current_parameters:
  150. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  151. if parameter.name in parameters:
  152. encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
  153. parameters[parameter.name] = encrypted
  154. return parameters
  155. def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  156. """
  157. decrypt tool parameters with tenant id
  158. return a deep copy of parameters with decrypted values
  159. """
  160. cache = ToolParameterCache(
  161. tenant_id=self.tenant_id,
  162. provider=f'{self.provider_type}.{self.provider_name}',
  163. tool_name=self.tool_runtime.identity.name,
  164. cache_type=ToolParameterCacheType.PARAMETER
  165. )
  166. cached_parameters = cache.get()
  167. if cached_parameters:
  168. return cached_parameters
  169. # override parameters
  170. current_parameters = self._merge_parameters()
  171. has_secret_input = False
  172. for parameter in current_parameters:
  173. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  174. if parameter.name in parameters:
  175. try:
  176. has_secret_input = True
  177. parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
  178. except:
  179. pass
  180. if has_secret_input:
  181. cache.set(parameters)
  182. return parameters
  183. def delete_tool_parameters_cache(self):
  184. cache = ToolParameterCache(
  185. tenant_id=self.tenant_id,
  186. provider=f'{self.provider_type}.{self.provider_name}',
  187. tool_name=self.tool_runtime.identity.name,
  188. cache_type=ToolParameterCacheType.PARAMETER
  189. )
  190. cache.delete()
  191. class ModelToolConfigurationManager:
  192. """
  193. Model as tool configuration
  194. """
  195. _configurations: dict[str, ModelToolProviderConfiguration] = {}
  196. _model_configurations: dict[str, ModelToolConfiguration] = {}
  197. _inited = False
  198. @classmethod
  199. def _init_configuration(cls):
  200. """
  201. init configuration
  202. """
  203. absolute_path = os.path.abspath(os.path.dirname(__file__))
  204. model_tools_path = os.path.join(absolute_path, '..', 'model_tools')
  205. # get all .yaml file
  206. files = [f for f in os.listdir(model_tools_path) if f.endswith('.yaml')]
  207. for file in files:
  208. provider = file.split('.')[0]
  209. with open(os.path.join(model_tools_path, file), encoding='utf-8') as f:
  210. configurations = ModelToolProviderConfiguration(**load(f, Loader=FullLoader))
  211. models = configurations.models or []
  212. for model in models:
  213. model_key = f'{provider}.{model.model}'
  214. cls._model_configurations[model_key] = model
  215. cls._configurations[provider] = configurations
  216. cls._inited = True
  217. @classmethod
  218. def get_configuration(cls, provider: str) -> Union[ModelToolProviderConfiguration, None]:
  219. """
  220. get configuration by provider
  221. """
  222. if not cls._inited:
  223. cls._init_configuration()
  224. return cls._configurations.get(provider, None)
  225. @classmethod
  226. def get_all_configuration(cls) -> dict[str, ModelToolProviderConfiguration]:
  227. """
  228. get all configurations
  229. """
  230. if not cls._inited:
  231. cls._init_configuration()
  232. return cls._configurations
  233. @classmethod
  234. def get_model_configuration(cls, provider: str, model: str) -> Union[ModelToolConfiguration, None]:
  235. """
  236. get model configuration
  237. """
  238. key = f'{provider}.{model}'
  239. if not cls._inited:
  240. cls._init_configuration()
  241. return cls._model_configurations.get(key, None)