@@ -2,7 +2,6 @@ import json
 from typing import Type
 
 from huggingface_hub import HfApi
-from langchain.llms import HuggingFaceEndpoint
 
 from core.helper import encrypter
 from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
@@ -10,6 +9,7 @@ from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHub
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
 from models.provider import ProviderType
 
 
@@ -85,10 +85,16 @@ class HuggingfaceHubProvider(BaseModelProvider):
             if 'huggingfacehub_endpoint_url' not in credentials:
                 raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.')
 
+            if 'task_type' not in credentials:
+                raise CredentialsValidateFailedError('Task Type must be provided.')
+
+            if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
+                raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
+
             try:
-                llm = HuggingFaceEndpoint(
+                llm = HuggingFaceEndpointLLM(
                     endpoint_url=credentials['huggingfacehub_endpoint_url'],
-                    task="text2text-generation",
+                    task=credentials['task_type'],
                     model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
                     huggingfacehub_api_token=credentials['huggingfacehub_api_token']
                 )
@@ -160,6 +166,10 @@ class HuggingfaceHubProvider(BaseModelProvider):
             }
 
         credentials = json.loads(provider_model.encrypted_config)
+
+        if 'task_type' not in credentials:
+            credentials['task_type'] = 'text-generation'
+
         if credentials['huggingfacehub_api_token']:
            credentials['huggingfacehub_api_token'] = encrypter.decrypt_token(
                 self.provider.tenant_id,
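
For reference, a minimal standalone sketch of the `task_type` handling this diff introduces: the new validation (the field is required and limited to three supported tasks) and the backfill for configs saved before the field existed. The `CredentialsError` class and helper names below are illustrative stand-ins, not the provider's actual symbols.

```python
import json

# Task types accepted by the new validation in the diff above.
SUPPORTED_TASK_TYPES = ("text2text-generation", "text-generation", "summarization")


class CredentialsError(Exception):
    """Illustrative stand-in for CredentialsValidateFailedError."""


def validate_task_type(credentials: dict) -> None:
    """Mirror the new checks: task_type must be present and supported."""
    if 'task_type' not in credentials:
        raise CredentialsError('Task Type must be provided.')
    if credentials['task_type'] not in SUPPORTED_TASK_TYPES:
        raise CredentialsError(
            'Task Type must be one of text2text-generation, text-generation, summarization.'
        )


def load_stored_credentials(encrypted_config: str) -> dict:
    """Mirror the backfill: older configs without task_type default to text-generation."""
    credentials = json.loads(encrypted_config)
    if 'task_type' not in credentials:
        credentials['task_type'] = 'text-generation'
    return credentials


if __name__ == '__main__':
    validate_task_type({'task_type': 'summarization'})  # passes silently
    print(load_stored_credentials('{"huggingfacehub_api_token": "hf_xxx"}')['task_type'])
    # -> text-generation
```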