
feature: Add presence_penalty and frequency_penalty parameters to the … (#5637)

Co-authored-by: liuzhenghua-jk <liuzhenghua-jk@360shuke.com>
liuzhenghua 9 months ago
parent commit 2b080b5cfc
1 changed file with 73 additions and 25 deletions
      api/core/model_runtime/model_providers/xinference/llm/llm.py

@@ -39,6 +39,7 @@ from core.model_runtime.entities.message_entities import (
 )
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,
+    DefaultParameterName,
     FetchFrom,
     ModelFeature,
     ModelPropertyKey,
@@ -67,7 +68,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                 model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                 stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-        -> LLMResult | Generator:
+            -> LLMResult | Generator:
         """
             invoke LLM
 
@@ -113,7 +114,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 elif 'generate' in extra_param.model_ability:
                     credentials['completion_type'] = 'completion'
                 else:
-                    raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')
+                    raise ValueError(
+                        f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')
 
             if extra_param.support_function_call:
                 credentials['support_function_call'] = True
@@ -206,6 +208,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         :param tools: tools for tool calling
         :return: number of tokens
         """
+
         def tokens(text: str):
             return self._get_num_tokens_by_gpt2(text)
 
@@ -339,6 +342,45 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                     zh_Hans='最大生成长度',
                     en_US='Max Tokens'
                 )
+            ),
+            ParameterRule(
+                name=DefaultParameterName.PRESENCE_PENALTY,
+                use_template=DefaultParameterName.PRESENCE_PENALTY,
+                type=ParameterType.FLOAT,
+                label=I18nObject(
+                    en_US='Presence Penalty',
+                    zh_Hans='存在惩罚',
+                ),
+                required=False,
+                help=I18nObject(
+                    en_US='Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they '
+                          'appear in the text so far, increasing the model\'s likelihood to talk about new topics.',
+                    zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词是否已出现在文本中对其进行惩罚,从而增加模型谈论新话题的可能性。'
+                ),
+                default=0.0,
+                min=-2.0,
+                max=2.0,
+                precision=2
+            ),
+            ParameterRule(
+                name=DefaultParameterName.FREQUENCY_PENALTY,
+                use_template=DefaultParameterName.FREQUENCY_PENALTY,
+                type=ParameterType.FLOAT,
+                label=I18nObject(
+                    en_US='Frequency Penalty',
+                    zh_Hans='频率惩罚',
+                ),
+                required=False,
+                help=I18nObject(
+                    en_US='Number between -2.0 and 2.0. Positive values penalize new tokens based on their '
+                          'existing frequency in the text so far, decreasing the model\'s likelihood to repeat the '
+                          'same line verbatim.',
+                    zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词在文本中的现有频率对其进行惩罚,从而降低模型逐字重复相同内容的可能性。'
+                ),
+                default=0.0,
+                min=-2.0,
+                max=2.0,
+                precision=2
             )
         ]
 
@@ -364,7 +406,6 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             else:
                 raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
 
-
         features = []
 
         support_function_call = credentials.get('support_function_call', False)
@@ -395,9 +436,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         return entity
 
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                 model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
-                 tools: list[PromptMessageTool] | None = None,
-                 stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+                  model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
+                  tools: list[PromptMessageTool] | None = None,
+                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
             generate text from LLM
@@ -429,6 +470,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             'temperature': model_parameters.get('temperature', 1.0),
             'top_p': model_parameters.get('top_p', 0.7),
             'max_tokens': model_parameters.get('max_tokens', 512),
+            'presence_penalty': model_parameters.get('presence_penalty', 0.0),
+            'frequency_penalty': model_parameters.get('frequency_penalty', 0.0),
         }
 
         if stop:
@@ -453,10 +496,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             if stream:
                 if tools and len(tools) > 0:
                     raise InvokeBadRequestError('xinference tool calls does not support stream mode')
-                return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                        tools=tools, resp=resp)
-            return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                        tools=tools, resp=resp)
+                return self._handle_chat_stream_response(model=model, credentials=credentials,
+                                                         prompt_messages=prompt_messages,
+                                                         tools=tools, resp=resp)
+            return self._handle_chat_generate_response(model=model, credentials=credentials,
+                                                       prompt_messages=prompt_messages,
+                                                       tools=tools, resp=resp)
         elif isinstance(xinference_model, RESTfulGenerateModelHandle):
             resp = client.completions.create(
                 model=credentials['model_uid'],
@@ -466,10 +511,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 **generate_config,
             )
             if stream:
-                return self._handle_completion_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                        tools=tools, resp=resp)
-            return self._handle_completion_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                        tools=tools, resp=resp)
+                return self._handle_completion_stream_response(model=model, credentials=credentials,
+                                                               prompt_messages=prompt_messages,
+                                                               tools=tools, resp=resp)
+            return self._handle_completion_generate_response(model=model, credentials=credentials,
+                                                             prompt_messages=prompt_messages,
+                                                             tools=tools, resp=resp)
         else:
             raise NotImplementedError(f'xinference model handle type {type(xinference_model)} is not supported')
 
@@ -523,8 +570,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         return tool_call
 
     def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                                        tools: list[PromptMessageTool],
-                                        resp: ChatCompletion) -> LLMResult:
+                                       tools: list[PromptMessageTool],
+                                       resp: ChatCompletion) -> LLMResult:
         """
             handle normal chat generate response
         """
@@ -549,7 +596,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
         completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
 
-        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
+        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
+                                          completion_tokens=completion_tokens)
 
         response = LLMResult(
             model=model,
@@ -560,10 +608,10 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         )
 
         return response
-    
+
     def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                                        tools: list[PromptMessageTool],
-                                        resp: Iterator[ChatCompletionChunk]) -> Generator:
+                                     tools: list[PromptMessageTool],
+                                     resp: Iterator[ChatCompletionChunk]) -> Generator:
         """
             handle stream chat generate response
         """
@@ -634,8 +682,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 full_response += delta.delta.content
 
     def _handle_completion_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                                        tools: list[PromptMessageTool],
-                                        resp: Completion) -> LLMResult:
+                                             tools: list[PromptMessageTool],
+                                             resp: Completion) -> LLMResult:
         """
             handle normal completion generate response
         """
@@ -671,8 +719,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         return response
 
     def _handle_completion_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                                        tools: list[PromptMessageTool],
-                                        resp: Iterator[Completion]) -> Generator:
+                                           tools: list[PromptMessageTool],
+                                           resp: Iterator[Completion]) -> Generator:
         """
             handle stream completion generate response
         """
@@ -764,4 +812,4 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             InvokeBadRequestError: [
                 ValueError
             ]
-        }
+        }
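
For readers skimming the diff: the two new parameters are declared as ParameterRule entries (bounded to [-2.0, 2.0], defaulting to 0.0) and then forwarded from model_parameters into the generate_config dict that _generate passes to the Xinference client. The standalone sketch below mirrors that mapping and illustrates the usual penalty semantics. Note that build_generate_config and apply_penalties are hypothetical helpers written for this note, and the logit formula assumes Xinference follows the common OpenAI-style convention, which the diff does not itself confirm.

    # Minimal standalone sketch; Dify's real code builds this dict inside
    # XinferenceAILargeLanguageModel._generate (see the hunk above).

    def build_generate_config(model_parameters: dict) -> dict:
        """Collect sampling parameters, falling back to the diff's defaults."""
        config = {
            'temperature': model_parameters.get('temperature', 1.0),
            'top_p': model_parameters.get('top_p', 0.7),
            'max_tokens': model_parameters.get('max_tokens', 512),
            # New in this commit: both penalties default to 0.0 (no effect).
            'presence_penalty': model_parameters.get('presence_penalty', 0.0),
            'frequency_penalty': model_parameters.get('frequency_penalty', 0.0),
        }
        # The ParameterRule entries bound both penalties to [-2.0, 2.0];
        # this explicit check is illustrative, not part of the diff.
        for key in ('presence_penalty', 'frequency_penalty'):
            if not -2.0 <= config[key] <= 2.0:
                raise ValueError(f'{key} must be between -2.0 and 2.0')
        return config

    def apply_penalties(logit: float, token_count: int,
                        presence_penalty: float, frequency_penalty: float) -> float:
        """OpenAI-style penalty rule (an assumption about Xinference's backend):
        presence penalizes any prior appearance of a token, while frequency
        scales with how often the token has already been generated."""
        presence_term = presence_penalty if token_count > 0 else 0.0
        return logit - presence_term - frequency_penalty * token_count

    if __name__ == '__main__':
        # Example: encourage new topics, discourage verbatim repetition.
        print(build_generate_config({
            'temperature': 0.8,
            'presence_penalty': 0.6,
            'frequency_penalty': 0.4,
        }))

In short, presence_penalty applies a flat penalty once a token has appeared at all, while frequency_penalty grows with each repetition, which is why the help text describes the former as promoting new topics and the latter as reducing verbatim repeats.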