model_factory.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. from typing import Optional
  2. from langchain.callbacks.base import Callbacks
  3. from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
  4. from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS
  5. from core.model_providers.models.base import BaseProviderModel
  6. from core.model_providers.models.embedding.base import BaseEmbedding
  7. from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
  8. from core.model_providers.models.llm.base import BaseLLM
  9. from core.model_providers.models.moderation.base import BaseModeration
  10. from core.model_providers.models.speech2text.base import BaseSpeech2Text
  11. from extensions.ext_database import db
  12. from models.provider import TenantDefaultModel
  13. class ModelFactory:
  14. @classmethod
  15. def get_text_generation_model_from_model_config(cls, tenant_id: str,
  16. model_config: dict,
  17. streaming: bool = False,
  18. callbacks: Callbacks = None) -> Optional[BaseLLM]:
  19. provider_name = model_config.get("provider")
  20. model_name = model_config.get("name")
  21. completion_params = model_config.get("completion_params", {})
  22. return cls.get_text_generation_model(
  23. tenant_id=tenant_id,
  24. model_provider_name=provider_name,
  25. model_name=model_name,
  26. model_kwargs=ModelKwargs(
  27. temperature=completion_params.get('temperature', 0),
  28. max_tokens=completion_params.get('max_tokens', 256),
  29. top_p=completion_params.get('top_p', 0),
  30. frequency_penalty=completion_params.get('frequency_penalty', 0.1),
  31. presence_penalty=completion_params.get('presence_penalty', 0.1)
  32. ),
  33. streaming=streaming,
  34. callbacks=callbacks
  35. )
  36. @classmethod
  37. def get_text_generation_model(cls,
  38. tenant_id: str,
  39. model_provider_name: Optional[str] = None,
  40. model_name: Optional[str] = None,
  41. model_kwargs: Optional[ModelKwargs] = None,
  42. streaming: bool = False,
  43. callbacks: Callbacks = None,
  44. deduct_quota: bool = True) -> Optional[BaseLLM]:
  45. """
  46. get text generation model.
  47. :param tenant_id: a string representing the ID of the tenant.
  48. :param model_provider_name:
  49. :param model_name:
  50. :param model_kwargs:
  51. :param streaming:
  52. :param callbacks:
  53. :param deduct_quota:
  54. :return:
  55. """
  56. is_default_model = False
  57. if model_provider_name is None and model_name is None:
  58. default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION)
  59. if not default_model:
  60. raise LLMBadRequestError(f"Default model is not available. "
  61. f"Please configure a Default System Reasoning Model "
  62. f"in the Settings -> Model Provider.")
  63. model_provider_name = default_model.provider_name
  64. model_name = default_model.model_name
  65. is_default_model = True
  66. # get model provider
  67. model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
  68. if not model_provider:
  69. raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
  70. # init text generation model
  71. model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)
  72. try:
  73. model_instance = model_class(
  74. model_provider=model_provider,
  75. name=model_name,
  76. model_kwargs=model_kwargs,
  77. streaming=streaming,
  78. callbacks=callbacks
  79. )
  80. except LLMBadRequestError as e:
  81. if is_default_model:
  82. raise LLMBadRequestError(f"Default model {model_name} is not available. "
  83. f"Please check your model provider credentials.")
  84. else:
  85. raise e
  86. if is_default_model or not deduct_quota:
  87. model_instance.deduct_quota = False
  88. return model_instance
  89. @classmethod
  90. def get_embedding_model(cls,
  91. tenant_id: str,
  92. model_provider_name: Optional[str] = None,
  93. model_name: Optional[str] = None) -> Optional[BaseEmbedding]:
  94. """
  95. get embedding model.
  96. :param tenant_id: a string representing the ID of the tenant.
  97. :param model_provider_name:
  98. :param model_name:
  99. :return:
  100. """
  101. if model_provider_name is None and model_name is None:
  102. default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS)
  103. if not default_model:
  104. raise LLMBadRequestError(f"Default model is not available. "
  105. f"Please configure a Default Embedding Model "
  106. f"in the Settings -> Model Provider.")
  107. model_provider_name = default_model.provider_name
  108. model_name = default_model.model_name
  109. # get model provider
  110. model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
  111. if not model_provider:
  112. raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
  113. # init embedding model
  114. model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS)
  115. return model_class(
  116. model_provider=model_provider,
  117. name=model_name
  118. )
  119. @classmethod
  120. def get_speech2text_model(cls,
  121. tenant_id: str,
  122. model_provider_name: Optional[str] = None,
  123. model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]:
  124. """
  125. get speech to text model.
  126. :param tenant_id: a string representing the ID of the tenant.
  127. :param model_provider_name:
  128. :param model_name:
  129. :return:
  130. """
  131. if model_provider_name is None and model_name is None:
  132. default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT)
  133. if not default_model:
  134. raise LLMBadRequestError(f"Default model is not available. "
  135. f"Please configure a Default Speech-to-Text Model "
  136. f"in the Settings -> Model Provider.")
  137. model_provider_name = default_model.provider_name
  138. model_name = default_model.model_name
  139. # get model provider
  140. model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
  141. if not model_provider:
  142. raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
  143. # init speech to text model
  144. model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT)
  145. return model_class(
  146. model_provider=model_provider,
  147. name=model_name
  148. )
  149. @classmethod
  150. def get_moderation_model(cls,
  151. tenant_id: str,
  152. model_provider_name: str,
  153. model_name: str) -> Optional[BaseModeration]:
  154. """
  155. get moderation model.
  156. :param tenant_id: a string representing the ID of the tenant.
  157. :param model_provider_name:
  158. :param model_name:
  159. :return:
  160. """
  161. # get model provider
  162. model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
  163. if not model_provider:
  164. raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
  165. # init moderation model
  166. model_class = model_provider.get_model_class(model_type=ModelType.MODERATION)
  167. return model_class(
  168. model_provider=model_provider,
  169. name=model_name
  170. )
  171. @classmethod
  172. def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel:
  173. """
  174. get default model of model type.
  175. :param tenant_id:
  176. :param model_type:
  177. :return:
  178. """
  179. # get default model
  180. default_model = db.session.query(TenantDefaultModel) \
  181. .filter(
  182. TenantDefaultModel.tenant_id == tenant_id,
  183. TenantDefaultModel.model_type == model_type.value
  184. ).first()
  185. if not default_model:
  186. model_provider_rules = ModelProviderFactory.get_provider_rules()
  187. for model_provider_name, model_provider_rule in model_provider_rules.items():
  188. model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
  189. if not model_provider:
  190. continue
  191. model_list = model_provider.get_supported_model_list(model_type)
  192. if model_list:
  193. model_info = model_list[0]
  194. default_model = TenantDefaultModel(
  195. tenant_id=tenant_id,
  196. model_type=model_type.value,
  197. provider_name=model_provider_name,
  198. model_name=model_info['id']
  199. )
  200. db.session.add(default_model)
  201. db.session.commit()
  202. break
  203. return default_model
  204. @classmethod
  205. def update_default_model(cls,
  206. tenant_id: str,
  207. model_type: ModelType,
  208. provider_name: str,
  209. model_name: str) -> TenantDefaultModel:
  210. """
  211. update default model of model type.
  212. :param tenant_id:
  213. :param model_type:
  214. :param provider_name:
  215. :param model_name:
  216. :return:
  217. """
  218. model_provider_name = ModelProviderFactory.get_provider_names()
  219. if provider_name not in model_provider_name:
  220. raise ValueError(f'Invalid provider name: {provider_name}')
  221. model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)
  222. if not model_provider:
  223. raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
  224. model_list = model_provider.get_supported_model_list(model_type)
  225. model_ids = [model['id'] for model in model_list]
  226. if model_name not in model_ids:
  227. raise ValueError(f'Invalid model name: {model_name}')
  228. # get default model
  229. default_model = db.session.query(TenantDefaultModel) \
  230. .filter(
  231. TenantDefaultModel.tenant_id == tenant_id,
  232. TenantDefaultModel.model_type == model_type.value
  233. ).first()
  234. if default_model:
  235. # update default model
  236. default_model.provider_name = provider_name
  237. default_model.model_name = model_name
  238. db.session.commit()
  239. else:
  240. # create default model
  241. default_model = TenantDefaultModel(
  242. tenant_id=tenant_id,
  243. model_type=model_type.value,
  244. provider_name=provider_name,
  245. model_name=model_name,
  246. )
  247. db.session.add(default_model)
  248. db.session.commit()
  249. return default_model