@@ -39,6 +39,7 @@ from core.model_runtime.entities.message_entities import (
 )
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,
+    DefaultParameterName,
     FetchFrom,
     ModelFeature,
     ModelPropertyKey,
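
Note: DefaultParameterName supplies the canonical parameter names that the two new ParameterRule entries further down reference for both name and use_template. As a rough sketch of the shape being relied on (an illustrative stand-in, not the real class, which lives in core.model_runtime.entities.model_entities):

    # Illustrative stand-in only; the real enum defines many more members.
    from enum import Enum

    class DefaultParameterName(str, Enum):
        PRESENCE_PENALTY = 'presence_penalty'
        FREQUENCY_PENALTY = 'frequency_penalty'

    assert DefaultParameterName.PRESENCE_PENALTY.value == 'presence_penalty'
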
@@ -67,7 +68,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                 model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                 stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-        -> LLMResult | Generator:
+            -> LLMResult | Generator:
         """
             invoke LLM
 
@@ -113,7 +114,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 elif 'generate' in extra_param.model_ability:
                     credentials['completion_type'] = 'completion'
                 else:
-                    raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')
+                    raise ValueError(
+                        f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')
 
             if extra_param.support_function_call:
                 credentials['support_function_call'] = True
@@ -206,6 +208,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         :param tools: tools for tool calling
         :return: number of tokens
         """
+
         def tokens(text: str):
             return self._get_num_tokens_by_gpt2(text)
 
@@ -339,6 +342,45 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                     zh_Hans='最大生成长度',
                     en_US='Max Tokens'
                 )
+            ),
+            ParameterRule(
+                name=DefaultParameterName.PRESENCE_PENALTY,
+                use_template=DefaultParameterName.PRESENCE_PENALTY,
+                type=ParameterType.FLOAT,
+                label=I18nObject(
+                    en_US='Presence Penalty',
+                    zh_Hans='存在惩罚',
+                ),
+                required=False,
+                help=I18nObject(
+                    en_US='Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they '
+                          'appear in the text so far, increasing the model\'s likelihood to talk about new topics.',
+                    zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词是否已出现在文本中对其进行惩罚,从而增加模型谈论新话题的可能性。'
+                ),
+                default=0.0,
+                min=-2.0,
+                max=2.0,
+                precision=2
+            ),
+            ParameterRule(
+                name=DefaultParameterName.FREQUENCY_PENALTY,
+                use_template=DefaultParameterName.FREQUENCY_PENALTY,
+                type=ParameterType.FLOAT,
+                label=I18nObject(
+                    en_US='Frequency Penalty',
+                    zh_Hans='频率惩罚',
+                ),
+                required=False,
+                help=I18nObject(
+                    en_US='Number between -2.0 and 2.0. Positive values penalize new tokens based on their '
+                          'existing frequency in the text so far, decreasing the model\'s likelihood to repeat the '
+                          'same line verbatim.',
+                    zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词在文本中的现有频率对其进行惩罚,从而降低模型逐字重复相同内容的可能性。'
+                ),
+                default=0.0,
+                min=-2.0,
+                max=2.0,
+                precision=2
             )
         ]
 
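
Note: the two rules added above expose the OpenAI-style presence and frequency penalties (optional floats between -2.0 and 2.0, defaulting to 0.0). A minimal sketch of how a caller could now set them via model_parameters; the values are illustrative, not taken from this patch:

    # Illustrative values only -- both penalties are optional.
    model_parameters = {
        'temperature': 0.7,
        'top_p': 0.7,
        'max_tokens': 256,
        'presence_penalty': 0.5,    # positive: nudge toward new topics
        'frequency_penalty': 0.3,   # positive: discourage verbatim repeats
    }
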
@@ -364,7 +406,6 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             else:
                 raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
 
-
         features = []
 
         support_function_call = credentials.get('support_function_call', False)
@@ -395,9 +436,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         return entity
 
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                 model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
-                 tools: list[PromptMessageTool] | None = None,
-                 stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+                  model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
+                  tools: list[PromptMessageTool] | None = None,
+                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
             generate text from LLM
@@ -429,6 +470,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             'temperature': model_parameters.get('temperature', 1.0),
             'top_p': model_parameters.get('top_p', 0.7),
             'max_tokens': model_parameters.get('max_tokens', 512),
+            'presence_penalty': model_parameters.get('presence_penalty', 0.0),
+            'frequency_penalty': model_parameters.get('frequency_penalty', 0.0),
         }
 
         if stop:
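
Note: because the new keys are read with dict.get and a 0.0 fallback, requests that omit the penalties behave exactly as before. The defaulting can be checked in isolation:

    # Sketch of the fallback behaviour added above.
    model_parameters = {}
    generate_config = {
        'presence_penalty': model_parameters.get('presence_penalty', 0.0),
        'frequency_penalty': model_parameters.get('frequency_penalty', 0.0),
    }
    assert generate_config == {'presence_penalty': 0.0, 'frequency_penalty': 0.0}
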
@@ -453,10 +496,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             if stream:
                 if tools and len(tools) > 0:
                     raise InvokeBadRequestError('xinference tool calls does not support stream mode')
-                return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                        tools=tools, resp=resp)
-            return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                        tools=tools, resp=resp)
+                return self._handle_chat_stream_response(model=model, credentials=credentials,
+                                                         prompt_messages=prompt_messages,
+                                                         tools=tools, resp=resp)
+            return self._handle_chat_generate_response(model=model, credentials=credentials,
+                                                       prompt_messages=prompt_messages,
+                                                       tools=tools, resp=resp)
         elif isinstance(xinference_model, RESTfulGenerateModelHandle):
             resp = client.completions.create(
                 model=credentials['model_uid'],
@@ -466,10 +511,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 **generate_config,
             )
             if stream:
-                return self._handle_completion_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                        tools=tools, resp=resp)
-            return self._handle_completion_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                        tools=tools, resp=resp)
+                return self._handle_completion_stream_response(model=model, credentials=credentials,
+                                                               prompt_messages=prompt_messages,
+                                                               tools=tools, resp=resp)
+            return self._handle_completion_generate_response(model=model, credentials=credentials,
+                                                             prompt_messages=prompt_messages,
+                                                             tools=tools, resp=resp)
         else:
             raise NotImplementedError(f'xinference model handle type {type(xinference_model)} is not supported')
 
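
Note: the two hunks above are line-wrapping only; dispatch behaviour is unchanged, including the rule that chat tool calls cannot be combined with streaming. A self-contained sketch of that guard, with illustrative names not taken from the patch:

    # Mirrors the guard in the chat branch: streaming plus tools is rejected.
    def check_stream_request(stream: bool, tools: list) -> None:
        if stream and tools and len(tools) > 0:
            raise ValueError('xinference tool calls do not support stream mode')

    check_stream_request(stream=False, tools=['a_tool'])  # ok: blocking call
    check_stream_request(stream=True, tools=[])           # ok: no tools requested
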
@@ -523,8 +570,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         return tool_call
 
     def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                                        tools: list[PromptMessageTool],
-                                        resp: ChatCompletion) -> LLMResult:
+                                       tools: list[PromptMessageTool],
+                                       resp: ChatCompletion) -> LLMResult:
         """
             handle normal chat generate response
         """
@@ -549,7 +596,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
         completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
 
-        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
+        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
+                                          completion_tokens=completion_tokens)
 
         response = LLMResult(
             model=model,
@@ -560,10 +608,10 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         )
 
         return response
-    
+
     def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                                        tools: list[PromptMessageTool],
-                                        resp: Iterator[ChatCompletionChunk]) -> Generator:
+                                     tools: list[PromptMessageTool],
+                                     resp: Iterator[ChatCompletionChunk]) -> Generator:
         """
             handle stream chat generate response
         """
@@ -634,8 +682,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 full_response += delta.delta.content
 
     def _handle_completion_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                                        tools: list[PromptMessageTool],
-                                        resp: Completion) -> LLMResult:
+                                             tools: list[PromptMessageTool],
+                                             resp: Completion) -> LLMResult:
         """
             handle normal completion generate response
         """
@@ -671,8 +719,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         return response
 
     def _handle_completion_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                                        tools: list[PromptMessageTool],
-                                        resp: Iterator[Completion]) -> Generator:
+                                           tools: list[PromptMessageTool],
+                                           resp: Iterator[Completion]) -> Generator:
         """
             handle stream completion generate response
         """
@@ -764,4 +812,4 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             InvokeBadRequestError: [
                 ValueError
             ]
-        }
\ No newline at end of file
+        }