configuration.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. from copy import deepcopy
  2. from typing import Any
  3. from pydantic import BaseModel
  4. from core.helper import encrypter
  5. from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
  6. from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
  7. from core.tools.entities.tool_entities import (
  8. ToolParameter,
  9. ToolProviderCredentials,
  10. )
  11. from core.tools.provider.tool_provider import ToolProviderController
  12. from core.tools.tool.tool import Tool
  13. class ToolConfigurationManager(BaseModel):
  14. tenant_id: str
  15. provider_controller: ToolProviderController
  16. def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
  17. """
  18. deep copy credentials
  19. """
  20. return deepcopy(credentials)
  21. def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  22. """
  23. encrypt tool credentials with tenant id
  24. return a deep copy of credentials with encrypted values
  25. """
  26. credentials = self._deep_copy(credentials)
  27. # get fields need to be decrypted
  28. fields = self.provider_controller.get_credentials_schema()
  29. for field_name, field in fields.items():
  30. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  31. if field_name in credentials:
  32. encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
  33. credentials[field_name] = encrypted
  34. return credentials
  35. def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
  36. """
  37. mask tool credentials
  38. return a deep copy of credentials with masked values
  39. """
  40. credentials = self._deep_copy(credentials)
  41. # get fields need to be decrypted
  42. fields = self.provider_controller.get_credentials_schema()
  43. for field_name, field in fields.items():
  44. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  45. if field_name in credentials:
  46. if len(credentials[field_name]) > 6:
  47. credentials[field_name] = (
  48. credentials[field_name][:2]
  49. + "*" * (len(credentials[field_name]) - 4)
  50. + credentials[field_name][-2:]
  51. )
  52. else:
  53. credentials[field_name] = "*" * len(credentials[field_name])
  54. return credentials
  55. def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  56. """
  57. decrypt tool credentials with tenant id
  58. return a deep copy of credentials with decrypted values
  59. """
  60. cache = ToolProviderCredentialsCache(
  61. tenant_id=self.tenant_id,
  62. identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}",
  63. cache_type=ToolProviderCredentialsCacheType.PROVIDER,
  64. )
  65. cached_credentials = cache.get()
  66. if cached_credentials:
  67. return cached_credentials
  68. credentials = self._deep_copy(credentials)
  69. # get fields need to be decrypted
  70. fields = self.provider_controller.get_credentials_schema()
  71. for field_name, field in fields.items():
  72. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  73. if field_name in credentials:
  74. try:
  75. credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
  76. except:
  77. pass
  78. cache.set(credentials)
  79. return credentials
  80. def delete_tool_credentials_cache(self):
  81. cache = ToolProviderCredentialsCache(
  82. tenant_id=self.tenant_id,
  83. identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}",
  84. cache_type=ToolProviderCredentialsCacheType.PROVIDER,
  85. )
  86. cache.delete()
  87. class ToolParameterConfigurationManager(BaseModel):
  88. """
  89. Tool parameter configuration manager
  90. """
  91. tenant_id: str
  92. tool_runtime: Tool
  93. provider_name: str
  94. provider_type: str
  95. identity_id: str
  96. def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
  97. """
  98. deep copy parameters
  99. """
  100. return deepcopy(parameters)
  101. def _merge_parameters(self) -> list[ToolParameter]:
  102. """
  103. merge parameters
  104. """
  105. # get tool parameters
  106. tool_parameters = self.tool_runtime.parameters or []
  107. # get tool runtime parameters
  108. runtime_parameters = self.tool_runtime.get_runtime_parameters() or []
  109. # override parameters
  110. current_parameters = tool_parameters.copy()
  111. for runtime_parameter in runtime_parameters:
  112. found = False
  113. for index, parameter in enumerate(current_parameters):
  114. if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
  115. current_parameters[index] = runtime_parameter
  116. found = True
  117. break
  118. if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
  119. current_parameters.append(runtime_parameter)
  120. return current_parameters
  121. def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  122. """
  123. mask tool parameters
  124. return a deep copy of parameters with masked values
  125. """
  126. parameters = self._deep_copy(parameters)
  127. # override parameters
  128. current_parameters = self._merge_parameters()
  129. for parameter in current_parameters:
  130. if (
  131. parameter.form == ToolParameter.ToolParameterForm.FORM
  132. and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
  133. ):
  134. if parameter.name in parameters:
  135. if len(parameters[parameter.name]) > 6:
  136. parameters[parameter.name] = (
  137. parameters[parameter.name][:2]
  138. + "*" * (len(parameters[parameter.name]) - 4)
  139. + parameters[parameter.name][-2:]
  140. )
  141. else:
  142. parameters[parameter.name] = "*" * len(parameters[parameter.name])
  143. return parameters
  144. def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  145. """
  146. encrypt tool parameters with tenant id
  147. return a deep copy of parameters with encrypted values
  148. """
  149. # override parameters
  150. current_parameters = self._merge_parameters()
  151. parameters = self._deep_copy(parameters)
  152. for parameter in current_parameters:
  153. if (
  154. parameter.form == ToolParameter.ToolParameterForm.FORM
  155. and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
  156. ):
  157. if parameter.name in parameters:
  158. encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
  159. parameters[parameter.name] = encrypted
  160. return parameters
  161. def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  162. """
  163. decrypt tool parameters with tenant id
  164. return a deep copy of parameters with decrypted values
  165. """
  166. cache = ToolParameterCache(
  167. tenant_id=self.tenant_id,
  168. provider=f"{self.provider_type}.{self.provider_name}",
  169. tool_name=self.tool_runtime.identity.name,
  170. cache_type=ToolParameterCacheType.PARAMETER,
  171. identity_id=self.identity_id,
  172. )
  173. cached_parameters = cache.get()
  174. if cached_parameters:
  175. return cached_parameters
  176. # override parameters
  177. current_parameters = self._merge_parameters()
  178. has_secret_input = False
  179. for parameter in current_parameters:
  180. if (
  181. parameter.form == ToolParameter.ToolParameterForm.FORM
  182. and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
  183. ):
  184. if parameter.name in parameters:
  185. try:
  186. has_secret_input = True
  187. parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
  188. except:
  189. pass
  190. if has_secret_input:
  191. cache.set(parameters)
  192. return parameters
  193. def delete_tool_parameters_cache(self):
  194. cache = ToolParameterCache(
  195. tenant_id=self.tenant_id,
  196. provider=f"{self.provider_type}.{self.provider_name}",
  197. tool_name=self.tool_runtime.identity.name,
  198. cache_type=ToolParameterCacheType.PARAMETER,
  199. identity_id=self.identity_id,
  200. )
  201. cache.delete()