spark_llm.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. import base64
  2. import hashlib
  3. import hmac
  4. import json
  5. import queue
  6. import ssl
  7. from datetime import datetime
  8. from time import mktime
  9. from typing import Optional
  10. from urllib.parse import urlencode, urlparse
  11. from wsgiref.handlers import format_date_time
  12. import websocket
  13. class SparkLLMClient:
  14. def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
  15. domain = 'spark-api.xf-yun.com'
  16. endpoint = 'chat'
  17. if api_domain:
  18. domain = api_domain
  19. if model_name == 'spark-v3':
  20. endpoint = 'multimodal'
  21. model_api_configs = {
  22. 'spark': {
  23. 'version': 'v1.1',
  24. 'chat_domain': 'general'
  25. },
  26. 'spark-v2': {
  27. 'version': 'v2.1',
  28. 'chat_domain': 'generalv2'
  29. },
  30. 'spark-v3': {
  31. 'version': 'v3.1',
  32. 'chat_domain': 'generalv3'
  33. },
  34. 'spark-v3.5': {
  35. 'version': 'v3.5',
  36. 'chat_domain': 'generalv3.5'
  37. }
  38. }
  39. api_version = model_api_configs[model_name]['version']
  40. self.chat_domain = model_api_configs[model_name]['chat_domain']
  41. self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
  42. self.app_id = app_id
  43. self.ws_url = self.create_url(
  44. urlparse(self.api_base).netloc,
  45. urlparse(self.api_base).path,
  46. self.api_base,
  47. api_key,
  48. api_secret
  49. )
  50. self.queue = queue.Queue()
  51. self.blocking_message = ''
  52. def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
  53. # generate timestamp by RFC1123
  54. now = datetime.now()
  55. date = format_date_time(mktime(now.timetuple()))
  56. signature_origin = "host: " + host + "\n"
  57. signature_origin += "date: " + date + "\n"
  58. signature_origin += "GET " + path + " HTTP/1.1"
  59. # encrypt using hmac-sha256
  60. signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
  61. digestmod=hashlib.sha256).digest()
  62. signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
  63. authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
  64. authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
  65. v = {
  66. "authorization": authorization,
  67. "date": date,
  68. "host": host
  69. }
  70. # generate url
  71. url = api_base + '?' + urlencode(v)
  72. return url
  73. def run(self, messages: list, user_id: str,
  74. model_kwargs: Optional[dict] = None, streaming: bool = False):
  75. websocket.enableTrace(False)
  76. ws = websocket.WebSocketApp(
  77. self.ws_url,
  78. on_message=self.on_message,
  79. on_error=self.on_error,
  80. on_close=self.on_close,
  81. on_open=self.on_open
  82. )
  83. ws.messages = messages
  84. ws.user_id = user_id
  85. ws.model_kwargs = model_kwargs
  86. ws.streaming = streaming
  87. ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
  88. def on_error(self, ws, error):
  89. self.queue.put({
  90. 'status_code': error.status_code,
  91. 'error': error.resp_body.decode('utf-8')
  92. })
  93. ws.close()
  94. def on_close(self, ws, close_status_code, close_reason):
  95. self.queue.put({'done': True})
  96. def on_open(self, ws):
  97. self.blocking_message = ''
  98. data = json.dumps(self.gen_params(
  99. messages=ws.messages,
  100. user_id=ws.user_id,
  101. model_kwargs=ws.model_kwargs
  102. ))
  103. ws.send(data)
  104. def on_message(self, ws, message):
  105. data = json.loads(message)
  106. code = data['header']['code']
  107. if code != 0:
  108. self.queue.put({
  109. 'status_code': 400,
  110. 'error': f"Code: {code}, Error: {data['header']['message']}"
  111. })
  112. ws.close()
  113. else:
  114. choices = data["payload"]["choices"]
  115. status = choices["status"]
  116. content = choices["text"][0]["content"]
  117. if ws.streaming:
  118. self.queue.put({'data': content})
  119. else:
  120. self.blocking_message += content
  121. if status == 2:
  122. if not ws.streaming:
  123. self.queue.put({'data': self.blocking_message})
  124. ws.close()
  125. def gen_params(self, messages: list, user_id: str,
  126. model_kwargs: Optional[dict] = None) -> dict:
  127. data = {
  128. "header": {
  129. "app_id": self.app_id,
  130. "uid": user_id
  131. },
  132. "parameter": {
  133. "chat": {
  134. "domain": self.chat_domain
  135. }
  136. },
  137. "payload": {
  138. "message": {
  139. "text": messages
  140. }
  141. }
  142. }
  143. if model_kwargs:
  144. data['parameter']['chat'].update(model_kwargs)
  145. return data
  146. def subscribe(self):
  147. while True:
  148. content = self.queue.get()
  149. if 'error' in content:
  150. if content['status_code'] == 401:
  151. raise SparkError('[Spark] The credentials you provided are incorrect. '
  152. 'Please double-check and fill them in again.')
  153. elif content['status_code'] == 403:
  154. raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
  155. "Please try again after obtaining the necessary permissions.")
  156. else:
  157. raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
  158. if 'data' not in content:
  159. break
  160. yield content
  161. class SparkError(Exception):
  162. pass