provider_configuration.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798
  import datetime
  import json
  import logging
  from collections.abc import Iterator, ValuesView
  from json import JSONDecodeError
  from typing import Optional

  from pydantic import BaseModel

  from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
  from core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus
  from core.helper import encrypter
  from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  from core.model_runtime.entities.model_entities import FetchFrom, ModelType
  from core.model_runtime.entities.provider_entities import (
      ConfigurateMethod,
      CredentialFormSchema,
      FormType,
      ProviderEntity,
  )
  from core.model_runtime.model_providers import model_provider_factory
  from core.model_runtime.model_providers.__base.ai_model import AIModel
  from core.model_runtime.model_providers.__base.model_provider import ModelProvider
  from extensions.ext_database import db
  from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider
  # Module logger, named after this module.
  logger = logging.getLogger(name=__name__)

  # Process-wide map: provider name -> the provider's configurate methods as they were the
  # first time a ProviderConfiguration for that provider was built (snapshotted before
  # ProviderConfiguration.__init__ may append PREDEFINED_MODEL to the provider schema).
  original_provider_configurate_methods: dict[str, list] = dict()
  class ProviderConfiguration(BaseModel):
      """
      Model class for provider configuration.

      Holds, for one tenant and one model provider, the provider schema together with the
      system (quota-based) configuration and the tenant's custom configuration, and offers
      helpers to read, validate, persist and delete credentials and to list the models
      available under the provider type currently in use.
      """
      tenant_id: str
      provider: ProviderEntity
      preferred_provider_type: ProviderType
      using_provider_type: ProviderType
      system_configuration: SystemConfiguration
      custom_configuration: CustomConfiguration

      def __init__(self, **data):
          super().__init__(**data)

          provider_name = self.provider.provider
          if provider_name not in original_provider_configurate_methods:
              # Snapshot the configurate methods the first time this provider is seen, before the
              # append below can add PREDEFINED_MODEL to ``self.provider.configurate_methods``.
              original_provider_configurate_methods[provider_name] = list(self.provider.configurate_methods)

          if original_provider_configurate_methods[provider_name] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
              # A customizable-only provider whose system quotas restrict it to specific models
              # also advertises PREDEFINED_MODEL (see _get_system_provider_models).
              has_restricted_quota = any(
                  len(quota_configuration.restrict_models) > 0
                  for quota_configuration in self.system_configuration.quota_configurations
              )
              if has_restricted_quota and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods:
                  self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)

      def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
          """
          Get current credentials.

          :param model_type: model type
          :param model: model name
          :return: when the system provider type is in use, a copy of the system credentials,
              with ``base_model_name`` set if the current quota's restrict models map this
              model to a base model; otherwise the custom credentials of the matching custom
              model, else the custom provider credentials, else None
          """
          if self.using_provider_type == ProviderType.SYSTEM:
              # Restrict models of the current quota type (the last matching entry wins).
              restrict_models = []
              for quota_configuration in self.system_configuration.quota_configurations:
                  if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                      continue
                  restrict_models = quota_configuration.restrict_models

              copy_credentials = self.system_configuration.credentials.copy()
              for restrict_model in restrict_models:
                  if (restrict_model.model_type == model_type
                          and restrict_model.model == model
                          and restrict_model.base_model_name):
                      copy_credentials['base_model_name'] = restrict_model.base_model_name
              return copy_credentials

          if self.custom_configuration.models:
              for model_configuration in self.custom_configuration.models:
                  if model_configuration.model_type == model_type and model_configuration.model == model:
                      return model_configuration.credentials

          if self.custom_configuration.provider:
              return self.custom_configuration.provider.credentials
          return None

      def get_system_configuration_status(self) -> SystemConfigurationStatus:
          """
          Get system configuration status.

          :return: UNSUPPORTED when the system configuration is disabled or no quota
              configuration matches the current quota type; otherwise ACTIVE if the current
              quota is valid, else QUOTA_EXCEEDED
          """
          if self.system_configuration.enabled is False:
              return SystemConfigurationStatus.UNSUPPORTED

          current_quota_type = self.system_configuration.current_quota_type
          current_quota_configuration = next(
              (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
              None
          )
          if current_quota_configuration is None:
              # Previously this dereferenced None; without a matching quota the system mode is unusable.
              return SystemConfigurationStatus.UNSUPPORTED

          if current_quota_configuration.is_valid:
              return SystemConfigurationStatus.ACTIVE
          return SystemConfigurationStatus.QUOTA_EXCEEDED

      def is_custom_configuration_available(self) -> bool:
          """
          Check custom configuration available.

          :return: True if custom provider credentials or at least one custom model is configured
          """
          return (self.custom_configuration.provider is not None
                  or len(self.custom_configuration.models) > 0)

      def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
          """
          Get custom credentials.

          :param obfuscated: obfuscated secret data in credentials
          :return: the custom provider credentials (secret fields obfuscated when requested),
              or None if no custom provider configuration exists
          """
          if self.custom_configuration.provider is None:
              return None

          credentials = self.custom_configuration.provider.credentials
          if not obfuscated:
              return credentials

          # Obfuscate credentials
          return self._obfuscated_credentials(
              credentials=credentials,
              credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
              if self.provider.provider_credential_schema else []
          )

      def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]:
          """
          Validate custom provider credentials.

          Secret fields submitted as ``[__HIDDEN__]`` are replaced with the decrypted value
          already stored for this provider; the result is validated through the model
          provider factory and its secret fields are encrypted. The caller's ``credentials``
          dict is left unmodified.

          :param credentials: provider credentials
          :return: tuple of (existing custom Provider record or None,
              validated credentials with secret fields encrypted)
          """
          # Work on a copy so decrypted secrets are never written back into the caller's dict.
          credentials = dict(credentials)

          # get provider
          provider_record = db.session.query(Provider) \
              .filter(
                  Provider.tenant_id == self.tenant_id,
                  Provider.provider_name == self.provider.provider,
                  Provider.provider_type == ProviderType.CUSTOM.value
              ).first()

          # Get provider credential secret variables
          provider_credential_secret_variables = self._extract_secret_variables(
              self.provider.provider_credential_schema.credential_form_schemas
              if self.provider.provider_credential_schema else []
          )

          if provider_record:
              try:
                  if not provider_record.encrypted_config:
                      original_credentials = {}
                  elif not provider_record.encrypted_config.startswith("{"):
                      # A non-JSON config is treated as the value of "openai_api_key"
                      # (presumably a legacy storage format).
                      original_credentials = {
                          "openai_api_key": provider_record.encrypted_config
                      }
                  else:
                      original_credentials = json.loads(provider_record.encrypted_config)
              except JSONDecodeError:
                  original_credentials = {}

              for key, value in credentials.items():
                  # if send [__HIDDEN__] in secret input, it will be same as original value
                  if (key in provider_credential_secret_variables
                          and value == '[__HIDDEN__]'
                          and key in original_credentials):
                      credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

          credentials = model_provider_factory.provider_credentials_validate(
              self.provider.provider,
              credentials
          )

          # Encrypt secret fields of the validated credentials.
          credentials = {
              key: encrypter.encrypt_token(self.tenant_id, value)
              if key in provider_credential_secret_variables else value
              for key, value in credentials.items()
          }

          return provider_record, credentials

      def add_or_update_custom_credentials(self, credentials: dict) -> None:
          """
          Add or update custom provider credentials.

          Validates and encrypts ``credentials``, stores them on the tenant's custom Provider
          record (creating it when missing), deletes the cached provider credentials and then
          switches the preferred provider type to CUSTOM.

          :param credentials: provider credentials
          :return:
          """
          # validate custom provider config
          provider_record, credentials = self.custom_credentials_validate(credentials)

          # save provider
          # NOTE(review): an earlier comment here said the preferred provider is NOT switched
          # (so users consume quotas first), yet this method ends by switching to CUSTOM —
          # confirm which behavior is intended.
          if provider_record:
              provider_record.encrypted_config = json.dumps(credentials)
              provider_record.is_valid = True
              provider_record.updated_at = datetime.datetime.utcnow()
          else:
              provider_record = Provider(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  provider_type=ProviderType.CUSTOM.value,
                  encrypted_config=json.dumps(credentials),
                  is_valid=True
              )
              db.session.add(provider_record)
          db.session.commit()

          ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=provider_record.id,
              cache_type=ProviderCredentialsCacheType.PROVIDER
          ).delete()

          self.switch_preferred_provider_type(ProviderType.CUSTOM)

      def delete_custom_credentials(self) -> None:
          """
          Delete custom provider credentials.

          If the tenant has a custom Provider record, switches the preferred provider type to
          SYSTEM, deletes the record and deletes its cached credentials.

          :return:
          """
          # get provider
          provider_record = db.session.query(Provider) \
              .filter(
                  Provider.tenant_id == self.tenant_id,
                  Provider.provider_name == self.provider.provider,
                  Provider.provider_type == ProviderType.CUSTOM.value
              ).first()

          if not provider_record:
              return

          # Read the id up front so the cache key does not depend on attribute access
          # on an instance that has already been deleted and committed.
          provider_record_id = provider_record.id

          self.switch_preferred_provider_type(ProviderType.SYSTEM)

          db.session.delete(provider_record)
          db.session.commit()

          ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=provider_record_id,
              cache_type=ProviderCredentialsCacheType.PROVIDER
          ).delete()

      def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
              -> Optional[dict]:
          """
          Get custom model credentials.

          :param model_type: model type
          :param model: model name
          :param obfuscated: obfuscated secret data in credentials
          :return: credentials of the matching custom model (secret fields obfuscated when
              requested), or None if no custom model matches
          """
          if not self.custom_configuration.models:
              return None

          for model_configuration in self.custom_configuration.models:
              if model_configuration.model_type != model_type or model_configuration.model != model:
                  continue

              credentials = model_configuration.credentials
              if not obfuscated:
                  return credentials

              # Obfuscate credentials
              return self._obfuscated_credentials(
                  credentials=credentials,
                  credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
                  if self.provider.model_credential_schema else []
              )

          return None

      def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
              -> tuple[Optional[ProviderModel], dict]:
          """
          Validate custom model credentials.

          Secret fields submitted as ``[__HIDDEN__]`` are replaced with the decrypted value
          already stored for this model; the result is validated through the model provider
          factory and its secret fields are encrypted. The caller's ``credentials`` dict is
          left unmodified.

          :param model_type: model type
          :param model: model name
          :param credentials: model credentials
          :return: tuple of (existing ProviderModel record or None,
              validated credentials with secret fields encrypted)
          """
          # Work on a copy so decrypted secrets are never written back into the caller's dict.
          credentials = dict(credentials)

          # get provider model
          provider_model_record = db.session.query(ProviderModel) \
              .filter(
                  ProviderModel.tenant_id == self.tenant_id,
                  ProviderModel.provider_name == self.provider.provider,
                  ProviderModel.model_name == model,
                  ProviderModel.model_type == model_type.to_origin_model_type()
              ).first()

          # Get provider credential secret variables
          provider_credential_secret_variables = self._extract_secret_variables(
              self.provider.model_credential_schema.credential_form_schemas
              if self.provider.model_credential_schema else []
          )

          if provider_model_record:
              try:
                  original_credentials = json.loads(
                      provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
              except JSONDecodeError:
                  original_credentials = {}

              # decrypt credentials
              for key, value in credentials.items():
                  # if send [__HIDDEN__] in secret input, it will be same as original value
                  if (key in provider_credential_secret_variables
                          and value == '[__HIDDEN__]'
                          and key in original_credentials):
                      credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

          credentials = model_provider_factory.model_credentials_validate(
              provider=self.provider.provider,
              model_type=model_type,
              model=model,
              credentials=credentials
          )

          # Encrypt secret fields of the validated credentials.
          credentials = {
              key: encrypter.encrypt_token(self.tenant_id, value)
              if key in provider_credential_secret_variables else value
              for key, value in credentials.items()
          }

          return provider_model_record, credentials

      def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
          """
          Add or update custom model credentials.

          Validates and encrypts ``credentials``, stores them on the matching ProviderModel
          record (creating it when missing) and deletes the cached model credentials.

          :param model_type: model type
          :param model: model name
          :param credentials: model credentials
          :return:
          """
          # validate custom model config
          provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)

          # save provider model
          # Note: Do not switch the preferred provider, which allows users to use quotas first
          if provider_model_record:
              provider_model_record.encrypted_config = json.dumps(credentials)
              provider_model_record.is_valid = True
              provider_model_record.updated_at = datetime.datetime.utcnow()
          else:
              provider_model_record = ProviderModel(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_name=model,
                  model_type=model_type.to_origin_model_type(),
                  encrypted_config=json.dumps(credentials),
                  is_valid=True
              )
              db.session.add(provider_model_record)
          db.session.commit()

          ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=provider_model_record.id,
              cache_type=ProviderCredentialsCacheType.MODEL
          ).delete()

      def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
          """
          Delete custom model credentials.

          If a matching ProviderModel record exists, deletes it and its cached credentials.

          :param model_type: model type
          :param model: model name
          :return:
          """
          # get provider model
          provider_model_record = db.session.query(ProviderModel) \
              .filter(
                  ProviderModel.tenant_id == self.tenant_id,
                  ProviderModel.provider_name == self.provider.provider,
                  ProviderModel.model_name == model,
                  ProviderModel.model_type == model_type.to_origin_model_type()
              ).first()

          if not provider_model_record:
              return

          # Read the id up front so the cache key does not depend on attribute access
          # on an instance that has already been deleted and committed.
          provider_model_record_id = provider_model_record.id

          db.session.delete(provider_model_record)
          db.session.commit()

          ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=provider_model_record_id,
              cache_type=ProviderCredentialsCacheType.MODEL
          ).delete()

      def get_provider_instance(self) -> ModelProvider:
          """
          Get provider instance.

          :return: the runtime provider instance from the model provider factory
          """
          return model_provider_factory.get_provider_instance(self.provider.provider)

      def get_model_type_instance(self, model_type: ModelType) -> AIModel:
          """
          Get current model type instance.

          :param model_type: model type
          :return: the provider's model instance for ``model_type``
          """
          return self.get_provider_instance().get_model_instance(model_type)

      def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
          """
          Switch preferred provider type.

          No-op when ``provider_type`` already is the preferred type, or when switching to
          SYSTEM while the system configuration is disabled. Otherwise upserts the tenant's
          TenantPreferredModelProvider row and commits.

          :param provider_type: the new preferred provider type
          :return:
          """
          if provider_type == self.preferred_provider_type:
              return

          if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
              return

          # get preferred provider
          preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
              .filter(
                  TenantPreferredModelProvider.tenant_id == self.tenant_id,
                  TenantPreferredModelProvider.provider_name == self.provider.provider
              ).first()

          if preferred_model_provider:
              preferred_model_provider.preferred_provider_type = provider_type.value
          else:
              preferred_model_provider = TenantPreferredModelProvider(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  preferred_provider_type=provider_type.value
              )
              db.session.add(preferred_model_provider)

          db.session.commit()

      def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
          """
          Extract secret input form variables.

          :param credential_form_schemas: credential form schemas
          :return: variable names of the schemas whose type is SECRET_INPUT, in schema order
          """
          return [
              credential_form_schema.variable
              for credential_form_schema in credential_form_schemas
              if credential_form_schema.type == FormType.SECRET_INPUT
          ]

      def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
          """
          Obfuscated credentials.

          :param credentials: credentials
          :param credential_form_schemas: credential form schemas
          :return: a new dict where every secret-input field is replaced by its obfuscated
              form; other fields are copied unchanged
          """
          credential_secret_variables = self._extract_secret_variables(credential_form_schemas)

          return {
              key: encrypter.obfuscated_token(value) if key in credential_secret_variables else value
              for key, value in credentials.items()
          }

      def get_provider_model(self, model_type: ModelType,
                             model: str,
                             only_active: bool = False) -> Optional[ModelWithProviderEntity]:
          """
          Get provider model.

          :param model_type: model type
          :param model: model name
          :param only_active: return active model only
          :return: the first provider model named ``model``, or None
          """
          provider_models = self.get_provider_models(model_type, only_active)

          return next((m for m in provider_models if m.model == model), None)

      def get_provider_models(self, model_type: Optional[ModelType] = None,
                              only_active: bool = False) -> list[ModelWithProviderEntity]:
          """
          Get provider models.

          :param model_type: model type; when omitted, every model type the provider schema supports
          :param only_active: only active models
          :return: models sorted by model type value
          """
          provider_instance = self.get_provider_instance()

          if model_type:
              model_types = [model_type]
          else:
              model_types = provider_instance.get_provider_schema().supported_model_types

          if self.using_provider_type == ProviderType.SYSTEM:
              provider_models = self._get_system_provider_models(
                  model_types=model_types,
                  provider_instance=provider_instance
              )
          else:
              provider_models = self._get_custom_provider_models(
                  model_types=model_types,
                  provider_instance=provider_instance
              )

          if only_active:
              provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]

          # resort provider_models
          return sorted(provider_models, key=lambda x: x.model_type.value)

      def _get_system_provider_models(self,
                                      model_types: list[ModelType],
                                      provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
          """
          Get system provider models.

          Lists the provider's predefined models as ACTIVE. For customizable-only providers,
          also builds a model for each restrict model of the current quota from the system
          credentials. Then, for the current quota, LLMs outside its restrict list are marked
          NO_PERMISSION and, if the quota is invalid, the remaining models QUOTA_EXCEEDED.
          When the current quota has no restrict models, statuses are left untouched.

          :param model_types: model types
          :param provider_instance: provider instance
          :return: provider models
          """
          provider_models = []
          for model_type in model_types:
              provider_models.extend(
                  ModelWithProviderEntity(
                      model=m.model,
                      label=m.label,
                      model_type=m.model_type,
                      features=m.features,
                      fetch_from=m.fetch_from,
                      model_properties=m.model_properties,
                      deprecated=m.deprecated,
                      provider=SimpleModelProviderEntity(self.provider),
                      status=ModelStatus.ACTIVE
                  )
                  for m in provider_instance.models(model_type)
              )

          provider_name = self.provider.provider
          if provider_name not in original_provider_configurate_methods:
              original_provider_configurate_methods[provider_name] = list(
                  provider_instance.get_provider_schema().configurate_methods
              )

          should_use_custom_model = (
              original_provider_configurate_methods[provider_name] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]
          )

          for quota_configuration in self.system_configuration.quota_configurations:
              if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                  continue

              restrict_models = quota_configuration.restrict_models
              if len(restrict_models) == 0:
                  break

              if should_use_custom_model:
                  # only customizable model: build each restricted model's schema from system credentials
                  for restrict_model in restrict_models:
                      copy_credentials = self.system_configuration.credentials.copy()
                      if restrict_model.base_model_name:
                          copy_credentials['base_model_name'] = restrict_model.base_model_name

                      try:
                          custom_model_schema = (
                              provider_instance.get_model_instance(restrict_model.model_type)
                              .get_customizable_model_schema_from_credentials(
                                  restrict_model.model,
                                  copy_credentials
                              )
                          )
                      except Exception as ex:
                          # Best effort: one broken model schema must not hide the other models.
                          logger.warning('get custom model schema failed, %s', ex)
                          continue

                      if not custom_model_schema:
                          continue

                      if custom_model_schema.model_type not in model_types:
                          continue

                      provider_models.append(
                          ModelWithProviderEntity(
                              model=custom_model_schema.model,
                              label=custom_model_schema.label,
                              model_type=custom_model_schema.model_type,
                              features=custom_model_schema.features,
                              fetch_from=FetchFrom.PREDEFINED_MODEL,
                              model_properties=custom_model_schema.model_properties,
                              deprecated=custom_model_schema.deprecated,
                              provider=SimpleModelProviderEntity(self.provider),
                              status=ModelStatus.ACTIVE
                          )
                      )

              # if llm name not in restricted llm list, remove it
              restrict_model_names = {rm.model for rm in restrict_models}
              for m in provider_models:
                  if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
                      m.status = ModelStatus.NO_PERMISSION
                  elif not quota_configuration.is_valid:
                      m.status = ModelStatus.QUOTA_EXCEEDED

          return provider_models

      def _get_custom_provider_models(self,
                                      model_types: list[ModelType],
                                      provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
          """
          Get custom provider models.

          Lists the provider's predefined models (ACTIVE when custom provider credentials
          exist, otherwise NO_CONFIGURE), followed by each configured custom model whose
          schema can be built from its credentials.

          :param model_types: model types
          :param provider_instance: provider instance
          :return: provider models
          """
          credentials = None
          if self.custom_configuration.provider:
              credentials = self.custom_configuration.provider.credentials

          predefined_status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE

          provider_models = []
          for model_type in model_types:
              if model_type not in self.provider.supported_model_types:
                  continue

              provider_models.extend(
                  ModelWithProviderEntity(
                      model=m.model,
                      label=m.label,
                      model_type=m.model_type,
                      features=m.features,
                      fetch_from=m.fetch_from,
                      model_properties=m.model_properties,
                      deprecated=m.deprecated,
                      provider=SimpleModelProviderEntity(self.provider),
                      status=predefined_status
                  )
                  for m in provider_instance.models(model_type)
              )

          # custom models
          for model_configuration in self.custom_configuration.models:
              if model_configuration.model_type not in model_types:
                  continue

              try:
                  custom_model_schema = (
                      provider_instance.get_model_instance(model_configuration.model_type)
                      .get_customizable_model_schema_from_credentials(
                          model_configuration.model,
                          model_configuration.credentials
                      )
                  )
              except Exception as ex:
                  # Best effort: one broken model schema must not hide the other models.
                  logger.warning('get custom model schema failed, %s', ex)
                  continue

              if not custom_model_schema:
                  continue

              provider_models.append(
                  ModelWithProviderEntity(
                      model=custom_model_schema.model,
                      label=custom_model_schema.label,
                      model_type=custom_model_schema.model_type,
                      features=custom_model_schema.features,
                      fetch_from=custom_model_schema.fetch_from,
                      model_properties=custom_model_schema.model_properties,
                      deprecated=custom_model_schema.deprecated,
                      provider=SimpleModelProviderEntity(self.provider),
                      status=ModelStatus.ACTIVE
                  )
              )

          return provider_models
  class ProviderConfigurations(BaseModel):
      """
      Model class for provider configuration dict.

      A dict-like container of ProviderConfiguration objects for one tenant
      (keys are presumably provider names — confirm against the caller that fills it).
      """
      tenant_id: str
      configurations: dict[str, ProviderConfiguration] = {}

      def __init__(self, tenant_id: str):
          super().__init__(tenant_id=tenant_id)

      def get_models(self,
                     provider: Optional[str] = None,
                     model_type: Optional[ModelType] = None,
                     only_active: bool = False) \
              -> list[ModelWithProviderEntity]:
          """
          Get available models.

          If preferred provider type is `system`:
            Use the current **system mode** if the provider supports it.
            If no system mode is available (no quota), fall back to the **custom credential mode**.
            If no model is configured in custom mode, it is treated as no_configure.
            system > custom > no_configure

          If preferred provider type is `custom`:
            If custom credentials are configured, it is treated as custom mode.
            Otherwise, use the current **system mode** if supported.
            If no system mode is available (no quota), it is treated as no_configure.
            custom > system > no_configure

          If real mode is `system`, use system credentials to get models,
            paid quotas > provider free quotas > system free quotas
            include pre-defined models (exclude GPT-4, status marked as `no_permission`).
          If real mode is `custom`, use workspace custom credentials to get models,
            include pre-defined models, custom models(manual append).
          If real mode is `no_configure`, only return pre-defined models from `model runtime`.
            (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
          model status marked as `active` is available.

          :param provider: provider name; when given, only that provider's models are returned
          :param model_type: model type
          :param only_active: only active models
          :return: models of every matching provider configuration, in configuration order
          """
          all_models = []
          for provider_configuration in self.values():
              if provider and provider_configuration.provider.provider != provider:
                  continue
              all_models.extend(provider_configuration.get_provider_models(model_type, only_active))

          return all_models

      def to_list(self) -> list[ProviderConfiguration]:
          """
          Convert to list.

          :return: the stored provider configurations, in insertion order
          """
          return list(self.values())

      def __getitem__(self, key):
          return self.configurations[key]

      def __setitem__(self, key, value):
          self.configurations[key] = value

      def __contains__(self, key) -> bool:
          # O(1) dict lookup instead of the linear scan Python falls back to via __iter__.
          return key in self.configurations

      def __iter__(self):
          # Iterates keys, like a dict.
          return iter(self.configurations)

      def values(self) -> ValuesView[ProviderConfiguration]:
          # A live dict values view (iterable, sized, re-iterable) — not a one-shot iterator.
          return self.configurations.values()

      def get(self, key, default=None):
          return self.configurations.get(key, default)
  class ProviderModelBundle(BaseModel):
      """
      Provider model bundle.

      Groups a tenant's provider configuration with the runtime provider instance and the
      runtime model instance for one model type into a single object.
      """
      # Tenant-level configuration (credentials, quotas, preferred type) for the provider.
      configuration: ProviderConfiguration
      # Runtime provider implementation.
      provider_instance: ModelProvider
      # Runtime model instance for a single model type of this provider.
      model_type_instance: AIModel

      class Config:
          """Configuration for this pydantic object."""
          # Lets pydantic accept field types it cannot validate natively
          # (presumably ModelProvider / AIModel, which are not pydantic models — confirm).
          arbitrary_types_allowed = True