|
@@ -2,6 +2,7 @@ import json
|
|
from typing import Type
|
|
from typing import Type
|
|
|
|
|
|
import requests
|
|
import requests
|
|
|
|
+from langchain.embeddings import XinferenceEmbeddings
|
|
|
|
|
|
from core.helper import encrypter
|
|
from core.helper import encrypter
|
|
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
|
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
|
@@ -97,11 +98,18 @@ class XinferenceProvider(BaseModelProvider):
|
|
'model_uid': credentials['model_uid'],
|
|
'model_uid': credentials['model_uid'],
|
|
}
|
|
}
|
|
|
|
|
|
- llm = XinferenceLLM(
|
|
|
|
- **credential_kwargs
|
|
|
|
- )
|
|
|
|
|
|
+ if model_type == ModelType.TEXT_GENERATION:
|
|
|
|
+ llm = XinferenceLLM(
|
|
|
|
+ **credential_kwargs
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ llm("ping")
|
|
|
|
+ elif model_type == ModelType.EMBEDDINGS:
|
|
|
|
+ embedding = XinferenceEmbeddings(
|
|
|
|
+ **credential_kwargs
|
|
|
|
+ )
|
|
|
|
|
|
- llm("ping")
|
|
|
|
|
|
+ embedding.embed_query("ping")
|
|
except Exception as ex:
|
|
except Exception as ex:
|
|
raise CredentialsValidateFailedError(str(ex))
|
|
raise CredentialsValidateFailedError(str(ex))
|
|
|
|
|
|
@@ -117,8 +125,9 @@ class XinferenceProvider(BaseModelProvider):
|
|
:param credentials:
|
|
:param credentials:
|
|
:return:
|
|
:return:
|
|
"""
|
|
"""
|
|
- extra_credentials = cls._get_extra_credentials(credentials)
|
|
|
|
- credentials.update(extra_credentials)
|
|
|
|
|
|
+ if model_type == ModelType.TEXT_GENERATION:
|
|
|
|
+ extra_credentials = cls._get_extra_credentials(credentials)
|
|
|
|
+ credentials.update(extra_credentials)
|
|
|
|
|
|
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
|
|
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
|
|
|
|
|