llm_builder.py

from typing import Union, Optional

from langchain.callbacks import CallbackManager
from langchain.llms.fake import FakeListLLM

from core.constant import llm_constant
from core.llm.provider.llm_provider_service import LLMProviderService
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI


class LLMBuilder:
  9. """
  10. This class handles the following logic:
  11. 1. For providers with the name 'OpenAI', the OPENAI_API_KEY value is stored directly in encrypted_config.
  12. 2. For providers with the name 'Azure OpenAI', encrypted_config stores the serialized values of four fields, as shown below:
  13. OPENAI_API_TYPE=azure
  14. OPENAI_API_VERSION=2022-12-01
  15. OPENAI_API_BASE=https://your-resource-name.openai.azure.com
  16. OPENAI_API_KEY=<your Azure OpenAI API key>
  17. 3. For providers with the name 'Anthropic', the ANTHROPIC_API_KEY value is stored directly in encrypted_config.
  18. 4. For providers with the name 'Cohere', the COHERE_API_KEY value is stored directly in encrypted_config.
  19. 5. For providers with the name 'HUGGINGFACEHUB', the HUGGINGFACEHUB_API_KEY value is stored directly in encrypted_config.
  20. 6. Providers with the provider_type 'CUSTOM' can be created through the admin interface, while 'System' providers cannot be created through the admin interface.
  21. 7. If both CUSTOM and System providers exist in the records, the CUSTOM provider is preferred by default, but this preference can be changed via an input parameter.
  22. 8. For providers with the provider_type 'System', the quota_used must not exceed quota_limit. If the quota is exceeded, the provider cannot be used. Currently, only the TRIAL quota_type is supported, which is permanently non-resetting.
  23. """

    @classmethod
    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]:
        if model_name == 'fake':
            return FakeListLLM(responses=[])

        mode = cls.get_mode_by_model(model_name)
        if mode == 'chat':
            # llm_cls = StreamableAzureChatOpenAI
            llm_cls = StreamableChatOpenAI
        elif mode == 'completion':
            llm_cls = StreamableOpenAI
        else:
            raise ValueError(f"model name {model_name} is not supported.")

        model_credentials = cls.get_model_credentials(tenant_id, model_name)

        return llm_cls(
            model_name=model_name,
            temperature=kwargs.get('temperature', 0),
            max_tokens=kwargs.get('max_tokens', 256),
            top_p=kwargs.get('top_p', 1),
            frequency_penalty=kwargs.get('frequency_penalty', 0),
            presence_penalty=kwargs.get('presence_penalty', 0),
            callback_manager=kwargs.get('callback_manager', None),
            streaming=kwargs.get('streaming', False),
            # request_timeout=None
            **model_credentials
        )
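
    # Illustrative usage (the tenant id and model name are placeholder
    # assumptions, not values defined in this file):
    #
    #     llm = LLMBuilder.to_llm(
    #         tenant_id='tenant-123',
    #         model_name='gpt-3.5-turbo',
    #         streaming=True,
    #     )
    #
    # Unrecognized kwargs are silently ignored; only the parameters read above
    # take effect.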

    @classmethod
    def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
                          callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
        model_name = model.get("name")
        completion_params = model.get("completion_params", {})

        # Note: the fallback defaults here (top_p=0, penalties=0.1) differ from
        # the kwargs defaults in to_llm (top_p=1, penalties=0); they apply only
        # when the corresponding key is absent from completion_params.
        return cls.to_llm(
            tenant_id=tenant_id,
            model_name=model_name,
            temperature=completion_params.get('temperature', 0),
            max_tokens=completion_params.get('max_tokens', 256),
            top_p=completion_params.get('top_p', 0),
            frequency_penalty=completion_params.get('frequency_penalty', 0.1),
            presence_penalty=completion_params.get('presence_penalty', 0.1),
            streaming=streaming,
            callback_manager=callback_manager
        )
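
    # The model dict is expected to look roughly like this (an illustrative
    # sketch inferred from the .get() calls above, not a schema from this file):
    #
    #     {
    #         "name": "gpt-3.5-turbo",
    #         "completion_params": {
    #             "temperature": 0.7,
    #             "max_tokens": 512,
    #             "top_p": 1,
    #             "frequency_penalty": 0,
    #             "presence_penalty": 0
    #         }
    #     }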

    @classmethod
    def get_mode_by_model(cls, model_name: str) -> str:
        if not model_name:
            raise ValueError("empty model name is not supported.")

        if model_name in llm_constant.models_by_mode['chat']:
            return "chat"
        elif model_name in llm_constant.models_by_mode['completion']:
            return "completion"
        else:
            raise ValueError(f"model name {model_name} is not supported.")
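
    # Assumed shape of llm_constant.models_by_mode (defined in core.constant,
    # not in this file); the model names are illustrative:
    #
    #     models_by_mode = {
    #         'chat': ['gpt-4', 'gpt-3.5-turbo', ...],
    #         'completion': ['text-davinci-003', ...],
    #     }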

    @classmethod
    def get_model_credentials(cls, tenant_id: str, model_name: str) -> dict:
        """
        Returns the API credentials for the given tenant_id and model_name, based on the model's provider.

        Raises an exception if the model_name is not found or if the provider is not found.
        """
        if not model_name:
            raise Exception('model name not found')

        if model_name not in llm_constant.models:
            raise Exception(f'model {model_name} not found')

        model_provider = llm_constant.models[model_name]

        provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
        return provider_service.get_credentials(model_name)
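

if __name__ == '__main__':
    # Minimal smoke test (illustrative; 'tenant-123' is a placeholder). The
    # special 'fake' model name returns a stub before any mode or credential
    # lookup, so this runs without any provider configuration.
    llm = LLMBuilder.to_llm(tenant_id='tenant-123', model_name='fake')
    print(type(llm).__name__)  # FakeListLLM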