spark_llm.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. import base64
  2. import datetime
  3. import hashlib
  4. import hmac
  5. import json
  6. import queue
  7. import ssl
  8. from datetime import datetime
  9. from time import mktime
  10. from typing import Optional
  11. from urllib.parse import urlencode, urlparse
  12. from wsgiref.handlers import format_date_time
  13. import websocket
  14. class SparkLLMClient:
  15. def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
  16. domain = 'spark-api.xf-yun.com'
  17. endpoint = 'chat'
  18. if api_domain:
  19. domain = api_domain
  20. if model_name == 'spark-v3':
  21. endpoint = 'multimodal'
  22. model_api_configs = {
  23. 'spark': {
  24. 'version': 'v1.1',
  25. 'chat_domain': 'general'
  26. },
  27. 'spark-v2': {
  28. 'version': 'v2.1',
  29. 'chat_domain': 'generalv2'
  30. },
  31. 'spark-v3': {
  32. 'version': 'v3.1',
  33. 'chat_domain': 'generalv3'
  34. },
  35. 'spark-v3.5': {
  36. 'version': 'v3.5',
  37. 'chat_domain': 'generalv3.5'
  38. }
  39. }
  40. api_version = model_api_configs[model_name]['version']
  41. self.chat_domain = model_api_configs[model_name]['chat_domain']
  42. self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
  43. self.app_id = app_id
  44. self.ws_url = self.create_url(
  45. urlparse(self.api_base).netloc,
  46. urlparse(self.api_base).path,
  47. self.api_base,
  48. api_key,
  49. api_secret
  50. )
  51. self.queue = queue.Queue()
  52. self.blocking_message = ''
  53. def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
  54. # generate timestamp by RFC1123
  55. now = datetime.now()
  56. date = format_date_time(mktime(now.timetuple()))
  57. signature_origin = "host: " + host + "\n"
  58. signature_origin += "date: " + date + "\n"
  59. signature_origin += "GET " + path + " HTTP/1.1"
  60. # encrypt using hmac-sha256
  61. signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
  62. digestmod=hashlib.sha256).digest()
  63. signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
  64. authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
  65. authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
  66. v = {
  67. "authorization": authorization,
  68. "date": date,
  69. "host": host
  70. }
  71. # generate url
  72. url = api_base + '?' + urlencode(v)
  73. return url
  74. def run(self, messages: list, user_id: str,
  75. model_kwargs: Optional[dict] = None, streaming: bool = False):
  76. websocket.enableTrace(False)
  77. ws = websocket.WebSocketApp(
  78. self.ws_url,
  79. on_message=self.on_message,
  80. on_error=self.on_error,
  81. on_close=self.on_close,
  82. on_open=self.on_open
  83. )
  84. ws.messages = messages
  85. ws.user_id = user_id
  86. ws.model_kwargs = model_kwargs
  87. ws.streaming = streaming
  88. ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
  89. def on_error(self, ws, error):
  90. self.queue.put({
  91. 'status_code': error.status_code,
  92. 'error': error.resp_body.decode('utf-8')
  93. })
  94. ws.close()
  95. def on_close(self, ws, close_status_code, close_reason):
  96. self.queue.put({'done': True})
  97. def on_open(self, ws):
  98. self.blocking_message = ''
  99. data = json.dumps(self.gen_params(
  100. messages=ws.messages,
  101. user_id=ws.user_id,
  102. model_kwargs=ws.model_kwargs
  103. ))
  104. ws.send(data)
  105. def on_message(self, ws, message):
  106. data = json.loads(message)
  107. code = data['header']['code']
  108. if code != 0:
  109. self.queue.put({
  110. 'status_code': 400,
  111. 'error': f"Code: {code}, Error: {data['header']['message']}"
  112. })
  113. ws.close()
  114. else:
  115. choices = data["payload"]["choices"]
  116. status = choices["status"]
  117. content = choices["text"][0]["content"]
  118. if ws.streaming:
  119. self.queue.put({'data': content})
  120. else:
  121. self.blocking_message += content
  122. if status == 2:
  123. if not ws.streaming:
  124. self.queue.put({'data': self.blocking_message})
  125. ws.close()
  126. def gen_params(self, messages: list, user_id: str,
  127. model_kwargs: Optional[dict] = None) -> dict:
  128. data = {
  129. "header": {
  130. "app_id": self.app_id,
  131. "uid": user_id
  132. },
  133. "parameter": {
  134. "chat": {
  135. "domain": self.chat_domain
  136. }
  137. },
  138. "payload": {
  139. "message": {
  140. "text": messages
  141. }
  142. }
  143. }
  144. if model_kwargs:
  145. data['parameter']['chat'].update(model_kwargs)
  146. return data
  147. def subscribe(self):
  148. while True:
  149. content = self.queue.get()
  150. if 'error' in content:
  151. if content['status_code'] == 401:
  152. raise SparkError('[Spark] The credentials you provided are incorrect. '
  153. 'Please double-check and fill them in again.')
  154. elif content['status_code'] == 403:
  155. raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
  156. "Please try again after obtaining the necessary permissions.")
  157. else:
  158. raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
  159. if 'data' not in content:
  160. break
  161. yield content
  162. class SparkError(Exception):
  163. pass