Browse Source

fix: hf hosted inference check (#1128)

takatost 1 year ago
parent
commit
c4d8bdc3db

+ 5 - 3
api/core/model_providers/models/llm/huggingface_hub_model.py

@@ -1,6 +1,5 @@
 from typing import List, Optional, Any
 
-from langchain import HuggingFaceHub
 from langchain.callbacks.manager import Callbacks
 from langchain.schema import LLMResult
 
@@ -9,6 +8,7 @@ from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
 from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
+from core.third_party.langchain.llms.huggingface_hub_llm import HuggingFaceHubLLM
 
 
 class HuggingfaceHubModel(BaseLLM):
@@ -31,7 +31,7 @@ class HuggingfaceHubModel(BaseLLM):
                 streaming=streaming
             )
         else:
-            client = HuggingFaceHub(
+            client = HuggingFaceHubLLM(
                 repo_id=self.name,
                 task=self.credentials['task_type'],
                 model_kwargs=provider_model_kwargs,
@@ -88,4 +88,6 @@ class HuggingfaceHubModel(BaseLLM):
             if 'baichuan' in self.name.lower():
                 return False
 
-        return True
+            return True
+        else:
+            return False

+ 2 - 1
api/core/model_providers/providers/huggingface_hub_provider.py

@@ -89,7 +89,8 @@ class HuggingfaceHubProvider(BaseModelProvider):
                 raise CredentialsValidateFailedError('Task Type must be provided.')
 
             if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
-                raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
+                raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
+                                                     'text-generation, summarization.')
 
             try:
                 llm = HuggingFaceEndpointLLM(

+ 62 - 0
api/core/third_party/langchain/llms/huggingface_hub_llm.py

@@ -0,0 +1,62 @@
+from typing import Dict, Optional, List, Any
+
+from huggingface_hub import HfApi, InferenceApi
+from langchain import HuggingFaceHub
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.huggingface_hub import VALID_TASKS
+from pydantic import root_validator
+
+from langchain.utils import get_from_dict_or_env
+
+
+class HuggingFaceHubLLM(HuggingFaceHub):
+    """HuggingFaceHub  models.
+
+    To use, you should have the ``huggingface_hub`` python package installed, and the
+    environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
+    it as a named parameter to the constructor.
+
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import HuggingFaceHub
+            hf = HuggingFaceHub(repo_id="gpt2", huggingfacehub_api_token="my-api-key")
+    """
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        huggingfacehub_api_token = get_from_dict_or_env(
+            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
+        )
+        client = InferenceApi(
+            repo_id=values["repo_id"],
+            token=huggingfacehub_api_token,
+            task=values.get("task"),
+        )
+        client.options = {"wait_for_model": False, "use_gpu": False}
+        values["client"] = client
+        return values
+
+    def _call(
+            self,
+            prompt: str,
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> str:
+        hfapi = HfApi(token=self.huggingfacehub_api_token)
+        model_info = hfapi.model_info(repo_id=self.repo_id)
+        if not model_info:
+            raise ValueError(f"Model {self.repo_id} not found.")
+
+        if 'inference' in model_info.cardData and not model_info.cardData['inference']:
+            raise ValueError(f"Inference API has been turned off for this model {self.repo_id}.")
+
+        if model_info.pipeline_tag not in VALID_TASKS:
+            raise ValueError(f"Model {self.repo_id} is not a valid task, "
+                             f"must be one of {VALID_TASKS}.")
+
+        return super()._call(prompt, stop, run_manager, **kwargs)