from typing import Any, Dict, List, Optional

from huggingface_hub import HfApi, InferenceApi
from pydantic import root_validator

from langchain import HuggingFaceHub
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.huggingface_hub import VALID_TASKS
from langchain.utils import get_from_dict_or_env


class HuggingFaceHubLLM(HuggingFaceHub):
    """HuggingFaceHub models with pre-flight validation of the target repo.

    Extends ``HuggingFaceHub`` to verify, before each call, that the model
    exists on the Hub, that its Inference API has not been turned off, and
    that its pipeline task is supported.

    To use, you should have the ``huggingface_hub`` python package installed,
    and the environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your
    API token, or pass it as a named parameter to the constructor.

    Only `text-generation`, `text2text-generation` and `summarization` are
    supported for now.

    Example:
        .. code-block:: python

            hf = HuggingFaceHubLLM(
                repo_id="gpt2", huggingfacehub_api_token="my-api-key"
            )
    """

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and python package exist in the environment."""
        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )
        client = InferenceApi(
            repo_id=values["repo_id"],
            token=huggingfacehub_api_token,
            task=values.get("task"),
        )
        # Fail fast rather than queue: do not wait for a cold model to load,
        # and do not request GPU-backed inference.
        client.options = {"wait_for_model": False, "use_gpu": False}
        values["client"] = client
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        hfapi = HfApi(token=self.huggingfacehub_api_token)
        model_info = hfapi.model_info(repo_id=self.repo_id)
        if not model_info:
            raise ValueError(f"Model {self.repo_id} not found.")

        # `cardData` can be None when the repo has no model-card metadata, so
        # guard before checking whether inference has been disabled.
        card_data = model_info.cardData or {}
        if "inference" in card_data and not card_data["inference"]:
            raise ValueError(
                f"The Inference API has been turned off for model {self.repo_id}."
            )

        if model_info.pipeline_tag not in VALID_TASKS:
            raise ValueError(
                f"Task {model_info.pipeline_tag!r} of model {self.repo_id} is "
                f"not supported, must be one of {VALID_TASKS}."
            )

        return super()._call(prompt, stop, run_manager, **kwargs)
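
# --- Usage sketch -----------------------------------------------------------
# A minimal, hypothetical example of the pre-flight validation above; it
# assumes HUGGINGFACEHUB_API_TOKEN is set in the environment and that this
# module is on the import path. Any hosted repo_id can stand in for "gpt2".
if __name__ == "__main__":
    llm = HuggingFaceHubLLM(repo_id="gpt2")
    try:
        # _call first checks that the repo exists, that its Inference API is
        # enabled, and that its task is in VALID_TASKS, then defers to
        # HuggingFaceHub._call for the actual generation request.
        print(llm("The meaning of life is"))
    except ValueError as err:
        print(f"Pre-flight validation failed: {err}")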