Bladeren bron

fix: inference embedding validate (#1187)

takatost 1 jaar geleden
bovenliggende
commit
c8bd76cd66
2 gewijzigde bestanden met 17 toevoegingen en 8 verwijderingen
  1. +15 -6
      api/core/model_providers/providers/xinference_provider.py
  2. +2 -2
      api/requirements.txt

+ 15 - 6
api/core/model_providers/providers/xinference_provider.py

@@ -2,6 +2,7 @@ import json
 from typing import Type
 
 import requests
+from langchain.embeddings import XinferenceEmbeddings
 
 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -97,11 +98,18 @@ class XinferenceProvider(BaseModelProvider):
                 'model_uid': credentials['model_uid'],
             }
 
-            llm = XinferenceLLM(
-                **credential_kwargs
-            )
+            if model_type == ModelType.TEXT_GENERATION:
+                llm = XinferenceLLM(
+                    **credential_kwargs
+                )
+
+                llm("ping")
+            elif model_type == ModelType.EMBEDDINGS:
+                embedding = XinferenceEmbeddings(
+                    **credential_kwargs
+                )
 
-            llm("ping")
+                embedding.embed_query("ping")
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
@@ -117,8 +125,9 @@ class XinferenceProvider(BaseModelProvider):
         :param credentials:
         :return:
         """
-        extra_credentials = cls._get_extra_credentials(credentials)
-        credentials.update(extra_credentials)
+        if model_type == ModelType.TEXT_GENERATION:
+            extra_credentials = cls._get_extra_credentials(credentials)
+            credentials.update(extra_credentials)
 
         credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
 

+ 2 - 2
api/requirements.txt

@@ -19,7 +19,7 @@ pytest~=7.3.1
 pytest-mock~=3.11.1
 tiktoken==0.3.3
 Authlib==1.2.0
-boto3~=1.26.123
+boto3==1.28.17
 tenacity==8.2.2
 cachetools~=5.3.0
 weaviate-client~=3.21.0
@@ -49,5 +49,5 @@ huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
-xinference==0.2.1
+xinference==0.4.2
 safetensors==0.3.2