# provider_configuration.py
import datetime
import json
import logging
from json import JSONDecodeError
from typing import Optional, List, Dict, Tuple, Iterator, ValuesView

from pydantic import BaseModel

from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity
from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import ModelType, FetchFrom
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \
    ConfigurateMethod
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from core.model_runtime.utils import encoders
from extensions.ext_database import db
from models.provider import ProviderType, Provider, ProviderModel, TenantPreferredModelProvider
  20. logger = logging.getLogger(__name__)
  21. original_provider_configurate_methods = {}
  22. class ProviderConfiguration(BaseModel):
  23. """
  24. Model class for provider configuration.
  25. """
  26. tenant_id: str
  27. provider: ProviderEntity
  28. preferred_provider_type: ProviderType
  29. using_provider_type: ProviderType
  30. system_configuration: SystemConfiguration
  31. custom_configuration: CustomConfiguration
  32. def __init__(self, **data):
  33. super().__init__(**data)
  34. if self.provider.provider not in original_provider_configurate_methods:
  35. original_provider_configurate_methods[self.provider.provider] = []
  36. for configurate_method in self.provider.configurate_methods:
  37. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  38. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  39. if (any([len(quota_configuration.restrict_models) > 0
  40. for quota_configuration in self.system_configuration.quota_configurations])
  41. and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
  42. self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
  43. def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
  44. """
  45. Get current credentials.
  46. :param model_type: model type
  47. :param model: model name
  48. :return:
  49. """
  50. if self.using_provider_type == ProviderType.SYSTEM:
  51. restrict_models = []
  52. for quota_configuration in self.system_configuration.quota_configurations:
  53. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  54. continue
  55. restrict_models = quota_configuration.restrict_models
  56. copy_credentials = self.system_configuration.credentials.copy()
  57. if restrict_models:
  58. for restrict_model in restrict_models:
  59. if (restrict_model.model_type == model_type
  60. and restrict_model.model == model
  61. and restrict_model.base_model_name):
  62. copy_credentials['base_model_name'] = restrict_model.base_model_name
  63. return copy_credentials
  64. else:
  65. if self.custom_configuration.models:
  66. for model_configuration in self.custom_configuration.models:
  67. if model_configuration.model_type == model_type and model_configuration.model == model:
  68. return model_configuration.credentials
  69. if self.custom_configuration.provider:
  70. return self.custom_configuration.provider.credentials
  71. else:
  72. return None
  73. def get_system_configuration_status(self) -> SystemConfigurationStatus:
  74. """
  75. Get system configuration status.
  76. :return:
  77. """
  78. if self.system_configuration.enabled is False:
  79. return SystemConfigurationStatus.UNSUPPORTED
  80. current_quota_type = self.system_configuration.current_quota_type
  81. current_quota_configuration = next(
  82. (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
  83. None
  84. )
  85. return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
  86. SystemConfigurationStatus.QUOTA_EXCEEDED
  87. def is_custom_configuration_available(self) -> bool:
  88. """
  89. Check custom configuration available.
  90. :return:
  91. """
  92. return (self.custom_configuration.provider is not None
  93. or len(self.custom_configuration.models) > 0)
  94. def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
  95. """
  96. Get custom credentials.
  97. :param obfuscated: obfuscated secret data in credentials
  98. :return:
  99. """
  100. if self.custom_configuration.provider is None:
  101. return None
  102. credentials = self.custom_configuration.provider.credentials
  103. if not obfuscated:
  104. return credentials
  105. # Obfuscate credentials
  106. return self._obfuscated_credentials(
  107. credentials=credentials,
  108. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  109. if self.provider.provider_credential_schema else []
  110. )
  111. def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
  112. """
  113. Validate custom credentials.
  114. :param credentials: provider credentials
  115. :return:
  116. """
  117. # get provider
  118. provider_record = db.session.query(Provider) \
  119. .filter(
  120. Provider.tenant_id == self.tenant_id,
  121. Provider.provider_name == self.provider.provider,
  122. Provider.provider_type == ProviderType.CUSTOM.value
  123. ).first()
  124. # Get provider credential secret variables
  125. provider_credential_secret_variables = self._extract_secret_variables(
  126. self.provider.provider_credential_schema.credential_form_schemas
  127. if self.provider.provider_credential_schema else []
  128. )
  129. if provider_record:
  130. try:
  131. original_credentials = json.loads(
  132. provider_record.encrypted_config) if provider_record.encrypted_config else {}
  133. except JSONDecodeError:
  134. original_credentials = {}
  135. # encrypt credentials
  136. for key, value in credentials.items():
  137. if key in provider_credential_secret_variables:
  138. # if send [__HIDDEN__] in secret input, it will be same as original value
  139. if value == '[__HIDDEN__]' and key in original_credentials:
  140. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  141. model_provider_factory.provider_credentials_validate(
  142. self.provider.provider,
  143. credentials
  144. )
  145. for key, value in credentials.items():
  146. if key in provider_credential_secret_variables:
  147. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  148. return provider_record, credentials
  149. def add_or_update_custom_credentials(self, credentials: dict) -> None:
  150. """
  151. Add or update custom provider credentials.
  152. :param credentials:
  153. :return:
  154. """
  155. # validate custom provider config
  156. provider_record, credentials = self.custom_credentials_validate(credentials)
  157. # save provider
  158. # Note: Do not switch the preferred provider, which allows users to use quotas first
  159. if provider_record:
  160. provider_record.encrypted_config = json.dumps(credentials)
  161. provider_record.is_valid = True
  162. provider_record.updated_at = datetime.datetime.utcnow()
  163. db.session.commit()
  164. else:
  165. provider_record = Provider(
  166. tenant_id=self.tenant_id,
  167. provider_name=self.provider.provider,
  168. provider_type=ProviderType.CUSTOM.value,
  169. encrypted_config=json.dumps(credentials),
  170. is_valid=True
  171. )
  172. db.session.add(provider_record)
  173. db.session.commit()
  174. provider_model_credentials_cache = ProviderCredentialsCache(
  175. tenant_id=self.tenant_id,
  176. identity_id=provider_record.id,
  177. cache_type=ProviderCredentialsCacheType.PROVIDER
  178. )
  179. provider_model_credentials_cache.delete()
  180. self.switch_preferred_provider_type(ProviderType.CUSTOM)
  181. def delete_custom_credentials(self) -> None:
  182. """
  183. Delete custom provider credentials.
  184. :return:
  185. """
  186. # get provider
  187. provider_record = db.session.query(Provider) \
  188. .filter(
  189. Provider.tenant_id == self.tenant_id,
  190. Provider.provider_name == self.provider.provider,
  191. Provider.provider_type == ProviderType.CUSTOM.value
  192. ).first()
  193. # delete provider
  194. if provider_record:
  195. self.switch_preferred_provider_type(ProviderType.SYSTEM)
  196. db.session.delete(provider_record)
  197. db.session.commit()
  198. provider_model_credentials_cache = ProviderCredentialsCache(
  199. tenant_id=self.tenant_id,
  200. identity_id=provider_record.id,
  201. cache_type=ProviderCredentialsCacheType.PROVIDER
  202. )
  203. provider_model_credentials_cache.delete()
  204. def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
  205. -> Optional[dict]:
  206. """
  207. Get custom model credentials.
  208. :param model_type: model type
  209. :param model: model name
  210. :param obfuscated: obfuscated secret data in credentials
  211. :return:
  212. """
  213. if not self.custom_configuration.models:
  214. return None
  215. for model_configuration in self.custom_configuration.models:
  216. if model_configuration.model_type == model_type and model_configuration.model == model:
  217. credentials = model_configuration.credentials
  218. if not obfuscated:
  219. return credentials
  220. # Obfuscate credentials
  221. return self._obfuscated_credentials(
  222. credentials=credentials,
  223. credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
  224. if self.provider.model_credential_schema else []
  225. )
  226. return None
  227. def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
  228. -> Tuple[ProviderModel, dict]:
  229. """
  230. Validate custom model credentials.
  231. :param model_type: model type
  232. :param model: model name
  233. :param credentials: model credentials
  234. :return:
  235. """
  236. # get provider model
  237. provider_model_record = db.session.query(ProviderModel) \
  238. .filter(
  239. ProviderModel.tenant_id == self.tenant_id,
  240. ProviderModel.provider_name == self.provider.provider,
  241. ProviderModel.model_name == model,
  242. ProviderModel.model_type == model_type.to_origin_model_type()
  243. ).first()
  244. # Get provider credential secret variables
  245. provider_credential_secret_variables = self._extract_secret_variables(
  246. self.provider.model_credential_schema.credential_form_schemas
  247. if self.provider.model_credential_schema else []
  248. )
  249. if provider_model_record:
  250. try:
  251. original_credentials = json.loads(
  252. provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
  253. except JSONDecodeError:
  254. original_credentials = {}
  255. # decrypt credentials
  256. for key, value in credentials.items():
  257. if key in provider_credential_secret_variables:
  258. # if send [__HIDDEN__] in secret input, it will be same as original value
  259. if value == '[__HIDDEN__]' and key in original_credentials:
  260. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  261. model_provider_factory.model_credentials_validate(
  262. provider=self.provider.provider,
  263. model_type=model_type,
  264. model=model,
  265. credentials=credentials
  266. )
  267. model_schema = (
  268. model_provider_factory.get_provider_instance(self.provider.provider)
  269. .get_model_instance(model_type)._get_customizable_model_schema(
  270. model=model,
  271. credentials=credentials
  272. )
  273. )
  274. if model_schema:
  275. credentials['schema'] = json.dumps(encoders.jsonable_encoder(model_schema))
  276. for key, value in credentials.items():
  277. if key in provider_credential_secret_variables:
  278. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  279. return provider_model_record, credentials
  280. def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
  281. """
  282. Add or update custom model credentials.
  283. :param model_type: model type
  284. :param model: model name
  285. :param credentials: model credentials
  286. :return:
  287. """
  288. # validate custom model config
  289. provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
  290. # save provider model
  291. # Note: Do not switch the preferred provider, which allows users to use quotas first
  292. if provider_model_record:
  293. provider_model_record.encrypted_config = json.dumps(credentials)
  294. provider_model_record.is_valid = True
  295. provider_model_record.updated_at = datetime.datetime.utcnow()
  296. db.session.commit()
  297. else:
  298. provider_model_record = ProviderModel(
  299. tenant_id=self.tenant_id,
  300. provider_name=self.provider.provider,
  301. model_name=model,
  302. model_type=model_type.to_origin_model_type(),
  303. encrypted_config=json.dumps(credentials),
  304. is_valid=True
  305. )
  306. db.session.add(provider_model_record)
  307. db.session.commit()
  308. provider_model_credentials_cache = ProviderCredentialsCache(
  309. tenant_id=self.tenant_id,
  310. identity_id=provider_model_record.id,
  311. cache_type=ProviderCredentialsCacheType.MODEL
  312. )
  313. provider_model_credentials_cache.delete()
  314. def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
  315. """
  316. Delete custom model credentials.
  317. :param model_type: model type
  318. :param model: model name
  319. :return:
  320. """
  321. # get provider model
  322. provider_model_record = db.session.query(ProviderModel) \
  323. .filter(
  324. ProviderModel.tenant_id == self.tenant_id,
  325. ProviderModel.provider_name == self.provider.provider,
  326. ProviderModel.model_name == model,
  327. ProviderModel.model_type == model_type.to_origin_model_type()
  328. ).first()
  329. # delete provider model
  330. if provider_model_record:
  331. db.session.delete(provider_model_record)
  332. db.session.commit()
  333. provider_model_credentials_cache = ProviderCredentialsCache(
  334. tenant_id=self.tenant_id,
  335. identity_id=provider_model_record.id,
  336. cache_type=ProviderCredentialsCacheType.MODEL
  337. )
  338. provider_model_credentials_cache.delete()
  339. def get_provider_instance(self) -> ModelProvider:
  340. """
  341. Get provider instance.
  342. :return:
  343. """
  344. return model_provider_factory.get_provider_instance(self.provider.provider)
  345. def get_model_type_instance(self, model_type: ModelType) -> AIModel:
  346. """
  347. Get current model type instance.
  348. :param model_type: model type
  349. :return:
  350. """
  351. # Get provider instance
  352. provider_instance = self.get_provider_instance()
  353. # Get model instance of LLM
  354. return provider_instance.get_model_instance(model_type)
  355. def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
  356. """
  357. Switch preferred provider type.
  358. :param provider_type:
  359. :return:
  360. """
  361. if provider_type == self.preferred_provider_type:
  362. return
  363. if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
  364. return
  365. # get preferred provider
  366. preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
  367. .filter(
  368. TenantPreferredModelProvider.tenant_id == self.tenant_id,
  369. TenantPreferredModelProvider.provider_name == self.provider.provider
  370. ).first()
  371. if preferred_model_provider:
  372. preferred_model_provider.preferred_provider_type = provider_type.value
  373. else:
  374. preferred_model_provider = TenantPreferredModelProvider(
  375. tenant_id=self.tenant_id,
  376. provider_name=self.provider.provider,
  377. preferred_provider_type=provider_type.value
  378. )
  379. db.session.add(preferred_model_provider)
  380. db.session.commit()
  381. def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
  382. """
  383. Extract secret input form variables.
  384. :param credential_form_schemas:
  385. :return:
  386. """
  387. secret_input_form_variables = []
  388. for credential_form_schema in credential_form_schemas:
  389. if credential_form_schema.type == FormType.SECRET_INPUT:
  390. secret_input_form_variables.append(credential_form_schema.variable)
  391. return secret_input_form_variables
  392. def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
  393. """
  394. Obfuscated credentials.
  395. :param credentials: credentials
  396. :param credential_form_schemas: credential form schemas
  397. :return:
  398. """
  399. # Get provider credential secret variables
  400. credential_secret_variables = self._extract_secret_variables(
  401. credential_form_schemas
  402. )
  403. # Obfuscate provider credentials
  404. copy_credentials = credentials.copy()
  405. for key, value in copy_credentials.items():
  406. if key in credential_secret_variables:
  407. copy_credentials[key] = encrypter.obfuscated_token(value)
  408. return copy_credentials
  409. def get_provider_model(self, model_type: ModelType,
  410. model: str,
  411. only_active: bool = False) -> Optional[ModelWithProviderEntity]:
  412. """
  413. Get provider model.
  414. :param model_type: model type
  415. :param model: model name
  416. :param only_active: return active model only
  417. :return:
  418. """
  419. provider_models = self.get_provider_models(model_type, only_active)
  420. for provider_model in provider_models:
  421. if provider_model.model == model:
  422. return provider_model
  423. return None
  424. def get_provider_models(self, model_type: Optional[ModelType] = None,
  425. only_active: bool = False) -> list[ModelWithProviderEntity]:
  426. """
  427. Get provider models.
  428. :param model_type: model type
  429. :param only_active: only active models
  430. :return:
  431. """
  432. provider_instance = self.get_provider_instance()
  433. model_types = []
  434. if model_type:
  435. model_types.append(model_type)
  436. else:
  437. model_types = provider_instance.get_provider_schema().supported_model_types
  438. if self.using_provider_type == ProviderType.SYSTEM:
  439. provider_models = self._get_system_provider_models(
  440. model_types=model_types,
  441. provider_instance=provider_instance
  442. )
  443. else:
  444. provider_models = self._get_custom_provider_models(
  445. model_types=model_types,
  446. provider_instance=provider_instance
  447. )
  448. if only_active:
  449. provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
  450. # resort provider_models
  451. return sorted(provider_models, key=lambda x: x.model_type.value)
  452. def _get_system_provider_models(self,
  453. model_types: list[ModelType],
  454. provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
  455. """
  456. Get system provider models.
  457. :param model_types: model types
  458. :param provider_instance: provider instance
  459. :return:
  460. """
  461. provider_models = []
  462. for model_type in model_types:
  463. provider_models.extend(
  464. [
  465. ModelWithProviderEntity(
  466. model=m.model,
  467. label=m.label,
  468. model_type=m.model_type,
  469. features=m.features,
  470. fetch_from=m.fetch_from,
  471. model_properties=m.model_properties,
  472. deprecated=m.deprecated,
  473. provider=SimpleModelProviderEntity(self.provider),
  474. status=ModelStatus.ACTIVE
  475. )
  476. for m in provider_instance.models(model_type)
  477. ]
  478. )
  479. if self.provider.provider not in original_provider_configurate_methods:
  480. original_provider_configurate_methods[self.provider.provider] = []
  481. for configurate_method in provider_instance.get_provider_schema().configurate_methods:
  482. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  483. should_use_custom_model = False
  484. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  485. should_use_custom_model = True
  486. for quota_configuration in self.system_configuration.quota_configurations:
  487. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  488. continue
  489. restrict_models = quota_configuration.restrict_models
  490. if len(restrict_models) == 0:
  491. break
  492. if should_use_custom_model:
  493. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  494. # only customizable model
  495. for restrict_model in restrict_models:
  496. copy_credentials = self.system_configuration.credentials.copy()
  497. if restrict_model.base_model_name:
  498. copy_credentials['base_model_name'] = restrict_model.base_model_name
  499. try:
  500. custom_model_schema = (
  501. provider_instance.get_model_instance(restrict_model.model_type)
  502. .get_customizable_model_schema_from_credentials(
  503. restrict_model.model,
  504. copy_credentials
  505. )
  506. )
  507. except Exception as ex:
  508. logger.warning(f'get custom model schema failed, {ex}')
  509. continue
  510. if not custom_model_schema:
  511. continue
  512. if custom_model_schema.model_type not in model_types:
  513. continue
  514. provider_models.append(
  515. ModelWithProviderEntity(
  516. model=custom_model_schema.model,
  517. label=custom_model_schema.label,
  518. model_type=custom_model_schema.model_type,
  519. features=custom_model_schema.features,
  520. fetch_from=FetchFrom.PREDEFINED_MODEL,
  521. model_properties=custom_model_schema.model_properties,
  522. deprecated=custom_model_schema.deprecated,
  523. provider=SimpleModelProviderEntity(self.provider),
  524. status=ModelStatus.ACTIVE
  525. )
  526. )
  527. # if llm name not in restricted llm list, remove it
  528. restrict_model_names = [rm.model for rm in restrict_models]
  529. for m in provider_models:
  530. if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
  531. m.status = ModelStatus.NO_PERMISSION
  532. elif not quota_configuration.is_valid:
  533. m.status = ModelStatus.QUOTA_EXCEEDED
  534. return provider_models
  535. def _get_custom_provider_models(self,
  536. model_types: list[ModelType],
  537. provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
  538. """
  539. Get custom provider models.
  540. :param model_types: model types
  541. :param provider_instance: provider instance
  542. :return:
  543. """
  544. provider_models = []
  545. credentials = None
  546. if self.custom_configuration.provider:
  547. credentials = self.custom_configuration.provider.credentials
  548. for model_type in model_types:
  549. if model_type not in self.provider.supported_model_types:
  550. continue
  551. models = provider_instance.models(model_type)
  552. for m in models:
  553. provider_models.append(
  554. ModelWithProviderEntity(
  555. model=m.model,
  556. label=m.label,
  557. model_type=m.model_type,
  558. features=m.features,
  559. fetch_from=m.fetch_from,
  560. model_properties=m.model_properties,
  561. deprecated=m.deprecated,
  562. provider=SimpleModelProviderEntity(self.provider),
  563. status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
  564. )
  565. )
  566. # custom models
  567. for model_configuration in self.custom_configuration.models:
  568. if model_configuration.model_type not in model_types:
  569. continue
  570. try:
  571. custom_model_schema = (
  572. provider_instance.get_model_instance(model_configuration.model_type)
  573. .get_customizable_model_schema_from_credentials(
  574. model_configuration.model,
  575. model_configuration.credentials
  576. )
  577. )
  578. except Exception as ex:
  579. logger.warning(f'get custom model schema failed, {ex}')
  580. continue
  581. if not custom_model_schema:
  582. continue
  583. provider_models.append(
  584. ModelWithProviderEntity(
  585. model=custom_model_schema.model,
  586. label=custom_model_schema.label,
  587. model_type=custom_model_schema.model_type,
  588. features=custom_model_schema.features,
  589. fetch_from=custom_model_schema.fetch_from,
  590. model_properties=custom_model_schema.model_properties,
  591. deprecated=custom_model_schema.deprecated,
  592. provider=SimpleModelProviderEntity(self.provider),
  593. status=ModelStatus.ACTIVE
  594. )
  595. )
  596. return provider_models
  597. class ProviderConfigurations(BaseModel):
  598. """
  599. Model class for provider configuration dict.
  600. """
  601. tenant_id: str
  602. configurations: Dict[str, ProviderConfiguration] = {}
  603. def __init__(self, tenant_id: str):
  604. super().__init__(tenant_id=tenant_id)
  605. def get_models(self,
  606. provider: Optional[str] = None,
  607. model_type: Optional[ModelType] = None,
  608. only_active: bool = False) \
  609. -> list[ModelWithProviderEntity]:
  610. """
  611. Get available models.
  612. If preferred provider type is `system`:
  613. Get the current **system mode** if provider supported,
  614. if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
  615. If there is no model configured in custom mode, it is treated as no_configure.
  616. system > custom > no_configure
  617. If preferred provider type is `custom`:
  618. If custom credentials are configured, it is treated as custom mode.
  619. Otherwise, get the current **system mode** if supported,
  620. If all system modes are not available (no quota), it is treated as no_configure.
  621. custom > system > no_configure
  622. If real mode is `system`, use system credentials to get models,
  623. paid quotas > provider free quotas > system free quotas
  624. include pre-defined models (exclude GPT-4, status marked as `no_permission`).
  625. If real mode is `custom`, use workspace custom credentials to get models,
  626. include pre-defined models, custom models(manual append).
  627. If real mode is `no_configure`, only return pre-defined models from `model runtime`.
  628. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
  629. model status marked as `active` is available.
  630. :param provider: provider name
  631. :param model_type: model type
  632. :param only_active: only active models
  633. :return:
  634. """
  635. all_models = []
  636. for provider_configuration in self.values():
  637. if provider and provider_configuration.provider.provider != provider:
  638. continue
  639. all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
  640. return all_models
  641. def to_list(self) -> List[ProviderConfiguration]:
  642. """
  643. Convert to list.
  644. :return:
  645. """
  646. return list(self.values())
  647. def __getitem__(self, key):
  648. return self.configurations[key]
  649. def __setitem__(self, key, value):
  650. self.configurations[key] = value
  651. def __iter__(self):
  652. return iter(self.configurations)
  653. def values(self) -> Iterator[ProviderConfiguration]:
  654. return self.configurations.values()
  655. def get(self, key, default=None):
  656. return self.configurations.get(key, default)
  657. class ProviderModelBundle(BaseModel):
  658. """
  659. Provider model bundle.
  660. """
  661. configuration: ProviderConfiguration
  662. provider_instance: ModelProvider
  663. model_type_instance: AIModel
  664. class Config:
  665. """Configuration for this pydantic object."""
  666. arbitrary_types_allowed = True