| 
					
				 | 
			
			
				@@ -43,16 +43,17 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         server_url = credentials['server_url'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         model_uid = credentials['model_uid'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        api_key = credentials.get('api_key') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if server_url.endswith('/'): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             server_url = server_url[:-1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={}) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             embeddings = handle.create_embedding(input=texts) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         except RuntimeError as e: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            raise InvokeServerUnavailableError(e) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            raise InvokeServerUnavailableError(str(e)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         for convenience, the response json is like: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         class Embedding(TypedDict): 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -106,7 +107,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-             
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             server_url = credentials['server_url'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             model_uid = credentials['model_uid'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -117,7 +118,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 server_url = server_url[:-1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             client = Client(base_url=server_url) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 handle = client.get_model(model_uid=model_uid) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             except RuntimeError as e: 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -151,7 +152,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 KeyError 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             ] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         Calculate response usage 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -186,7 +187,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             used to define customizable model schema 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         entity = AIModelEntity( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             model=model, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             label=I18nObject( 
			 |