configuration.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. from copy import deepcopy
  2. from typing import Any
  3. from pydantic import BaseModel
  4. from core.helper import encrypter
  5. from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
  6. from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
  7. from core.tools.entities.tool_entities import (
  8. ToolParameter,
  9. ToolProviderCredentials,
  10. )
  11. from core.tools.provider.tool_provider import ToolProviderController
  12. from core.tools.tool.tool import Tool
  13. class ToolConfigurationManager(BaseModel):
  14. tenant_id: str
  15. provider_controller: ToolProviderController
  16. def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
  17. """
  18. deep copy credentials
  19. """
  20. return deepcopy(credentials)
  21. def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  22. """
  23. encrypt tool credentials with tenant id
  24. return a deep copy of credentials with encrypted values
  25. """
  26. credentials = self._deep_copy(credentials)
  27. # get fields need to be decrypted
  28. fields = self.provider_controller.get_credentials_schema()
  29. for field_name, field in fields.items():
  30. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  31. if field_name in credentials:
  32. encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
  33. credentials[field_name] = encrypted
  34. return credentials
  35. def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
  36. """
  37. mask tool credentials
  38. return a deep copy of credentials with masked values
  39. """
  40. credentials = self._deep_copy(credentials)
  41. # get fields need to be decrypted
  42. fields = self.provider_controller.get_credentials_schema()
  43. for field_name, field in fields.items():
  44. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  45. if field_name in credentials:
  46. if len(credentials[field_name]) > 6:
  47. credentials[field_name] = \
  48. credentials[field_name][:2] + \
  49. '*' * (len(credentials[field_name]) - 4) + \
  50. credentials[field_name][-2:]
  51. else:
  52. credentials[field_name] = '*' * len(credentials[field_name])
  53. return credentials
  54. def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  55. """
  56. decrypt tool credentials with tenant id
  57. return a deep copy of credentials with decrypted values
  58. """
  59. cache = ToolProviderCredentialsCache(
  60. tenant_id=self.tenant_id,
  61. identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
  62. cache_type=ToolProviderCredentialsCacheType.PROVIDER
  63. )
  64. cached_credentials = cache.get()
  65. if cached_credentials:
  66. return cached_credentials
  67. credentials = self._deep_copy(credentials)
  68. # get fields need to be decrypted
  69. fields = self.provider_controller.get_credentials_schema()
  70. for field_name, field in fields.items():
  71. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  72. if field_name in credentials:
  73. try:
  74. credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
  75. except:
  76. pass
  77. cache.set(credentials)
  78. return credentials
  79. def delete_tool_credentials_cache(self):
  80. cache = ToolProviderCredentialsCache(
  81. tenant_id=self.tenant_id,
  82. identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
  83. cache_type=ToolProviderCredentialsCacheType.PROVIDER
  84. )
  85. cache.delete()
  86. class ToolParameterConfigurationManager(BaseModel):
  87. """
  88. Tool parameter configuration manager
  89. """
  90. tenant_id: str
  91. tool_runtime: Tool
  92. provider_name: str
  93. provider_type: str
  94. identity_id: str
  95. def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
  96. """
  97. deep copy parameters
  98. """
  99. return deepcopy(parameters)
  100. def _merge_parameters(self) -> list[ToolParameter]:
  101. """
  102. merge parameters
  103. """
  104. # get tool parameters
  105. tool_parameters = self.tool_runtime.parameters or []
  106. # get tool runtime parameters
  107. runtime_parameters = self.tool_runtime.get_runtime_parameters() or []
  108. # override parameters
  109. current_parameters = tool_parameters.copy()
  110. for runtime_parameter in runtime_parameters:
  111. found = False
  112. for index, parameter in enumerate(current_parameters):
  113. if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
  114. current_parameters[index] = runtime_parameter
  115. found = True
  116. break
  117. if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
  118. current_parameters.append(runtime_parameter)
  119. return current_parameters
  120. def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  121. """
  122. mask tool parameters
  123. return a deep copy of parameters with masked values
  124. """
  125. parameters = self._deep_copy(parameters)
  126. # override parameters
  127. current_parameters = self._merge_parameters()
  128. for parameter in current_parameters:
  129. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  130. if parameter.name in parameters:
  131. if len(parameters[parameter.name]) > 6:
  132. parameters[parameter.name] = \
  133. parameters[parameter.name][:2] + \
  134. '*' * (len(parameters[parameter.name]) - 4) + \
  135. parameters[parameter.name][-2:]
  136. else:
  137. parameters[parameter.name] = '*' * len(parameters[parameter.name])
  138. return parameters
  139. def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  140. """
  141. encrypt tool parameters with tenant id
  142. return a deep copy of parameters with encrypted values
  143. """
  144. # override parameters
  145. current_parameters = self._merge_parameters()
  146. parameters = self._deep_copy(parameters)
  147. for parameter in current_parameters:
  148. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  149. if parameter.name in parameters:
  150. encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
  151. parameters[parameter.name] = encrypted
  152. return parameters
  153. def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  154. """
  155. decrypt tool parameters with tenant id
  156. return a deep copy of parameters with decrypted values
  157. """
  158. cache = ToolParameterCache(
  159. tenant_id=self.tenant_id,
  160. provider=f'{self.provider_type}.{self.provider_name}',
  161. tool_name=self.tool_runtime.identity.name,
  162. cache_type=ToolParameterCacheType.PARAMETER,
  163. identity_id=self.identity_id
  164. )
  165. cached_parameters = cache.get()
  166. if cached_parameters:
  167. return cached_parameters
  168. # override parameters
  169. current_parameters = self._merge_parameters()
  170. has_secret_input = False
  171. for parameter in current_parameters:
  172. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  173. if parameter.name in parameters:
  174. try:
  175. has_secret_input = True
  176. parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
  177. except:
  178. pass
  179. if has_secret_input:
  180. cache.set(parameters)
  181. return parameters
  182. def delete_tool_parameters_cache(self):
  183. cache = ToolParameterCache(
  184. tenant_id=self.tenant_id,
  185. provider=f'{self.provider_type}.{self.provider_name}',
  186. tool_name=self.tool_runtime.identity.name,
  187. cache_type=ToolParameterCacheType.PARAMETER,
  188. identity_id=self.identity_id
  189. )
  190. cache.delete()