|
@@ -498,12 +498,16 @@ class ToolManageService:
|
|
|
|
|
|
@staticmethod
|
|
|
def test_api_tool_preview(
|
|
|
- tenant_id: str, tool_name: str, credentials: dict, parameters: dict, schema_type: str, schema: str
|
|
|
+ tenant_id: str,
|
|
|
+ provider_name: str,
|
|
|
+ tool_name: str,
|
|
|
+ credentials: dict,
|
|
|
+ parameters: dict,
|
|
|
+ schema_type: str,
|
|
|
+ schema: str
|
|
|
):
|
|
|
"""
|
|
|
test api tool before adding api tool provider
|
|
|
-
|
|
|
- 1. parse schema into tool bundle
|
|
|
"""
|
|
|
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
|
|
raise ValueError(f'invalid schema type {schema_type}')
|
|
@@ -518,15 +522,21 @@ class ToolManageService:
|
|
|
if tool_bundle is None:
|
|
|
raise ValueError(f'invalid tool name {tool_name}')
|
|
|
|
|
|
- # create a fake db provider
|
|
|
- db_provider = ApiToolProvider(
|
|
|
- tenant_id='', user_id='', name='', icon='',
|
|
|
- schema=schema,
|
|
|
- description='',
|
|
|
- schema_type_str=ApiProviderSchemaType.OPENAPI.value,
|
|
|
- tools_str=serialize_base_model_array(tool_bundles),
|
|
|
- credentials_str=json.dumps(credentials),
|
|
|
- )
|
|
|
+ db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
|
|
+ ApiToolProvider.tenant_id == tenant_id,
|
|
|
+ ApiToolProvider.name == provider_name,
|
|
|
+ ).first()
|
|
|
+
|
|
|
+ if not db_provider:
|
|
|
+ # create a fake db provider
|
|
|
+ db_provider = ApiToolProvider(
|
|
|
+ tenant_id='', user_id='', name='', icon='',
|
|
|
+ schema=schema,
|
|
|
+ description='',
|
|
|
+ schema_type_str=ApiProviderSchemaType.OPENAPI.value,
|
|
|
+ tools_str=serialize_base_model_array(tool_bundles),
|
|
|
+ credentials_str=json.dumps(credentials),
|
|
|
+ )
|
|
|
|
|
|
if 'auth_type' not in credentials:
|
|
|
raise ValueError('auth_type is required')
|
|
@@ -539,6 +549,19 @@ class ToolManageService:
|
|
|
# load tools into provider entity
|
|
|
provider_controller.load_bundled_tools(tool_bundles)
|
|
|
|
|
|
+ # decrypt credentials
|
|
|
+ if db_provider.id:
|
|
|
+ tool_configuration = ToolConfiguration(
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ provider_controller=provider_controller
|
|
|
+ )
|
|
|
+ decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
|
|
+ # check if the credential has changed, save the original credential
|
|
|
+ masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
|
|
+ for name, value in credentials.items():
|
|
|
+ if name in masked_credentials and value == masked_credentials[name]:
|
|
|
+ credentials[name] = decrypted_credentials[name]
|
|
|
+
|
|
|
try:
|
|
|
provider_controller.validate_credentials_format(credentials)
|
|
|
# get tool
|