浏览代码

Fix/workflow tool incorrect parameter configurations (#3402)

Co-authored-by: Joel <iamjoel007@gmail.com>
Yeuoly 1 年之前
父节点
当前提交
64e395d6cf

+ 15 - 2
api/core/tools/tool/tool.py

@@ -243,8 +243,21 @@ class Tool(BaseModel, ABC):
                             tool_parameters[parameter.name] = float(tool_parameters[parameter.name])
                 elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
                     if not isinstance(tool_parameters[parameter.name], bool):
-                        tool_parameters[parameter.name] = bool(tool_parameters[parameter.name])
-
+                        # check if it is a string
+                        if isinstance(tool_parameters[parameter.name], str):
+                            # check true false
+                            if tool_parameters[parameter.name].lower() in ['true', 'false']:
+                                tool_parameters[parameter.name] = tool_parameters[parameter.name].lower() == 'true'
+                            # check 1 0
+                            elif tool_parameters[parameter.name] in ['1', '0']:
+                                tool_parameters[parameter.name] = tool_parameters[parameter.name] == '1'
+                            else:
+                                tool_parameters[parameter.name] = bool(tool_parameters[parameter.name])
+                        elif isinstance(tool_parameters[parameter.name], int | float):
+                            tool_parameters[parameter.name] = tool_parameters[parameter.name] != 0
+                        else:
+                            tool_parameters[parameter.name] = bool(tool_parameters[parameter.name])
+                            
         return tool_parameters
 
     @abstractmethod

+ 23 - 8
api/core/workflow/nodes/tool/entities.py

@@ -1,10 +1,9 @@
-from typing import Literal, Union
+from typing import Any, Literal, Union
 
 from pydantic import BaseModel, validator
 
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 
-ToolParameterValue = Union[str, int, float, bool]
 
 class ToolEntity(BaseModel):
     provider_id: str
@@ -12,11 +11,23 @@ class ToolEntity(BaseModel):
     provider_name: str # redundancy
     tool_name: str
     tool_label: str # redundancy
-    tool_configurations: dict[str, ToolParameterValue]
+    tool_configurations: dict[str, Any]
+
+    @validator('tool_configurations', pre=True, always=True)
+    def validate_tool_configurations(cls, value, values):
+        if not isinstance(value, dict):
+            raise ValueError('tool_configurations must be a dictionary')
+        
+        for key in values.get('tool_configurations', {}).keys():
+            value = values.get('tool_configurations', {}).get(key)
+            if not isinstance(value, str | int | float | bool):
+                raise ValueError(f'{key} must be a string')
+            
+        return value
 
 class ToolNodeData(BaseNodeData, ToolEntity):
     class ToolInput(BaseModel):
-        value: Union[ToolParameterValue, list[str]]
+        value: Union[Any, list[str]]
         type: Literal['mixed', 'variable', 'constant']
 
         @validator('type', pre=True, always=True)
@@ -25,12 +36,16 @@ class ToolNodeData(BaseNodeData, ToolEntity):
             value = values.get('value')
             if typ == 'mixed' and not isinstance(value, str):
                 raise ValueError('value must be a string')
-            elif typ == 'variable' and not isinstance(value, list):
-                raise ValueError('value must be a list')
-            elif typ == 'constant' and not isinstance(value, ToolParameterValue):
+            elif typ == 'variable':
+                if not isinstance(value, list):
+                    raise ValueError('value must be a list')
+                for val in value:
+                    if not isinstance(val, str):
+                        raise ValueError('value must be a list of strings')
+            elif typ == 'constant' and not isinstance(value, str | int | float | bool):
                 raise ValueError('value must be a string, int, float, or bool')
             return typ
-            
+        
     """
     Tool Node Schema
     """

+ 34 - 3
web/app/components/workflow/nodes/tool/use-config.ts

@@ -1,4 +1,4 @@
-import { useCallback, useEffect, useState } from 'react'
+import { useCallback, useEffect, useMemo, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import produce from 'immer'
 import { useBoolean } from 'ahooks'
@@ -25,7 +25,7 @@ const useConfig = (id: string, payload: ToolNodeType) => {
   const { t } = useTranslation()
 
   const language = useLanguage()
-  const { inputs, setInputs } = useNodeCrud<ToolNodeType>(id, payload)
+  const { inputs, setInputs: doSetInputs } = useNodeCrud<ToolNodeType>(id, payload)
   /*
   * tool_configurations: tool setting, not dynamic setting
   * tool_parameters: tool dynamic setting(by user)
@@ -58,10 +58,41 @@ const useConfig = (id: string, payload: ToolNodeType) => {
   }, [currCollection?.name, hideSetAuthModal, t, handleFetchAllTools, provider_type])
 
   const currTool = currCollection?.tools.find(tool => tool.name === tool_name)
-  const formSchemas = currTool ? toolParametersToFormSchemas(currTool.parameters) : []
+  const formSchemas = useMemo(() => {
+    return currTool ? toolParametersToFormSchemas(currTool.parameters) : []
+  }, [currTool])
   const toolInputVarSchema = formSchemas.filter((item: any) => item.form === 'llm')
   // use setting
   const toolSettingSchema = formSchemas.filter((item: any) => item.form !== 'llm')
+  const hasShouldTransferTypeSettingInput = toolSettingSchema.some(item => item.type === 'boolean' || item.type === 'number-input')
+
+  const setInputs = useCallback((value: ToolNodeType) => {
+    if (!hasShouldTransferTypeSettingInput) {
+      doSetInputs(value)
+      return
+    }
+    const newInputs = produce(value, (draft) => {
+      const newConfig = { ...draft.tool_configurations }
+      Object.keys(draft.tool_configurations).forEach((key) => {
+        const schema = formSchemas.find(item => item.variable === key)
+        const value = newConfig[key]
+        if (schema?.type === 'boolean') {
+          if (typeof value === 'string')
+            newConfig[key] = parseInt(value, 10)
+
+          if (typeof value === 'boolean')
+            newConfig[key] = value ? 1 : 0
+        }
+
+        if (schema?.type === 'number-input') {
+          if (typeof value === 'string' && value !== '')
+            newConfig[key] = parseFloat(value)
+        }
+      })
+      draft.tool_configurations = newConfig
+    })
+    doSetInputs(newInputs)
+  }, [doSetInputs, formSchemas, hasShouldTransferTypeSettingInput])
   const [notSetDefaultValue, setNotSetDefaultValue] = useState(false)
   const toolSettingValue = (() => {
     if (notSetDefaultValue)