models.py 13 KB


  1. import logging
  2. from flask_login import current_user
  3. from flask_restful import Resource, reqparse
  4. from werkzeug.exceptions import Forbidden
  5. from controllers.console import api
  6. from controllers.console.wraps import account_initialization_required, setup_required
  7. from core.model_runtime.entities.model_entities import ModelType
  8. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  9. from core.model_runtime.utils.encoders import jsonable_encoder
  10. from libs.login import login_required
  11. from services.model_load_balancing_service import ModelLoadBalancingService
  12. from services.model_provider_service import ModelProviderService
  class DefaultModelApi(Resource):
      """Console API for the current tenant's default model of each model type."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """Return ``{"data": <default model>}`` for the ``model_type`` query arg.

          ``model_type`` is required and must be one of the ``ModelType`` values
          (enforced by ``reqparse`` through ``choices``).
          """
          parser = reqparse.RequestParser()
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="args",
          )
          args = parser.parse_args()

          tenant_id = current_user.current_tenant_id

          model_provider_service = ModelProviderService()
          default_model_entity = model_provider_service.get_default_model_of_model_type(
              tenant_id=tenant_id, model_type=args["model_type"]
          )

          return jsonable_encoder({"data": default_model_entity})

      @setup_required
      @login_required
      @account_initialization_required
      def post(self):
          """Update default models from the JSON body's ``model_settings`` list.

          Each entry must be a dict whose ``model_type`` is a ``ModelType`` value.
          Entries without a ``provider`` key are skipped; entries that have a
          ``provider`` must also carry a ``model``.

          Every entry is validated before any of them is saved, so an invalid
          entry cannot leave the entries before it already applied.

          Returns:
              ``{"result": "success"}``.

          Raises:
              Forbidden: the current user is not an admin or owner.
              ValueError: an entry is not a dict, has a missing/unknown
                  ``model_type``, or has a ``provider`` but no ``model``.
              Exception: anything raised by the service while saving is logged
                  and re-raised unchanged.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
          args = parser.parse_args()

          tenant_id = current_user.current_tenant_id
          # A tuple rather than a set: membership then uses ==, so an unhashable
          # model_type sent by the client is rejected instead of raising TypeError.
          valid_model_types = tuple(mt.value for mt in ModelType)

          # Phase 1: validate every entry before touching any stored defaults.
          settings_to_save = []
          for model_setting in args["model_settings"]:
              # Non-dict entries (e.g. strings) would otherwise reach substring
              # checks / str indexing and fail with an unhandled TypeError.
              if not isinstance(model_setting, dict) or model_setting.get("model_type") not in valid_model_types:
                  raise ValueError("invalid model type")

              if "provider" not in model_setting:
                  continue

              if "model" not in model_setting:
                  raise ValueError("invalid model")

              settings_to_save.append(model_setting)

          # Phase 2: persist the validated entries.
          model_provider_service = ModelProviderService()
          for model_setting in settings_to_save:
              try:
                  model_provider_service.update_default_model_of_model_type(
                      tenant_id=tenant_id,
                      model_type=model_setting["model_type"],
                      provider=model_setting["provider"],
                      model=model_setting["model"],
                  )
              except Exception as ex:
                  logging.exception("%s save error: %s", model_setting["model_type"], ex)
                  # Bare raise keeps the original traceback untouched.
                  raise

          return {"result": "success"}
  64. class ModelProviderModelApi(Resource):
  # Console API for the models of one provider (URL segment <provider>):
  # GET lists them, POST saves a model's load-balancing settings and
  # credentials, DELETE removes a model's credentials.
  # NOTE(review): this listing lost its Python indentation; the code lines
  # are kept verbatim and must be re-indented before they can run.
  65. @setup_required
  66. @login_required
  67. @account_initialization_required
  68. def get(self, provider):
  # GET: returns {"data": <models of this provider for the current tenant>},
  # serialized with jsonable_encoder. No admin/owner check here.
  69. tenant_id = current_user.current_tenant_id
  70. model_provider_service = ModelProviderService()
  71. models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
  72. return jsonable_encoder({"data": models})
  73. @setup_required
  74. @login_required
  75. @account_initialization_required
  76. def post(self, provider: str):
  # POST (admin/owner only, otherwise 403 Forbidden).
  # JSON body: "model" (str, required), "model_type" (required, one of the
  # ModelType values), optional "credentials" (dict), "load_balancing" (dict)
  # and "config_from" (str).
  # If load_balancing["enabled"] is truthy, load_balancing["configs"] is
  # required (ValueError otherwise); the configs are saved and load balancing
  # is enabled. Otherwise load balancing is disabled for the model.
  # When config_from != "predefined-model", the credentials are saved; a
  # CredentialsValidateFailedError is logged and re-raised as ValueError.
  # Returns ({"result": "success"}, 200).
  77. if not current_user.is_admin_or_owner:
  78. raise Forbidden()
  79. tenant_id = current_user.current_tenant_id
  80. parser = reqparse.RequestParser()
  81. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  82. parser.add_argument(
  83. "model_type",
  84. type=str,
  85. required=True,
  86. nullable=False,
  87. choices=[mt.value for mt in ModelType],
  88. location="json",
  89. )
  90. parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
  91. parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
  92. parser.add_argument("config_from", type=str, required=False, nullable=True, location="json")
  93. args = parser.parse_args()
  94. model_load_balancing_service = ModelLoadBalancingService()
  # Enable path only when load_balancing is a non-empty dict with a truthy
  # "enabled" flag; every other shape falls through to the disable path.
  95. if (
  96. "load_balancing" in args
  97. and args["load_balancing"]
  98. and "enabled" in args["load_balancing"]
  99. and args["load_balancing"]["enabled"]
  100. ):
  101. if "configs" not in args["load_balancing"]:
  102. raise ValueError("invalid load balancing configs")
  103. # save load balancing configs
  104. model_load_balancing_service.update_load_balancing_configs(
  105. tenant_id=tenant_id,
  106. provider=provider,
  107. model=args["model"],
  108. model_type=args["model_type"],
  109. configs=args["load_balancing"]["configs"],
  110. )
  111. # enable load balancing
  112. model_load_balancing_service.enable_model_load_balancing(
  113. tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
  114. )
  115. else:
  116. # disable load balancing
  117. model_load_balancing_service.disable_model_load_balancing(
  118. tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
  119. )
  # NOTE(review): with indentation lost it is not visible whether the
  # config_from / save_model_credentials step below belongs to the `else`
  # (load balancing disabled) branch or runs after the whole if/else.
  # Confirm against the original file. If it runs after the if/else,
  # load-balancing state has already been changed before credential
  # validation can fail.
  120. if args.get("config_from", "") != "predefined-model":
  121. model_provider_service = ModelProviderService()
  122. try:
  123. model_provider_service.save_model_credentials(
  124. tenant_id=tenant_id,
  125. provider=provider,
  126. model=args["model"],
  127. model_type=args["model_type"],
  128. credentials=args["credentials"],
  129. )
  130. except CredentialsValidateFailedError as ex:
  131. logging.exception(f"save model credentials error: {ex}")
  # NOTE(review): consider `raise ValueError(str(ex)) from ex` so the
  # original validation error stays chained in tracebacks.
  132. raise ValueError(str(ex))
  133. return {"result": "success"}, 200
  134. @setup_required
  135. @login_required
  136. @account_initialization_required
  137. def delete(self, provider: str):
  # DELETE (admin/owner only, otherwise 403 Forbidden).
  # JSON body: "model" and "model_type"; removes that model's credentials
  # for the current tenant.
  138. if not current_user.is_admin_or_owner:
  139. raise Forbidden()
  140. tenant_id = current_user.current_tenant_id
  141. parser = reqparse.RequestParser()
  142. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  143. parser.add_argument(
  144. "model_type",
  145. type=str,
  146. required=True,
  147. nullable=False,
  148. choices=[mt.value for mt in ModelType],
  149. location="json",
  150. )
  151. args = parser.parse_args()
  152. model_provider_service = ModelProviderService()
  153. model_provider_service.remove_model_credentials(
  154. tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
  155. )
  # NOTE(review): HTTP 204 responses must not carry content (RFC 9110
  # §15.3.5), yet a JSON body is returned with 204 here. Consider an empty
  # body or status 200, after checking what the frontend expects.
  156. return {"result": "success"}, 204
  class ModelProviderModelCredentialApi(Resource):
      """Console API exposing one provider model's credentials and load-balancing setup."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider: str):
          """Return the stored credentials and load-balancing state of a model.

          Query args ``model`` and ``model_type`` (a ``ModelType`` value) select
          the model. The response is
          ``{"credentials": ..., "load_balancing": {"enabled": ..., "configs": ...}}``.
          """
          current_tenant = current_user.current_tenant_id

          request_parser = reqparse.RequestParser()
          request_parser.add_argument("model", type=str, required=True, nullable=False, location="args")
          request_parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="args",
          )
          query = request_parser.parse_args()
          model_name = query["model"]
          model_type = query["model_type"]

          stored_credentials = ModelProviderService().get_model_credentials(
              tenant_id=current_tenant, provider=provider, model_type=model_type, model=model_name
          )

          # The service hands back an (enabled, configs) pair.
          lb_enabled, lb_configs = ModelLoadBalancingService().get_load_balancing_configs(
              tenant_id=current_tenant, provider=provider, model=model_name, model_type=model_type
          )

          load_balancing = {"enabled": lb_enabled, "configs": lb_configs}
          return {"credentials": stored_credentials, "load_balancing": load_balancing}
  class ModelProviderModelEnableApi(Resource):
      """Console API that enables a single model of a provider for the current tenant."""

      @setup_required
      @login_required
      @account_initialization_required
      def patch(self, provider: str):
          """Enable the model named by the JSON body's ``model`` / ``model_type``.

          Returns:
              ``{"result": "success"}``.

          Raises:
              Forbidden: the current user is not an admin or owner. This matches
                  the other state-changing endpoints in this module, which all
                  require that role.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()

          tenant_id = current_user.current_tenant_id

          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          args = parser.parse_args()

          model_provider_service = ModelProviderService()
          model_provider_service.enable_model(
              tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
          )

          return {"result": "success"}
  class ModelProviderModelDisableApi(Resource):
      """Console API that disables a single model of a provider for the current tenant."""

      @setup_required
      @login_required
      @account_initialization_required
      def patch(self, provider: str):
          """Disable the model named by the JSON body's ``model`` / ``model_type``.

          Returns:
              ``{"result": "success"}``.

          Raises:
              Forbidden: the current user is not an admin or owner. This matches
                  the other state-changing endpoints in this module, which all
                  require that role.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()

          tenant_id = current_user.current_tenant_id

          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          args = parser.parse_args()

          model_provider_service = ModelProviderService()
          model_provider_service.disable_model(
              tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
          )

          return {"result": "success"}
  class ModelProviderModelValidateApi(Resource):
      """Console API that runs credential validation for one provider model."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Validate the JSON body's ``credentials`` for ``model`` / ``model_type``.

          Returns ``{"result": "success"}`` when validation passes, or
          ``{"result": "error", "error": <message>}`` when the service raises
          ``CredentialsValidateFailedError``. Any other exception propagates.
          """
          tenant_id = current_user.current_tenant_id

          body_parser = reqparse.RequestParser()
          body_parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          body_parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          body_parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
          payload = body_parser.parse_args()

          validator = ModelProviderService()
          try:
              validator.model_credentials_validate(
                  tenant_id=tenant_id,
                  provider=provider,
                  model=payload["model"],
                  model_type=payload["model_type"],
                  credentials=payload["credentials"],
              )
          except CredentialsValidateFailedError as exc:
              # A validation failure is reported in the body, not as an HTTP error.
              return {"result": "error", "error": str(exc)}

          return {"result": "success"}
  class ModelProviderModelParameterRuleApi(Resource):
      """Console API exposing the parameter rules of a provider model."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider: str):
          """Return ``{"data": <parameter rules>}`` for the ``model`` query arg."""
          rule_parser = reqparse.RequestParser()
          rule_parser.add_argument("model", type=str, required=True, nullable=False, location="args")
          model_name = rule_parser.parse_args()["model"]

          current_tenant = current_user.current_tenant_id

          service = ModelProviderService()
          rules = service.get_model_parameter_rules(tenant_id=current_tenant, provider=provider, model=model_name)

          return jsonable_encoder({"data": rules})
  class ModelProviderAvailableModelApi(Resource):
      """Console API listing the current tenant's models of one model type."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, model_type):
          """Return ``{"data": <models>}`` for the ``model_type`` URL segment.

          NOTE(review): unlike the other endpoints in this module, ``model_type``
          is taken from the path and not checked against ``ModelType`` here;
          presumably the service validates it. Confirm.
          """
          current_tenant = current_user.current_tenant_id
          service = ModelProviderService()
          available = service.get_models_by_model_type(tenant_id=current_tenant, model_type=model_type)
          return jsonable_encoder({"data": available})
  # Route table: (resource class, URL rule, explicit endpoint name or None).
  # Registration order is preserved. Entries with None are registered without
  # an endpoint keyword, exactly like a plain api.add_resource(cls, url) call.
  for _resource, _url, _endpoint in (
      (ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models", None),
      (
          ModelProviderModelEnableApi,
          "/workspaces/current/model-providers/<string:provider>/models/enable",
          "model-provider-model-enable",
      ),
      (
          ModelProviderModelDisableApi,
          "/workspaces/current/model-providers/<string:provider>/models/disable",
          "model-provider-model-disable",
      ),
      (ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials", None),
      (
          ModelProviderModelValidateApi,
          "/workspaces/current/model-providers/<string:provider>/models/credentials/validate",
          None,
      ),
      (
          ModelProviderModelParameterRuleApi,
          "/workspaces/current/model-providers/<string:provider>/models/parameter-rules",
          None,
      ),
      (ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>", None),
      (DefaultModelApi, "/workspaces/current/default-model", None),
  ):
      if _endpoint is None:
          api.add_resource(_resource, _url)
      else:
          api.add_resource(_resource, _url, endpoint=_endpoint)

  # Drop the loop variables so the module namespace is unchanged.
  del _resource, _url, _endpoint