provider_configuration.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062
  1. import datetime
  2. import json
  3. import logging
  4. from collections import defaultdict
  5. from collections.abc import Iterator
  6. from json import JSONDecodeError
  7. from typing import Optional
  8. from pydantic import BaseModel, ConfigDict
  9. from constants import HIDDEN_VALUE
  10. from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
  11. from core.entities.provider_entities import (
  12. CustomConfiguration,
  13. ModelSettings,
  14. SystemConfiguration,
  15. SystemConfigurationStatus,
  16. )
  17. from core.helper import encrypter
  18. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  19. from core.model_runtime.entities.model_entities import FetchFrom, ModelType
  20. from core.model_runtime.entities.provider_entities import (
  21. ConfigurateMethod,
  22. CredentialFormSchema,
  23. FormType,
  24. ProviderEntity,
  25. )
  26. from core.model_runtime.model_providers import model_provider_factory
  27. from core.model_runtime.model_providers.__base.ai_model import AIModel
  28. from core.model_runtime.model_providers.__base.model_provider import ModelProvider
  29. from extensions.ext_database import db
  30. from models.provider import (
  31. LoadBalancingModelConfig,
  32. Provider,
  33. ProviderModel,
  34. ProviderModelSetting,
  35. ProviderType,
  36. TenantPreferredModelProvider,
  37. )
  logger = logging.getLogger(name=__name__)
  # Module-level cache: provider name -> copy of the configurate methods its schema
  # declared the first time it was looked at, i.e. before ProviderConfiguration.__init__
  # may append PREDEFINED_MODEL to the schema's own list.
  original_provider_configurate_methods: dict[str, list] = {}
  40. class ProviderConfiguration(BaseModel):
  41. """
  42. Model class for provider configuration.
  43. """
  44. tenant_id: str
  45. provider: ProviderEntity
  46. preferred_provider_type: ProviderType
  47. using_provider_type: ProviderType
  48. system_configuration: SystemConfiguration
  49. custom_configuration: CustomConfiguration
  50. model_settings: list[ModelSettings]
  51. # pydantic configs
  52. model_config = ConfigDict(protected_namespaces=())
  def __init__(self, **data):
      """
      Initialize the configuration and, when needed, expose PREDEFINED_MODEL.

      Providers whose schema originally declares only CUSTOMIZABLE_MODEL get
      PREDEFINED_MODEL appended to ``self.provider.configurate_methods`` when any
      system quota configuration restricts models.

      :param data: field values forwarded to the pydantic constructor
      """
      super().__init__(**data)

      provider_name = self.provider.provider
      # Snapshot the declared methods once per provider, before the in-place append
      # below can change the list they are read from.
      if provider_name not in original_provider_configurate_methods:
          original_provider_configurate_methods[provider_name] = list(self.provider.configurate_methods)

      if original_provider_configurate_methods[provider_name] != [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
          return

      quota_restricts_models = any(
          len(quota_configuration.restrict_models) > 0
          for quota_configuration in self.system_configuration.quota_configurations
      )
      if quota_restricts_models and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods:
          self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
  def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
      """
      Resolve the credentials that should be used for a model right now.

      :param model_type: model type
      :param model: model name
      :return: when the system provider is in use, a copy of the system credentials,
          with ``base_model_name`` taken from a matching restricted model of the current
          quota type if one sets it; otherwise the custom credentials of this model,
          falling back to the provider-level custom credentials (possibly None)
      :raises ValueError: if a model setting disables this model
      """
      # A disabled model setting blocks the model whichever provider type is in use.
      for model_setting in self.model_settings or []:
          if model_setting.model_type == model_type and model_setting.model == model:
              if not model_setting.enabled:
                  raise ValueError(f"Model {model} is disabled.")

      if self.using_provider_type == ProviderType.SYSTEM:
          # The last quota configuration matching the current quota type wins.
          restrict_models = []
          for quota_configuration in self.system_configuration.quota_configurations:
              if self.system_configuration.current_quota_type == quota_configuration.quota_type:
                  restrict_models = quota_configuration.restrict_models

          system_credentials = self.system_configuration.credentials.copy()
          for restrict_model in restrict_models or []:
              is_match = restrict_model.model_type == model_type and restrict_model.model == model
              if is_match and restrict_model.base_model_name:
                  system_credentials["base_model_name"] = restrict_model.base_model_name
          return system_credentials

      model_credentials = next(
          (
              model_configuration.credentials
              for model_configuration in self.custom_configuration.models or []
              if model_configuration.model_type == model_type and model_configuration.model == model
          ),
          None,
      )
      if not model_credentials and self.custom_configuration.provider:
          model_credentials = self.custom_configuration.provider.credentials
      return model_credentials
  def get_system_configuration_status(self) -> SystemConfigurationStatus:
      """
      Get system configuration status.

      :return: UNSUPPORTED if the system configuration is disabled or has no quota
          configuration for its current quota type; otherwise ACTIVE when that quota
          configuration is valid and QUOTA_EXCEEDED when it is not
      """
      if self.system_configuration.enabled is False:
          return SystemConfigurationStatus.UNSUPPORTED

      current_quota_type = self.system_configuration.current_quota_type
      current_quota_configuration = next(
          (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
      )
      # next() falls back to None when no quota configuration matches the current
      # type; report unsupported instead of failing on None.is_valid.
      if current_quota_configuration is None:
          return SystemConfigurationStatus.UNSUPPORTED

      return (
          SystemConfigurationStatus.ACTIVE
          if current_quota_configuration.is_valid
          else SystemConfigurationStatus.QUOTA_EXCEEDED
      )
  def is_custom_configuration_available(self) -> bool:
      """
      Tell whether any custom configuration exists, at provider level or for a model.

      :return: True if a provider-level custom configuration is set or at least one
          custom model configuration exists
      """
      if self.custom_configuration.provider is not None:
          return True
      return len(self.custom_configuration.models) > 0
  def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
      """
      Return the provider-level custom credentials.

      :param obfuscated: when True, mask secret fields defined by the provider
          credential schema
      :return: the credentials (masked copy when ``obfuscated``), or None when no
          provider-level custom configuration exists
      """
      custom_provider = self.custom_configuration.provider
      if custom_provider is None:
          return None

      credentials = custom_provider.credentials
      if obfuscated:
          credential_schema = self.provider.provider_credential_schema
          form_schemas = credential_schema.credential_form_schemas if credential_schema else []
          return self.obfuscated_credentials(credentials=credentials, credential_form_schemas=form_schemas)
      return credentials
  def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]:
      """
      Validate custom provider credentials and encrypt their secret fields.

      A secret field submitted as HIDDEN_VALUE is replaced by the decrypted value stored
      on the existing custom provider record, when that record holds the field.

      :param credentials: provider credentials; the caller's dict is not modified
      :return: the existing custom provider record (None if there is none) and the
          validated credentials with secret fields encrypted
      """
      # get provider
      provider_record = (
          db.session.query(Provider)
          .filter(
              Provider.tenant_id == self.tenant_id,
              Provider.provider_name == self.provider.provider,
              Provider.provider_type == ProviderType.CUSTOM.value,
          )
          .first()
      )

      # Get provider credential secret variables
      provider_credential_secret_variables = self.extract_secret_variables(
          self.provider.provider_credential_schema.credential_form_schemas
          if self.provider.provider_credential_schema
          else []
      )

      # Work on a copy so decrypted/encrypted secret values written below never leak
      # into the caller's dict.
      credentials = dict(credentials)

      if provider_record:
          try:
              # fix origin data: a non-JSON encrypted_config is treated as a bare
              # "openai_api_key" value
              if provider_record.encrypted_config:
                  if not provider_record.encrypted_config.startswith("{"):
                      original_credentials = {"openai_api_key": provider_record.encrypted_config}
                  else:
                      original_credentials = json.loads(provider_record.encrypted_config)
              else:
                  original_credentials = {}
          except JSONDecodeError:
              original_credentials = {}

          # restore secrets sent back as HIDDEN_VALUE from the stored record
          for key, value in credentials.items():
              if key in provider_credential_secret_variables:
                  # if send [__HIDDEN__] in secret input, it will be same as original value
                  if value == HIDDEN_VALUE and key in original_credentials:
                      credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

      credentials = model_provider_factory.provider_credentials_validate(
          provider=self.provider.provider, credentials=credentials
      )

      # encrypt secret fields before returning them
      for key, value in credentials.items():
          if key in provider_credential_secret_variables:
              credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

      return provider_record, credentials
  def add_or_update_custom_credentials(self, credentials: dict) -> None:
      """
      Validate and store custom provider credentials.

      Creates or updates the tenant's CUSTOM provider record, deletes the provider's
      ProviderCredentialsCache entry, then switches the preferred provider type to CUSTOM.

      :param credentials: provider credentials (secret fields may be HIDDEN_VALUE)
      """
      provider_record, encrypted_credentials = self.custom_credentials_validate(credentials)
      encrypted_config = json.dumps(encrypted_credentials)

      # NOTE(review): an older comment here said the preferred provider is not switched
      # (so that quotas are used first), yet this method ends by switching it to CUSTOM.
      # Confirm which behavior is intended.
      if not provider_record:
          provider_record = Provider(
              tenant_id=self.tenant_id,
              provider_name=self.provider.provider,
              provider_type=ProviderType.CUSTOM.value,
              encrypted_config=encrypted_config,
              is_valid=True,
          )
          db.session.add(provider_record)
      else:
          provider_record.encrypted_config = encrypted_config
          provider_record.is_valid = True
          provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
      db.session.commit()

      ProviderCredentialsCache(
          tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER
      ).delete()

      self.switch_preferred_provider_type(ProviderType.CUSTOM)
  def delete_custom_credentials(self) -> None:
      """
      Remove the tenant's custom provider credentials, if present.

      Switches the preferred provider type to SYSTEM, deletes the CUSTOM provider
      record, then deletes its ProviderCredentialsCache entry. Does nothing when no
      such record exists.
      """
      custom_provider_record = (
          db.session.query(Provider)
          .filter(
              Provider.tenant_id == self.tenant_id,
              Provider.provider_name == self.provider.provider,
              Provider.provider_type == ProviderType.CUSTOM.value,
          )
          .first()
      )
      if not custom_provider_record:
          return

      self.switch_preferred_provider_type(ProviderType.SYSTEM)
      db.session.delete(custom_provider_record)
      db.session.commit()

      ProviderCredentialsCache(
          tenant_id=self.tenant_id,
          identity_id=custom_provider_record.id,
          cache_type=ProviderCredentialsCacheType.PROVIDER,
      ).delete()
  def get_custom_model_credentials(
      self, model_type: ModelType, model: str, obfuscated: bool = False
  ) -> Optional[dict]:
      """
      Look up the custom credentials configured for one model.

      :param model_type: model type
      :param model: model name
      :param obfuscated: when True, mask secret fields defined by the model credential schema
      :return: the first matching model's credentials (masked copy when ``obfuscated``),
          or None if no custom configuration exists for the model
      """
      for model_configuration in self.custom_configuration.models or []:
          if not (model_configuration.model_type == model_type and model_configuration.model == model):
              continue

          if not obfuscated:
              return model_configuration.credentials

          model_credential_schema = self.provider.model_credential_schema
          return self.obfuscated_credentials(
              credentials=model_configuration.credentials,
              credential_form_schemas=model_credential_schema.credential_form_schemas
              if model_credential_schema
              else [],
          )

      return None
  def custom_model_credentials_validate(
      self, model_type: ModelType, model: str, credentials: dict
  ) -> tuple[Optional[ProviderModel], dict]:
      """
      Validate custom model credentials and encrypt their secret fields.

      A secret field submitted as HIDDEN_VALUE is replaced by the decrypted value stored
      on the existing ProviderModel record for this model, when that record holds the field.

      :param model_type: model type
      :param model: model name
      :param credentials: model credentials; the caller's dict is not modified
      :return: the existing provider model record (None if there is none) and the
          validated credentials with secret fields encrypted
      """
      # get provider model
      provider_model_record = (
          db.session.query(ProviderModel)
          .filter(
              ProviderModel.tenant_id == self.tenant_id,
              ProviderModel.provider_name == self.provider.provider,
              ProviderModel.model_name == model,
              ProviderModel.model_type == model_type.to_origin_model_type(),
          )
          .first()
      )

      # Get model credential secret variables
      provider_credential_secret_variables = self.extract_secret_variables(
          self.provider.model_credential_schema.credential_form_schemas
          if self.provider.model_credential_schema
          else []
      )

      # Work on a copy so decrypted/encrypted secret values written below never leak
      # into the caller's dict.
      credentials = dict(credentials)

      if provider_model_record:
          try:
              original_credentials = (
                  json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
              )
          except JSONDecodeError:
              original_credentials = {}

          # restore secrets sent back as HIDDEN_VALUE from the stored record
          for key, value in credentials.items():
              if key in provider_credential_secret_variables:
                  # if send [__HIDDEN__] in secret input, it will be same as original value
                  if value == HIDDEN_VALUE and key in original_credentials:
                      credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

      credentials = model_provider_factory.model_credentials_validate(
          provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
      )

      # encrypt secret fields before returning them
      for key, value in credentials.items():
          if key in provider_credential_secret_variables:
              credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

      return provider_model_record, credentials
  def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
      """
      Validate and store custom credentials for one model.

      Creates or updates the matching ProviderModel record and deletes its
      ProviderCredentialsCache entry. The preferred provider type is left unchanged,
      which allows users to use quotas first.

      :param model_type: model type
      :param model: model name
      :param credentials: model credentials (secret fields may be HIDDEN_VALUE)
      """
      provider_model_record, encrypted_credentials = self.custom_model_credentials_validate(
          model_type, model, credentials
      )
      encrypted_config = json.dumps(encrypted_credentials)

      if not provider_model_record:
          provider_model_record = ProviderModel(
              tenant_id=self.tenant_id,
              provider_name=self.provider.provider,
              model_name=model,
              model_type=model_type.to_origin_model_type(),
              encrypted_config=encrypted_config,
              is_valid=True,
          )
          db.session.add(provider_model_record)
      else:
          provider_model_record.encrypted_config = encrypted_config
          provider_model_record.is_valid = True
          provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
      db.session.commit()

      ProviderCredentialsCache(
          tenant_id=self.tenant_id,
          identity_id=provider_model_record.id,
          cache_type=ProviderCredentialsCacheType.MODEL,
      ).delete()
  def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
      """
      Delete the custom credentials record of one model, if present, and delete its
      ProviderCredentialsCache entry.

      :param model_type: model type
      :param model: model name
      """
      model_record = (
          db.session.query(ProviderModel)
          .filter(
              ProviderModel.tenant_id == self.tenant_id,
              ProviderModel.provider_name == self.provider.provider,
              ProviderModel.model_name == model,
              ProviderModel.model_type == model_type.to_origin_model_type(),
          )
          .first()
      )
      if not model_record:
          return

      db.session.delete(model_record)
      db.session.commit()

      ProviderCredentialsCache(
          tenant_id=self.tenant_id,
          identity_id=model_record.id,
          cache_type=ProviderCredentialsCacheType.MODEL,
      ).delete()
  def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Mark a model as enabled for this tenant and provider.

      Updates the existing ProviderModelSetting row, or creates one with
      ``enabled=True`` when none exists, then commits.

      :param model_type: model type
      :param model: model name
      :return: the updated or newly created model setting
      """
      model_setting = self.get_provider_model_setting(model_type, model)
      if not model_setting:
          model_setting = ProviderModelSetting(
              tenant_id=self.tenant_id,
              provider_name=self.provider.provider,
              model_type=model_type.to_origin_model_type(),
              model_name=model,
              enabled=True,
          )
          db.session.add(model_setting)
      else:
          model_setting.enabled = True
          model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
      db.session.commit()
      return model_setting
  def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Mark a model as disabled for this tenant and provider.

      Updates the existing ProviderModelSetting row, or creates one with
      ``enabled=False`` when none exists, then commits.

      :param model_type: model type
      :param model: model name
      :return: the updated or newly created model setting
      """
      model_setting = self.get_provider_model_setting(model_type, model)
      if not model_setting:
          model_setting = ProviderModelSetting(
              tenant_id=self.tenant_id,
              provider_name=self.provider.provider,
              model_type=model_type.to_origin_model_type(),
              model_name=model,
              enabled=False,
          )
          db.session.add(model_setting)
      else:
          model_setting.enabled = False
          model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
      db.session.commit()
      return model_setting
  def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
      """
      Fetch the stored setting row for one model of this tenant and provider.

      :param model_type: model type
      :param model: model name
      :return: the first matching ProviderModelSetting, or None if there is none
      """
      settings_query = db.session.query(ProviderModelSetting).filter(
          ProviderModelSetting.tenant_id == self.tenant_id,
          ProviderModelSetting.provider_name == self.provider.provider,
          ProviderModelSetting.model_type == model_type.to_origin_model_type(),
          ProviderModelSetting.model_name == model,
      )
      return settings_query.first()
  def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Turn on load balancing for one model.

      :param model_type: model type
      :param model: model name
      :return: the updated or newly created model setting
      :raises ValueError: if the model has at most one load balancing config
      """
      balancing_configs_query = db.session.query(LoadBalancingModelConfig).filter(
          LoadBalancingModelConfig.tenant_id == self.tenant_id,
          LoadBalancingModelConfig.provider_name == self.provider.provider,
          LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
          LoadBalancingModelConfig.model_name == model,
      )
      # Load balancing needs at least two configurations to balance between.
      if balancing_configs_query.count() <= 1:
          raise ValueError("Model load balancing configuration must be more than 1.")

      model_setting = self.get_provider_model_setting(model_type, model)
      if not model_setting:
          model_setting = ProviderModelSetting(
              tenant_id=self.tenant_id,
              provider_name=self.provider.provider,
              model_type=model_type.to_origin_model_type(),
              model_name=model,
              load_balancing_enabled=True,
          )
          db.session.add(model_setting)
      else:
          model_setting.load_balancing_enabled = True
          model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
      db.session.commit()
      return model_setting
  def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Turn off load balancing for one model.

      Updates the existing ProviderModelSetting row, or creates one with
      ``load_balancing_enabled=False`` when none exists, then commits.

      :param model_type: model type
      :param model: model name
      :return: the updated or newly created model setting
      """
      model_setting = self.get_provider_model_setting(model_type, model)
      if not model_setting:
          model_setting = ProviderModelSetting(
              tenant_id=self.tenant_id,
              provider_name=self.provider.provider,
              model_type=model_type.to_origin_model_type(),
              model_name=model,
              load_balancing_enabled=False,
          )
          db.session.add(model_setting)
      else:
          model_setting.load_balancing_enabled = False
          model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
      db.session.commit()
      return model_setting
  def get_provider_instance(self) -> ModelProvider:
      """
      Return the model provider instance obtained from ``model_provider_factory``
      for this provider's name.
      """
      provider_name = self.provider.provider
      return model_provider_factory.get_provider_instance(provider_name)
  def get_model_type_instance(self, model_type: ModelType) -> AIModel:
      """
      Return this provider's model instance for ``model_type``.

      :param model_type: model type
      :return: model instance obtained from the provider instance
      """
      return self.get_provider_instance().get_model_instance(model_type)
  def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
      """
      Persist ``provider_type`` as this tenant's preferred provider type for the provider.

      Does nothing when ``provider_type`` already equals ``self.preferred_provider_type``,
      or when switching to SYSTEM while the system configuration is disabled.

      :param provider_type: provider type to prefer
      """
      if provider_type == self.preferred_provider_type:
          return

      if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
          return

      # get preferred provider
      preferred_model_provider = (
          db.session.query(TenantPreferredModelProvider)
          .filter(
              TenantPreferredModelProvider.tenant_id == self.tenant_id,
              TenantPreferredModelProvider.provider_name == self.provider.provider,
          )
          .first()
      )

      if preferred_model_provider:
          preferred_model_provider.preferred_provider_type = provider_type.value
      else:
          preferred_model_provider = TenantPreferredModelProvider(
              tenant_id=self.tenant_id,
              provider_name=self.provider.provider,
              preferred_provider_type=provider_type.value,
          )
          db.session.add(preferred_model_provider)

      db.session.commit()

      # Keep the in-memory value in sync with what was just stored. Otherwise the
      # early-return check above compares against a stale value, and a later switch
      # on this same instance (e.g. to CUSTOM, then back to SYSTEM) could be skipped.
      self.preferred_provider_type = provider_type
  def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
      """
      Collect the variable names of secret-input form fields.

      :param credential_form_schemas: credential form schemas to scan
      :return: variables whose form type is SECRET_INPUT, in schema order
      """
      return [
          form_schema.variable
          for form_schema in credential_form_schemas
          if form_schema.type == FormType.SECRET_INPUT
      ]
  def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
      """
      Return a copy of ``credentials`` with its secret fields obfuscated.

      :param credentials: credentials to mask; not modified
      :param credential_form_schemas: schemas identifying which fields are secret inputs
      :return: new dict with the same keys, where secret values are replaced by
          ``encrypter.obfuscated_token(value)``
      """
      secret_variables = self.extract_secret_variables(credential_form_schemas)
      return {
          key: encrypter.obfuscated_token(value) if key in secret_variables else value
          for key, value in credentials.items()
      }
  def get_provider_model(
      self, model_type: ModelType, model: str, only_active: bool = False
  ) -> Optional[ModelWithProviderEntity]:
      """
      Find one model of this provider by name.

      :param model_type: model type
      :param model: model name
      :param only_active: consider ACTIVE models only
      :return: the first model named ``model`` from ``get_provider_models``, or None
      """
      return next(
          (
              candidate
              for candidate in self.get_provider_models(model_type, only_active)
              if candidate.model == model
          ),
          None,
      )
  def get_provider_models(
      self, model_type: Optional[ModelType] = None, only_active: bool = False
  ) -> list[ModelWithProviderEntity]:
      """
      List this provider's models.

      :param model_type: restrict to this model type; when falsy, every model type the
          provider schema supports is used
      :param only_active: keep only models whose status is ACTIVE
      :return: models sorted by their model type value
      """
      provider_instance = self.get_provider_instance()
      if model_type:
          model_types = [model_type]
      else:
          model_types = provider_instance.get_provider_schema().supported_model_types

      # model type -> model name -> model setting
      settings_by_type = defaultdict(dict)
      for setting in self.model_settings:
          settings_by_type[setting.model_type][setting.model] = setting

      collect_models = (
          self._get_system_provider_models
          if self.using_provider_type == ProviderType.SYSTEM
          else self._get_custom_provider_models
      )
      provider_models = collect_models(
          model_types=model_types, provider_instance=provider_instance, model_setting_map=settings_by_type
      )

      if only_active:
          provider_models = [entity for entity in provider_models if entity.status == ModelStatus.ACTIVE]

      return sorted(provider_models, key=lambda entity: entity.model_type.value)
  def _get_system_provider_models(
      self,
      model_types: list[ModelType],
      provider_instance: ModelProvider,
      model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  ) -> list[ModelWithProviderEntity]:
      """
      Get system provider models.

      Collects the predefined models of ``provider_instance`` for each of ``model_types``
      (DISABLED when a model setting disables them, ACTIVE otherwise). Then, for the
      quota configuration of the current quota type, if it lists restricted models:

      * providers whose original configurate methods are only CUSTOMIZABLE_MODEL also
        get one entry per restricted model, built from the customizable model schema
        resolved with the system credentials;
      * every LLM missing from the restricted list is marked NO_PERMISSION, and every
        other model is marked QUOTA_EXCEEDED when the quota configuration is not valid.

      :param model_types: model types to include
      :param provider_instance: provider instance
      :param model_setting_map: model settings keyed by model type, then model name
      :return: provider models with their status
      """
      provider_models = []
      for model_type in model_types:
          for m in provider_instance.models(model_type):
              status = ModelStatus.ACTIVE
              if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
                  model_setting = model_setting_map[m.model_type][m.model]
                  if model_setting.enabled is False:
                      status = ModelStatus.DISABLED
              provider_models.append(
                  ModelWithProviderEntity(
                      model=m.model,
                      label=m.label,
                      model_type=m.model_type,
                      features=m.features,
                      fetch_from=m.fetch_from,
                      model_properties=m.model_properties,
                      deprecated=m.deprecated,
                      provider=SimpleModelProviderEntity(self.provider),
                      status=status,
                  )
              )

      provider_name = self.provider.provider
      # Normally populated by __init__; fill it from the provider schema otherwise.
      if provider_name not in original_provider_configurate_methods:
          original_provider_configurate_methods[provider_name] = list(
              provider_instance.get_provider_schema().configurate_methods
          )
      only_customizable_model = original_provider_configurate_methods[provider_name] == [
          ConfigurateMethod.CUSTOMIZABLE_MODEL
      ]

      for quota_configuration in self.system_configuration.quota_configurations:
          if self.system_configuration.current_quota_type != quota_configuration.quota_type:
              continue

          restrict_models = quota_configuration.restrict_models
          if len(restrict_models) == 0:
              # NOTE(review): this also skips the QUOTA_EXCEEDED marking below, so an
              # invalid quota configuration without restricted models never marks
              # models QUOTA_EXCEEDED here — confirm this is intended.
              break

          if only_customizable_model:
              for restrict_model in restrict_models:
                  copy_credentials = self.system_configuration.credentials.copy()
                  if restrict_model.base_model_name:
                      copy_credentials["base_model_name"] = restrict_model.base_model_name

                  try:
                      custom_model_schema = provider_instance.get_model_instance(
                          restrict_model.model_type
                      ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
                  except Exception as ex:
                      # Best effort: skip restricted models whose schema cannot be resolved.
                      logger.warning("get custom model schema failed, %s", ex)
                      continue

                  if not custom_model_schema:
                      continue

                  if custom_model_schema.model_type not in model_types:
                      continue

                  status = ModelStatus.ACTIVE
                  if (
                      custom_model_schema.model_type in model_setting_map
                      and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
                  ):
                      model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
                      if model_setting.enabled is False:
                          status = ModelStatus.DISABLED

                  provider_models.append(
                      ModelWithProviderEntity(
                          model=custom_model_schema.model,
                          label=custom_model_schema.label,
                          model_type=custom_model_schema.model_type,
                          features=custom_model_schema.features,
                          fetch_from=FetchFrom.PREDEFINED_MODEL,
                          model_properties=custom_model_schema.model_properties,
                          deprecated=custom_model_schema.deprecated,
                          provider=SimpleModelProviderEntity(self.provider),
                          status=status,
                      )
                  )

          # An LLM not in the restricted list is not permitted; a set gives O(1) lookups.
          restrict_model_names = {rm.model for rm in restrict_models}
          for m in provider_models:
              if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
                  m.status = ModelStatus.NO_PERMISSION
              elif not quota_configuration.is_valid:
                  m.status = ModelStatus.QUOTA_EXCEEDED

      return provider_models
  754. def _get_custom_provider_models(
  755. self,
  756. model_types: list[ModelType],
  757. provider_instance: ModelProvider,
  758. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  759. ) -> list[ModelWithProviderEntity]:
  760. """
  761. Get custom provider models.
  762. :param model_types: model types
  763. :param provider_instance: provider instance
  764. :param model_setting_map: model setting map
  765. :return:
  766. """
  767. provider_models = []
  768. credentials = None
  769. if self.custom_configuration.provider:
  770. credentials = self.custom_configuration.provider.credentials
  771. for model_type in model_types:
  772. if model_type not in self.provider.supported_model_types:
  773. continue
  774. models = provider_instance.models(model_type)
  775. for m in models:
  776. status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
  777. load_balancing_enabled = False
  778. if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
  779. model_setting = model_setting_map[m.model_type][m.model]
  780. if model_setting.enabled is False:
  781. status = ModelStatus.DISABLED
  782. if len(model_setting.load_balancing_configs) > 1:
  783. load_balancing_enabled = True
  784. provider_models.append(
  785. ModelWithProviderEntity(
  786. model=m.model,
  787. label=m.label,
  788. model_type=m.model_type,
  789. features=m.features,
  790. fetch_from=m.fetch_from,
  791. model_properties=m.model_properties,
  792. deprecated=m.deprecated,
  793. provider=SimpleModelProviderEntity(self.provider),
  794. status=status,
  795. load_balancing_enabled=load_balancing_enabled,
  796. )
  797. )
  798. # custom models
  799. for model_configuration in self.custom_configuration.models:
  800. if model_configuration.model_type not in model_types:
  801. continue
  802. try:
  803. custom_model_schema = provider_instance.get_model_instance(
  804. model_configuration.model_type
  805. ).get_customizable_model_schema_from_credentials(
  806. model_configuration.model, model_configuration.credentials
  807. )
  808. except Exception as ex:
  809. logger.warning(f"get custom model schema failed, {ex}")
  810. continue
  811. if not custom_model_schema:
  812. continue
  813. status = ModelStatus.ACTIVE
  814. load_balancing_enabled = False
  815. if (
  816. custom_model_schema.model_type in model_setting_map
  817. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  818. ):
  819. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  820. if model_setting.enabled is False:
  821. status = ModelStatus.DISABLED
  822. if len(model_setting.load_balancing_configs) > 1:
  823. load_balancing_enabled = True
  824. provider_models.append(
  825. ModelWithProviderEntity(
  826. model=custom_model_schema.model,
  827. label=custom_model_schema.label,
  828. model_type=custom_model_schema.model_type,
  829. features=custom_model_schema.features,
  830. fetch_from=custom_model_schema.fetch_from,
  831. model_properties=custom_model_schema.model_properties,
  832. deprecated=custom_model_schema.deprecated,
  833. provider=SimpleModelProviderEntity(self.provider),
  834. status=status,
  835. load_balancing_enabled=load_balancing_enabled,
  836. )
  837. )
  838. return provider_models
  839. class ProviderConfigurations(BaseModel):
  840. """
  841. Model class for provider configuration dict.
  842. """
  843. tenant_id: str
  844. configurations: dict[str, ProviderConfiguration] = {}
  845. def __init__(self, tenant_id: str):
  846. super().__init__(tenant_id=tenant_id)
  847. def get_models(
  848. self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
  849. ) -> list[ModelWithProviderEntity]:
  850. """
  851. Get available models.
  852. If preferred provider type is `system`:
  853. Get the current **system mode** if provider supported,
  854. if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
  855. If there is no model configured in custom mode, it is treated as no_configure.
  856. system > custom > no_configure
  857. If preferred provider type is `custom`:
  858. If custom credentials are configured, it is treated as custom mode.
  859. Otherwise, get the current **system mode** if supported,
  860. If all system modes are not available (no quota), it is treated as no_configure.
  861. custom > system > no_configure
  862. If real mode is `system`, use system credentials to get models,
  863. paid quotas > provider free quotas > system free quotas
  864. include pre-defined models (exclude GPT-4, status marked as `no_permission`).
  865. If real mode is `custom`, use workspace custom credentials to get models,
  866. include pre-defined models, custom models(manual append).
  867. If real mode is `no_configure`, only return pre-defined models from `model runtime`.
  868. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
  869. model status marked as `active` is available.
  870. :param provider: provider name
  871. :param model_type: model type
  872. :param only_active: only active models
  873. :return:
  874. """
  875. all_models = []
  876. for provider_configuration in self.values():
  877. if provider and provider_configuration.provider.provider != provider:
  878. continue
  879. all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
  880. return all_models
  881. def to_list(self) -> list[ProviderConfiguration]:
  882. """
  883. Convert to list.
  884. :return:
  885. """
  886. return list(self.values())
  887. def __getitem__(self, key):
  888. return self.configurations[key]
  889. def __setitem__(self, key, value):
  890. self.configurations[key] = value
  891. def __iter__(self):
  892. return iter(self.configurations)
  893. def values(self) -> Iterator[ProviderConfiguration]:
  894. return self.configurations.values()
  895. def get(self, key, default=None):
  896. return self.configurations.get(key, default)
  897. class ProviderModelBundle(BaseModel):
  898. """
  899. Provider model bundle.
  900. """
  901. configuration: ProviderConfiguration
  902. provider_instance: ModelProvider
  903. model_type_instance: AIModel
  904. # pydantic configs
  905. model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())