spark_llm.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import base64
  2. import datetime
  3. import hashlib
  4. import hmac
  5. import json
  6. import queue
  7. from typing import Optional
  8. from urllib.parse import urlparse
  9. import ssl
  10. from datetime import datetime
  11. from time import mktime
  12. from urllib.parse import urlencode
  13. from wsgiref.handlers import format_date_time
  14. import websocket
  15. class SparkLLMClient:
  16. def __init__(self, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
  17. self.api_base = "wss://spark-api.xf-yun.com/v1.1/chat" if not api_domain else ('wss://' + api_domain + '/v1.1/chat')
  18. self.app_id = app_id
  19. self.ws_url = self.create_url(
  20. urlparse(self.api_base).netloc,
  21. urlparse(self.api_base).path,
  22. self.api_base,
  23. api_key,
  24. api_secret
  25. )
  26. self.queue = queue.Queue()
  27. self.blocking_message = ''
  28. def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
  29. # generate timestamp by RFC1123
  30. now = datetime.now()
  31. date = format_date_time(mktime(now.timetuple()))
  32. signature_origin = "host: " + host + "\n"
  33. signature_origin += "date: " + date + "\n"
  34. signature_origin += "GET " + path + " HTTP/1.1"
  35. # encrypt using hmac-sha256
  36. signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
  37. digestmod=hashlib.sha256).digest()
  38. signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
  39. authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
  40. authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
  41. v = {
  42. "authorization": authorization,
  43. "date": date,
  44. "host": host
  45. }
  46. # generate url
  47. url = api_base + '?' + urlencode(v)
  48. return url
  49. def run(self, messages: list, user_id: str,
  50. model_kwargs: Optional[dict] = None, streaming: bool = False):
  51. websocket.enableTrace(False)
  52. ws = websocket.WebSocketApp(
  53. self.ws_url,
  54. on_message=self.on_message,
  55. on_error=self.on_error,
  56. on_close=self.on_close,
  57. on_open=self.on_open
  58. )
  59. ws.messages = messages
  60. ws.user_id = user_id
  61. ws.model_kwargs = model_kwargs
  62. ws.streaming = streaming
  63. ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
  64. def on_error(self, ws, error):
  65. self.queue.put({'error': error})
  66. ws.close()
  67. def on_close(self, ws, close_status_code, close_reason):
  68. self.queue.put({'done': True})
  69. def on_open(self, ws):
  70. self.blocking_message = ''
  71. data = json.dumps(self.gen_params(
  72. messages=ws.messages,
  73. user_id=ws.user_id,
  74. model_kwargs=ws.model_kwargs
  75. ))
  76. ws.send(data)
  77. def on_message(self, ws, message):
  78. data = json.loads(message)
  79. code = data['header']['code']
  80. if code != 0:
  81. self.queue.put({'error': f"Code: {code}, Error: {data['header']['message']}"})
  82. ws.close()
  83. else:
  84. choices = data["payload"]["choices"]
  85. status = choices["status"]
  86. content = choices["text"][0]["content"]
  87. if ws.streaming:
  88. self.queue.put({'data': content})
  89. else:
  90. self.blocking_message += content
  91. if status == 2:
  92. if not ws.streaming:
  93. self.queue.put({'data': self.blocking_message})
  94. ws.close()
  95. def gen_params(self, messages: list, user_id: str,
  96. model_kwargs: Optional[dict] = None) -> dict:
  97. data = {
  98. "header": {
  99. "app_id": self.app_id,
  100. "uid": user_id
  101. },
  102. "parameter": {
  103. "chat": {
  104. "domain": "general"
  105. }
  106. },
  107. "payload": {
  108. "message": {
  109. "text": messages
  110. }
  111. }
  112. }
  113. if model_kwargs:
  114. data['parameter']['chat'].update(model_kwargs)
  115. return data
  116. def subscribe(self):
  117. while True:
  118. content = self.queue.get()
  119. if 'error' in content:
  120. raise SparkError(content['error'])
  121. if 'data' not in content:
  122. break
  123. yield content
  124. class SparkError(Exception):
  125. pass