# spark_llm.py
  1. import base64
  2. import datetime
  3. import hashlib
  4. import hmac
  5. import json
  6. import queue
  7. import ssl
  8. from datetime import datetime
  9. from time import mktime
  10. from typing import Optional
  11. from urllib.parse import urlencode, urlparse
  12. from wsgiref.handlers import format_date_time
  13. import websocket
  14. class SparkLLMClient:
  15. def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
  16. domain = 'spark-api.xf-yun.com'
  17. endpoint = 'chat'
  18. if api_domain:
  19. domain = api_domain
  20. if model_name == 'spark-v3':
  21. endpoint = 'multimodal'
  22. model_api_configs = {
  23. 'spark': {
  24. 'version': 'v1.1',
  25. 'chat_domain': 'general'
  26. },
  27. 'spark-v2': {
  28. 'version': 'v2.1',
  29. 'chat_domain': 'generalv2'
  30. },
  31. 'spark-v3': {
  32. 'version': 'v3.1',
  33. 'chat_domain': 'generalv3'
  34. }
  35. }
  36. api_version = model_api_configs[model_name]['version']
  37. self.chat_domain = model_api_configs[model_name]['chat_domain']
  38. self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
  39. self.app_id = app_id
  40. self.ws_url = self.create_url(
  41. urlparse(self.api_base).netloc,
  42. urlparse(self.api_base).path,
  43. self.api_base,
  44. api_key,
  45. api_secret
  46. )
  47. self.queue = queue.Queue()
  48. self.blocking_message = ''
  49. def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
  50. # generate timestamp by RFC1123
  51. now = datetime.now()
  52. date = format_date_time(mktime(now.timetuple()))
  53. signature_origin = "host: " + host + "\n"
  54. signature_origin += "date: " + date + "\n"
  55. signature_origin += "GET " + path + " HTTP/1.1"
  56. # encrypt using hmac-sha256
  57. signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
  58. digestmod=hashlib.sha256).digest()
  59. signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
  60. authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
  61. authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
  62. v = {
  63. "authorization": authorization,
  64. "date": date,
  65. "host": host
  66. }
  67. # generate url
  68. url = api_base + '?' + urlencode(v)
  69. return url
  70. def run(self, messages: list, user_id: str,
  71. model_kwargs: Optional[dict] = None, streaming: bool = False):
  72. websocket.enableTrace(False)
  73. ws = websocket.WebSocketApp(
  74. self.ws_url,
  75. on_message=self.on_message,
  76. on_error=self.on_error,
  77. on_close=self.on_close,
  78. on_open=self.on_open
  79. )
  80. ws.messages = messages
  81. ws.user_id = user_id
  82. ws.model_kwargs = model_kwargs
  83. ws.streaming = streaming
  84. ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
  85. def on_error(self, ws, error):
  86. self.queue.put({
  87. 'status_code': error.status_code,
  88. 'error': error.resp_body.decode('utf-8')
  89. })
  90. ws.close()
  91. def on_close(self, ws, close_status_code, close_reason):
  92. self.queue.put({'done': True})
  93. def on_open(self, ws):
  94. self.blocking_message = ''
  95. data = json.dumps(self.gen_params(
  96. messages=ws.messages,
  97. user_id=ws.user_id,
  98. model_kwargs=ws.model_kwargs
  99. ))
  100. ws.send(data)
  101. def on_message(self, ws, message):
  102. data = json.loads(message)
  103. code = data['header']['code']
  104. if code != 0:
  105. self.queue.put({
  106. 'status_code': 400,
  107. 'error': f"Code: {code}, Error: {data['header']['message']}"
  108. })
  109. ws.close()
  110. else:
  111. choices = data["payload"]["choices"]
  112. status = choices["status"]
  113. content = choices["text"][0]["content"]
  114. if ws.streaming:
  115. self.queue.put({'data': content})
  116. else:
  117. self.blocking_message += content
  118. if status == 2:
  119. if not ws.streaming:
  120. self.queue.put({'data': self.blocking_message})
  121. ws.close()
  122. def gen_params(self, messages: list, user_id: str,
  123. model_kwargs: Optional[dict] = None) -> dict:
  124. data = {
  125. "header": {
  126. "app_id": self.app_id,
  127. "uid": user_id
  128. },
  129. "parameter": {
  130. "chat": {
  131. "domain": self.chat_domain
  132. }
  133. },
  134. "payload": {
  135. "message": {
  136. "text": messages
  137. }
  138. }
  139. }
  140. if model_kwargs:
  141. data['parameter']['chat'].update(model_kwargs)
  142. return data
  143. def subscribe(self):
  144. while True:
  145. content = self.queue.get()
  146. if 'error' in content:
  147. if content['status_code'] == 401:
  148. raise SparkError('[Spark] The credentials you provided are incorrect. '
  149. 'Please double-check and fill them in again.')
  150. elif content['status_code'] == 403:
  151. raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
  152. "Please try again after obtaining the necessary permissions.")
  153. else:
  154. raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
  155. if 'data' not in content:
  156. break
  157. yield content
  158. class SparkError(Exception):
  159. pass