| 
					
				 | 
			
			
				@@ -1,7 +1,8 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import json 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from typing import Type 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from langchain.llms import Xinference 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import requests 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.helper import encrypter 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -10,6 +11,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.model_providers.models.base import BaseProviderModel 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from models.provider import ProviderType 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -48,13 +50,32 @@ class XinferenceProvider(BaseModelProvider): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         :param model_type: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         :return: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return ModelKwargsRules( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            temperature=KwargRule[float](min=0, max=2, default=1), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            top_p=KwargRule[float](min=0, max=1, default=0.7), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            presence_penalty=KwargRule[float](min=-2, max=2, default=0), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            frequency_penalty=KwargRule[float](min=-2, max=2, default=0), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            max_tokens=KwargRule[int](min=10, max=4000, default=256), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        credentials = self.get_model_credentials(model_name, model_type) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return ModelKwargsRules( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                temperature=KwargRule[float](min=0.01, max=2, default=1), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                top_p=KwargRule[float](min=0, max=1, default=0.7), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                presence_penalty=KwargRule[float](enabled=False), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                frequency_penalty=KwargRule[float](enabled=False), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                max_tokens=KwargRule[int](min=10, max=4000, default=256), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif credentials['model_format'] == "ggmlv3": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return ModelKwargsRules( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                temperature=KwargRule[float](min=0.01, max=2, default=1), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                top_p=KwargRule[float](min=0, max=1, default=0.7), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                presence_penalty=KwargRule[float](min=-2, max=2, default=0), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                frequency_penalty=KwargRule[float](min=-2, max=2, default=0), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                max_tokens=KwargRule[int](min=10, max=4000, default=256), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return ModelKwargsRules( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                temperature=KwargRule[float](min=0.01, max=2, default=1), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                top_p=KwargRule[float](min=0, max=1, default=0.7), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                presence_penalty=KwargRule[float](enabled=False), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                frequency_penalty=KwargRule[float](enabled=False), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     @classmethod 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -77,11 +98,11 @@ class XinferenceProvider(BaseModelProvider): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 'model_uid': credentials['model_uid'], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            llm = Xinference( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            llm = XinferenceLLM( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 **credential_kwargs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            llm("ping", generate_config={'max_tokens': 10}) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            llm("ping") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         except Exception as ex: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             raise CredentialsValidateFailedError(str(ex)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -97,7 +118,11 @@ class XinferenceProvider(BaseModelProvider): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         :param credentials: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         :return: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        extra_credentials = cls._get_extra_credentials(credentials) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        credentials.update(extra_credentials) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return credentials 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict: 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -132,6 +157,30 @@ class XinferenceProvider(BaseModelProvider): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return credentials 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
    @classmethod
    def _get_extra_credentials(cls, credentials: dict) -> dict:
        """
        Fetch the description of a launched model and derive extra credential fields from it.

        Issues ``GET {server_url}/v1/models/{model_uid}`` and reads the
        ``model_format``, ``model_name`` and ``model_ability`` fields of the
        JSON response.
        NOTE(review): presumably this is the Xinference RESTful "describe model"
        endpoint -- confirm against the server version in use.

        :param credentials: must contain ``server_url`` (plain, not encrypted) and ``model_uid``
        :return: dict with ``model_format`` and ``model_handle_type``
                 (one of ``'chatglm'``, ``'generate'``, ``'chat'``)
        :raises RuntimeError: if the server does not answer with HTTP 200
        :raises NotImplementedError: if no supported handle type can be derived
        """
        # Tolerate a trailing slash in the configured URL so the path does not
        # become ".../​/v1/models/..." with a doubled separator.
        server_url = credentials['server_url'].rstrip('/')
        url = f"{server_url}/v1/models/{credentials['model_uid']}"

        # Without a timeout an unreachable or hung server would block the
        # caller forever; requests never times out by default.
        response = requests.get(url, timeout=10)
        if response.status_code != 200:
            # The error body is not guaranteed to be JSON carrying a 'detail'
            # key (e.g. a proxy error page); fall back to the raw text so the
            # real failure is reported instead of a secondary parsing error.
            try:
                detail = response.json()['detail']
            except (ValueError, KeyError, TypeError):
                detail = response.text
            raise RuntimeError(
                f"Failed to get the model description, detail: {detail}"
            )
        desc = response.json()

        extra_credentials = {
            'model_format': desc['model_format'],
        }
        # Order matters: the ggmlv3 chatglm case is checked first so it takes
        # precedence over the generic ability-based handle types below.
        if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
            extra_credentials['model_handle_type'] = 'chatglm'
        elif "generate" in desc["model_ability"]:
            extra_credentials['model_handle_type'] = 'generate'
        elif "chat" in desc["model_ability"]:
            extra_credentials['model_handle_type'] = 'chat'
        else:
            raise NotImplementedError("Model handle type not supported.")

        return extra_credentials
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     @classmethod 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def is_provider_credentials_valid_or_raise(cls, credentials: dict): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return 
			 |