@@ -43,16 +43,17 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         """
         server_url = credentials['server_url']
         model_uid = credentials['model_uid']
-
+        api_key = credentials.get('api_key')
         if server_url.endswith('/'):
             server_url = server_url[:-1]
+        auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
 
         try:
-            handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
+            handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers)
             embeddings = handle.create_embedding(input=texts)
         except RuntimeError as e:
-            raise InvokeServerUnavailableError(e)
-        
+            raise InvokeServerUnavailableError(str(e))
+
         """
         for convenience, the response json is like:
         class Embedding(TypedDict):
@@ -106,7 +107,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         try:
             if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
                 raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
-            
+
             server_url = credentials['server_url']
             model_uid = credentials['model_uid']
             extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
@@ -117,7 +118,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
                 server_url = server_url[:-1]
 
             client = Client(base_url=server_url)
-        
+
             try:
                 handle = client.get_model(model_uid=model_uid)
             except RuntimeError as e:
@@ -151,7 +152,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
                 KeyError
             ]
         }
-    
+
     def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
         """
         Calculate response usage
@@ -186,7 +187,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         """
            used to define customizable model schema
         """
-        
+
         entity = AIModelEntity(
             model=model,
             label=I18nObject(
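The substantive change in the first hunk is the optional API-key handling: if `api_key` is present in the credentials it is turned into a Bearer `Authorization` header and passed to `RESTfulEmbeddingModelHandle`; otherwise an empty header dict is used, as before. The sketch below isolates that pattern so it can be read and run on its own; the helper name `build_auth_headers` and the example credential values are illustrative only, since the actual change inlines this logic before constructing the handle.

```python
# Minimal, self-contained sketch of the credential handling added above.
# Assumes `credentials` is a plain dict, as it is in _invoke / validate_credentials.
def build_auth_headers(credentials: dict) -> dict:
    """Return a Bearer Authorization header only when an api_key is configured."""
    api_key = credentials.get('api_key')  # optional; absent for unsecured servers
    return {'Authorization': f'Bearer {api_key}'} if api_key else {}


if __name__ == '__main__':
    # With an api_key, the embedding handle receives a Bearer header ...
    print(build_auth_headers({'server_url': 'http://localhost:9997',
                              'model_uid': 'bge-base-en',
                              'api_key': 'sk-example'}))
    # ... and without one it falls back to the previous behaviour: no auth header at all.
    print(build_auth_headers({'server_url': 'http://localhost:9997',
                              'model_uid': 'bge-base-en'}))
```

Passing the resulting dict positionally keeps the previous behaviour (an empty `auth_headers` dict) for Xinference servers that are not secured with an API key, so existing configurations continue to work unchanged.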