|  | @@ -1,7 +1,8 @@
 | 
	
		
			
				|  |  |  import json
 | 
	
		
			
				|  |  |  from typing import Type
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -from langchain.llms import Xinference
 | 
	
		
			
				|  |  | +import requests
 | 
	
		
			
				|  |  | +from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from core.helper import encrypter
 | 
	
		
			
				|  |  |  from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
 | 
	
	
		
			
				|  | @@ -10,6 +11,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel
 | 
	
		
			
				|  |  |  from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from core.model_providers.models.base import BaseProviderModel
 | 
	
		
			
				|  |  | +from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
 | 
	
		
			
				|  |  |  from models.provider import ProviderType
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -48,13 +50,32 @@ class XinferenceProvider(BaseModelProvider):
 | 
	
		
			
				|  |  |          :param model_type:
 | 
	
		
			
				|  |  |          :return:
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  | -        return ModelKwargsRules(
 | 
	
		
			
				|  |  | -            temperature=KwargRule[float](min=0, max=2, default=1),
 | 
	
		
			
				|  |  | -            top_p=KwargRule[float](min=0, max=1, default=0.7),
 | 
	
		
			
				|  |  | -            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
 | 
	
		
			
				|  |  | -            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
 | 
	
		
			
				|  |  | -            max_tokens=KwargRule[int](min=10, max=4000, default=256),
 | 
	
		
			
				|  |  | -        )
 | 
	
		
			
				|  |  | +        credentials = self.get_model_credentials(model_name, model_type)
 | 
	
		
			
				|  |  | +        if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
 | 
	
		
			
				|  |  | +            return ModelKwargsRules(
 | 
	
		
			
				|  |  | +                temperature=KwargRule[float](min=0.01, max=2, default=1),
 | 
	
		
			
				|  |  | +                top_p=KwargRule[float](min=0, max=1, default=0.7),
 | 
	
		
			
				|  |  | +                presence_penalty=KwargRule[float](enabled=False),
 | 
	
		
			
				|  |  | +                frequency_penalty=KwargRule[float](enabled=False),
 | 
	
		
			
				|  |  | +                max_tokens=KwargRule[int](min=10, max=4000, default=256),
 | 
	
		
			
				|  |  | +            )
 | 
	
		
			
				|  |  | +        elif credentials['model_format'] == "ggmlv3":
 | 
	
		
			
				|  |  | +            return ModelKwargsRules(
 | 
	
		
			
				|  |  | +                temperature=KwargRule[float](min=0.01, max=2, default=1),
 | 
	
		
			
				|  |  | +                top_p=KwargRule[float](min=0, max=1, default=0.7),
 | 
	
		
			
				|  |  | +                presence_penalty=KwargRule[float](min=-2, max=2, default=0),
 | 
	
		
			
				|  |  | +                frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
 | 
	
		
			
				|  |  | +                max_tokens=KwargRule[int](min=10, max=4000, default=256),
 | 
	
		
			
				|  |  | +            )
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +            return ModelKwargsRules(
 | 
	
		
			
				|  |  | +                temperature=KwargRule[float](min=0.01, max=2, default=1),
 | 
	
		
			
				|  |  | +                top_p=KwargRule[float](min=0, max=1, default=0.7),
 | 
	
		
			
				|  |  | +                presence_penalty=KwargRule[float](enabled=False),
 | 
	
		
			
				|  |  | +                frequency_penalty=KwargRule[float](enabled=False),
 | 
	
		
			
				|  |  | +                max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
 | 
	
		
			
				|  |  | +            )
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @classmethod
 | 
	
		
			
				|  |  |      def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
 | 
	
	
		
			
				|  | @@ -77,11 +98,11 @@ class XinferenceProvider(BaseModelProvider):
 | 
	
		
			
				|  |  |                  'model_uid': credentials['model_uid'],
 | 
	
		
			
				|  |  |              }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -            llm = Xinference(
 | 
	
		
			
				|  |  | +            llm = XinferenceLLM(
 | 
	
		
			
				|  |  |                  **credential_kwargs
 | 
	
		
			
				|  |  |              )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -            llm("ping", generate_config={'max_tokens': 10})
 | 
	
		
			
				|  |  | +            llm("ping")
 | 
	
		
			
				|  |  |          except Exception as ex:
 | 
	
		
			
				|  |  |              raise CredentialsValidateFailedError(str(ex))
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -97,7 +118,11 @@ class XinferenceProvider(BaseModelProvider):
 | 
	
		
			
				|  |  |          :param credentials:
 | 
	
		
			
				|  |  |          :return:
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  | +        extra_credentials = cls._get_extra_credentials(credentials)
 | 
	
		
			
				|  |  | +        credentials.update(extra_credentials)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          return credentials
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
 | 
	
	
		
			
				|  | @@ -132,6 +157,30 @@ class XinferenceProvider(BaseModelProvider):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          return credentials
 | 
	
		
			
				|  |  |  
 | 
	
		
			
    @classmethod
    def _get_extra_credentials(cls, credentials: dict) -> dict:
        """
        Query the Xinference server for the model's description and derive
        the extra credential fields this provider needs.

        :param credentials: must contain 'server_url' and 'model_uid'
        :return: dict with 'model_format' and 'model_handle_type'
            ('chatglm' | 'generate' | 'chat')
        :raises RuntimeError: when the server rejects the model lookup
        :raises NotImplementedError: when the model supports neither
            'generate' nor 'chat'
        """
        url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
        # Always bound outbound HTTP calls: without a timeout an unreachable
        # server would hang this request (and the caller) indefinitely.
        response = requests.get(url, timeout=10)
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to get the model description, detail: {response.json()['detail']}"
            )
        desc = response.json()

        extra_credentials = {
            'model_format': desc['model_format'],
        }

        # ggmlv3 chatglm models need a dedicated handle type; for everything
        # else pick the handle from the model's declared abilities.
        if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
            extra_credentials['model_handle_type'] = 'chatglm'
        elif "generate" in desc["model_ability"]:
            extra_credentials['model_handle_type'] = 'generate'
        elif "chat" in desc["model_ability"]:
            extra_credentials['model_handle_type'] = 'chat'
        else:
            raise NotImplementedError("Model handle type not supported.")

        return extra_credentials
 | 
	
		
			
				|  |  | +
 | 
	
		
			
    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validate provider-level credentials.

        Xinference has no provider-wide credentials (everything is configured
        per model), so there is intentionally nothing to check here.

        :param credentials: provider credentials dict (ignored)
        :return: None
        """
        return
 |