"""WebSocket client for the iFlytek Spark chat API."""
import base64
import datetime
import hashlib
import hmac
import json
import queue
import ssl
from datetime import datetime
from time import mktime
from typing import Optional
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time

import websocket


class SparkLLMClient:
    """Send chat requests to Spark over a signed WebSocket URL.

    Results are delivered through ``self.queue``: the WebSocket callbacks
    push dicts onto it and :meth:`subscribe` consumes them.  Three dict
    shapes are used:

    * ``{'data': str}`` -- a piece of (or the whole) answer,
    * ``{'status_code': int, 'error': str}`` -- a failure,
    * ``{'done': True}`` -- the socket was closed.
    """

    # Per-model URL version segment and "domain" value sent in the request.
    _MODEL_API_CONFIGS = {
        'spark': {
            'version': 'v1.1',
            'chat_domain': 'general'
        },
        'spark-v2': {
            'version': 'v2.1',
            'chat_domain': 'generalv2'
        },
        'spark-v3': {
            'version': 'v3.1',
            'chat_domain': 'generalv3'
        }
    }

    # Status code reported for failures that carry no HTTP status of their
    # own (connection errors, timeouts, exceptions raised in callbacks).
    _GENERIC_ERROR_STATUS = 500

    def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str,
                 api_domain: Optional[str] = None):
        """Build the signed WebSocket URL for ``model_name``.

        Args:
            model_name: One of ``'spark'``, ``'spark-v2'``, ``'spark-v3'``.
            app_id: Application id placed in every request header.
            api_key: Key embedded in the authorization query parameter.
            api_secret: Secret used to HMAC-sign the handshake.
            api_domain: Optional host overriding the default API host.  When
                it is set and the model is ``'spark-v3'``, the ``multimodal``
                endpoint is used instead of ``chat``.

        Raises:
            KeyError: If ``model_name`` is not a known model.
        """
        domain = 'spark-api.xf-yun.com'
        endpoint = 'chat'
        if api_domain:
            domain = api_domain
            if model_name == 'spark-v3':
                endpoint = 'multimodal'

        model_config = self._MODEL_API_CONFIGS[model_name]
        api_version = model_config['version']

        self.chat_domain = model_config['chat_domain']
        self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
        self.app_id = app_id

        parsed_base = urlparse(self.api_base)
        # NOTE: the URL is signed once, here, with the current time; it is
        # reused by every later call to run().
        self.ws_url = self.create_url(
            parsed_base.netloc,
            parsed_base.path,
            self.api_base,
            api_key,
            api_secret
        )

        self.queue = queue.Queue()
        self.blocking_message = ''

    def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
        """Return ``api_base`` with the signed authorization query appended.

        The signature is an HMAC-SHA256 (keyed with ``api_secret``) over the
        ``host`` line, the ``date`` line and the ``GET <path> HTTP/1.1``
        request line; it is base64-encoded, wrapped in an authorization
        string together with ``api_key``, and that string is base64-encoded
        again before being URL-encoded.
        """
        # generate timestamp by RFC1123
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        signature_origin = "host: " + host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + path + " HTTP/1.1"

        # sign using hmac-sha256
        signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = (
            f'api_key="{api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature_sha_base64}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        v = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        # generate url
        url = api_base + '?' + urlencode(v)
        return url

    def run(self, messages: list, user_id: str,
            model_kwargs: Optional[dict] = None, streaming: bool = False):
        """Open the WebSocket, send one request and block until it closes.

        Args:
            messages: Chat messages placed under ``payload.message.text``.
            user_id: Value sent as ``header.uid``.
            model_kwargs: Extra entries merged into ``parameter.chat``.
            streaming: If true, every received chunk is queued as it
                arrives; otherwise chunks are accumulated and queued once
                when the final frame is received.
        """
        websocket.enableTrace(False)
        ws = websocket.WebSocketApp(
            self.ws_url,
            on_message=self.on_message,
            on_error=self.on_error,
            on_close=self.on_close,
            on_open=self.on_open
        )
        # Request state travels on the app object so the callbacks can read it.
        ws.messages = messages
        ws.user_id = user_id
        ws.model_kwargs = model_kwargs
        ws.streaming = streaming
        # NOTE(review): certificate verification is disabled here, so the
        # signed URL is sent without authenticating the server.  Kept as-is
        # for backward compatibility -- confirm whether this is intentional.
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

    def on_error(self, ws, error):
        """Queue a failure record for ``error`` and close the socket.

        ``status_code`` / ``resp_body`` are only present on handshake
        failures that carry an HTTP response.  Any other exception is
        reported with a generic status code and its string form, so the
        failure still reaches :meth:`subscribe` instead of being lost to an
        ``AttributeError`` raised inside this callback.
        """
        status_code = getattr(error, 'status_code', None)
        if status_code is None:
            status_code = self._GENERIC_ERROR_STATUS

        body = getattr(error, 'resp_body', None)
        if isinstance(body, (bytes, bytearray)):
            # Tolerant decode: a malformed body must not mask the real error.
            message = bytes(body).decode('utf-8', errors='replace')
        elif body:
            message = str(body)
        else:
            message = str(error)

        self.queue.put({
            'status_code': status_code,
            'error': message
        })
        ws.close()

    def on_close(self, ws, close_status_code, close_reason):
        """Queue the end-of-stream marker consumed by :meth:`subscribe`."""
        self.queue.put({'done': True})

    def on_open(self, ws):
        """Reset the accumulated answer and send the request payload."""
        self.blocking_message = ''
        data = json.dumps(self.gen_params(
            messages=ws.messages,
            user_id=ws.user_id,
            model_kwargs=ws.model_kwargs
        ))
        ws.send(data)

    def on_message(self, ws, message):
        """Handle one server frame.

        A non-zero ``header.code`` is queued as an error (status 400) and the
        socket is closed.  Otherwise the chunk text is either queued
        immediately (streaming) or appended to ``self.blocking_message``;
        when ``choices.status`` is 2 the accumulated answer is queued (in
        blocking mode) and the socket is closed.
        """
        data = json.loads(message)
        code = data['header']['code']
        if code != 0:
            self.queue.put({
                'status_code': 400,
                'error': f"Code: {code}, Error: {data['header']['message']}"
            })
            ws.close()
            return

        choices = data["payload"]["choices"]
        status = choices["status"]
        content = choices["text"][0]["content"]

        if ws.streaming:
            self.queue.put({'data': content})
        else:
            self.blocking_message += content

        # status == 2 is treated as the last frame of the answer.
        if status == 2:
            if not ws.streaming:
                self.queue.put({'data': self.blocking_message})
            ws.close()

    def gen_params(self, messages: list, user_id: str,
                   model_kwargs: Optional[dict] = None) -> dict:
        """Build the request payload sent when the socket opens.

        Entries of ``model_kwargs`` (if any) are merged into
        ``parameter.chat`` and may override the default ``domain``.
        """
        data = {
            "header": {
                "app_id": self.app_id,
                "uid": user_id
            },
            "parameter": {
                "chat": {
                    "domain": self.chat_domain
                }
            },
            "payload": {
                "message": {
                    "text": messages
                }
            }
        }

        if model_kwargs:
            data['parameter']['chat'].update(model_kwargs)

        return data

    def subscribe(self):
        """Yield ``{'data': ...}`` records from the queue until the stream ends.

        Blocks on the queue.  Stops at the first record without a ``'data'``
        key (the ``{'done': True}`` marker).

        Raises:
            SparkError: When an error record is received; 401 and 403 get
                dedicated messages, every other status is reported verbatim.
        """
        while True:
            content = self.queue.get()
            if 'error' in content:
                if content['status_code'] == 401:
                    raise SparkError('[Spark] The credentials you provided are incorrect. '
                                     'Please double-check and fill them in again.')
                elif content['status_code'] == 403:
                    raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
                                     "Please try again after obtaining the necessary permissions.")
                else:
                    raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")

            if 'data' not in content:
                break
            yield content


class SparkError(Exception):
    """Raised by SparkLLMClient.subscribe() when the API reports a failure."""
    pass