import base64
import datetime
import hashlib
import hmac
import json
import queue
import ssl
from datetime import datetime
from time import mktime
from typing import Optional
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time

import websocket


class SparkLLMClient:
    """Websocket client for the iFlytek Spark chat API.

    :meth:`run` opens a websocket, sends a single chat request and pushes
    results onto ``self.queue`` from the websocket callbacks; :meth:`subscribe`
    drains that queue and yields the results to the caller.

    Queue items are dicts of one of three shapes:

    * ``{'data': str}`` -- a content chunk (streaming) or the full reply.
    * ``{'status_code': int, 'error': str}`` -- a failure.
    * ``{'done': True}`` -- the connection was closed.
    """

    # Per-model API version (used as a URL path segment) and the value sent
    # as ``parameter.chat.domain`` in the request body.
    _MODEL_API_CONFIGS = {
        'spark': {
            'version': 'v1.1',
            'chat_domain': 'general'
        },
        'spark-v2': {
            'version': 'v2.1',
            'chat_domain': 'generalv2'
        },
        'spark-v3': {
            'version': 'v3.1',
            'chat_domain': 'generalv3'
        },
        'spark-v3.5': {
            'version': 'v3.5',
            'chat_domain': 'generalv3.5'
        }
    }

    # Status code reported for failures that carry no HTTP status of their own
    # (e.g. network errors, or exceptions raised inside another callback).
    _GENERIC_ERROR_STATUS = 500

    def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
        """Build the signed websocket URL for ``model_name``.

        :param model_name: one of the keys of ``_MODEL_API_CONFIGS``.
        :param app_id: application id placed in the request header.
        :param api_key: key embedded in the authorization parameter.
        :param api_secret: secret used to HMAC-sign the URL.
        :param api_domain: optional host overriding ``spark-api.xf-yun.com``.
        :raises KeyError: if ``model_name`` is not a supported model.
        """
        domain = 'spark-api.xf-yun.com'
        endpoint = 'chat'
        if api_domain:
            domain = api_domain
            # The multimodal endpoint is only selected for a custom domain.
            if model_name == 'spark-v3':
                endpoint = 'multimodal'

        model_config = self._MODEL_API_CONFIGS[model_name]
        api_version = model_config['version']

        self.chat_domain = model_config['chat_domain']
        self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
        self.app_id = app_id

        parsed_base = urlparse(self.api_base)
        self.ws_url = self.create_url(
            parsed_base.netloc,
            parsed_base.path,
            self.api_base,
            api_key,
            api_secret
        )

        self.queue = queue.Queue()
        self.blocking_message = ''

    def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
        """Return ``api_base`` with HMAC-SHA256 authorization query parameters.

        The signed string covers the ``host`` header, the ``date`` header and
        the ``GET <path> HTTP/1.1`` request line; the signature and ``api_key``
        are packed into a base64 ``authorization`` parameter.
        """
        # generate timestamp by RFC1123
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        signature_origin = "host: " + host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + path + " HTTP/1.1"

        # sign using hmac-sha256
        signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        v = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        # generate url
        url = api_base + '?' + urlencode(v)
        return url

    def run(self, messages: list, user_id: str,
            model_kwargs: Optional[dict] = None, streaming: bool = False,
            sslopt: Optional[dict] = None):
        """Open the websocket, send one chat request and block until it closes.

        Results are delivered through ``self.queue``; consume them with
        :meth:`subscribe` (presumably from another thread, since this call
        blocks -- confirm against the caller).

        :param messages: chat history placed in ``payload.message.text``.
        :param user_id: value sent as ``header.uid``.
        :param model_kwargs: extra fields merged into ``parameter.chat``.
        :param streaming: queue each chunk as it arrives instead of one
            concatenated reply.
        :param sslopt: SSL options forwarded to ``run_forever``. Defaults to
            ``{"cert_reqs": ssl.CERT_NONE}`` for backward compatibility.
        """
        websocket.enableTrace(False)
        ws = websocket.WebSocketApp(
            self.ws_url,
            on_message=self.on_message,
            on_error=self.on_error,
            on_close=self.on_close,
            on_open=self.on_open
        )
        # Per-request state read back by the callbacks via the ws object.
        ws.messages = messages
        ws.user_id = user_id
        ws.model_kwargs = model_kwargs
        ws.streaming = streaming
        if sslopt is None:
            # NOTE(security): CERT_NONE disables TLS certificate verification.
            # Kept as the default for compatibility; pass
            # {"cert_reqs": ssl.CERT_REQUIRED} to verify the server.
            sslopt = {"cert_reqs": ssl.CERT_NONE}
        ws.run_forever(sslopt=sslopt)

    def on_error(self, ws, error):
        """Queue ``error`` for :meth:`subscribe` and close the socket.

        ``error`` may be any exception, not only a handshake failure, so
        ``status_code``/``resp_body`` are read defensively. Raising here would
        prevent the error (and possibly the ``done`` sentinel) from reaching
        the queue, leaving :meth:`subscribe` blocked.
        """
        status_code = getattr(error, 'status_code', None)
        if status_code is None:
            status_code = self._GENERIC_ERROR_STATUS

        resp_body = getattr(error, 'resp_body', None)
        if isinstance(resp_body, bytes):
            message = resp_body.decode('utf-8', errors='replace')
        elif resp_body:
            message = str(resp_body)
        else:
            message = ''
        if not message:
            # No response body: fall back to the exception's own description.
            message = str(error)

        self.queue.put({
            'status_code': status_code,
            'error': message
        })
        ws.close()

    def on_close(self, ws, close_status_code, close_reason):
        """Queue the ``done`` sentinel that ends :meth:`subscribe`."""
        self.queue.put({'done': True})

    def on_open(self, ws):
        """Reset the accumulated reply and send the chat request."""
        self.blocking_message = ''
        data = json.dumps(self.gen_params(
            messages=ws.messages,
            user_id=ws.user_id,
            model_kwargs=ws.model_kwargs
        ))
        ws.send(data)

    def on_message(self, ws, message):
        """Handle one JSON frame from the server.

        A non-zero ``header.code`` queues an error (status 400) and closes the
        socket. Otherwise the frame's text content is queued immediately
        (streaming) or appended to ``blocking_message``; a frame with
        ``payload.choices.status == 2`` is treated as the last one and closes
        the socket, queuing the full reply first in non-streaming mode.
        """
        data = json.loads(message)
        code = data['header']['code']
        if code != 0:
            self.queue.put({
                'status_code': 400,
                'error': f"Code: {code}, Error: {data['header']['message']}"
            })
            ws.close()
        else:
            choices = data["payload"]["choices"]
            status = choices["status"]
            content = choices["text"][0]["content"]
            if ws.streaming:
                self.queue.put({'data': content})
            else:
                self.blocking_message += content

            if status == 2:
                if not ws.streaming:
                    self.queue.put({'data': self.blocking_message})
                ws.close()

    def gen_params(self, messages: list, user_id: str,
                   model_kwargs: Optional[dict] = None) -> dict:
        """Build the request body sent on connection open.

        ``model_kwargs`` entries are merged into ``parameter.chat`` and may
        override ``domain``.
        """
        data = {
            "header": {
                "app_id": self.app_id,
                "uid": user_id
            },
            "parameter": {
                "chat": {
                    "domain": self.chat_domain
                }
            },
            "payload": {
                "message": {
                    "text": messages
                }
            }
        }

        if model_kwargs:
            data['parameter']['chat'].update(model_kwargs)

        return data

    def subscribe(self):
        """Yield ``{'data': str}`` items from the queue until it reports done.

        Blocks on the queue between items.

        :raises SparkError: when an error item is dequeued; 401 and 403 get
            dedicated credential messages.
        """
        while True:
            content = self.queue.get()
            if 'error' in content:
                if content['status_code'] == 401:
                    raise SparkError('[Spark] The credentials you provided are incorrect. '
                                     'Please double-check and fill them in again.')
                elif content['status_code'] == 403:
                    raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
                                     "Please try again after obtaining the necessary permissions.")
                else:
                    raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")

            # Anything without 'data' (the 'done' sentinel) ends the stream.
            if 'data' not in content:
                break
            yield content


class SparkError(Exception):
    """Error raised by ``SparkLLMClient.subscribe`` when the Spark API fails."""