Browse Source

feat: add spark v2 support (#885)

takatost 1 year ago
parent
commit
f42e7d1a61

+ 1 - 1
api/core/model_providers/models/llm/spark_model.py

@@ -1,5 +1,4 @@
 import decimal
-from functools import wraps
 from typing import List, Optional, Any
 
 from langchain.callbacks.manager import Callbacks
@@ -19,6 +18,7 @@ class SparkModel(BaseLLM):
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
         return ChatSpark(
+            model_name=self.name,
             streaming=self.streaming,
             callbacks=self.callbacks,
             **self.credentials,

+ 5 - 1
api/core/model_providers/providers/spark_provider.py

@@ -29,7 +29,11 @@ class SparkProvider(BaseModelProvider):
             return [
                 {
                     'id': 'spark',
-                    'name': '星火认知大模型',
+                    'name': 'Spark V1.5',
+                },
+                {
+                    'id': 'spark-v2',
+                    'name': 'Spark V2.0',
                 }
             ]
         else:

+ 5 - 0
api/core/third_party/langchain/llms/spark.py

@@ -25,6 +25,7 @@ class ChatSpark(BaseChatModel):
         .. code-block:: python
 
         client = SparkLLMClient(
+            model_name="<model_name>",
             app_id="<app_id>",
             api_key="<api_key>",
             api_secret="<api_secret>"
@@ -32,6 +33,9 @@ class ChatSpark(BaseChatModel):
     """
     client: Any = None  #: :meta private:
 
+    model_name: str = "spark"
+    """The Spark model name."""
+
     max_tokens: int = 256
     """Denotes the number of tokens to predict per generation."""
 
@@ -66,6 +70,7 @@ class ChatSpark(BaseChatModel):
         )
 
         values["client"] = SparkLLMClient(
+            model_name=values["model_name"],
             app_id=values["app_id"],
             api_key=values["api_key"],
             api_secret=values["api_secret"],

+ 19 - 5
api/core/third_party/spark/spark_llm.py

@@ -16,9 +16,13 @@ import websocket
 
 
 class SparkLLMClient:
-    def __init__(self, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
+    def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
 
-        self.api_base = "wss://spark-api.xf-yun.com/v1.1/chat" if not api_domain else ('wss://' + api_domain + '/v1.1/chat')
+        domain = 'spark-api.xf-yun.com' if not api_domain else api_domain
+        api_version = 'v2.1' if model_name == 'spark-v2' else 'v1.1'
+
+        self.chat_domain = 'generalv2' if model_name == 'spark-v2' else 'general'
+        self.api_base = f"wss://{domain}/{api_version}/chat"
         self.app_id = app_id
         self.ws_url = self.create_url(
             urlparse(self.api_base).netloc,
@@ -76,7 +80,10 @@ class SparkLLMClient:
         ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
 
     def on_error(self, ws, error):
         """Report a websocket failure on ``self.queue`` and close the socket.

         The record put on the queue always has the keys ``status_code`` and
         ``error`` so that a reader of the queue can branch on the HTTP status
         of a rejected handshake (e.g. 401 / 403).

         Not every exception handed to this callback carries ``status_code``
         or ``resp_body`` (plain connection errors do not, and ``resp_body``
         may be ``None``), so both are read defensively; otherwise an
         ``AttributeError`` raised here would prevent any error record from
         being queued.

         :param ws: the websocket the error occurred on; it is closed here.
         :param error: the exception (or error object) reported for ``ws``.
         """
         # None when the failure is not an HTTP handshake rejection.
         status_code = getattr(error, 'status_code', None)

         body = getattr(error, 'resp_body', None)
         if isinstance(body, (bytes, bytearray)):
             # Never let an undecodable response body mask the real error.
             body = body.decode('utf-8', errors='replace')

         # Fall back to the exception text when there is no response body.
         message = str(body) if body else str(error)

         self.queue.put({
             'status_code': status_code,
             'error': message
         })
         ws.close()
 
     def on_close(self, ws, close_status_code, close_reason):
@@ -120,7 +127,7 @@ class SparkLLMClient:
             },
             "parameter": {
                 "chat": {
-                    "domain": "general"
+                    "domain": self.chat_domain
                 }
             },
             "payload": {
@@ -139,7 +146,14 @@ class SparkLLMClient:
         while True:
             content = self.queue.get()
             if 'error' in content:
-                raise SparkError(content['error'])
+                if content['status_code'] == 401:
+                    raise SparkError('[Spark] The credentials you provided are incorrect. '
+                                     'Please double-check and fill them in again.')
+                elif content['status_code'] == 403:
+                    raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
+                                     "Please try again after obtaining the necessary permissions.")
+                else:
+                    raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
 
             if 'data' not in content:
                 break

+ 1 - 0
api/services/provider_service.py

@@ -471,6 +471,7 @@ class ProviderService:
             for model in model_list:
                 valid_model_dict = {
                     "model_name": model['id'],
+                    "model_display_name": model['name'],
                     "model_type": model_type,
                     "model_provider": {
                         "provider_name": provider.provider_name,