chatglm_provider.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import json
  2. from json import JSONDecodeError
  3. from typing import Type
  4. import requests
  5. from langchain.llms import ChatGLM
  6. from core.helper import encrypter
  7. from core.model_providers.models.base import BaseProviderModel
  8. from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
  9. from core.model_providers.models.llm.chatglm_model import ChatGLMModel
  10. from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
  11. from models.provider import ProviderType
  12. class ChatGLMProvider(BaseModelProvider):
  13. @property
  14. def provider_name(self):
  15. """
  16. Returns the name of a provider.
  17. """
  18. return 'chatglm'
  19. def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
  20. if model_type == ModelType.TEXT_GENERATION:
  21. return [
  22. {
  23. 'id': 'chatglm3-6b',
  24. 'name': 'ChatGLM3-6B',
  25. 'mode': ModelMode.CHAT.value,
  26. },
  27. {
  28. 'id': 'chatglm3-6b-32k',
  29. 'name': 'ChatGLM3-6B-32K',
  30. 'mode': ModelMode.CHAT.value,
  31. },
  32. {
  33. 'id': 'chatglm2-6b',
  34. 'name': 'ChatGLM2-6B',
  35. 'mode': ModelMode.CHAT.value,
  36. }
  37. ]
  38. else:
  39. return []
  40. def _get_text_generation_model_mode(self, model_name) -> str:
  41. return ModelMode.CHAT.value
  42. def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
  43. """
  44. Returns the model class.
  45. :param model_type:
  46. :return:
  47. """
  48. if model_type == ModelType.TEXT_GENERATION:
  49. model_class = ChatGLMModel
  50. else:
  51. raise NotImplementedError
  52. return model_class
  53. def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
  54. """
  55. get model parameter rules.
  56. :param model_name:
  57. :param model_type:
  58. :return:
  59. """
  60. model_max_tokens = {
  61. 'chatglm3-6b-32k': 32000,
  62. 'chatglm3-6b': 8000,
  63. 'chatglm2-6b': 8000,
  64. }
  65. max_tokens_alias = 'max_length' if model_name == 'chatglm2-6b' else 'max_tokens'
  66. return ModelKwargsRules(
  67. temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
  68. top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
  69. presence_penalty=KwargRule[float](enabled=False),
  70. frequency_penalty=KwargRule[float](enabled=False),
  71. max_tokens=KwargRule[int](alias=max_tokens_alias, min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
  72. )
  73. @classmethod
  74. def is_provider_credentials_valid_or_raise(cls, credentials: dict):
  75. """
  76. Validates the given credentials.
  77. """
  78. if 'api_base' not in credentials:
  79. raise CredentialsValidateFailedError('ChatGLM Endpoint URL must be provided.')
  80. try:
  81. response = requests.get(f"{credentials['api_base']}/v1/models", timeout=5)
  82. if response.status_code != 200:
  83. raise Exception('ChatGLM Endpoint URL is invalid.')
  84. except Exception as ex:
  85. raise CredentialsValidateFailedError(str(ex))
  86. @classmethod
  87. def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
  88. credentials['api_base'] = encrypter.encrypt_token(tenant_id, credentials['api_base'])
  89. return credentials
  90. def get_provider_credentials(self, obfuscated: bool = False) -> dict:
  91. if self.provider.provider_type == ProviderType.CUSTOM.value:
  92. try:
  93. credentials = json.loads(self.provider.encrypted_config)
  94. except JSONDecodeError:
  95. credentials = {
  96. 'api_base': None
  97. }
  98. if credentials['api_base']:
  99. credentials['api_base'] = encrypter.decrypt_token(
  100. self.provider.tenant_id,
  101. credentials['api_base']
  102. )
  103. if obfuscated:
  104. credentials['api_base'] = encrypter.obfuscated_token(credentials['api_base'])
  105. return credentials
  106. return {}
  107. @classmethod
  108. def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
  109. """
  110. check model credentials valid.
  111. :param model_name:
  112. :param model_type:
  113. :param credentials:
  114. """
  115. return
  116. @classmethod
  117. def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
  118. credentials: dict) -> dict:
  119. """
  120. encrypt model credentials for save.
  121. :param tenant_id:
  122. :param model_name:
  123. :param model_type:
  124. :param credentials:
  125. :return:
  126. """
  127. return {}
  128. def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
  129. """
  130. get credentials for llm use.
  131. :param model_name:
  132. :param model_type:
  133. :param obfuscated:
  134. :return:
  135. """
  136. return self.get_provider_credentials(obfuscated)