provider_configuration.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783
  1. import datetime
  2. import json
  3. import logging
  4. from json import JSONDecodeError
  5. from typing import Optional, List, Dict, Tuple, Iterator
  6. from pydantic import BaseModel
  7. from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity
  8. from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
  9. from core.helper import encrypter
  10. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  11. from core.model_runtime.entities.model_entities import ModelType, FetchFrom
  12. from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \
  13. ConfigurateMethod
  14. from core.model_runtime.model_providers import model_provider_factory
  15. from core.model_runtime.model_providers.__base.ai_model import AIModel
  16. from core.model_runtime.model_providers.__base.model_provider import ModelProvider
  17. from core.model_runtime.utils import encoders
  18. from extensions.ext_database import db
  19. from models.provider import ProviderType, Provider, ProviderModel, TenantPreferredModelProvider
  20. logger = logging.getLogger(__name__)
  21. original_provider_configurate_methods = {}
  22. class ProviderConfiguration(BaseModel):
  23. """
  24. Model class for provider configuration.
  25. """
  26. tenant_id: str
  27. provider: ProviderEntity
  28. preferred_provider_type: ProviderType
  29. using_provider_type: ProviderType
  30. system_configuration: SystemConfiguration
  31. custom_configuration: CustomConfiguration
  32. def __init__(self, **data):
  33. super().__init__(**data)
  34. if self.provider.provider not in original_provider_configurate_methods:
  35. original_provider_configurate_methods[self.provider.provider] = []
  36. for configurate_method in self.provider.configurate_methods:
  37. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  38. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  39. if (any([len(quota_configuration.restrict_models) > 0
  40. for quota_configuration in self.system_configuration.quota_configurations])
  41. and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
  42. self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
  43. def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
  44. """
  45. Get current credentials.
  46. :param model_type: model type
  47. :param model: model name
  48. :return:
  49. """
  50. if self.using_provider_type == ProviderType.SYSTEM:
  51. return self.system_configuration.credentials
  52. else:
  53. if self.custom_configuration.models:
  54. for model_configuration in self.custom_configuration.models:
  55. if model_configuration.model_type == model_type and model_configuration.model == model:
  56. return model_configuration.credentials
  57. if self.custom_configuration.provider:
  58. return self.custom_configuration.provider.credentials
  59. else:
  60. return None
  61. def get_system_configuration_status(self) -> SystemConfigurationStatus:
  62. """
  63. Get system configuration status.
  64. :return:
  65. """
  66. if self.system_configuration.enabled is False:
  67. return SystemConfigurationStatus.UNSUPPORTED
  68. current_quota_type = self.system_configuration.current_quota_type
  69. current_quota_configuration = next(
  70. (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
  71. None
  72. )
  73. return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
  74. SystemConfigurationStatus.QUOTA_EXCEEDED
  75. def is_custom_configuration_available(self) -> bool:
  76. """
  77. Check custom configuration available.
  78. :return:
  79. """
  80. return (self.custom_configuration.provider is not None
  81. or len(self.custom_configuration.models) > 0)
  82. def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
  83. """
  84. Get custom credentials.
  85. :param obfuscated: obfuscated secret data in credentials
  86. :return:
  87. """
  88. if self.custom_configuration.provider is None:
  89. return None
  90. credentials = self.custom_configuration.provider.credentials
  91. if not obfuscated:
  92. return credentials
  93. # Obfuscate credentials
  94. return self._obfuscated_credentials(
  95. credentials=credentials,
  96. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  97. if self.provider.provider_credential_schema else []
  98. )
  99. def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
  100. """
  101. Validate custom credentials.
  102. :param credentials: provider credentials
  103. :return:
  104. """
  105. # get provider
  106. provider_record = db.session.query(Provider) \
  107. .filter(
  108. Provider.tenant_id == self.tenant_id,
  109. Provider.provider_name == self.provider.provider,
  110. Provider.provider_type == ProviderType.CUSTOM.value
  111. ).first()
  112. # Get provider credential secret variables
  113. provider_credential_secret_variables = self._extract_secret_variables(
  114. self.provider.provider_credential_schema.credential_form_schemas
  115. if self.provider.provider_credential_schema else []
  116. )
  117. if provider_record:
  118. try:
  119. original_credentials = json.loads(
  120. provider_record.encrypted_config) if provider_record.encrypted_config else {}
  121. except JSONDecodeError:
  122. original_credentials = {}
  123. # encrypt credentials
  124. for key, value in credentials.items():
  125. if key in provider_credential_secret_variables:
  126. # if send [__HIDDEN__] in secret input, it will be same as original value
  127. if value == '[__HIDDEN__]' and key in original_credentials:
  128. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  129. model_provider_factory.provider_credentials_validate(
  130. self.provider.provider,
  131. credentials
  132. )
  133. for key, value in credentials.items():
  134. if key in provider_credential_secret_variables:
  135. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  136. return provider_record, credentials
  137. def add_or_update_custom_credentials(self, credentials: dict) -> None:
  138. """
  139. Add or update custom provider credentials.
  140. :param credentials:
  141. :return:
  142. """
  143. # validate custom provider config
  144. provider_record, credentials = self.custom_credentials_validate(credentials)
  145. # save provider
  146. # Note: Do not switch the preferred provider, which allows users to use quotas first
  147. if provider_record:
  148. provider_record.encrypted_config = json.dumps(credentials)
  149. provider_record.is_valid = True
  150. provider_record.updated_at = datetime.datetime.utcnow()
  151. db.session.commit()
  152. else:
  153. provider_record = Provider(
  154. tenant_id=self.tenant_id,
  155. provider_name=self.provider.provider,
  156. provider_type=ProviderType.CUSTOM.value,
  157. encrypted_config=json.dumps(credentials),
  158. is_valid=True
  159. )
  160. db.session.add(provider_record)
  161. db.session.commit()
  162. provider_model_credentials_cache = ProviderCredentialsCache(
  163. tenant_id=self.tenant_id,
  164. identity_id=provider_record.id,
  165. cache_type=ProviderCredentialsCacheType.PROVIDER
  166. )
  167. provider_model_credentials_cache.delete()
  168. self.switch_preferred_provider_type(ProviderType.CUSTOM)
  169. def delete_custom_credentials(self) -> None:
  170. """
  171. Delete custom provider credentials.
  172. :return:
  173. """
  174. # get provider
  175. provider_record = db.session.query(Provider) \
  176. .filter(
  177. Provider.tenant_id == self.tenant_id,
  178. Provider.provider_name == self.provider.provider,
  179. Provider.provider_type == ProviderType.CUSTOM.value
  180. ).first()
  181. # delete provider
  182. if provider_record:
  183. self.switch_preferred_provider_type(ProviderType.SYSTEM)
  184. db.session.delete(provider_record)
  185. db.session.commit()
  186. provider_model_credentials_cache = ProviderCredentialsCache(
  187. tenant_id=self.tenant_id,
  188. identity_id=provider_record.id,
  189. cache_type=ProviderCredentialsCacheType.PROVIDER
  190. )
  191. provider_model_credentials_cache.delete()
  192. def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
  193. -> Optional[dict]:
  194. """
  195. Get custom model credentials.
  196. :param model_type: model type
  197. :param model: model name
  198. :param obfuscated: obfuscated secret data in credentials
  199. :return:
  200. """
  201. if not self.custom_configuration.models:
  202. return None
  203. for model_configuration in self.custom_configuration.models:
  204. if model_configuration.model_type == model_type and model_configuration.model == model:
  205. credentials = model_configuration.credentials
  206. if not obfuscated:
  207. return credentials
  208. # Obfuscate credentials
  209. return self._obfuscated_credentials(
  210. credentials=credentials,
  211. credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
  212. if self.provider.model_credential_schema else []
  213. )
  214. return None
  215. def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
  216. -> Tuple[ProviderModel, dict]:
  217. """
  218. Validate custom model credentials.
  219. :param model_type: model type
  220. :param model: model name
  221. :param credentials: model credentials
  222. :return:
  223. """
  224. # get provider model
  225. provider_model_record = db.session.query(ProviderModel) \
  226. .filter(
  227. ProviderModel.tenant_id == self.tenant_id,
  228. ProviderModel.provider_name == self.provider.provider,
  229. ProviderModel.model_name == model,
  230. ProviderModel.model_type == model_type.to_origin_model_type()
  231. ).first()
  232. # Get provider credential secret variables
  233. provider_credential_secret_variables = self._extract_secret_variables(
  234. self.provider.model_credential_schema.credential_form_schemas
  235. if self.provider.model_credential_schema else []
  236. )
  237. if provider_model_record:
  238. try:
  239. original_credentials = json.loads(
  240. provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
  241. except JSONDecodeError:
  242. original_credentials = {}
  243. # decrypt credentials
  244. for key, value in credentials.items():
  245. if key in provider_credential_secret_variables:
  246. # if send [__HIDDEN__] in secret input, it will be same as original value
  247. if value == '[__HIDDEN__]' and key in original_credentials:
  248. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  249. model_provider_factory.model_credentials_validate(
  250. provider=self.provider.provider,
  251. model_type=model_type,
  252. model=model,
  253. credentials=credentials
  254. )
  255. model_schema = (
  256. model_provider_factory.get_provider_instance(self.provider.provider)
  257. .get_model_instance(model_type)._get_customizable_model_schema(
  258. model=model,
  259. credentials=credentials
  260. )
  261. )
  262. if model_schema:
  263. credentials['schema'] = json.dumps(encoders.jsonable_encoder(model_schema))
  264. for key, value in credentials.items():
  265. if key in provider_credential_secret_variables:
  266. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  267. return provider_model_record, credentials
  268. def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
  269. """
  270. Add or update custom model credentials.
  271. :param model_type: model type
  272. :param model: model name
  273. :param credentials: model credentials
  274. :return:
  275. """
  276. # validate custom model config
  277. provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
  278. # save provider model
  279. # Note: Do not switch the preferred provider, which allows users to use quotas first
  280. if provider_model_record:
  281. provider_model_record.encrypted_config = json.dumps(credentials)
  282. provider_model_record.is_valid = True
  283. provider_model_record.updated_at = datetime.datetime.utcnow()
  284. db.session.commit()
  285. else:
  286. provider_model_record = ProviderModel(
  287. tenant_id=self.tenant_id,
  288. provider_name=self.provider.provider,
  289. model_name=model,
  290. model_type=model_type.to_origin_model_type(),
  291. encrypted_config=json.dumps(credentials),
  292. is_valid=True
  293. )
  294. db.session.add(provider_model_record)
  295. db.session.commit()
  296. provider_model_credentials_cache = ProviderCredentialsCache(
  297. tenant_id=self.tenant_id,
  298. identity_id=provider_model_record.id,
  299. cache_type=ProviderCredentialsCacheType.MODEL
  300. )
  301. provider_model_credentials_cache.delete()
  302. def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
  303. """
  304. Delete custom model credentials.
  305. :param model_type: model type
  306. :param model: model name
  307. :return:
  308. """
  309. # get provider model
  310. provider_model_record = db.session.query(ProviderModel) \
  311. .filter(
  312. ProviderModel.tenant_id == self.tenant_id,
  313. ProviderModel.provider_name == self.provider.provider,
  314. ProviderModel.model_name == model,
  315. ProviderModel.model_type == model_type.to_origin_model_type()
  316. ).first()
  317. # delete provider model
  318. if provider_model_record:
  319. db.session.delete(provider_model_record)
  320. db.session.commit()
  321. provider_model_credentials_cache = ProviderCredentialsCache(
  322. tenant_id=self.tenant_id,
  323. identity_id=provider_model_record.id,
  324. cache_type=ProviderCredentialsCacheType.MODEL
  325. )
  326. provider_model_credentials_cache.delete()
  327. def get_provider_instance(self) -> ModelProvider:
  328. """
  329. Get provider instance.
  330. :return:
  331. """
  332. return model_provider_factory.get_provider_instance(self.provider.provider)
  333. def get_model_type_instance(self, model_type: ModelType) -> AIModel:
  334. """
  335. Get current model type instance.
  336. :param model_type: model type
  337. :return:
  338. """
  339. # Get provider instance
  340. provider_instance = self.get_provider_instance()
  341. # Get model instance of LLM
  342. return provider_instance.get_model_instance(model_type)
  343. def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
  344. """
  345. Switch preferred provider type.
  346. :param provider_type:
  347. :return:
  348. """
  349. if provider_type == self.preferred_provider_type:
  350. return
  351. if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
  352. return
  353. # get preferred provider
  354. preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
  355. .filter(
  356. TenantPreferredModelProvider.tenant_id == self.tenant_id,
  357. TenantPreferredModelProvider.provider_name == self.provider.provider
  358. ).first()
  359. if preferred_model_provider:
  360. preferred_model_provider.preferred_provider_type = provider_type.value
  361. else:
  362. preferred_model_provider = TenantPreferredModelProvider(
  363. tenant_id=self.tenant_id,
  364. provider_name=self.provider.provider,
  365. preferred_provider_type=provider_type.value
  366. )
  367. db.session.add(preferred_model_provider)
  368. db.session.commit()
  369. def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
  370. """
  371. Extract secret input form variables.
  372. :param credential_form_schemas:
  373. :return:
  374. """
  375. secret_input_form_variables = []
  376. for credential_form_schema in credential_form_schemas:
  377. if credential_form_schema.type == FormType.SECRET_INPUT:
  378. secret_input_form_variables.append(credential_form_schema.variable)
  379. return secret_input_form_variables
  380. def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
  381. """
  382. Obfuscated credentials.
  383. :param credentials: credentials
  384. :param credential_form_schemas: credential form schemas
  385. :return:
  386. """
  387. # Get provider credential secret variables
  388. credential_secret_variables = self._extract_secret_variables(
  389. credential_form_schemas
  390. )
  391. # Obfuscate provider credentials
  392. copy_credentials = credentials.copy()
  393. for key, value in copy_credentials.items():
  394. if key in credential_secret_variables:
  395. copy_credentials[key] = encrypter.obfuscated_token(value)
  396. return copy_credentials
  397. def get_provider_model(self, model_type: ModelType,
  398. model: str,
  399. only_active: bool = False) -> Optional[ModelWithProviderEntity]:
  400. """
  401. Get provider model.
  402. :param model_type: model type
  403. :param model: model name
  404. :param only_active: return active model only
  405. :return:
  406. """
  407. provider_models = self.get_provider_models(model_type, only_active)
  408. for provider_model in provider_models:
  409. if provider_model.model == model:
  410. return provider_model
  411. return None
  412. def get_provider_models(self, model_type: Optional[ModelType] = None,
  413. only_active: bool = False) -> list[ModelWithProviderEntity]:
  414. """
  415. Get provider models.
  416. :param model_type: model type
  417. :param only_active: only active models
  418. :return:
  419. """
  420. provider_instance = self.get_provider_instance()
  421. model_types = []
  422. if model_type:
  423. model_types.append(model_type)
  424. else:
  425. model_types = provider_instance.get_provider_schema().supported_model_types
  426. if self.using_provider_type == ProviderType.SYSTEM:
  427. provider_models = self._get_system_provider_models(
  428. model_types=model_types,
  429. provider_instance=provider_instance
  430. )
  431. else:
  432. provider_models = self._get_custom_provider_models(
  433. model_types=model_types,
  434. provider_instance=provider_instance
  435. )
  436. if only_active:
  437. provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
  438. # resort provider_models
  439. return sorted(provider_models, key=lambda x: x.model_type.value)
  440. def _get_system_provider_models(self,
  441. model_types: list[ModelType],
  442. provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
  443. """
  444. Get system provider models.
  445. :param model_types: model types
  446. :param provider_instance: provider instance
  447. :return:
  448. """
  449. provider_models = []
  450. for model_type in model_types:
  451. provider_models.extend(
  452. [
  453. ModelWithProviderEntity(
  454. model=m.model,
  455. label=m.label,
  456. model_type=m.model_type,
  457. features=m.features,
  458. fetch_from=m.fetch_from,
  459. model_properties=m.model_properties,
  460. deprecated=m.deprecated,
  461. provider=SimpleModelProviderEntity(self.provider),
  462. status=ModelStatus.ACTIVE
  463. )
  464. for m in provider_instance.models(model_type)
  465. ]
  466. )
  467. if self.provider.provider not in original_provider_configurate_methods:
  468. original_provider_configurate_methods[self.provider.provider] = []
  469. for configurate_method in provider_instance.get_provider_schema().configurate_methods:
  470. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  471. should_use_custom_model = False
  472. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  473. should_use_custom_model = True
  474. for quota_configuration in self.system_configuration.quota_configurations:
  475. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  476. continue
  477. restrict_models = quota_configuration.restrict_models
  478. if len(restrict_models) == 0:
  479. break
  480. if should_use_custom_model:
  481. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  482. # only customizable model
  483. for restrict_model in restrict_models:
  484. copy_credentials = self.system_configuration.credentials.copy()
  485. if restrict_model.base_model_name:
  486. copy_credentials['base_model_name'] = restrict_model.base_model_name
  487. try:
  488. custom_model_schema = (
  489. provider_instance.get_model_instance(restrict_model.model_type)
  490. .get_customizable_model_schema_from_credentials(
  491. restrict_model.model,
  492. copy_credentials
  493. )
  494. )
  495. except Exception as ex:
  496. logger.warning(f'get custom model schema failed, {ex}')
  497. continue
  498. if not custom_model_schema:
  499. continue
  500. if custom_model_schema.model_type not in model_types:
  501. continue
  502. provider_models.append(
  503. ModelWithProviderEntity(
  504. model=custom_model_schema.model,
  505. label=custom_model_schema.label,
  506. model_type=custom_model_schema.model_type,
  507. features=custom_model_schema.features,
  508. fetch_from=FetchFrom.PREDEFINED_MODEL,
  509. model_properties=custom_model_schema.model_properties,
  510. deprecated=custom_model_schema.deprecated,
  511. provider=SimpleModelProviderEntity(self.provider),
  512. status=ModelStatus.ACTIVE
  513. )
  514. )
  515. # if llm name not in restricted llm list, remove it
  516. restrict_model_names = [rm.model for rm in restrict_models]
  517. for m in provider_models:
  518. if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
  519. m.status = ModelStatus.NO_PERMISSION
  520. elif not quota_configuration.is_valid:
  521. m.status = ModelStatus.QUOTA_EXCEEDED
  522. return provider_models
  523. def _get_custom_provider_models(self,
  524. model_types: list[ModelType],
  525. provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
  526. """
  527. Get custom provider models.
  528. :param model_types: model types
  529. :param provider_instance: provider instance
  530. :return:
  531. """
  532. provider_models = []
  533. credentials = None
  534. if self.custom_configuration.provider:
  535. credentials = self.custom_configuration.provider.credentials
  536. for model_type in model_types:
  537. if model_type not in self.provider.supported_model_types:
  538. continue
  539. models = provider_instance.models(model_type)
  540. for m in models:
  541. provider_models.append(
  542. ModelWithProviderEntity(
  543. model=m.model,
  544. label=m.label,
  545. model_type=m.model_type,
  546. features=m.features,
  547. fetch_from=m.fetch_from,
  548. model_properties=m.model_properties,
  549. deprecated=m.deprecated,
  550. provider=SimpleModelProviderEntity(self.provider),
  551. status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
  552. )
  553. )
  554. # custom models
  555. for model_configuration in self.custom_configuration.models:
  556. if model_configuration.model_type not in model_types:
  557. continue
  558. try:
  559. custom_model_schema = (
  560. provider_instance.get_model_instance(model_configuration.model_type)
  561. .get_customizable_model_schema_from_credentials(
  562. model_configuration.model,
  563. model_configuration.credentials
  564. )
  565. )
  566. except Exception as ex:
  567. logger.warning(f'get custom model schema failed, {ex}')
  568. continue
  569. if not custom_model_schema:
  570. continue
  571. provider_models.append(
  572. ModelWithProviderEntity(
  573. model=custom_model_schema.model,
  574. label=custom_model_schema.label,
  575. model_type=custom_model_schema.model_type,
  576. features=custom_model_schema.features,
  577. fetch_from=custom_model_schema.fetch_from,
  578. model_properties=custom_model_schema.model_properties,
  579. deprecated=custom_model_schema.deprecated,
  580. provider=SimpleModelProviderEntity(self.provider),
  581. status=ModelStatus.ACTIVE
  582. )
  583. )
  584. return provider_models
  585. class ProviderConfigurations(BaseModel):
  586. """
  587. Model class for provider configuration dict.
  588. """
  589. tenant_id: str
  590. configurations: Dict[str, ProviderConfiguration] = {}
  591. def __init__(self, tenant_id: str):
  592. super().__init__(tenant_id=tenant_id)
  593. def get_models(self,
  594. provider: Optional[str] = None,
  595. model_type: Optional[ModelType] = None,
  596. only_active: bool = False) \
  597. -> list[ModelWithProviderEntity]:
  598. """
  599. Get available models.
  600. If preferred provider type is `system`:
  601. Get the current **system mode** if provider supported,
  602. if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
  603. If there is no model configured in custom mode, it is treated as no_configure.
  604. system > custom > no_configure
  605. If preferred provider type is `custom`:
  606. If custom credentials are configured, it is treated as custom mode.
  607. Otherwise, get the current **system mode** if supported,
  608. If all system modes are not available (no quota), it is treated as no_configure.
  609. custom > system > no_configure
  610. If real mode is `system`, use system credentials to get models,
  611. paid quotas > provider free quotas > system free quotas
  612. include pre-defined models (exclude GPT-4, status marked as `no_permission`).
  613. If real mode is `custom`, use workspace custom credentials to get models,
  614. include pre-defined models, custom models(manual append).
  615. If real mode is `no_configure`, only return pre-defined models from `model runtime`.
  616. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
  617. model status marked as `active` is available.
  618. :param provider: provider name
  619. :param model_type: model type
  620. :param only_active: only active models
  621. :return:
  622. """
  623. all_models = []
  624. for provider_configuration in self.values():
  625. if provider and provider_configuration.provider.provider != provider:
  626. continue
  627. all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
  628. return all_models
  629. def to_list(self) -> List[ProviderConfiguration]:
  630. """
  631. Convert to list.
  632. :return:
  633. """
  634. return list(self.values())
  635. def __getitem__(self, key):
  636. return self.configurations[key]
  637. def __setitem__(self, key, value):
  638. self.configurations[key] = value
  639. def __iter__(self):
  640. return iter(self.configurations)
  641. def values(self) -> Iterator[ProviderConfiguration]:
  642. return self.configurations.values()
  643. def get(self, key, default=None):
  644. return self.configurations.get(key, default)
  645. class ProviderModelBundle(BaseModel):
  646. """
  647. Provider model bundle.
  648. """
  649. configuration: ProviderConfiguration
  650. provider_instance: ModelProvider
  651. model_type_instance: AIModel
  652. class Config:
  653. """Configuration for this pydantic object."""
  654. arbitrary_types_allowed = True