
Add tool-call support for the "hunyuan" model provider (#6656)

Co-authored-by: sun <sun@centen.cn>
Giga Group 9 months ago
parent
commit
ca696fe94c

+ 120 - 10
api/core/model_runtime/model_providers/hunyuan/llm/llm.py

@@ -14,6 +14,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessage,
     PromptMessageTool,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import InvokeError
@@ -44,6 +45,17 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
             "Stream": stream,
             **custom_parameters,
         }
+        # attach tool definitions and let the model decide when to call them
+        if tools:
+            params['ToolChoice'] = "auto"
+            params['Tools'] = [{
+                "Type": "function",
+                "Function": {
+                    "Name": tool.name,
+                    "Description": tool.description,
+                    "Parameters": json.dumps(tool.parameters)
+                }
+            } for tool in tools]
 
         request.from_json_string(json.dumps(params))
         response = client.ChatCompletions(request)
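
For reference, when a tool is supplied, the serialized body handed to `request.from_json_string` takes roughly this shape (a sketch: the model name, message list, and `get_weather` schema are invented for illustration; note that `Parameters` is a JSON string because of the `json.dumps` call above):

    params = {
        "Model": "hunyuan-pro",        # hypothetical model name
        "Messages": [...],             # output of _convert_prompt_messages_to_dicts
        "Stream": True,
        "ToolChoice": "auto",
        "Tools": [{
            "Type": "function",
            "Function": {
                "Name": "get_weather",  # invented example tool
                "Description": "Look up the current weather for a city",
                "Parameters": '{"type": "object", "properties": {"city": {"type": "string"}}}',
            },
        }],
    }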
@@ -89,9 +101,43 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
 
     def _convert_prompt_messages_to_dicts(self, prompt_messages: list[PromptMessage]) -> list[dict]:
         """Convert a list of PromptMessage objects to a list of dictionaries with 'Role' and 'Content' keys."""
-        return [{"Role": message.role.value, "Content": message.content} for message in prompt_messages]
+        dict_list = []
+        for message in prompt_messages:
+            if isinstance(message, AssistantPromptMessage):
+                tool_calls = message.tool_calls
+                if tool_calls:
+                    dict_tool_calls = [
+                        {
+                            "Id": tool_call.id,
+                            "Type": tool_call.type,
+                            "Function": {
+                                "Name": tool_call.function.name,
+                                "Arguments": tool_call.function.arguments if tool_call.function.arguments else "{}"
+                            }
+                        } for tool_call in tool_calls]
+
+                    dict_list.append({
+                        "Role": message.role.value,
+                        # Hunyuan rejects an empty Content alongside ToolCalls:
+                        # [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time.
+                        # Send a single space as a placeholder instead.
+                        "Content": " ",
+                        "ToolCalls": dict_tool_calls
+                    })
+                else:
+                    dict_list.append({"Role": message.role.value, "Content": message.content})
+            elif isinstance(message, ToolPromptMessage):
+                tool_execute_result = {"result": message.content}
+                content = json.dumps(tool_execute_result, ensure_ascii=False)
+                dict_list.append({"Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id})
+            else:
+                dict_list.append({"Role": message.role.value, "Content": message.content})
+        return dict_list
 
     def _handle_stream_chat_response(self, model, credentials, prompt_messages, resp):
+        # accumulate tool-call fragments that Hunyuan streams across chunks
+        tool_call = None
+        tool_calls = []
+
         for index, event in enumerate(resp):
             logging.debug("_handle_stream_chat_response, event: %s", event)
 
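To make the new conversion concrete, a tool round trip serializes roughly as follows (the id, tool name, and result are invented; the role strings assume the usual "assistant"/"tool" enum values):

    # assistant turn that requested a tool: Content is the " " placeholder
    {"Role": "assistant", "Content": " ", "ToolCalls": [{
        "Id": "call_0",
        "Type": "function",
        "Function": {"Name": "get_weather", "Arguments": '{"city": "Beijing"}'},
    }]}
    # matching tool turn: the result is wrapped as {"result": ...} and serialized
    {"Role": "tool", "Content": '{"result": "sunny, 25C"}', "ToolCallId": "call_0"}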
@@ -109,20 +155,54 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
             usage = data.get('Usage', {})
             prompt_tokens = usage.get('PromptTokens', 0)
             completion_tokens = usage.get('CompletionTokens', 0)
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+            response_tool_calls = delta.get('ToolCalls')
+            if response_tool_calls is not None:
+                new_tool_calls = self._extract_response_tool_calls(response_tool_calls)
+                if new_tool_calls:
+                    new_tool_call = new_tool_calls[0]
+                    if tool_call is None:
+                        tool_call = new_tool_call
+                    elif tool_call.id != new_tool_call.id:
+                        # a new id means the previous call is complete
+                        tool_calls.append(tool_call)
+                        tool_call = new_tool_call
+                    else:
+                        # same id: merge streamed name/arguments fragments
+                        tool_call.function.name += new_tool_call.function.name
+                        tool_call.function.arguments += new_tool_call.function.arguments
+                if tool_call is not None and tool_call.function.name and tool_call.function.arguments:
+                    tool_calls.append(tool_call)
+                    tool_call = None
 
             assistant_prompt_message = AssistantPromptMessage(
                 content=message_content,
                 tool_calls=[]
             )
-
-            delta_chunk = LLMResultChunkDelta(
-                index=index,
-                role=delta.get('Role', 'assistant'),
-                message=assistant_prompt_message,
-                usage=usage,
-                finish_reason=finish_reason,
-            )
+            # clear the content while tool calls are streaming so it is not rendered in the UI
+            if tool_calls:
+                assistant_prompt_message.content = ""
+
+            # attach the accumulated tool calls once the model finishes emitting them
+            if finish_reason == 'tool_calls':
+                assistant_prompt_message.tool_calls = tool_calls
+                tool_call = None
+                tool_calls = []
+
+            if finish_reason:
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                delta_chunk = LLMResultChunkDelta(
+                    index=index,
+                    role=delta.get('Role', 'assistant'),
+                    message=assistant_prompt_message,
+                    usage=usage,
+                    finish_reason=finish_reason,
+                )
+                tool_call = None
+                tool_calls = []
+            else:
+                delta_chunk = LLMResultChunkDelta(
+                    index=index,
+                    message=assistant_prompt_message,
+                )
 
             yield LLMResultChunk(
                 model=model,
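
The id-based merge above is needed because Hunyuan may stream a single tool call across several chunks. A minimal sketch of how two fragments combine (chunk payloads invented):

    # consecutive ToolCalls fragments sharing one Id are concatenated field by field
    chunk_1 = [{"Id": "call_0", "Function": {"Name": "get_", "Arguments": ""}}]
    chunk_2 = [{"Id": "call_0", "Function": {"Name": "weather", "Arguments": '{"city": "Beijing"}'}}]
    # after merging: name == "get_weather", arguments == '{"city": "Beijing"}';
    # both fields are now non-empty, so the call is flushed into tool_calls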
@@ -177,12 +257,15 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
         """
         human_prompt = "\n\nHuman:"
         ai_prompt = "\n\nAssistant:"
+        tool_prompt = "\n\nTool:"
         content = message.content
 
         if isinstance(message, UserPromptMessage):
             message_text = f"{human_prompt} {content}"
         elif isinstance(message, AssistantPromptMessage):
             message_text = f"{ai_prompt} {content}"
+        elif isinstance(message, ToolPromptMessage):
+            message_text = f"{tool_prompt} {content}"
         elif isinstance(message, SystemPromptMessage):
             message_text = content
         else:
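
With the new `ToolPromptMessage` branch, a tool result renders into the plain-text prompt (presumably used for token counting) like this:

    # ToolPromptMessage(content='{"result": "sunny, 25C"}') becomes:
    '\n\nTool: {"result": "sunny, 25C"}'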
@@ -203,3 +286,30 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
         return {
             InvokeError: [TencentCloudSDKException],
         }
+
+    def _extract_response_tool_calls(self,
+                                     response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
+        """
+        Extract tool calls from response
+
+        :param response_tool_calls: response tool calls
+        :return: list of tool calls
+        """
+        tool_calls = []
+        if response_tool_calls:
+            for response_tool_call in response_tool_calls:
+                response_function = response_tool_call.get('Function', {})
+                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                    name=response_function.get('Name', ''),
+                    arguments=response_function.get('Arguments', '')
+                )
+
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=response_tool_call.get('Id', ''),
+                    type='function',
+                    function=function
+                )
+                tool_calls.append(tool_call)
+
+        return tool_calls
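
As a usage sketch, a raw ToolCalls entry maps onto the runtime's ToolCall objects like this (payload values invented; `model` stands for a HunyuanLargeLanguageModel instance):

    raw = [{"Id": "call_0", "Type": "function",
            "Function": {"Name": "get_weather", "Arguments": '{"city": "Beijing"}'}}]
    calls = model._extract_response_tool_calls(raw)
    assert calls[0].id == "call_0"
    assert calls[0].function.name == "get_weather"
    assert calls[0].function.arguments == '{"city": "Beijing"}'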