| 
					
				 | 
			
			
				@@ -2,6 +2,7 @@ import json 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from typing import Type 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import requests 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from langchain.embeddings import XinferenceEmbeddings 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.helper import encrypter 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -97,11 +98,18 @@ class XinferenceProvider(BaseModelProvider): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 'model_uid': credentials['model_uid'], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            llm = XinferenceLLM( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                **credential_kwargs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if model_type == ModelType.TEXT_GENERATION: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                llm = XinferenceLLM( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    **credential_kwargs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                llm("ping") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            elif model_type == ModelType.EMBEDDINGS: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                embedding = XinferenceEmbeddings( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    **credential_kwargs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            llm("ping") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                embedding.embed_query("ping") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         except Exception as ex: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             raise CredentialsValidateFailedError(str(ex)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -117,8 +125,9 @@ class XinferenceProvider(BaseModelProvider): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         :param credentials: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         :return: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        extra_credentials = cls._get_extra_credentials(credentials) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        credentials.update(extra_credentials) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if model_type == ModelType.TEXT_GENERATION: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            extra_credentials = cls._get_extra_credentials(credentials) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            credentials.update(extra_credentials) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 |