# model_factory.py (11 KB)

  1. from typing import Optional
  2. from langchain.callbacks.base import Callbacks
  3. from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
  4. from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS
  5. from core.model_providers.models.base import BaseProviderModel
  6. from core.model_providers.models.embedding.base import BaseEmbedding
  7. from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
  8. from core.model_providers.models.llm.base import BaseLLM
  9. from core.model_providers.models.speech2text.base import BaseSpeech2Text
  10. from extensions.ext_database import db
  11. from models.provider import TenantDefaultModel
  class ModelFactory:
      """
      Builds model instances (text generation, embedding, speech-to-text,
      moderation) for a tenant and reads/updates the tenant's default model
      for a given model type.
      """

      @classmethod
      def get_text_generation_model_from_model_config(cls, tenant_id: str,
                                                      model_config: dict,
                                                      streaming: bool = False,
                                                      callbacks: Callbacks = None) -> Optional[BaseLLM]:
          """
          get text generation model from a model config dict.

          Reads the "provider", "name" and "completion_params" keys of
          ``model_config`` and delegates to ``get_text_generation_model``.
          Missing completion params fall back to: temperature=0,
          max_tokens=256, top_p=0, frequency_penalty=0.1, presence_penalty=0.1.

          :param tenant_id: a string representing the ID of the tenant.
          :param model_config: dict holding "provider", "name" and (optionally) "completion_params".
          :param streaming: passed through to ``get_text_generation_model``.
          :param callbacks: passed through to ``get_text_generation_model``.
          :return: whatever ``get_text_generation_model`` returns.
          """
          # `or {}` also covers an explicit `"completion_params": None` entry,
          # which a `.get(key, {})` default would hand back as None and crash on.
          completion_params = model_config.get("completion_params") or {}

          return cls.get_text_generation_model(
              tenant_id=tenant_id,
              model_provider_name=model_config.get("provider"),
              model_name=model_config.get("name"),
              model_kwargs=ModelKwargs(
                  temperature=completion_params.get('temperature', 0),
                  max_tokens=completion_params.get('max_tokens', 256),
                  top_p=completion_params.get('top_p', 0),
                  frequency_penalty=completion_params.get('frequency_penalty', 0.1),
                  presence_penalty=completion_params.get('presence_penalty', 0.1)
              ),
              streaming=streaming,
              callbacks=callbacks
          )

      @classmethod
      def get_text_generation_model(cls,
                                    tenant_id: str,
                                    model_provider_name: Optional[str] = None,
                                    model_name: Optional[str] = None,
                                    model_kwargs: Optional[ModelKwargs] = None,
                                    streaming: bool = False,
                                    callbacks: Callbacks = None) -> Optional[BaseLLM]:
          """
          get text generation model.

          When both ``model_provider_name`` and ``model_name`` are None, the
          tenant's default text generation model is used and the returned
          instance has ``deduct_quota`` set to False.

          :param tenant_id: a string representing the ID of the tenant.
          :param model_provider_name: provider name, or None together with model_name to use the default.
          :param model_name: model name, or None together with model_provider_name to use the default.
          :param model_kwargs: passed to the model class constructor.
          :param streaming: passed to the model class constructor.
          :param callbacks: passed to the model class constructor.
          :return: the constructed model instance.
          :raises LLMBadRequestError: no default model is available, or the model class rejects
              construction.
          :raises ProviderTokenNotInitError: no preferred provider is found for the tenant.
          """
          is_default_model = model_provider_name is None and model_name is None
          if is_default_model:
              model_provider_name, model_name = cls._resolve_default_model_names(
                  tenant_id, ModelType.TEXT_GENERATION, "System Reasoning Model"
              )

          # get model provider
          model_provider = cls._get_initialized_provider(tenant_id, model_provider_name, model_name)

          # init text generation model
          model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)

          try:
              model_instance = model_class(
                  model_provider=model_provider,
                  name=model_name,
                  model_kwargs=model_kwargs,
                  streaming=streaming,
                  callbacks=callbacks
              )
          except LLMBadRequestError as e:
              if is_default_model:
                  # Keep the provider's original error attached as the cause.
                  raise LLMBadRequestError(f"Default model {model_name} is not available. "
                                           "Please check your model provider credentials.") from e
              raise

          if is_default_model:
              model_instance.deduct_quota = False

          return model_instance

      @classmethod
      def get_embedding_model(cls,
                              tenant_id: str,
                              model_provider_name: Optional[str] = None,
                              model_name: Optional[str] = None) -> Optional[BaseEmbedding]:
          """
          get embedding model.

          When both ``model_provider_name`` and ``model_name`` are None, the
          tenant's default embedding model is used.

          :param tenant_id: a string representing the ID of the tenant.
          :param model_provider_name: provider name, or None together with model_name to use the default.
          :param model_name: model name, or None together with model_provider_name to use the default.
          :return: the constructed embedding model instance.
          :raises LLMBadRequestError: no default embedding model is available.
          :raises ProviderTokenNotInitError: no preferred provider is found for the tenant.
          """
          if model_provider_name is None and model_name is None:
              model_provider_name, model_name = cls._resolve_default_model_names(
                  tenant_id, ModelType.EMBEDDINGS, "Embedding Model"
              )

          # get model provider
          model_provider = cls._get_initialized_provider(tenant_id, model_provider_name, model_name)

          # init embedding model
          model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS)
          return model_class(
              model_provider=model_provider,
              name=model_name
          )

      @classmethod
      def get_speech2text_model(cls,
                                tenant_id: str,
                                model_provider_name: Optional[str] = None,
                                model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]:
          """
          get speech to text model.

          When both ``model_provider_name`` and ``model_name`` are None, the
          tenant's default speech-to-text model is used.

          :param tenant_id: a string representing the ID of the tenant.
          :param model_provider_name: provider name, or None together with model_name to use the default.
          :param model_name: model name, or None together with model_provider_name to use the default.
          :return: the constructed speech-to-text model instance.
          :raises LLMBadRequestError: no default speech-to-text model is available.
          :raises ProviderTokenNotInitError: no preferred provider is found for the tenant.
          """
          if model_provider_name is None and model_name is None:
              model_provider_name, model_name = cls._resolve_default_model_names(
                  tenant_id, ModelType.SPEECH_TO_TEXT, "Speech-to-Text Model"
              )

          # get model provider
          model_provider = cls._get_initialized_provider(tenant_id, model_provider_name, model_name)

          # init speech to text model
          model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT)
          return model_class(
              model_provider=model_provider,
              name=model_name
          )

      @classmethod
      def get_moderation_model(cls,
                               tenant_id: str,
                               model_provider_name: str,
                               model_name: str) -> Optional[BaseProviderModel]:
          """
          get moderation model.

          Unlike the other getters there is no default-model fallback here:
          both the provider name and the model name must be supplied.

          :param tenant_id: a string representing the ID of the tenant.
          :param model_provider_name: provider name.
          :param model_name: model name.
          :return: the constructed moderation model instance.
          :raises ProviderTokenNotInitError: no preferred provider is found for the tenant.
          """
          # get model provider
          model_provider = cls._get_initialized_provider(tenant_id, model_provider_name, model_name)

          # init moderation model
          model_class = model_provider.get_model_class(model_type=ModelType.MODERATION)
          return model_class(
              model_provider=model_provider,
              name=model_name
          )

      @classmethod
      def get_default_model(cls, tenant_id: str, model_type: ModelType) -> Optional[TenantDefaultModel]:
          """
          get default model of model type.

          Looks up the stored ``TenantDefaultModel`` row for the tenant and
          model type. If there is none, walks the provider rules in order and,
          for the first provider that is available to the tenant and supports
          at least one model of this type, stores that provider's first
          supported model as the default and returns it.

          :param tenant_id: a string representing the ID of the tenant.
          :param model_type: the model type whose default is requested.
          :return: the default model record, or None when no provider qualifies.
          """
          # get default model
          default_model = db.session.query(TenantDefaultModel) \
              .filter(
                  TenantDefaultModel.tenant_id == tenant_id,
                  TenantDefaultModel.model_type == model_type.value
              ).first()

          if not default_model:
              model_provider_rules = ModelProviderFactory.get_provider_rules()
              for model_provider_name, _ in model_provider_rules.items():
                  model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
                  if not model_provider:
                      continue

                  model_list = model_provider.get_supported_model_list(model_type)
                  if not model_list:
                      continue

                  # First supported model of the first usable provider becomes the default.
                  default_model = TenantDefaultModel(
                      tenant_id=tenant_id,
                      model_type=model_type.value,
                      provider_name=model_provider_name,
                      model_name=model_list[0]['id']
                  )
                  db.session.add(default_model)
                  db.session.commit()
                  break

          return default_model

      @classmethod
      def update_default_model(cls,
                               tenant_id: str,
                               model_type: ModelType,
                               provider_name: str,
                               model_name: str) -> TenantDefaultModel:
          """
          update default model of model type.

          Validates the provider name and model name, then updates the
          tenant's existing default model row for ``model_type`` or creates
          one if none exists, and commits the session.

          :param tenant_id: a string representing the ID of the tenant.
          :param model_type: the model type whose default is being set.
          :param provider_name: must be one of ``ModelProviderFactory.get_provider_names()``.
          :param model_name: must be the 'id' of a model the provider supports for ``model_type``.
          :return: the updated or newly created default model record.
          :raises ValueError: the provider name or model name is not valid.
          :raises ProviderTokenNotInitError: no preferred provider is found for the tenant.
          """
          provider_names = ModelProviderFactory.get_provider_names()
          if provider_name not in provider_names:
              raise ValueError(f'Invalid provider name: {provider_name}')

          model_provider = cls._get_initialized_provider(tenant_id, provider_name, model_name)

          model_list = model_provider.get_supported_model_list(model_type)
          model_ids = [model['id'] for model in model_list]
          if model_name not in model_ids:
              raise ValueError(f'Invalid model name: {model_name}')

          # get default model
          default_model = db.session.query(TenantDefaultModel) \
              .filter(
                  TenantDefaultModel.tenant_id == tenant_id,
                  TenantDefaultModel.model_type == model_type.value
              ).first()

          if default_model:
              # update default model
              default_model.provider_name = provider_name
              default_model.model_name = model_name
          else:
              # create default model
              default_model = TenantDefaultModel(
                  tenant_id=tenant_id,
                  model_type=model_type.value,
                  provider_name=provider_name,
                  model_name=model_name,
              )
              db.session.add(default_model)

          # Both branches end with the same commit.
          db.session.commit()

          return default_model

      # ---- private helpers -------------------------------------------------

      @classmethod
      def _resolve_default_model_names(cls, tenant_id: str, model_type: ModelType, settings_label: str) -> tuple:
          """
          Return ``(provider_name, model_name)`` of the tenant's default model.

          :param tenant_id: a string representing the ID of the tenant.
          :param model_type: the model type whose default is requested.
          :param settings_label: model label inserted into the error message,
              e.g. "Embedding Model".
          :return: tuple of provider name and model name of the default model.
          :raises LLMBadRequestError: ``get_default_model`` returned nothing.
          """
          default_model = cls.get_default_model(tenant_id, model_type)
          if not default_model:
              raise LLMBadRequestError("Default model is not available. "
                                       f"Please configure a Default {settings_label} "
                                       "in the Settings -> Model Provider.")

          return default_model.provider_name, default_model.model_name

      @classmethod
      def _get_initialized_provider(cls, tenant_id: str, model_provider_name: Optional[str],
                                    model_name: Optional[str]):
          """
          Return the tenant's preferred provider for ``model_provider_name``.

          :param tenant_id: a string representing the ID of the tenant.
          :param model_provider_name: provider name to look up.
          :param model_name: only used in the error message.
          :return: the provider returned by ``ModelProviderFactory.get_preferred_model_provider``.
          :raises ProviderTokenNotInitError: the lookup returned nothing.
          """
          model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
          if not model_provider:
              raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")

          return model_provider
  246. return default_model