provider_configuration.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027
  1. import datetime
  2. import json
  3. import logging
  4. from collections import defaultdict
  5. from collections.abc import Iterator
  6. from json import JSONDecodeError
  7. from typing import Optional
  8. from pydantic import BaseModel, ConfigDict
  9. from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
  10. from core.entities.provider_entities import (
  11. CustomConfiguration,
  12. ModelSettings,
  13. SystemConfiguration,
  14. SystemConfigurationStatus,
  15. )
  16. from core.helper import encrypter
  17. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  18. from core.model_runtime.entities.model_entities import FetchFrom, ModelType
  19. from core.model_runtime.entities.provider_entities import (
  20. ConfigurateMethod,
  21. CredentialFormSchema,
  22. FormType,
  23. ProviderEntity,
  24. )
  25. from core.model_runtime.model_providers import model_provider_factory
  26. from core.model_runtime.model_providers.__base.ai_model import AIModel
  27. from core.model_runtime.model_providers.__base.model_provider import ModelProvider
  28. from extensions.ext_database import db
  29. from models.provider import (
  30. LoadBalancingModelConfig,
  31. Provider,
  32. ProviderModel,
  33. ProviderModelSetting,
  34. ProviderType,
  35. TenantPreferredModelProvider,
  36. )
  37. logger = logging.getLogger(__name__)
  38. original_provider_configurate_methods = {}
  39. class ProviderConfiguration(BaseModel):
  40. """
  41. Model class for provider configuration.
  42. """
  43. tenant_id: str
  44. provider: ProviderEntity
  45. preferred_provider_type: ProviderType
  46. using_provider_type: ProviderType
  47. system_configuration: SystemConfiguration
  48. custom_configuration: CustomConfiguration
  49. model_settings: list[ModelSettings]
  50. # pydantic configs
  51. model_config = ConfigDict(protected_namespaces=())
  52. def __init__(self, **data):
  53. super().__init__(**data)
  54. if self.provider.provider not in original_provider_configurate_methods:
  55. original_provider_configurate_methods[self.provider.provider] = []
  56. for configurate_method in self.provider.configurate_methods:
  57. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  58. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  59. if (any(len(quota_configuration.restrict_models) > 0
  60. for quota_configuration in self.system_configuration.quota_configurations)
  61. and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
  62. self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
  63. def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
  64. """
  65. Get current credentials.
  66. :param model_type: model type
  67. :param model: model name
  68. :return:
  69. """
  70. if self.model_settings:
  71. # check if model is disabled by admin
  72. for model_setting in self.model_settings:
  73. if (model_setting.model_type == model_type
  74. and model_setting.model == model):
  75. if not model_setting.enabled:
  76. raise ValueError(f'Model {model} is disabled.')
  77. if self.using_provider_type == ProviderType.SYSTEM:
  78. restrict_models = []
  79. for quota_configuration in self.system_configuration.quota_configurations:
  80. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  81. continue
  82. restrict_models = quota_configuration.restrict_models
  83. copy_credentials = self.system_configuration.credentials.copy()
  84. if restrict_models:
  85. for restrict_model in restrict_models:
  86. if (restrict_model.model_type == model_type
  87. and restrict_model.model == model
  88. and restrict_model.base_model_name):
  89. copy_credentials['base_model_name'] = restrict_model.base_model_name
  90. return copy_credentials
  91. else:
  92. credentials = None
  93. if self.custom_configuration.models:
  94. for model_configuration in self.custom_configuration.models:
  95. if model_configuration.model_type == model_type and model_configuration.model == model:
  96. credentials = model_configuration.credentials
  97. break
  98. if self.custom_configuration.provider:
  99. credentials = self.custom_configuration.provider.credentials
  100. return credentials
  101. def get_system_configuration_status(self) -> SystemConfigurationStatus:
  102. """
  103. Get system configuration status.
  104. :return:
  105. """
  106. if self.system_configuration.enabled is False:
  107. return SystemConfigurationStatus.UNSUPPORTED
  108. current_quota_type = self.system_configuration.current_quota_type
  109. current_quota_configuration = next(
  110. (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
  111. None
  112. )
  113. return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
  114. SystemConfigurationStatus.QUOTA_EXCEEDED
  115. def is_custom_configuration_available(self) -> bool:
  116. """
  117. Check custom configuration available.
  118. :return:
  119. """
  120. return (self.custom_configuration.provider is not None
  121. or len(self.custom_configuration.models) > 0)
  122. def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
  123. """
  124. Get custom credentials.
  125. :param obfuscated: obfuscated secret data in credentials
  126. :return:
  127. """
  128. if self.custom_configuration.provider is None:
  129. return None
  130. credentials = self.custom_configuration.provider.credentials
  131. if not obfuscated:
  132. return credentials
  133. # Obfuscate credentials
  134. return self.obfuscated_credentials(
  135. credentials=credentials,
  136. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  137. if self.provider.provider_credential_schema else []
  138. )
  139. def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
  140. """
  141. Validate custom credentials.
  142. :param credentials: provider credentials
  143. :return:
  144. """
  145. # get provider
  146. provider_record = db.session.query(Provider) \
  147. .filter(
  148. Provider.tenant_id == self.tenant_id,
  149. Provider.provider_name == self.provider.provider,
  150. Provider.provider_type == ProviderType.CUSTOM.value
  151. ).first()
  152. # Get provider credential secret variables
  153. provider_credential_secret_variables = self.extract_secret_variables(
  154. self.provider.provider_credential_schema.credential_form_schemas
  155. if self.provider.provider_credential_schema else []
  156. )
  157. if provider_record:
  158. try:
  159. # fix origin data
  160. if provider_record.encrypted_config:
  161. if not provider_record.encrypted_config.startswith("{"):
  162. original_credentials = {
  163. "openai_api_key": provider_record.encrypted_config
  164. }
  165. else:
  166. original_credentials = json.loads(provider_record.encrypted_config)
  167. else:
  168. original_credentials = {}
  169. except JSONDecodeError:
  170. original_credentials = {}
  171. # encrypt credentials
  172. for key, value in credentials.items():
  173. if key in provider_credential_secret_variables:
  174. # if send [__HIDDEN__] in secret input, it will be same as original value
  175. if value == '[__HIDDEN__]' and key in original_credentials:
  176. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  177. credentials = model_provider_factory.provider_credentials_validate(
  178. provider=self.provider.provider,
  179. credentials=credentials
  180. )
  181. for key, value in credentials.items():
  182. if key in provider_credential_secret_variables:
  183. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  184. return provider_record, credentials
  185. def add_or_update_custom_credentials(self, credentials: dict) -> None:
  186. """
  187. Add or update custom provider credentials.
  188. :param credentials:
  189. :return:
  190. """
  191. # validate custom provider config
  192. provider_record, credentials = self.custom_credentials_validate(credentials)
  193. # save provider
  194. # Note: Do not switch the preferred provider, which allows users to use quotas first
  195. if provider_record:
  196. provider_record.encrypted_config = json.dumps(credentials)
  197. provider_record.is_valid = True
  198. provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  199. db.session.commit()
  200. else:
  201. provider_record = Provider(
  202. tenant_id=self.tenant_id,
  203. provider_name=self.provider.provider,
  204. provider_type=ProviderType.CUSTOM.value,
  205. encrypted_config=json.dumps(credentials),
  206. is_valid=True
  207. )
  208. db.session.add(provider_record)
  209. db.session.commit()
  210. provider_model_credentials_cache = ProviderCredentialsCache(
  211. tenant_id=self.tenant_id,
  212. identity_id=provider_record.id,
  213. cache_type=ProviderCredentialsCacheType.PROVIDER
  214. )
  215. provider_model_credentials_cache.delete()
  216. self.switch_preferred_provider_type(ProviderType.CUSTOM)
  217. def delete_custom_credentials(self) -> None:
  218. """
  219. Delete custom provider credentials.
  220. :return:
  221. """
  222. # get provider
  223. provider_record = db.session.query(Provider) \
  224. .filter(
  225. Provider.tenant_id == self.tenant_id,
  226. Provider.provider_name == self.provider.provider,
  227. Provider.provider_type == ProviderType.CUSTOM.value
  228. ).first()
  229. # delete provider
  230. if provider_record:
  231. self.switch_preferred_provider_type(ProviderType.SYSTEM)
  232. db.session.delete(provider_record)
  233. db.session.commit()
  234. provider_model_credentials_cache = ProviderCredentialsCache(
  235. tenant_id=self.tenant_id,
  236. identity_id=provider_record.id,
  237. cache_type=ProviderCredentialsCacheType.PROVIDER
  238. )
  239. provider_model_credentials_cache.delete()
  240. def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
  241. -> Optional[dict]:
  242. """
  243. Get custom model credentials.
  244. :param model_type: model type
  245. :param model: model name
  246. :param obfuscated: obfuscated secret data in credentials
  247. :return:
  248. """
  249. if not self.custom_configuration.models:
  250. return None
  251. for model_configuration in self.custom_configuration.models:
  252. if model_configuration.model_type == model_type and model_configuration.model == model:
  253. credentials = model_configuration.credentials
  254. if not obfuscated:
  255. return credentials
  256. # Obfuscate credentials
  257. return self.obfuscated_credentials(
  258. credentials=credentials,
  259. credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
  260. if self.provider.model_credential_schema else []
  261. )
  262. return None
  263. def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
  264. -> tuple[ProviderModel, dict]:
  265. """
  266. Validate custom model credentials.
  267. :param model_type: model type
  268. :param model: model name
  269. :param credentials: model credentials
  270. :return:
  271. """
  272. # get provider model
  273. provider_model_record = db.session.query(ProviderModel) \
  274. .filter(
  275. ProviderModel.tenant_id == self.tenant_id,
  276. ProviderModel.provider_name == self.provider.provider,
  277. ProviderModel.model_name == model,
  278. ProviderModel.model_type == model_type.to_origin_model_type()
  279. ).first()
  280. # Get provider credential secret variables
  281. provider_credential_secret_variables = self.extract_secret_variables(
  282. self.provider.model_credential_schema.credential_form_schemas
  283. if self.provider.model_credential_schema else []
  284. )
  285. if provider_model_record:
  286. try:
  287. original_credentials = json.loads(
  288. provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
  289. except JSONDecodeError:
  290. original_credentials = {}
  291. # decrypt credentials
  292. for key, value in credentials.items():
  293. if key in provider_credential_secret_variables:
  294. # if send [__HIDDEN__] in secret input, it will be same as original value
  295. if value == '[__HIDDEN__]' and key in original_credentials:
  296. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  297. credentials = model_provider_factory.model_credentials_validate(
  298. provider=self.provider.provider,
  299. model_type=model_type,
  300. model=model,
  301. credentials=credentials
  302. )
  303. for key, value in credentials.items():
  304. if key in provider_credential_secret_variables:
  305. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  306. return provider_model_record, credentials
  307. def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
  308. """
  309. Add or update custom model credentials.
  310. :param model_type: model type
  311. :param model: model name
  312. :param credentials: model credentials
  313. :return:
  314. """
  315. # validate custom model config
  316. provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
  317. # save provider model
  318. # Note: Do not switch the preferred provider, which allows users to use quotas first
  319. if provider_model_record:
  320. provider_model_record.encrypted_config = json.dumps(credentials)
  321. provider_model_record.is_valid = True
  322. provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  323. db.session.commit()
  324. else:
  325. provider_model_record = ProviderModel(
  326. tenant_id=self.tenant_id,
  327. provider_name=self.provider.provider,
  328. model_name=model,
  329. model_type=model_type.to_origin_model_type(),
  330. encrypted_config=json.dumps(credentials),
  331. is_valid=True
  332. )
  333. db.session.add(provider_model_record)
  334. db.session.commit()
  335. provider_model_credentials_cache = ProviderCredentialsCache(
  336. tenant_id=self.tenant_id,
  337. identity_id=provider_model_record.id,
  338. cache_type=ProviderCredentialsCacheType.MODEL
  339. )
  340. provider_model_credentials_cache.delete()
  341. def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
  342. """
  343. Delete custom model credentials.
  344. :param model_type: model type
  345. :param model: model name
  346. :return:
  347. """
  348. # get provider model
  349. provider_model_record = db.session.query(ProviderModel) \
  350. .filter(
  351. ProviderModel.tenant_id == self.tenant_id,
  352. ProviderModel.provider_name == self.provider.provider,
  353. ProviderModel.model_name == model,
  354. ProviderModel.model_type == model_type.to_origin_model_type()
  355. ).first()
  356. # delete provider model
  357. if provider_model_record:
  358. db.session.delete(provider_model_record)
  359. db.session.commit()
  360. provider_model_credentials_cache = ProviderCredentialsCache(
  361. tenant_id=self.tenant_id,
  362. identity_id=provider_model_record.id,
  363. cache_type=ProviderCredentialsCacheType.MODEL
  364. )
  365. provider_model_credentials_cache.delete()
  366. def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  367. """
  368. Enable model.
  369. :param model_type: model type
  370. :param model: model name
  371. :return:
  372. """
  373. model_setting = db.session.query(ProviderModelSetting) \
  374. .filter(
  375. ProviderModelSetting.tenant_id == self.tenant_id,
  376. ProviderModelSetting.provider_name == self.provider.provider,
  377. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  378. ProviderModelSetting.model_name == model
  379. ).first()
  380. if model_setting:
  381. model_setting.enabled = True
  382. model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  383. db.session.commit()
  384. else:
  385. model_setting = ProviderModelSetting(
  386. tenant_id=self.tenant_id,
  387. provider_name=self.provider.provider,
  388. model_type=model_type.to_origin_model_type(),
  389. model_name=model,
  390. enabled=True
  391. )
  392. db.session.add(model_setting)
  393. db.session.commit()
  394. return model_setting
  395. def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  396. """
  397. Disable model.
  398. :param model_type: model type
  399. :param model: model name
  400. :return:
  401. """
  402. model_setting = db.session.query(ProviderModelSetting) \
  403. .filter(
  404. ProviderModelSetting.tenant_id == self.tenant_id,
  405. ProviderModelSetting.provider_name == self.provider.provider,
  406. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  407. ProviderModelSetting.model_name == model
  408. ).first()
  409. if model_setting:
  410. model_setting.enabled = False
  411. model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  412. db.session.commit()
  413. else:
  414. model_setting = ProviderModelSetting(
  415. tenant_id=self.tenant_id,
  416. provider_name=self.provider.provider,
  417. model_type=model_type.to_origin_model_type(),
  418. model_name=model,
  419. enabled=False
  420. )
  421. db.session.add(model_setting)
  422. db.session.commit()
  423. return model_setting
  424. def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
  425. """
  426. Get provider model setting.
  427. :param model_type: model type
  428. :param model: model name
  429. :return:
  430. """
  431. return db.session.query(ProviderModelSetting) \
  432. .filter(
  433. ProviderModelSetting.tenant_id == self.tenant_id,
  434. ProviderModelSetting.provider_name == self.provider.provider,
  435. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  436. ProviderModelSetting.model_name == model
  437. ).first()
  438. def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  439. """
  440. Enable model load balancing.
  441. :param model_type: model type
  442. :param model: model name
  443. :return:
  444. """
  445. load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \
  446. .filter(
  447. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  448. LoadBalancingModelConfig.provider_name == self.provider.provider,
  449. LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
  450. LoadBalancingModelConfig.model_name == model
  451. ).count()
  452. if load_balancing_config_count <= 1:
  453. raise ValueError('Model load balancing configuration must be more than 1.')
  454. model_setting = db.session.query(ProviderModelSetting) \
  455. .filter(
  456. ProviderModelSetting.tenant_id == self.tenant_id,
  457. ProviderModelSetting.provider_name == self.provider.provider,
  458. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  459. ProviderModelSetting.model_name == model
  460. ).first()
  461. if model_setting:
  462. model_setting.load_balancing_enabled = True
  463. model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  464. db.session.commit()
  465. else:
  466. model_setting = ProviderModelSetting(
  467. tenant_id=self.tenant_id,
  468. provider_name=self.provider.provider,
  469. model_type=model_type.to_origin_model_type(),
  470. model_name=model,
  471. load_balancing_enabled=True
  472. )
  473. db.session.add(model_setting)
  474. db.session.commit()
  475. return model_setting
  476. def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  477. """
  478. Disable model load balancing.
  479. :param model_type: model type
  480. :param model: model name
  481. :return:
  482. """
  483. model_setting = db.session.query(ProviderModelSetting) \
  484. .filter(
  485. ProviderModelSetting.tenant_id == self.tenant_id,
  486. ProviderModelSetting.provider_name == self.provider.provider,
  487. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  488. ProviderModelSetting.model_name == model
  489. ).first()
  490. if model_setting:
  491. model_setting.load_balancing_enabled = False
  492. model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  493. db.session.commit()
  494. else:
  495. model_setting = ProviderModelSetting(
  496. tenant_id=self.tenant_id,
  497. provider_name=self.provider.provider,
  498. model_type=model_type.to_origin_model_type(),
  499. model_name=model,
  500. load_balancing_enabled=False
  501. )
  502. db.session.add(model_setting)
  503. db.session.commit()
  504. return model_setting
  505. def get_provider_instance(self) -> ModelProvider:
  506. """
  507. Get provider instance.
  508. :return:
  509. """
  510. return model_provider_factory.get_provider_instance(self.provider.provider)
  511. def get_model_type_instance(self, model_type: ModelType) -> AIModel:
  512. """
  513. Get current model type instance.
  514. :param model_type: model type
  515. :return:
  516. """
  517. # Get provider instance
  518. provider_instance = self.get_provider_instance()
  519. # Get model instance of LLM
  520. return provider_instance.get_model_instance(model_type)
  521. def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
  522. """
  523. Switch preferred provider type.
  524. :param provider_type:
  525. :return:
  526. """
  527. if provider_type == self.preferred_provider_type:
  528. return
  529. if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
  530. return
  531. # get preferred provider
  532. preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
  533. .filter(
  534. TenantPreferredModelProvider.tenant_id == self.tenant_id,
  535. TenantPreferredModelProvider.provider_name == self.provider.provider
  536. ).first()
  537. if preferred_model_provider:
  538. preferred_model_provider.preferred_provider_type = provider_type.value
  539. else:
  540. preferred_model_provider = TenantPreferredModelProvider(
  541. tenant_id=self.tenant_id,
  542. provider_name=self.provider.provider,
  543. preferred_provider_type=provider_type.value
  544. )
  545. db.session.add(preferred_model_provider)
  546. db.session.commit()
  547. def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
  548. """
  549. Extract secret input form variables.
  550. :param credential_form_schemas:
  551. :return:
  552. """
  553. secret_input_form_variables = []
  554. for credential_form_schema in credential_form_schemas:
  555. if credential_form_schema.type == FormType.SECRET_INPUT:
  556. secret_input_form_variables.append(credential_form_schema.variable)
  557. return secret_input_form_variables
  558. def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
  559. """
  560. Obfuscated credentials.
  561. :param credentials: credentials
  562. :param credential_form_schemas: credential form schemas
  563. :return:
  564. """
  565. # Get provider credential secret variables
  566. credential_secret_variables = self.extract_secret_variables(
  567. credential_form_schemas
  568. )
  569. # Obfuscate provider credentials
  570. copy_credentials = credentials.copy()
  571. for key, value in copy_credentials.items():
  572. if key in credential_secret_variables:
  573. copy_credentials[key] = encrypter.obfuscated_token(value)
  574. return copy_credentials
  575. def get_provider_model(self, model_type: ModelType,
  576. model: str,
  577. only_active: bool = False) -> Optional[ModelWithProviderEntity]:
  578. """
  579. Get provider model.
  580. :param model_type: model type
  581. :param model: model name
  582. :param only_active: return active model only
  583. :return:
  584. """
  585. provider_models = self.get_provider_models(model_type, only_active)
  586. for provider_model in provider_models:
  587. if provider_model.model == model:
  588. return provider_model
  589. return None
  590. def get_provider_models(self, model_type: Optional[ModelType] = None,
  591. only_active: bool = False) -> list[ModelWithProviderEntity]:
  592. """
  593. Get provider models.
  594. :param model_type: model type
  595. :param only_active: only active models
  596. :return:
  597. """
  598. provider_instance = self.get_provider_instance()
  599. model_types = []
  600. if model_type:
  601. model_types.append(model_type)
  602. else:
  603. model_types = provider_instance.get_provider_schema().supported_model_types
  604. # Group model settings by model type and model
  605. model_setting_map = defaultdict(dict)
  606. for model_setting in self.model_settings:
  607. model_setting_map[model_setting.model_type][model_setting.model] = model_setting
  608. if self.using_provider_type == ProviderType.SYSTEM:
  609. provider_models = self._get_system_provider_models(
  610. model_types=model_types,
  611. provider_instance=provider_instance,
  612. model_setting_map=model_setting_map
  613. )
  614. else:
  615. provider_models = self._get_custom_provider_models(
  616. model_types=model_types,
  617. provider_instance=provider_instance,
  618. model_setting_map=model_setting_map
  619. )
  620. if only_active:
  621. provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
  622. # resort provider_models
  623. return sorted(provider_models, key=lambda x: x.model_type.value)
  624. def _get_system_provider_models(self,
  625. model_types: list[ModelType],
  626. provider_instance: ModelProvider,
  627. model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
  628. -> list[ModelWithProviderEntity]:
  629. """
  630. Get system provider models.
  631. :param model_types: model types
  632. :param provider_instance: provider instance
  633. :param model_setting_map: model setting map
  634. :return:
  635. """
  636. provider_models = []
  637. for model_type in model_types:
  638. for m in provider_instance.models(model_type):
  639. status = ModelStatus.ACTIVE
  640. if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
  641. model_setting = model_setting_map[m.model_type][m.model]
  642. if model_setting.enabled is False:
  643. status = ModelStatus.DISABLED
  644. provider_models.append(
  645. ModelWithProviderEntity(
  646. model=m.model,
  647. label=m.label,
  648. model_type=m.model_type,
  649. features=m.features,
  650. fetch_from=m.fetch_from,
  651. model_properties=m.model_properties,
  652. deprecated=m.deprecated,
  653. provider=SimpleModelProviderEntity(self.provider),
  654. status=status
  655. )
  656. )
  657. if self.provider.provider not in original_provider_configurate_methods:
  658. original_provider_configurate_methods[self.provider.provider] = []
  659. for configurate_method in provider_instance.get_provider_schema().configurate_methods:
  660. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  661. should_use_custom_model = False
  662. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  663. should_use_custom_model = True
  664. for quota_configuration in self.system_configuration.quota_configurations:
  665. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  666. continue
  667. restrict_models = quota_configuration.restrict_models
  668. if len(restrict_models) == 0:
  669. break
  670. if should_use_custom_model:
  671. if original_provider_configurate_methods[self.provider.provider] == [
  672. ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  673. # only customizable model
  674. for restrict_model in restrict_models:
  675. copy_credentials = self.system_configuration.credentials.copy()
  676. if restrict_model.base_model_name:
  677. copy_credentials['base_model_name'] = restrict_model.base_model_name
  678. try:
  679. custom_model_schema = (
  680. provider_instance.get_model_instance(restrict_model.model_type)
  681. .get_customizable_model_schema_from_credentials(
  682. restrict_model.model,
  683. copy_credentials
  684. )
  685. )
  686. except Exception as ex:
  687. logger.warning(f'get custom model schema failed, {ex}')
  688. continue
  689. if not custom_model_schema:
  690. continue
  691. if custom_model_schema.model_type not in model_types:
  692. continue
  693. status = ModelStatus.ACTIVE
  694. if (custom_model_schema.model_type in model_setting_map
  695. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
  696. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  697. if model_setting.enabled is False:
  698. status = ModelStatus.DISABLED
  699. provider_models.append(
  700. ModelWithProviderEntity(
  701. model=custom_model_schema.model,
  702. label=custom_model_schema.label,
  703. model_type=custom_model_schema.model_type,
  704. features=custom_model_schema.features,
  705. fetch_from=FetchFrom.PREDEFINED_MODEL,
  706. model_properties=custom_model_schema.model_properties,
  707. deprecated=custom_model_schema.deprecated,
  708. provider=SimpleModelProviderEntity(self.provider),
  709. status=status
  710. )
  711. )
  712. # if llm name not in restricted llm list, remove it
  713. restrict_model_names = [rm.model for rm in restrict_models]
  714. for m in provider_models:
  715. if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
  716. m.status = ModelStatus.NO_PERMISSION
  717. elif not quota_configuration.is_valid:
  718. m.status = ModelStatus.QUOTA_EXCEEDED
  719. return provider_models
  def _get_custom_provider_models(self,
                                  model_types: list[ModelType],
                                  provider_instance: ModelProvider,
                                  model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
          -> list[ModelWithProviderEntity]:
      """
      Collect the models available through the workspace's custom configuration.

      Two sources are merged, predefined models first:

      1. Predefined models that ``provider_instance`` lists for each requested model type
         that this provider supports. They default to ACTIVE when provider-level custom
         credentials are present, otherwise NO_CONFIGURE.
      2. Customizable models from ``self.custom_configuration.models`` whose type is
         requested and whose schema can be resolved from their own credentials. They
         default to ACTIVE. If schema resolution raises, a warning is logged and the entry
         is skipped; an empty/falsy schema is skipped silently.

      For both sources, a matching entry in ``model_setting_map`` turns the status into
      DISABLED when its ``enabled`` flag is False. It also marks load balancing as enabled
      when it carries more than one load-balancing config.

      :param model_types: model types to include
      :param provider_instance: provider instance used to list predefined models and
                                resolve customizable-model schemas
      :param model_setting_map: per-model settings, keyed by model type and then model name
      :return: list of models annotated with provider info, status and load-balancing flag
      """
      def resolve_setting(model_type, model_name, status):
          # Overlay the per-model setting (if one exists) on the given default status.
          # Returns a (status, load_balancing_enabled) pair.
          load_balancing_enabled = False
          if model_type in model_setting_map and model_name in model_setting_map[model_type]:
              setting = model_setting_map[model_type][model_name]
              if setting.enabled is False:
                  status = ModelStatus.DISABLED
              # Load balancing is only reported when more than one config is present.
              load_balancing_enabled = len(setting.load_balancing_configs) > 1
          return status, load_balancing_enabled

      def to_entity(schema, status, load_balancing_enabled):
          # Predefined models and customizable-model schemas are read through the same
          # attributes, so one builder serves both sources.
          return ModelWithProviderEntity(
              model=schema.model,
              label=schema.label,
              model_type=schema.model_type,
              features=schema.features,
              fetch_from=schema.fetch_from,
              model_properties=schema.model_properties,
              deprecated=schema.deprecated,
              provider=SimpleModelProviderEntity(self.provider),
              status=status,
              load_balancing_enabled=load_balancing_enabled,
          )

      custom_provider = self.custom_configuration.provider
      credentials = custom_provider.credentials if custom_provider else None
      # Predefined models are only usable out of the box when custom credentials exist.
      predefined_default = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE

      collected = []

      # 1) Predefined models for every requested type that the provider supports.
      for model_type in model_types:
          if model_type in self.provider.supported_model_types:
              for predefined in provider_instance.models(model_type):
                  status, balanced = resolve_setting(
                      predefined.model_type, predefined.model, predefined_default
                  )
                  collected.append(to_entity(predefined, status, balanced))

      # 2) Customizable models configured for this workspace.
      for custom_entry in self.custom_configuration.models:
          if custom_entry.model_type not in model_types:
              continue

          try:
              schema = (
                  provider_instance.get_model_instance(custom_entry.model_type)
                  .get_customizable_model_schema_from_credentials(
                      custom_entry.model,
                      custom_entry.credentials
                  )
              )
          except Exception as ex:
              # Best effort: an entry whose schema cannot be resolved is logged and skipped.
              logger.warning(f'get custom model schema failed, {ex}')
              continue

          if not schema:
              continue

          status, balanced = resolve_setting(schema.model_type, schema.model, ModelStatus.ACTIVE)
          collected.append(to_entity(schema, status, balanced))

      return collected
  class ProviderConfigurations(BaseModel):
      """
      A tenant's provider configurations, exposed with dict-like access.

      Item access, ``get``, ``values`` and iteration (over keys) all delegate to the
      ``configurations`` dict. Keys are strings chosen by whoever populates the object
      (presumably provider names — confirm against the populating code).
      """
      tenant_id: str
      configurations: dict[str, ProviderConfiguration] = {}

      def __init__(self, tenant_id: str):
          # Only the tenant id is accepted at construction; ``configurations`` starts from
          # its default and entries can be added via item assignment.
          super().__init__(tenant_id=tenant_id)

      def get_models(self,
                     provider: Optional[str] = None,
                     model_type: Optional[ModelType] = None,
                     only_active: bool = False) \
              -> list[ModelWithProviderEntity]:
          """
          Collect models across all of this tenant's provider configurations.

          How the effective mode of each provider is chosen:

          - Preferred provider type ``system``: use the current **system mode** if the
            provider supports it. If no system mode is available (no quota), fall back to
            the **custom credential mode**. If no model is configured in custom mode either,
            the provider is treated as ``no_configure``.
            Precedence: system > custom > no_configure.
          - Preferred provider type ``custom``: use custom mode when custom credentials are
            configured. Otherwise use the current **system mode** if supported. If no system
            mode is available (no quota), the provider is treated as ``no_configure``.
            Precedence: custom > system > no_configure.

          What each effective mode yields:

          - ``system``: models obtained with system credentials
            (paid quotas > provider free quotas > system free quotas). This includes
            pre-defined models, except that GPT-4 is excluded and marked ``no_permission``.
          - ``custom``: models obtained with the workspace's custom credentials. This
            includes pre-defined models plus manually appended custom models.
          - ``no_configure``: only the pre-defined models from ``model runtime``. They are
            marked ``no_configure`` if the preferred provider type is ``custom``, otherwise
            ``quota_exceeded``.

          Only models whose status is ``active`` are usable.

          :param provider: if given, only configurations whose provider name equals it
                           are consulted
          :param model_type: model type filter, forwarded to every configuration
          :param only_active: forwarded to every configuration; requests active models only
          :return: the per-configuration model lists concatenated in iteration order
          """
          return [
              model
              for configuration in self.values()
              if not provider or configuration.provider.provider == provider
              for model in configuration.get_provider_models(model_type, only_active)
          ]

      def to_list(self) -> list[ProviderConfiguration]:
          """
          Return the stored configurations as a new list.

          :return: list of ProviderConfiguration in dict insertion order
          """
          return [*self.values()]

      def __getitem__(self, key):
          return self.configurations[key]

      def __setitem__(self, key, value):
          self.configurations[key] = value

      def __iter__(self):
          # Replaces pydantic's BaseModel iteration (field/value pairs) with iteration
          # over the configuration keys, matching plain-dict semantics.
          return iter(self.configurations.keys())

      def values(self) -> Iterator[ProviderConfiguration]:
          # NOTE(review): this returns the dict's live values view, which is iterable and
          # re-iterable but is not itself an Iterator as the annotation claims.
          return self.configurations.values()

      def get(self, key, default=None):
          return self.configurations.get(key, default)
  class ProviderModelBundle(BaseModel):
      """
      Bundles a provider configuration together with its provider instance and a
      model-type instance.
      """
      configuration: ProviderConfiguration
      provider_instance: ModelProvider
      model_type_instance: AIModel

      # pydantic settings:
      # - protected_namespaces=() disables pydantic's reserved ``model_`` prefix, so the
      #   ``model_type_instance`` field does not trigger a namespace-conflict warning;
      # - arbitrary_types_allowed permits field types that are not pydantic-aware
      #   (presumably ModelProvider / AIModel — confirm against their definitions).
      model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)