Explorar o código

feat: api_key support for xinference (#6417)

Signed-off-by: themanforfree <themanforfree@gmail.com>
themanforfree hai 10 meses
pai
achega
ba181197c2

+ 3 - 1
api/core/model_runtime/model_providers/xinference/llm/llm.py

@@ -453,9 +453,11 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         if credentials['server_url'].endswith('/'):
         if credentials['server_url'].endswith('/'):
             credentials['server_url'] = credentials['server_url'][:-1]
             credentials['server_url'] = credentials['server_url'][:-1]
 
 
+        api_key = credentials.get('api_key') or "abc"
+
         client = OpenAI(
         client = OpenAI(
             base_url=f'{credentials["server_url"]}/v1',
             base_url=f'{credentials["server_url"]}/v1',
-            api_key='abc',
+            api_key=api_key,
             max_retries=3,
             max_retries=3,
             timeout=60,
             timeout=60,
         )
         )

+ 18 - 10
api/core/model_runtime/model_providers/xinference/rerank/rerank.py

@@ -44,15 +44,23 @@ class XinferenceRerankModel(RerankModel):
                 docs=[]
                 docs=[]
             )
             )
 
 
-        if credentials['server_url'].endswith('/'):
-            credentials['server_url'] = credentials['server_url'][:-1]
-
-        handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'],auth_headers={})
-        response = handle.rerank(
-            documents=docs,
-            query=query,
-            top_n=top_n,
-        )
+        server_url = credentials['server_url']
+        model_uid = credentials['model_uid']
+        api_key = credentials.get('api_key')
+        if server_url.endswith('/'):
+            server_url = server_url[:-1]
+        auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
+
+        try:
+            handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers)
+            response = handle.rerank(
+                documents=docs,
+                query=query,
+                top_n=top_n,
+            )
+        except RuntimeError as e:
+            raise InvokeServerUnavailableError(str(e))
+
 
 
         rerank_documents = []
         rerank_documents = []
         for idx, result in enumerate(response['results']):
         for idx, result in enumerate(response['results']):
@@ -102,7 +110,7 @@ class XinferenceRerankModel(RerankModel):
             if not isinstance(xinference_client, RESTfulRerankModelHandle):
             if not isinstance(xinference_client, RESTfulRerankModelHandle):
                 raise InvokeBadRequestError(
                 raise InvokeBadRequestError(
                     'please check model type, the model you want to invoke is not a rerank model')
                     'please check model type, the model you want to invoke is not a rerank model')
-            
+
             self.invoke(
             self.invoke(
                 model=model,
                 model=model,
                 credentials=credentials,
                 credentials=credentials,

+ 21 - 14
api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py

@@ -99,9 +99,9 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
         }
         }
 
 
     def _speech2text_invoke(
     def _speech2text_invoke(
-        self, 
-        model: str, 
-        credentials: dict, 
+        self,
+        model: str,
+        credentials: dict,
         file: IO[bytes],
         file: IO[bytes],
         language: Optional[str] = None,
         language: Optional[str] = None,
         prompt: Optional[str] = None,
         prompt: Optional[str] = None,
@@ -121,17 +121,24 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
         :param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output mor            e random,while lower values like 0.2 will make it more focused and deterministic.If set to 0, the model wi            ll use log probability to automatically increase the temperature until certain thresholds are hit.
         :param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output mor            e random,while lower values like 0.2 will make it more focused and deterministic.If set to 0, the model wi            ll use log probability to automatically increase the temperature until certain thresholds are hit.
         :return: text for given audio file
         :return: text for given audio file
         """
         """
-        if credentials['server_url'].endswith('/'):
-            credentials['server_url'] = credentials['server_url'][:-1]
-
-        handle = RESTfulAudioModelHandle(credentials['model_uid'],credentials['server_url'],auth_headers={})
-        response = handle.transcriptions(
-            audio=file,
-            language = language,
-            prompt = prompt,
-            response_format = response_format,
-            temperature = temperature
-        )
+        server_url = credentials['server_url']
+        model_uid = credentials['model_uid']
+        api_key = credentials.get('api_key')
+        if server_url.endswith('/'):
+            server_url = server_url[:-1]
+        auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
+
+        try:
+            handle = RESTfulAudioModelHandle(model_uid, server_url, auth_headers)
+            response = handle.transcriptions(
+                audio=file,
+                language=language,
+                prompt=prompt,
+                response_format=response_format,
+                temperature=temperature
+            )
+        except RuntimeError as e:
+            raise InvokeServerUnavailableError(str(e))
 
 
         return response["text"]
         return response["text"]
 
 

+ 9 - 8
api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py

@@ -43,16 +43,17 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         """
         """
         server_url = credentials['server_url']
         server_url = credentials['server_url']
         model_uid = credentials['model_uid']
         model_uid = credentials['model_uid']
-
+        api_key = credentials.get('api_key')
         if server_url.endswith('/'):
         if server_url.endswith('/'):
             server_url = server_url[:-1]
             server_url = server_url[:-1]
+        auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
 
 
         try:
         try:
-            handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
+            handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers)
             embeddings = handle.create_embedding(input=texts)
             embeddings = handle.create_embedding(input=texts)
         except RuntimeError as e:
         except RuntimeError as e:
-            raise InvokeServerUnavailableError(e)
-        
+            raise InvokeServerUnavailableError(str(e))
+
         """
         """
         for convenience, the response json is like:
         for convenience, the response json is like:
         class Embedding(TypedDict):
         class Embedding(TypedDict):
@@ -106,7 +107,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         try:
         try:
             if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
             if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
                 raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
                 raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
-            
+
             server_url = credentials['server_url']
             server_url = credentials['server_url']
             model_uid = credentials['model_uid']
             model_uid = credentials['model_uid']
             extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
             extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
@@ -117,7 +118,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
                 server_url = server_url[:-1]
                 server_url = server_url[:-1]
 
 
             client = Client(base_url=server_url)
             client = Client(base_url=server_url)
-        
+
             try:
             try:
                 handle = client.get_model(model_uid=model_uid)
                 handle = client.get_model(model_uid=model_uid)
             except RuntimeError as e:
             except RuntimeError as e:
@@ -151,7 +152,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
                 KeyError
                 KeyError
             ]
             ]
         }
         }
-    
+
     def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
     def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
         """
         """
         Calculate response usage
         Calculate response usage
@@ -186,7 +187,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         """
         """
             used to define customizable model schema
             used to define customizable model schema
         """
         """
-        
+
         entity = AIModelEntity(
         entity = AIModelEntity(
             model=model,
             model=model,
             label=I18nObject(
             label=I18nObject(

+ 9 - 0
api/core/model_runtime/model_providers/xinference/xinference.yaml

@@ -46,3 +46,12 @@ model_credential_schema:
       placeholder:
       placeholder:
         zh_Hans: 在此输入您的Model UID
         zh_Hans: 在此输入您的Model UID
         en_US: Enter the model uid
         en_US: Enter the model uid
+    - variable: api_key
+      label:
+        zh_Hans: API密钥
+        en_US: API key
+      type: text-input
+      required: false
+      placeholder:
+        zh_Hans: 在此输入您的API密钥
+        en_US: Enter the api key