| 
					
				 | 
			
			
				@@ -16,9 +16,13 @@ import websocket 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 class SparkLLMClient: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def __init__(self, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.api_base = "wss://spark-api.xf-yun.com/v1.1/chat" if not api_domain else ('wss://' + api_domain + '/v1.1/chat') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        domain = 'spark-api.xf-yun.com' if not api_domain else api_domain 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        api_version = 'v2.1' if model_name == 'spark-v2' else 'v1.1' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.chat_domain = 'generalv2' if model_name == 'spark-v2' else 'general' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.api_base = f"wss://{domain}/{api_version}/chat" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.app_id = app_id 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.ws_url = self.create_url( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             urlparse(self.api_base).netloc, 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -76,7 +80,10 @@ class SparkLLMClient: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def on_error(self, ws, error): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.queue.put({'error': error}) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.queue.put({ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'status_code': error.status_code, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'error': error.resp_body.decode('utf-8') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        }) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         ws.close() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def on_close(self, ws, close_status_code, close_reason): 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -120,7 +127,7 @@ class SparkLLMClient: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             }, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             "parameter": { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 "chat": { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "domain": "general" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "domain": self.chat_domain 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             }, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             "payload": { 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -139,7 +146,14 @@ class SparkLLMClient: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         while True: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             content = self.queue.get() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if 'error' in content: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                raise SparkError(content['error']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if content['status_code'] == 401: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    raise SparkError('[Spark] The credentials you provided are incorrect. ' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                     'Please double-check and fill them in again.') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                elif content['status_code'] == 403: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    raise SparkError("[Spark] Sorry, the credentials you provided are access denied. " 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                     "Please try again after obtaining the necessary permissions.") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if 'data' not in content: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 break 
			 |