provider_configuration.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797
  1. import datetime
  2. import json
  3. import logging
  4. from json import JSONDecodeError
  5. from typing import Dict, Iterator, List, Optional, Tuple
  6. from pydantic import BaseModel
  7. from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
  8. from core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus
  9. from core.helper import encrypter
  10. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  11. from core.model_runtime.entities.model_entities import FetchFrom, ModelType
  12. from core.model_runtime.entities.provider_entities import (
  13. ConfigurateMethod,
  14. CredentialFormSchema,
  15. FormType,
  16. ProviderEntity,
  17. )
  18. from core.model_runtime.model_providers import model_provider_factory
  19. from core.model_runtime.model_providers.__base.ai_model import AIModel
  20. from core.model_runtime.model_providers.__base.model_provider import ModelProvider
  21. from extensions.ext_database import db
  22. from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider
  # Module-level logger named after this module.
  logger = logging.getLogger(__name__)

  # Process-wide cache: provider name -> the configurate methods recorded the
  # first time that provider was seen. ProviderConfiguration consults it because
  # it may later append PREDEFINED_MODEL to the provider's own list.
  original_provider_configurate_methods: Dict[str, list] = {}
  class ProviderConfiguration(BaseModel):
      """
      Model class for provider configuration.

      Holds one tenant's configuration of a single model provider: the provider
      schema, the preferred and the currently used provider type, and the
      system / custom configuration that credentials and models are read from.
      """
      tenant_id: str
      provider: ProviderEntity
      preferred_provider_type: ProviderType
      using_provider_type: ProviderType
      system_configuration: SystemConfiguration
      custom_configuration: CustomConfiguration

      def __init__(self, **data):
          """
          Build the configuration.

          For a provider whose originally recorded configurate methods are exactly
          [CUSTOMIZABLE_MODEL], PREDEFINED_MODEL is appended to the provider's
          configurate methods when any system quota configuration restricts models.
          """
          super().__init__(**data)

          provider_name = self.provider.provider

          # Snapshot the configurate methods the first time this provider is seen.
          # The append below mutates self.provider.configurate_methods, so the
          # comparison must use the snapshot rather than the live list.
          if provider_name not in original_provider_configurate_methods:
              original_provider_configurate_methods[provider_name] = list(self.provider.configurate_methods)

          if original_provider_configurate_methods[provider_name] != [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
              return

          has_restrict_models = any(
              len(quota_configuration.restrict_models) > 0
              for quota_configuration in self.system_configuration.quota_configurations
          )
          if has_restrict_models and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods:
              self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)

      def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
          """
          Get current credentials.

          :param model_type: model type
          :param model: model name
          :return: in system mode a copy of the system credentials (with
              `base_model_name` set when a matching restricted model defines one);
              in custom mode the matching model credentials, else the provider
              credentials, else None
          """
          if self.using_provider_type == ProviderType.SYSTEM:
              restrict_models = []
              # No break on purpose: as before, the last quota configuration that
              # matches the current quota type wins.
              for quota_configuration in self.system_configuration.quota_configurations:
                  if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                      continue

                  restrict_models = quota_configuration.restrict_models

              copy_credentials = self.system_configuration.credentials.copy()
              for restrict_model in restrict_models or []:
                  if (restrict_model.model_type == model_type
                          and restrict_model.model == model
                          and restrict_model.base_model_name):
                      copy_credentials['base_model_name'] = restrict_model.base_model_name

              return copy_credentials

          # Custom mode: model-level credentials take precedence over provider-level ones.
          for model_configuration in self.custom_configuration.models or []:
              if model_configuration.model_type == model_type and model_configuration.model == model:
                  return model_configuration.credentials

          if self.custom_configuration.provider:
              return self.custom_configuration.provider.credentials

          return None

      def get_system_configuration_status(self) -> SystemConfigurationStatus:
          """
          Get system configuration status.

          :return: UNSUPPORTED when the system configuration is disabled or has no
              quota configuration for the current quota type, ACTIVE when that
              quota configuration is valid, otherwise QUOTA_EXCEEDED
          """
          if self.system_configuration.enabled is False:
              return SystemConfigurationStatus.UNSUPPORTED

          current_quota_type = self.system_configuration.current_quota_type
          current_quota_configuration = next(
              (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
              None
          )

          # Without a quota configuration for the current quota type there is
          # nothing usable; previously this dereferenced None and raised.
          if current_quota_configuration is None:
              return SystemConfigurationStatus.UNSUPPORTED

          if current_quota_configuration.is_valid:
              return SystemConfigurationStatus.ACTIVE

          return SystemConfigurationStatus.QUOTA_EXCEEDED

      def is_custom_configuration_available(self) -> bool:
          """
          Check custom configuration available.

          :return: True when provider-level credentials or at least one custom model is configured
          """
          return (self.custom_configuration.provider is not None
                  or len(self.custom_configuration.models) > 0)

      def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
          """
          Get custom credentials.

          :param obfuscated: obfuscated secret data in credentials
          :return: provider-level custom credentials, or None when not configured
          """
          if self.custom_configuration.provider is None:
              return None

          credentials = self.custom_configuration.provider.credentials
          if not obfuscated:
              return credentials

          # Obfuscate credentials
          return self._obfuscated_credentials(
              credentials=credentials,
              credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
              if self.provider.provider_credential_schema else []
          )

      def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
          """
          Validate custom credentials.

          The caller's dict is not modified: hidden placeholders are restored on
          a copy and the encrypted result is a new dict.

          :param credentials: provider credentials
          :return: (existing custom provider record or None, validated credentials
              with secret fields encrypted)
          """
          # get provider
          provider_record = db.session.query(Provider) \
              .filter(
              Provider.tenant_id == self.tenant_id,
              Provider.provider_name == self.provider.provider,
              Provider.provider_type == ProviderType.CUSTOM.value
          ).first()

          # Get provider credential secret variables
          provider_credential_secret_variables = self._extract_secret_variables(
              self.provider.provider_credential_schema.credential_form_schemas
              if self.provider.provider_credential_schema else []
          )

          if provider_record:
              encrypted_config = provider_record.encrypted_config
              try:
                  # fix origin data
                  if not encrypted_config:
                      original_credentials = {}
                  elif not encrypted_config.startswith("{"):
                      # Stored value is not a JSON object: treated as a bare key
                      # and mapped to the `openai_api_key` field.
                      original_credentials = {"openai_api_key": encrypted_config}
                  else:
                      original_credentials = json.loads(encrypted_config)
              except JSONDecodeError:
                  original_credentials = {}

              # Copy first so decrypted secrets never leak into the caller's dict.
              credentials = dict(credentials)
              self._restore_hidden_secrets(credentials, original_credentials, provider_credential_secret_variables)

          credentials = model_provider_factory.provider_credentials_validate(
              self.provider.provider,
              credentials
          )

          return provider_record, self._encrypt_secret_variables(credentials, provider_credential_secret_variables)

      def add_or_update_custom_credentials(self, credentials: dict) -> None:
          """
          Add or update custom provider credentials.

          Validates and stores the credentials, clears the provider credentials
          cache and switches the preferred provider type to CUSTOM.

          :param credentials:
          :return:
          """
          # validate custom provider config
          provider_record, credentials = self.custom_credentials_validate(credentials)

          # save provider
          if provider_record:
              provider_record.encrypted_config = json.dumps(credentials)
              provider_record.is_valid = True
              # Naive UTC timestamp, same value utcnow() produced, without the deprecated call.
              provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
          else:
              provider_record = Provider(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  provider_type=ProviderType.CUSTOM.value,
                  encrypted_config=json.dumps(credentials),
                  is_valid=True
              )
              db.session.add(provider_record)

          db.session.commit()

          provider_model_credentials_cache = ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=provider_record.id,
              cache_type=ProviderCredentialsCacheType.PROVIDER
          )
          provider_model_credentials_cache.delete()

          self.switch_preferred_provider_type(ProviderType.CUSTOM)

      def delete_custom_credentials(self) -> None:
          """
          Delete custom provider credentials.

          Switches the preferred provider type back to SYSTEM, deletes the stored
          record and clears its credentials cache. Does nothing when no custom
          provider record exists.

          :return:
          """
          # get provider
          provider_record = db.session.query(Provider) \
              .filter(
              Provider.tenant_id == self.tenant_id,
              Provider.provider_name == self.provider.provider,
              Provider.provider_type == ProviderType.CUSTOM.value
          ).first()

          if not provider_record:
              return

          # Read the id while the record is still persistent; it is needed for the
          # cache key after the row has been deleted.
          provider_record_id = provider_record.id

          self.switch_preferred_provider_type(ProviderType.SYSTEM)

          # delete provider
          db.session.delete(provider_record)
          db.session.commit()

          provider_model_credentials_cache = ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=provider_record_id,
              cache_type=ProviderCredentialsCacheType.PROVIDER
          )
          provider_model_credentials_cache.delete()

      def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
              -> Optional[dict]:
          """
          Get custom model credentials.

          :param model_type: model type
          :param model: model name
          :param obfuscated: obfuscated secret data in credentials
          :return: credentials of the matching custom model, or None when not found
          """
          if not self.custom_configuration.models:
              return None

          for model_configuration in self.custom_configuration.models:
              if model_configuration.model_type != model_type or model_configuration.model != model:
                  continue

              credentials = model_configuration.credentials
              if not obfuscated:
                  return credentials

              # Obfuscate credentials
              return self._obfuscated_credentials(
                  credentials=credentials,
                  credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
                  if self.provider.model_credential_schema else []
              )

          return None

      def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
              -> Tuple[ProviderModel, dict]:
          """
          Validate custom model credentials.

          The caller's dict is not modified: hidden placeholders are restored on
          a copy and the encrypted result is a new dict.

          :param model_type: model type
          :param model: model name
          :param credentials: model credentials
          :return: (existing provider model record or None, validated credentials
              with secret fields encrypted)
          """
          # get provider model
          provider_model_record = db.session.query(ProviderModel) \
              .filter(
              ProviderModel.tenant_id == self.tenant_id,
              ProviderModel.provider_name == self.provider.provider,
              ProviderModel.model_name == model,
              ProviderModel.model_type == model_type.to_origin_model_type()
          ).first()

          # Get provider credential secret variables
          provider_credential_secret_variables = self._extract_secret_variables(
              self.provider.model_credential_schema.credential_form_schemas
              if self.provider.model_credential_schema else []
          )

          if provider_model_record:
              try:
                  original_credentials = json.loads(
                      provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
              except JSONDecodeError:
                  original_credentials = {}

              # Valid JSON that is not an object cannot hold credential fields.
              if not isinstance(original_credentials, dict):
                  original_credentials = {}

              # Copy first so decrypted secrets never leak into the caller's dict.
              credentials = dict(credentials)
              self._restore_hidden_secrets(credentials, original_credentials, provider_credential_secret_variables)

          credentials = model_provider_factory.model_credentials_validate(
              provider=self.provider.provider,
              model_type=model_type,
              model=model,
              credentials=credentials
          )

          return provider_model_record, self._encrypt_secret_variables(
              credentials, provider_credential_secret_variables
          )

      def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
          """
          Add or update custom model credentials.

          :param model_type: model type
          :param model: model name
          :param credentials: model credentials
          :return:
          """
          # validate custom model config
          provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)

          # save provider model
          # Note: Do not switch the preferred provider, which allows users to use quotas first
          if provider_model_record:
              provider_model_record.encrypted_config = json.dumps(credentials)
              provider_model_record.is_valid = True
              # Naive UTC timestamp, same value utcnow() produced, without the deprecated call.
              provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
          else:
              provider_model_record = ProviderModel(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_name=model,
                  model_type=model_type.to_origin_model_type(),
                  encrypted_config=json.dumps(credentials),
                  is_valid=True
              )
              db.session.add(provider_model_record)

          db.session.commit()

          provider_model_credentials_cache = ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=provider_model_record.id,
              cache_type=ProviderCredentialsCacheType.MODEL
          )
          provider_model_credentials_cache.delete()

      def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
          """
          Delete custom model credentials.

          Does nothing when no matching provider model record exists.

          :param model_type: model type
          :param model: model name
          :return:
          """
          # get provider model
          provider_model_record = db.session.query(ProviderModel) \
              .filter(
              ProviderModel.tenant_id == self.tenant_id,
              ProviderModel.provider_name == self.provider.provider,
              ProviderModel.model_name == model,
              ProviderModel.model_type == model_type.to_origin_model_type()
          ).first()

          if not provider_model_record:
              return

          # Read the id while the record is still persistent; it is needed for the
          # cache key after the row has been deleted.
          provider_model_record_id = provider_model_record.id

          # delete provider model
          db.session.delete(provider_model_record)
          db.session.commit()

          provider_model_credentials_cache = ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=provider_model_record_id,
              cache_type=ProviderCredentialsCacheType.MODEL
          )
          provider_model_credentials_cache.delete()

      def get_provider_instance(self) -> ModelProvider:
          """
          Get provider instance.

          :return: provider instance resolved by the model provider factory
          """
          return model_provider_factory.get_provider_instance(self.provider.provider)

      def get_model_type_instance(self, model_type: ModelType) -> AIModel:
          """
          Get current model type instance.

          :param model_type: model type
          :return: the provider's model instance for the given model type
          """
          # Get provider instance
          provider_instance = self.get_provider_instance()

          # Get model instance of the requested type
          return provider_instance.get_model_instance(model_type)

      def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
          """
          Switch preferred provider type.

          No-op when the type is already preferred, or when SYSTEM is requested
          but the system configuration is disabled. Otherwise the tenant's
          preferred-provider row is updated or created and committed.

          :param provider_type:
          :return:
          """
          if provider_type == self.preferred_provider_type:
              return

          if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
              return

          # get preferred provider
          preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
              .filter(
              TenantPreferredModelProvider.tenant_id == self.tenant_id,
              TenantPreferredModelProvider.provider_name == self.provider.provider
          ).first()

          if preferred_model_provider:
              preferred_model_provider.preferred_provider_type = provider_type.value
          else:
              preferred_model_provider = TenantPreferredModelProvider(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  preferred_provider_type=provider_type.value
              )
              db.session.add(preferred_model_provider)

          db.session.commit()

      def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
          """
          Extract secret input form variables.

          :param credential_form_schemas:
          :return: variable names of every schema whose type is SECRET_INPUT
          """
          return [
              credential_form_schema.variable
              for credential_form_schema in credential_form_schemas
              if credential_form_schema.type == FormType.SECRET_INPUT
          ]

      def _restore_hidden_secrets(self, credentials: dict, original_credentials: dict,
                                  secret_variables: list[str]) -> None:
          """
          Replace '[__HIDDEN__]' placeholders in secret fields with the stored value.

          Mutates `credentials` in place, so callers pass a private copy.

          :param credentials: submitted credentials (modified in place)
          :param original_credentials: previously stored, still encrypted credentials
          :param secret_variables: names of secret fields
          :return:
          """
          for key, value in credentials.items():
              # if send [__HIDDEN__] in secret input, it will be same as original value
              if key in secret_variables and value == '[__HIDDEN__]' and key in original_credentials:
                  credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

      def _encrypt_secret_variables(self, credentials: dict, secret_variables: list[str]) -> dict:
          """
          Return a new dict with every secret field encrypted for this tenant.

          :param credentials: validated credentials
          :param secret_variables: names of secret fields
          :return: copy of credentials with secret values encrypted
          """
          return {
              key: encrypter.encrypt_token(self.tenant_id, value) if key in secret_variables else value
              for key, value in credentials.items()
          }

      def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
          """
          Obfuscated credentials.

          :param credentials: credentials
          :param credential_form_schemas: credential form schemas
          :return: copy of credentials with secret values obfuscated
          """
          # Get provider credential secret variables
          credential_secret_variables = self._extract_secret_variables(
              credential_form_schemas
          )

          # Obfuscate provider credentials
          return {
              key: encrypter.obfuscated_token(value) if key in credential_secret_variables else value
              for key, value in credentials.items()
          }

      def get_provider_model(self, model_type: ModelType,
                             model: str,
                             only_active: bool = False) -> Optional[ModelWithProviderEntity]:
          """
          Get provider model.

          :param model_type: model type
          :param model: model name
          :param only_active: return active model only
          :return: first provider model with the given name, or None
          """
          return next(
              (
                  provider_model
                  for provider_model in self.get_provider_models(model_type, only_active)
                  if provider_model.model == model
              ),
              None
          )

      def get_provider_models(self, model_type: Optional[ModelType] = None,
                              only_active: bool = False) -> list[ModelWithProviderEntity]:
          """
          Get provider models.

          :param model_type: model type; when omitted, all model types supported
              by the provider schema are used
          :param only_active: only active models
          :return: provider models sorted by model type value
          """
          provider_instance = self.get_provider_instance()

          if model_type:
              model_types = [model_type]
          else:
              model_types = provider_instance.get_provider_schema().supported_model_types

          if self.using_provider_type == ProviderType.SYSTEM:
              provider_models = self._get_system_provider_models(
                  model_types=model_types,
                  provider_instance=provider_instance
              )
          else:
              provider_models = self._get_custom_provider_models(
                  model_types=model_types,
                  provider_instance=provider_instance
              )

          if only_active:
              provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]

          # resort provider_models
          return sorted(provider_models, key=lambda x: x.model_type.value)

      def _to_model_with_provider(self, model_schema, status: ModelStatus,
                                  fetch_from: Optional[FetchFrom] = None) -> ModelWithProviderEntity:
          """
          Wrap a model schema into a ModelWithProviderEntity bound to this provider.

          :param model_schema: object exposing model, label, model_type, features,
              fetch_from, model_properties and deprecated
          :param status: status to assign
          :param fetch_from: overrides the schema's own fetch_from when given
          :return: model entity carrying this provider
          """
          return ModelWithProviderEntity(
              model=model_schema.model,
              label=model_schema.label,
              model_type=model_schema.model_type,
              features=model_schema.features,
              fetch_from=model_schema.fetch_from if fetch_from is None else fetch_from,
              model_properties=model_schema.model_properties,
              deprecated=model_schema.deprecated,
              provider=SimpleModelProviderEntity(self.provider),
              status=status
          )

      def _get_customizable_model_schema(self, provider_instance: ModelProvider, model_type: ModelType,
                                         model: str, credentials: dict):
          """
          Fetch the customizable model schema for a model, or None on failure.

          :param provider_instance: provider instance
          :param model_type: model type
          :param model: model name
          :param credentials: credentials passed to the schema lookup
          :return: whatever the model instance returns, or None when it raised
          """
          try:
              return (
                  provider_instance.get_model_instance(model_type)
                  .get_customizable_model_schema_from_credentials(model, credentials)
              )
          except Exception as ex:
              # Deliberately best-effort: one broken model must not hide the rest.
              logger.warning('get custom model schema failed, %s', ex)
              return None

      def _get_system_provider_models(self,
                                      model_types: list[ModelType],
                                      provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
          """
          Get system provider models.

          :param model_types: model types
          :param provider_instance: provider instance
          :return: predefined models (plus restricted customizable models for
              customizable-only providers) with status adjusted by the current quota
          """
          provider_models = [
              self._to_model_with_provider(m, ModelStatus.ACTIVE)
              for model_type in model_types
              for m in provider_instance.models(model_type)
          ]

          provider_name = self.provider.provider
          if provider_name not in original_provider_configurate_methods:
              original_provider_configurate_methods[provider_name] = list(
                  provider_instance.get_provider_schema().configurate_methods
              )

          only_customizable_model = (
              original_provider_configurate_methods[provider_name] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]
          )

          for quota_configuration in self.system_configuration.quota_configurations:
              if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                  continue

              restrict_models = quota_configuration.restrict_models
              if len(restrict_models) == 0:
                  break

              if only_customizable_model:
                  # only customizable model: build entries from the restricted models
                  for restrict_model in restrict_models:
                      copy_credentials = self.system_configuration.credentials.copy()
                      if restrict_model.base_model_name:
                          copy_credentials['base_model_name'] = restrict_model.base_model_name

                      custom_model_schema = self._get_customizable_model_schema(
                          provider_instance,
                          restrict_model.model_type,
                          restrict_model.model,
                          copy_credentials
                      )
                      if not custom_model_schema:
                          continue

                      if custom_model_schema.model_type not in model_types:
                          continue

                      provider_models.append(
                          self._to_model_with_provider(
                              custom_model_schema,
                              ModelStatus.ACTIVE,
                              fetch_from=FetchFrom.PREDEFINED_MODEL
                          )
                      )

              # LLMs outside the restricted list are kept but marked NO_PERMISSION;
              # everything else is marked QUOTA_EXCEEDED when the quota is invalid.
              restrict_model_names = {rm.model for rm in restrict_models}
              for m in provider_models:
                  if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
                      m.status = ModelStatus.NO_PERMISSION
                  elif not quota_configuration.is_valid:
                      m.status = ModelStatus.QUOTA_EXCEEDED

          return provider_models

      def _get_custom_provider_models(self,
                                      model_types: list[ModelType],
                                      provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
          """
          Get custom provider models.

          :param model_types: model types
          :param provider_instance: provider instance
          :return: predefined models (ACTIVE when provider credentials exist, else
              NO_CONFIGURE) followed by the configured custom models
          """
          credentials = None
          if self.custom_configuration.provider:
              credentials = self.custom_configuration.provider.credentials

          predefined_status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE

          provider_models = []
          for model_type in model_types:
              if model_type not in self.provider.supported_model_types:
                  continue

              provider_models.extend(
                  self._to_model_with_provider(m, predefined_status)
                  for m in provider_instance.models(model_type)
              )

          # custom models
          for model_configuration in self.custom_configuration.models:
              if model_configuration.model_type not in model_types:
                  continue

              custom_model_schema = self._get_customizable_model_schema(
                  provider_instance,
                  model_configuration.model_type,
                  model_configuration.model,
                  model_configuration.credentials
              )
              if not custom_model_schema:
                  continue

              provider_models.append(
                  self._to_model_with_provider(custom_model_schema, ModelStatus.ACTIVE)
              )

          return provider_models
  class ProviderConfigurations(BaseModel):
      """
      Model class for provider configuration dict.

      Thin dict-style wrapper around `configurations` for one tenant.
      """
      tenant_id: str
      configurations: Dict[str, ProviderConfiguration] = {}

      def __init__(self, tenant_id: str):
          super().__init__(tenant_id=tenant_id)

      def get_models(self,
                     provider: Optional[str] = None,
                     model_type: Optional[ModelType] = None,
                     only_active: bool = False) \
              -> list[ModelWithProviderEntity]:
          """
          Get available models.

          If preferred provider type is `system`:
            Use the current **system mode** if the provider supports it;
            if no system mode is available (no quota), fall back to the
            **custom credential mode**. If nothing is configured in custom
            mode, it is treated as no_configure.
            system > custom > no_configure

          If preferred provider type is `custom`:
            If custom credentials are configured, it is treated as custom mode.
            Otherwise, use the current **system mode** if supported;
            if no system mode is available (no quota), it is treated as no_configure.
            custom > system > no_configure

          If real mode is `system`, use system credentials to get models,
            paid quotas > provider free quotas > system free quotas
            include pre-defined models (exclude GPT-4, status marked as `no_permission`).
          If real mode is `custom`, use workspace custom credentials to get models,
            include pre-defined models, custom models (manual append).
          If real mode is `no_configure`, only return pre-defined models from `model runtime`.
            (model status marked as `no_configure` if preferred provider type is `custom`
            otherwise `quota_exceeded`)
          A model whose status is `active` is available.

          :param provider: provider name; when given, other providers are skipped
          :param model_type: model type
          :param only_active: only active models
          :return: models of every selected provider configuration, concatenated
          """
          selected_configurations = (
              configuration
              for configuration in self.values()
              if not provider or configuration.provider.provider == provider
          )

          models = []
          for configuration in selected_configurations:
              models.extend(configuration.get_provider_models(model_type, only_active))

          return models

      def to_list(self) -> List[ProviderConfiguration]:
          """
          Convert to list.

          :return: the stored provider configurations as a new list
          """
          return [*self.values()]

      def __getitem__(self, key):
          return self.configurations[key]

      def __setitem__(self, key, value):
          self.configurations[key] = value

      def __iter__(self):
          # Iterates over the keys of `configurations`, like a dict.
          return iter(self.configurations)

      def values(self) -> Iterator[ProviderConfiguration]:
          # NOTE(review): annotated as Iterator, but what is returned is the
          # dict's values view (re-iterable, supports len()).
          return self.configurations.values()

      def get(self, key, default=None):
          return self.configurations.get(key, default)
  class ProviderModelBundle(BaseModel):
      """
      Provider model bundle.

      Groups a provider configuration with a provider instance and a model
      type instance.
      """
      configuration: ProviderConfiguration
      provider_instance: ModelProvider
      model_type_instance: AIModel

      class Config:
          """Configuration for this pydantic object."""

          # ModelProvider and AIModel are annotated as fields above; this flag
          # lets pydantic accept arbitrary (non-pydantic) field types.
          arbitrary_types_allowed = True