Browse Source

fix: Fix some type error in http executor. (#5915)

-LAN- 9 months ago
parent
commit
02982df0d4

+ 40 - 36
api/core/workflow/nodes/http_request/entities.py

@@ -9,49 +9,53 @@ MAX_CONNECT_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_CONNECT_TIMEOUT', '30
 MAX_READ_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_READ_TIMEOUT', '600'))
 MAX_WRITE_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_WRITE_TIMEOUT', '600'))
 
+
+class HttpRequestNodeAuthorizationConfig(BaseModel):
+    type: Literal[None, 'basic', 'bearer', 'custom']
+    api_key: Union[None, str] = None
+    header: Union[None, str] = None
+
+
+class HttpRequestNodeAuthorization(BaseModel):
+    type: Literal['no-auth', 'api-key']
+    config: Optional[HttpRequestNodeAuthorizationConfig] = None
+
+    @field_validator('config', mode='before')
+    @classmethod
+    def check_config(cls, v: HttpRequestNodeAuthorizationConfig, values: ValidationInfo):
+        """
+        Check config, if type is no-auth, config should be None, otherwise it should be a dict.
+        """
+        if values.data['type'] == 'no-auth':
+            return None
+        else:
+            if not v or not isinstance(v, dict):
+                raise ValueError('config should be a dict')
+
+            return v
+
+
+class HttpRequestNodeBody(BaseModel):
+    type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json']
+    data: Union[None, str] = None
+
+
+class HttpRequestNodeTimeout(BaseModel):
+    connect: int = MAX_CONNECT_TIMEOUT
+    read: int = MAX_READ_TIMEOUT
+    write: int = MAX_WRITE_TIMEOUT
+
+
 class HttpRequestNodeData(BaseNodeData):
     """
     Code Node Data.
     """
-    class Authorization(BaseModel):
-        # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
-        # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
-        class Config(BaseModel):
-            type: Literal[None, 'basic', 'bearer', 'custom']
-            api_key: Union[None, str] = None
-            header: Union[None, str] = None
-
-        type: Literal['no-auth', 'api-key']
-        config: Optional[Config] = None
-
-        @field_validator('config', mode='before')
-        @classmethod
-        def check_config(cls, v: Config, values: ValidationInfo):
-            """
-            Check config, if type is no-auth, config should be None, otherwise it should be a dict.
-            """
-            if values.data['type'] == 'no-auth':
-                return None
-            else:
-                if not v or not isinstance(v, dict):
-                    raise ValueError('config should be a dict')
-                
-                return v
-
-    class Body(BaseModel):
-        type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json']
-        data: Union[None, str] = None
-
-    class Timeout(BaseModel):
-        connect: Optional[int] = MAX_CONNECT_TIMEOUT
-        read:  Optional[int] = MAX_READ_TIMEOUT
-        write:  Optional[int] = MAX_WRITE_TIMEOUT
 
     method: Literal['get', 'post', 'put', 'patch', 'delete', 'head']
     url: str
-    authorization: Authorization
+    authorization: HttpRequestNodeAuthorization
     headers: str
     params: str
-    body: Optional[Body] = None
-    timeout: Optional[Timeout] = None
+    body: Optional[HttpRequestNodeBody] = None
+    timeout: Optional[HttpRequestNodeTimeout] = None
     mask_authorization_header: Optional[bool] = True

+ 44 - 26
api/core/workflow/nodes/http_request/http_executor.py

@@ -10,7 +10,12 @@ import httpx
 import core.helper.ssrf_proxy as ssrf_proxy
 from core.workflow.entities.variable_entities import VariableSelector
 from core.workflow.entities.variable_pool import ValueType, VariablePool
-from core.workflow.nodes.http_request.entities import HttpRequestNodeData
+from core.workflow.nodes.http_request.entities import (
+    HttpRequestNodeAuthorization,
+    HttpRequestNodeBody,
+    HttpRequestNodeData,
+    HttpRequestNodeTimeout,
+)
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 
 MAX_BINARY_SIZE = int(os.environ.get('HTTP_REQUEST_NODE_MAX_BINARY_SIZE', 1024 * 1024 * 10))  # 10MB
@@ -23,7 +28,7 @@ class HttpExecutorResponse:
     headers: dict[str, str]
     response: httpx.Response
 
-    def __init__(self, response: httpx.Response = None):
+    def __init__(self, response: httpx.Response):
         self.response = response
         self.headers = dict(response.headers) if isinstance(self.response, httpx.Response) else {}
 
@@ -40,7 +45,6 @@ class HttpExecutorResponse:
     def get_content_type(self) -> str:
         return self.headers.get('content-type', '')
 
-
     def extract_file(self) -> tuple[str, bytes]:
         """
         extract file from response if content type is file related
@@ -88,17 +92,21 @@ class HttpExecutorResponse:
 class HttpExecutor:
     server_url: str
     method: str
-    authorization: HttpRequestNodeData.Authorization
+    authorization: HttpRequestNodeAuthorization
     params: dict[str, Any]
     headers: dict[str, Any]
     body: Union[None, str]
     files: Union[None, dict[str, Any]]
     boundary: str
     variable_selectors: list[VariableSelector]
-    timeout: HttpRequestNodeData.Timeout
-
-    def __init__(self, node_data: HttpRequestNodeData, timeout: HttpRequestNodeData.Timeout,
-                 variable_pool: Optional[VariablePool] = None):
+    timeout: HttpRequestNodeTimeout
+
+    def __init__(
+        self,
+        node_data: HttpRequestNodeData,
+        timeout: HttpRequestNodeTimeout,
+        variable_pool: Optional[VariablePool] = None,
+    ):
         self.server_url = node_data.url
         self.method = node_data.method
         self.authorization = node_data.authorization
@@ -113,11 +121,11 @@ class HttpExecutor:
         self._init_template(node_data, variable_pool)
 
     @staticmethod
-    def _is_json_body(body: HttpRequestNodeData.Body):
+    def _is_json_body(body: HttpRequestNodeBody):
         """
         check if body is json
         """
-        if body and body.type == 'json':
+        if body and body.type == 'json' and body.data:
             try:
                 json.loads(body.data)
                 return True
@@ -146,7 +154,6 @@ class HttpExecutor:
         return result
 
     def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
-
         # extract all template in url
         self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool)
 
@@ -178,9 +185,7 @@ class HttpExecutor:
                 body = self._to_dict(body_data)
 
                 if node_data.body.type == 'form-data':
-                    self.files = {
-                        k: ('', v) for k, v in body.items()
-                    }
+                    self.files = {k: ('', v) for k, v in body.items()}
                     random_str = lambda n: ''.join([chr(randint(97, 122)) for _ in range(n)])
                     self.boundary = f'----WebKitFormBoundary{random_str(16)}'
 
@@ -192,13 +197,24 @@ class HttpExecutor:
             elif node_data.body.type == 'none':
                 self.body = ''
 
-        self.variable_selectors = (server_url_variable_selectors + params_variable_selectors
-                                   + headers_variable_selectors + body_data_variable_selectors)
+        self.variable_selectors = (
+            server_url_variable_selectors
+            + params_variable_selectors
+            + headers_variable_selectors
+            + body_data_variable_selectors
+        )
 
     def _assembling_headers(self) -> dict[str, Any]:
         authorization = deepcopy(self.authorization)
         headers = deepcopy(self.headers) or {}
         if self.authorization.type == 'api-key':
+            if self.authorization.config is None:
+                raise ValueError('self.authorization config is required')
+            if authorization.config is None:
+                raise ValueError('authorization config is required')
+            if authorization.config.header is None:
+                raise ValueError('authorization config header is required')
+
             if self.authorization.config.api_key is None:
                 raise ValueError('api_key is required')
 
@@ -216,7 +232,7 @@ class HttpExecutor:
 
     def _validate_and_parse_response(self, response: httpx.Response) -> HttpExecutorResponse:
         """
-            validate the response
+        validate the response
         """
         if isinstance(response, httpx.Response):
             executor_response = HttpExecutorResponse(response)
@@ -226,24 +242,26 @@ class HttpExecutor:
         if executor_response.is_file:
             if executor_response.size > MAX_BINARY_SIZE:
                 raise ValueError(
-                    f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.')
+                    f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.'
+                )
         else:
             if executor_response.size > MAX_TEXT_SIZE:
                 raise ValueError(
-                    f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.')
+                    f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.'
+                )
 
         return executor_response
 
     def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response:
         """
-            do http request depending on api bundle
+        do http request depending on api bundle
         """
         kwargs = {
             'url': self.server_url,
             'headers': headers,
             'params': self.params,
             'timeout': (self.timeout.connect, self.timeout.read, self.timeout.write),
-            'follow_redirects': True
+            'follow_redirects': True,
         }
 
         if self.method in ('get', 'head', 'post', 'put', 'delete', 'patch'):
@@ -306,8 +324,9 @@ class HttpExecutor:
 
         return raw_request
 
-    def _format_template(self, template: str, variable_pool: VariablePool, escape_quotes: bool = False) \
-            -> tuple[str, list[VariableSelector]]:
+    def _format_template(
+        self, template: str, variable_pool: Optional[VariablePool], escape_quotes: bool = False
+    ) -> tuple[str, list[VariableSelector]]:
         """
         format template
         """
@@ -318,14 +337,13 @@ class HttpExecutor:
             variable_value_mapping = {}
             for variable_selector in variable_selectors:
                 value = variable_pool.get_variable_value(
-                    variable_selector=variable_selector.value_selector,
-                    target_value_type=ValueType.STRING
+                    variable_selector=variable_selector.value_selector, target_value_type=ValueType.STRING
                 )
 
                 if value is None:
                     raise ValueError(f'Variable {variable_selector.variable} not found')
 
-                if escape_quotes:
+                if escape_quotes and isinstance(value, str):
                     value = value.replace('"', '\\"')
 
                 variable_value_mapping[variable_selector.variable] = value

+ 49 - 47
api/core/workflow/nodes/http_request/http_request_node.py

@@ -5,6 +5,7 @@ from typing import cast
 
 from core.file.file_obj import FileTransferMethod, FileType, FileVar
 from core.tools.tool_file_manager import ToolFileManager
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
@@ -13,49 +14,50 @@ from core.workflow.nodes.http_request.entities import (
     MAX_READ_TIMEOUT,
     MAX_WRITE_TIMEOUT,
     HttpRequestNodeData,
+    HttpRequestNodeTimeout,
 )
 from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse
 from models.workflow import WorkflowNodeExecutionStatus
 
-HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeData.Timeout(connect=min(10, MAX_CONNECT_TIMEOUT),
-                                                           read=min(60, MAX_READ_TIMEOUT),
-                                                           write=min(20, MAX_WRITE_TIMEOUT))
+HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
+    connect=min(10, MAX_CONNECT_TIMEOUT),
+    read=min(60, MAX_READ_TIMEOUT),
+    write=min(20, MAX_WRITE_TIMEOUT),
+)
 
 
 class HttpRequestNode(BaseNode):
     _node_data_cls = HttpRequestNodeData
-    node_type = NodeType.HTTP_REQUEST
+    _node_type = NodeType.HTTP_REQUEST
 
     @classmethod
-    def get_default_config(cls) -> dict:
+    def get_default_config(cls, filters: dict | None = None) -> dict:
         return {
-            "type": "http-request",
-            "config": {
-                "method": "get",
-                "authorization": {
-                    "type": "no-auth",
-                },
-                "body": {
-                    "type": "none"
+            'type': 'http-request',
+            'config': {
+                'method': 'get',
+                'authorization': {
+                    'type': 'no-auth',
                 },
-                "timeout": {
+                'body': {'type': 'none'},
+                'timeout': {
                     **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(),
-                    "max_connect_timeout": MAX_CONNECT_TIMEOUT,
-                    "max_read_timeout": MAX_READ_TIMEOUT,
-                    "max_write_timeout": MAX_WRITE_TIMEOUT,
-                }
+                    'max_connect_timeout': MAX_CONNECT_TIMEOUT,
+                    'max_read_timeout': MAX_READ_TIMEOUT,
+                    'max_write_timeout': MAX_WRITE_TIMEOUT,
+                },
             },
         }
 
     def _run(self, variable_pool: VariablePool) -> NodeRunResult:
-        node_data: HttpRequestNodeData = cast(self._node_data_cls, self.node_data)
+        node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data)
 
         # init http executor
         http_executor = None
         try:
-            http_executor = HttpExecutor(node_data=node_data,
-                                         timeout=self._get_request_timeout(node_data),
-                                         variable_pool=variable_pool)
+            http_executor = HttpExecutor(
+                node_data=node_data, timeout=self._get_request_timeout(node_data), variable_pool=variable_pool
+            )
 
             # invoke http executor
             response = http_executor.invoke()
@@ -70,7 +72,7 @@ class HttpRequestNode(BaseNode):
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 error=str(e),
-                process_data=process_data
+                process_data=process_data,
             )
 
         files = self.extract_files(http_executor.server_url, response)
@@ -85,34 +87,32 @@ class HttpRequestNode(BaseNode):
             },
             process_data={
                 'request': http_executor.to_raw_request(
-                    mask_authorization_header=node_data.mask_authorization_header
+                    mask_authorization_header=node_data.mask_authorization_header,
                 ),
-            }
+            },
         )
 
-    def _get_request_timeout(self, node_data: HttpRequestNodeData) -> HttpRequestNodeData.Timeout:
+    def _get_request_timeout(self, node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout:
         timeout = node_data.timeout
         if timeout is None:
             return HTTP_REQUEST_DEFAULT_TIMEOUT
 
-        if timeout.connect is None:
-            timeout.connect = HTTP_REQUEST_DEFAULT_TIMEOUT.connect
+        timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect
         timeout.connect = min(timeout.connect, MAX_CONNECT_TIMEOUT)
-        if timeout.read is None:
-            timeout.read = HTTP_REQUEST_DEFAULT_TIMEOUT.read
+        timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read
         timeout.read = min(timeout.read, MAX_READ_TIMEOUT)
-        if timeout.write is None:
-            timeout.write = HTTP_REQUEST_DEFAULT_TIMEOUT.write
+        timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write
         timeout.write = min(timeout.write, MAX_WRITE_TIMEOUT)
         return timeout
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: HttpRequestNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
         """
         Extract variable selector to variable mapping
         :param node_data: node data
         :return:
         """
+        node_data = cast(HttpRequestNodeData, node_data)
         try:
             http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT)
 
@@ -124,7 +124,7 @@ class HttpRequestNode(BaseNode):
 
             return variable_mapping
         except Exception as e:
-            logging.exception(f"Failed to extract variable selector to variable mapping: {e}")
+            logging.exception(f'Failed to extract variable selector to variable mapping: {e}')
             return {}
 
     def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]:
@@ -144,21 +144,23 @@ class HttpRequestNode(BaseNode):
             extension = guess_extension(mimetype) or '.bin'
 
             tool_file = ToolFileManager.create_file_by_raw(
-                user_id=self.user_id, 
-                tenant_id=self.tenant_id, 
-                conversation_id=None, 
-                file_binary=file_binary, 
+                user_id=self.user_id,
+                tenant_id=self.tenant_id,
+                conversation_id=None,
+                file_binary=file_binary,
                 mimetype=mimetype,
             )
 
-            files.append(FileVar(
-                tenant_id=self.tenant_id,
-                type=FileType.IMAGE,
-                transfer_method=FileTransferMethod.TOOL_FILE,
-                related_id=tool_file.id,
-                filename=filename,
-                extension=extension,
-                mime_type=mimetype,
-            ))
+            files.append(
+                FileVar(
+                    tenant_id=self.tenant_id,
+                    type=FileType.IMAGE,
+                    transfer_method=FileTransferMethod.TOOL_FILE,
+                    related_id=tool_file.id,
+                    filename=filename,
+                    extension=extension,
+                    mime_type=mimetype,
+                )
+            )
 
         return files