localai_provider.py

import json
from typing import Type

from langchain.embeddings import LocalAIEmbeddings
from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from models.provider import ProviderType

class LocalAIProvider(BaseModelProvider):
    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'localai'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = LocalAIModel
        elif model_type == ModelType.EMBEDDINGS:
            model_class = LocalAIEmbedding
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=0.7, precision=2),
            top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
            max_tokens=KwargRule[int](min=10, max=4097, default=16, precision=0),
        )

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        if 'server_url' not in credentials:
            raise CredentialsValidateFailedError('LocalAI Server URL must be provided.')

        try:
            if model_type == ModelType.EMBEDDINGS:
                model = LocalAIEmbeddings(
                    model=model_name,
                    openai_api_key='1',
                    openai_api_base=credentials['server_url']
                )

                model.embed_query("ping")
            else:
                if ('completion_type' not in credentials
                        or credentials['completion_type'] not in ['completion', 'chat_completion']):
                    raise CredentialsValidateFailedError('LocalAI Completion Type must be provided.')

                if credentials['completion_type'] == 'chat_completion':
                    model = EnhanceChatOpenAI(
                        model_name=model_name,
                        openai_api_key='1',
                        openai_api_base=credentials['server_url'] + '/v1',
                        max_tokens=10,
                        request_timeout=60,
                    )

                    model([HumanMessage(content='ping')])
                else:
                    model = EnhanceOpenAI(
                        model_name=model_name,
                        openai_api_key='1',
                        openai_api_base=credentials['server_url'] + '/v1',
                        max_tokens=10,
                        request_timeout=60,
                    )

                    model('ping')
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))
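
    # Illustrative only: a credentials dict that would pass the check above for a
    # chat-completion model. The server URL and model name are assumptions for the
    # sketch, not fixed values; the placeholder 'openai_api_key' used above exists
    # only because the OpenAI-compatible clients require one.
    #
    #     LocalAIProvider.is_model_credentials_valid_or_raise(
    #         model_name='ggml-gpt4all-j',
    #         model_type=ModelType.TEXT_GENERATION,
    #         credentials={
    #             'server_url': 'http://127.0.0.1:8080',
    #             'completion_type': 'chat_completion',
    #         }
    #     )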

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
        return credentials

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        if self.provider.provider_type != ProviderType.CUSTOM.value:
            raise NotImplementedError

        provider_model = self._get_provider_model(model_name, model_type)

        if not provider_model.encrypted_config:
            return {
                'server_url': None,
            }

        credentials = json.loads(provider_model.encrypted_config)
        if credentials['server_url']:
            credentials['server_url'] = encrypter.decrypt_token(
                self.provider.tenant_id,
                credentials['server_url']
            )

            if obfuscated:
                credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])

        return credentials

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        return

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        return {}

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        return {}
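
# Illustrative round trip (not part of the module): encrypt_model_credentials stores an
# encrypted 'server_url', and get_model_credentials later decrypts it (and obfuscates it
# when requested). The tenant id, model name and provider record below are assumptions
# made for the sketch; the provider must be a custom-type provider for the read path.
#
#     encrypted = LocalAIProvider.encrypt_model_credentials(
#         tenant_id='tenant-1',
#         model_name='ggml-gpt4all-j',
#         model_type=ModelType.TEXT_GENERATION,
#         credentials={'server_url': 'http://127.0.0.1:8080'}
#     )
#
#     provider = LocalAIProvider(provider=provider_record)
#     credentials = provider.get_model_credentials(
#         model_name='ggml-gpt4all-j',
#         model_type=ModelType.TEXT_GENERATION,
#         obfuscated=True
#     )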