Sfoglia il codice sorgente

fix: test custom tool already exists without decrypting credentials (#2668)

Yeuoly 1 anno fa
parent
commit
36686d7425

+ 2 - 0
api/controllers/console/workspace/tool_providers.py

@@ -259,6 +259,7 @@ class ToolApiProviderPreviousTestApi(Resource):
         parser = reqparse.RequestParser()
 
         parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json')
         parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
         parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json')
         parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
@@ -268,6 +269,7 @@ class ToolApiProviderPreviousTestApi(Resource):
 
         return ToolManageService.test_api_tool_preview(
             current_user.current_tenant_id,
+            args['provider_name'] if args['provider_name'] else '',
             args['tool_name'],
             args['credentials'],
             args['parameters'],

+ 3 - 0
api/core/tools/tool/api_tool.py

@@ -1,6 +1,7 @@
 import json
 from json import dumps
 from typing import Any, Union
+from urllib.parse import urlencode
 
 import httpx
 import requests
@@ -203,6 +204,8 @@ class ApiTool(Tool):
         if 'Content-Type' in headers:
             if headers['Content-Type'] == 'application/json':
                 body = dumps(body)
+            elif headers['Content-Type'] == 'application/x-www-form-urlencoded':
+                body = urlencode(body)
             else:
                 body = body
         

+ 35 - 12
api/services/tools_manage_service.py

@@ -498,12 +498,16 @@ class ToolManageService:
     
     @staticmethod
     def test_api_tool_preview(
-        tenant_id: str, tool_name: str, credentials: dict, parameters: dict, schema_type: str, schema: str
+        tenant_id: str, 
+        provider_name: str,
+        tool_name: str, 
+        credentials: dict, 
+        parameters: dict, 
+        schema_type: str, 
+        schema: str
     ):
         """
             test api tool before adding api tool provider
-
-            1. parse schema into tool bundle
         """
         if schema_type not in [member.value for member in ApiProviderSchemaType]:
             raise ValueError(f'invalid schema type {schema_type}')
@@ -518,15 +522,21 @@ class ToolManageService:
         if tool_bundle is None:
             raise ValueError(f'invalid tool name {tool_name}')
         
-        # create a fake db provider
-        db_provider = ApiToolProvider(
-            tenant_id='', user_id='', name='', icon='',
-            schema=schema,
-            description='',
-            schema_type_str=ApiProviderSchemaType.OPENAPI.value,
-            tools_str=serialize_base_model_array(tool_bundles),
-            credentials_str=json.dumps(credentials),
-        )
+        db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
+            ApiToolProvider.tenant_id == tenant_id,
+            ApiToolProvider.name == provider_name,
+        ).first()
+
+        if not db_provider:
+            # create a fake db provider
+            db_provider = ApiToolProvider(
+                tenant_id='', user_id='', name='', icon='',
+                schema=schema,
+                description='',
+                schema_type_str=ApiProviderSchemaType.OPENAPI.value,
+                tools_str=serialize_base_model_array(tool_bundles),
+                credentials_str=json.dumps(credentials),
+            )
 
         if 'auth_type' not in credentials:
             raise ValueError('auth_type is required')
@@ -539,6 +549,19 @@ class ToolManageService:
         # load tools into provider entity
         provider_controller.load_bundled_tools(tool_bundles)
 
+        # decrypt credentials
+        if db_provider.id:
+            tool_configuration = ToolConfiguration(
+                tenant_id=tenant_id, 
+                provider_controller=provider_controller
+            )
+            decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
+            # check if the credential has changed, save the original credential
+            masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
+            for name, value in credentials.items():
+                if name in masked_credentials and value == masked_credentials[name]:
+                    credentials[name] = decrypted_credentials[name]
+
         try:
             provider_controller.validate_credentials_format(credentials)
             # get tool

+ 1 - 0
web/app/components/tools/edit-custom-collection-modal/test-api.tsx

@@ -42,6 +42,7 @@ const TestApi: FC<Props> = ({
       delete credentials.api_key_value
     }
     const data = {
+      provider_name: customCollection.provider,
       tool_name: toolName,
       credentials,
       schema_type: customCollection.schema_type,