provider_configuration.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028
  1. import datetime
  2. import json
  3. import logging
  4. from collections import defaultdict
  5. from collections.abc import Iterator
  6. from json import JSONDecodeError
  7. from typing import Optional
  8. from pydantic import BaseModel, ConfigDict
  9. from constants import HIDDEN_VALUE
  10. from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
  11. from core.entities.provider_entities import (
  12. CustomConfiguration,
  13. ModelSettings,
  14. SystemConfiguration,
  15. SystemConfigurationStatus,
  16. )
  17. from core.helper import encrypter
  18. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  19. from core.model_runtime.entities.model_entities import FetchFrom, ModelType
  20. from core.model_runtime.entities.provider_entities import (
  21. ConfigurateMethod,
  22. CredentialFormSchema,
  23. FormType,
  24. ProviderEntity,
  25. )
  26. from core.model_runtime.model_providers import model_provider_factory
  27. from core.model_runtime.model_providers.__base.ai_model import AIModel
  28. from core.model_runtime.model_providers.__base.model_provider import ModelProvider
  29. from extensions.ext_database import db
  30. from models.provider import (
  31. LoadBalancingModelConfig,
  32. Provider,
  33. ProviderModel,
  34. ProviderModelSetting,
  35. ProviderType,
  36. TenantPreferredModelProvider,
  37. )
  # Module-level logger.
  logger = logging.getLogger(__name__)
  # Process-wide cache: provider name -> the configurate methods that provider declared
  # when a ProviderConfiguration for it was first built, recorded before
  # ProviderConfiguration.__init__ may append PREDEFINED_MODEL to
  # provider.configurate_methods. Entries are only written when absent, so they are
  # never refreshed by the code in this module.
  original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}
  40. class ProviderConfiguration(BaseModel):
  41. """
  42. Model class for provider configuration.
  43. """
  44. tenant_id: str
  45. provider: ProviderEntity
  46. preferred_provider_type: ProviderType
  47. using_provider_type: ProviderType
  48. system_configuration: SystemConfiguration
  49. custom_configuration: CustomConfiguration
  50. model_settings: list[ModelSettings]
  51. # pydantic configs
  52. model_config = ConfigDict(protected_namespaces=())
  def __init__(self, **data):
      """
      Initialize the configuration and expose restricted system models as predefined.

      The first instance built for a provider name records that provider's declared
      configurate methods in ``original_provider_configurate_methods``. When the provider
      originally declared only CUSTOMIZABLE_MODEL and at least one system quota
      configuration lists restricted models, PREDEFINED_MODEL is appended (once) to
      ``self.provider.configurate_methods``.
      """
      super().__init__(**data)

      provider_name = self.provider.provider
      if provider_name not in original_provider_configurate_methods:
          # Snapshot before any mutation below.
          original_provider_configurate_methods[provider_name] = list(self.provider.configurate_methods)

      if original_provider_configurate_methods[provider_name] != [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
          return

      quota_restricts_models = any(
          len(quota_configuration.restrict_models) > 0
          for quota_configuration in self.system_configuration.quota_configurations
      )
      # NOTE(review): this mutates the ProviderEntity in place; if that entity object is
      # shared with other callers they will observe the appended method too — confirm.
      if quota_restricts_models and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods:
          self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
  def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
      """
      Get the credentials currently used for a model.

      :param model_type: model type
      :param model: model name
      :return: for the system provider type, a copy of the system credentials with
          ``base_model_name`` taken from a matching restricted model of the current quota
          (when that restriction sets one); otherwise the model-level custom credentials
          for this model, falling back to the provider-level custom credentials, or None
          when neither exists.
      :raises ValueError: if a model setting for this model has ``enabled`` false.
      """
      # check if model is disabled by admin
      for model_setting in self.model_settings or []:
          if (model_setting.model_type == model_type
                  and model_setting.model == model
                  and not model_setting.enabled):
              raise ValueError(f'Model {model} is disabled.')

      if self.using_provider_type == ProviderType.SYSTEM:
          # Restrictions of the quota configuration matching the current quota type
          # (the last match wins if several share that type).
          restrict_models = []
          for quota_configuration in self.system_configuration.quota_configurations:
              if self.system_configuration.current_quota_type == quota_configuration.quota_type:
                  restrict_models = quota_configuration.restrict_models

          # Copy so the shared system credentials are never modified.
          copy_credentials = self.system_configuration.credentials.copy()
          for restrict_model in restrict_models or []:
              if (restrict_model.model_type == model_type
                      and restrict_model.model == model
                      and restrict_model.base_model_name):
                  copy_credentials['base_model_name'] = restrict_model.base_model_name
          return copy_credentials

      credentials = None
      for model_configuration in self.custom_configuration.models or []:
          if model_configuration.model_type == model_type and model_configuration.model == model:
              credentials = model_configuration.credentials
              break

      # Provider-level credentials are only a fallback: they must not shadow
      # credentials configured specifically for this model.
      if not credentials and self.custom_configuration.provider:
          credentials = self.custom_configuration.provider.credentials
      return credentials
  def get_system_configuration_status(self) -> SystemConfigurationStatus:
      """
      Get system configuration status.

      :return: UNSUPPORTED when the system configuration is disabled or when no quota
          configuration matches the current quota type; ACTIVE when the matching quota
          configuration is valid; QUOTA_EXCEEDED otherwise.
      """
      if self.system_configuration.enabled is False:
          return SystemConfigurationStatus.UNSUPPORTED

      current_quota_type = self.system_configuration.current_quota_type
      current_quota_configuration = next(
          (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
          None
      )
      if current_quota_configuration is None:
          # No quota configuration for the current quota type: the system provider
          # cannot be used (previously this dereferenced None and raised AttributeError).
          return SystemConfigurationStatus.UNSUPPORTED

      if current_quota_configuration.is_valid:
          return SystemConfigurationStatus.ACTIVE
      return SystemConfigurationStatus.QUOTA_EXCEEDED
  def is_custom_configuration_available(self) -> bool:
      """
      Tell whether any custom configuration exists.

      :return: True when provider-level custom credentials are configured or at least
          one model-level custom configuration exists, else False.
      """
      custom = self.custom_configuration
      if custom.provider is not None:
          return True
      return len(custom.models) > 0
  def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
      """
      Get the provider-level custom credentials.

      :param obfuscated: when True, secret fields are masked via ``obfuscated_credentials``
          using the provider credential schema (or no schema if none is defined)
      :return: the credentials (possibly obfuscated), or None when no custom provider
          configuration exists
      """
      provider_configuration = self.custom_configuration.provider
      if provider_configuration is None:
          return None

      credentials = provider_configuration.credentials
      if obfuscated:
          credential_schema = self.provider.provider_credential_schema
          form_schemas = credential_schema.credential_form_schemas if credential_schema else []
          return self.obfuscated_credentials(credentials=credentials, credential_form_schemas=form_schemas)
      return credentials
  def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
      """
      Validate custom provider credentials and encrypt their secret fields.

      Secret fields submitted as ``HIDDEN_VALUE`` are replaced by the decrypted value
      stored on the existing custom ``Provider`` record (when it has that key) before
      validation. The caller's ``credentials`` dict is not modified.

      :param credentials: provider credentials as submitted
      :return: the existing custom ``Provider`` record (or None when there is none) and
          the validated credentials with secret fields encrypted
      """
      # get provider
      provider_record = db.session.query(Provider) \
          .filter(
              Provider.tenant_id == self.tenant_id,
              Provider.provider_name == self.provider.provider,
              Provider.provider_type == ProviderType.CUSTOM.value
          ).first()

      # Get provider credential secret variables
      provider_credential_secret_variables = self.extract_secret_variables(
          self.provider.provider_credential_schema.credential_form_schemas
          if self.provider.provider_credential_schema else []
      )

      # Work on a copy so the caller's dict is never mutated.
      credentials = dict(credentials)

      if provider_record:
          encrypted_config = provider_record.encrypted_config
          if not encrypted_config:
              original_credentials = {}
          elif not encrypted_config.startswith("{"):
              # fix origin data: a non-JSON config is treated as a bare OpenAI API key
              original_credentials = {"openai_api_key": encrypted_config}
          else:
              try:
                  original_credentials = json.loads(encrypted_config)
              except JSONDecodeError:
                  original_credentials = {}

          # decrypt: a secret sent as [__HIDDEN__] keeps the stored original value
          for key, value in credentials.items():
              if (key in provider_credential_secret_variables
                      and value == HIDDEN_VALUE
                      and key in original_credentials):
                  credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

      credentials = model_provider_factory.provider_credentials_validate(
          provider=self.provider.provider,
          credentials=credentials
      )

      # encrypt secret fields into a new dict rather than mutating the validated one
      encrypted_credentials = {
          key: encrypter.encrypt_token(self.tenant_id, value)
          if key in provider_credential_secret_variables else value
          for key, value in credentials.items()
      }
      return provider_record, encrypted_credentials
  def add_or_update_custom_credentials(self, credentials: dict) -> None:
      """
      Validate and persist custom provider credentials.

      Updates the existing custom ``Provider`` record (marking it valid) or creates a new
      one, commits, deletes the provider credentials cache entry for that record, and
      then calls ``switch_preferred_provider_type(ProviderType.CUSTOM)``.

      :param credentials: provider credentials as submitted
      """
      provider_record, credentials = self.custom_credentials_validate(credentials)
      encrypted_config = json.dumps(credentials)

      if provider_record:
          provider_record.encrypted_config = encrypted_config
          provider_record.is_valid = True
          provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
      else:
          provider_record = Provider(
              tenant_id=self.tenant_id,
              provider_name=self.provider.provider,
              provider_type=ProviderType.CUSTOM.value,
              encrypted_config=encrypted_config,
              is_valid=True
          )
          db.session.add(provider_record)
      db.session.commit()

      ProviderCredentialsCache(
          tenant_id=self.tenant_id,
          identity_id=provider_record.id,
          cache_type=ProviderCredentialsCacheType.PROVIDER
      ).delete()

      # NOTE(review): saving custom credentials switches the preferred provider to CUSTOM;
      # confirm this is intended if users are expected to keep using system quotas first.
      self.switch_preferred_provider_type(ProviderType.CUSTOM)
  def delete_custom_credentials(self) -> None:
      """
      Delete the custom provider credentials, if any.

      Switches the preferred provider type back to SYSTEM (a no-op when that is not
      possible), deletes the custom ``Provider`` record, commits, and deletes its
      provider credentials cache entry. Does nothing when no custom record exists.
      """
      # get provider
      provider_record = db.session.query(Provider) \
          .filter(
              Provider.tenant_id == self.tenant_id,
              Provider.provider_name == self.provider.provider,
              Provider.provider_type == ProviderType.CUSTOM.value
          ).first()
      if not provider_record:
          return

      # Capture the id while the freshly loaded record is populated:
      # switch_preferred_provider_type() may commit (which, with SQLAlchemy's default
      # expire_on_commit, expires the instance) and the record is detached once deleted.
      provider_record_id = provider_record.id

      self.switch_preferred_provider_type(ProviderType.SYSTEM)

      db.session.delete(provider_record)
      db.session.commit()

      ProviderCredentialsCache(
          tenant_id=self.tenant_id,
          identity_id=provider_record_id,
          cache_type=ProviderCredentialsCacheType.PROVIDER
      ).delete()
  def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
          -> Optional[dict]:
      """
      Get the custom credentials configured for one model.

      :param model_type: model type
      :param model: model name
      :param obfuscated: when True, secret fields are masked via ``obfuscated_credentials``
          using the model credential schema (or no schema if none is defined)
      :return: credentials of the first matching model configuration (possibly
          obfuscated), or None when no configuration matches
      """
      matched = next(
          (
              configuration
              for configuration in self.custom_configuration.models or []
              if configuration.model_type == model_type and configuration.model == model
          ),
          None,
      )
      if matched is None:
          return None

      credentials = matched.credentials
      if not obfuscated:
          return credentials

      model_schema = self.provider.model_credential_schema
      return self.obfuscated_credentials(
          credentials=credentials,
          credential_form_schemas=model_schema.credential_form_schemas if model_schema else []
      )
  def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
          -> tuple[ProviderModel, dict]:
      """
      Validate custom model credentials and encrypt their secret fields.

      Secret fields submitted as ``HIDDEN_VALUE`` are replaced by the decrypted value
      stored on the existing ``ProviderModel`` record (when it has that key) before
      validation. The caller's ``credentials`` dict is not modified.

      :param model_type: model type
      :param model: model name
      :param credentials: model credentials as submitted
      :return: the existing ``ProviderModel`` record (or None when there is none) and the
          validated credentials with secret fields encrypted
      """
      # get provider model
      provider_model_record = db.session.query(ProviderModel) \
          .filter(
              ProviderModel.tenant_id == self.tenant_id,
              ProviderModel.provider_name == self.provider.provider,
              ProviderModel.model_name == model,
              ProviderModel.model_type == model_type.to_origin_model_type()
          ).first()

      # Get model credential secret variables
      provider_credential_secret_variables = self.extract_secret_variables(
          self.provider.model_credential_schema.credential_form_schemas
          if self.provider.model_credential_schema else []
      )

      # Work on a copy so the caller's dict is never mutated.
      credentials = dict(credentials)

      if provider_model_record:
          try:
              original_credentials = json.loads(
                  provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
          except JSONDecodeError:
              original_credentials = {}

          # decrypt: a secret sent as [__HIDDEN__] keeps the stored original value
          for key, value in credentials.items():
              if (key in provider_credential_secret_variables
                      and value == HIDDEN_VALUE
                      and key in original_credentials):
                  credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

      credentials = model_provider_factory.model_credentials_validate(
          provider=self.provider.provider,
          model_type=model_type,
          model=model,
          credentials=credentials
      )

      # encrypt secret fields into a new dict rather than mutating the validated one
      encrypted_credentials = {
          key: encrypter.encrypt_token(self.tenant_id, value)
          if key in provider_credential_secret_variables else value
          for key, value in credentials.items()
      }
      return provider_model_record, encrypted_credentials
  def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
      """
      Validate and persist custom credentials for one model.

      Updates the existing ``ProviderModel`` record (marking it valid) or creates a new
      one, commits, and deletes the model credentials cache entry for that record. The
      preferred provider type is deliberately left unchanged so quotas can be used first.

      :param model_type: model type
      :param model: model name
      :param credentials: model credentials as submitted
      """
      provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
      encrypted_config = json.dumps(credentials)

      if provider_model_record:
          provider_model_record.encrypted_config = encrypted_config
          provider_model_record.is_valid = True
          provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
      else:
          provider_model_record = ProviderModel(
              tenant_id=self.tenant_id,
              provider_name=self.provider.provider,
              model_name=model,
              model_type=model_type.to_origin_model_type(),
              encrypted_config=encrypted_config,
              is_valid=True
          )
          db.session.add(provider_model_record)
      db.session.commit()

      ProviderCredentialsCache(
          tenant_id=self.tenant_id,
          identity_id=provider_model_record.id,
          cache_type=ProviderCredentialsCacheType.MODEL
      ).delete()
  def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
      """
      Delete the custom credentials of one model, if any.

      Deletes the matching ``ProviderModel`` record, commits, and deletes its model
      credentials cache entry. Does nothing when no record exists.

      :param model_type: model type
      :param model: model name
      """
      # get provider model
      provider_model_record = db.session.query(ProviderModel) \
          .filter(
              ProviderModel.tenant_id == self.tenant_id,
              ProviderModel.provider_name == self.provider.provider,
              ProviderModel.model_name == model,
              ProviderModel.model_type == model_type.to_origin_model_type()
          ).first()
      if not provider_model_record:
          return

      # Capture the id before the record is deleted and detached by the commit.
      provider_model_record_id = provider_model_record.id

      db.session.delete(provider_model_record)
      db.session.commit()

      ProviderCredentialsCache(
          tenant_id=self.tenant_id,
          identity_id=provider_model_record_id,
          cache_type=ProviderCredentialsCacheType.MODEL
      ).delete()
  def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Enable a model for this tenant and provider.

      Sets ``enabled=True`` on the existing ``ProviderModelSetting`` row (refreshing
      ``updated_at``) or creates a new enabled row, then commits.

      :param model_type: model type
      :param model: model name
      :return: the persisted model setting
      """
      model_setting = self.get_provider_model_setting(model_type, model)
      if not model_setting:
          model_setting = ProviderModelSetting(
              tenant_id=self.tenant_id,
              provider_name=self.provider.provider,
              model_type=model_type.to_origin_model_type(),
              model_name=model,
              enabled=True
          )
          db.session.add(model_setting)
      else:
          model_setting.enabled = True
          model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
      db.session.commit()
      return model_setting
  def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Disable a model for this tenant and provider.

      Sets ``enabled=False`` on the existing ``ProviderModelSetting`` row (refreshing
      ``updated_at``) or creates a new disabled row, then commits.

      :param model_type: model type
      :param model: model name
      :return: the persisted model setting
      """
      model_setting = self.get_provider_model_setting(model_type, model)
      if not model_setting:
          model_setting = ProviderModelSetting(
              tenant_id=self.tenant_id,
              provider_name=self.provider.provider,
              model_type=model_type.to_origin_model_type(),
              model_name=model,
              enabled=False
          )
          db.session.add(model_setting)
      else:
          model_setting.enabled = False
          model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
      db.session.commit()
      return model_setting
  def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
      """
      Look up the stored setting row for one model of this tenant and provider.

      :param model_type: model type
      :param model: model name
      :return: the first matching ``ProviderModelSetting`` row, or None
      """
      setting_query = db.session.query(ProviderModelSetting).filter(
          ProviderModelSetting.tenant_id == self.tenant_id,
          ProviderModelSetting.provider_name == self.provider.provider,
          ProviderModelSetting.model_type == model_type.to_origin_model_type(),
          ProviderModelSetting.model_name == model
      )
      return setting_query.first()
  def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Enable load balancing for a model.

      Requires more than one ``LoadBalancingModelConfig`` row for the model, then sets
      ``load_balancing_enabled=True`` on the existing ``ProviderModelSetting`` row
      (refreshing ``updated_at``) or creates a new row, and commits.

      :param model_type: model type
      :param model: model name
      :return: the persisted model setting
      :raises ValueError: if at most one load balancing config exists for the model
      """
      config_query = db.session.query(LoadBalancingModelConfig).filter(
          LoadBalancingModelConfig.tenant_id == self.tenant_id,
          LoadBalancingModelConfig.provider_name == self.provider.provider,
          LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
          LoadBalancingModelConfig.model_name == model
      )
      if config_query.count() <= 1:
          raise ValueError('Model load balancing configuration must be more than 1.')

      model_setting = self.get_provider_model_setting(model_type, model)
      if not model_setting:
          model_setting = ProviderModelSetting(
              tenant_id=self.tenant_id,
              provider_name=self.provider.provider,
              model_type=model_type.to_origin_model_type(),
              model_name=model,
              load_balancing_enabled=True
          )
          db.session.add(model_setting)
      else:
          model_setting.load_balancing_enabled = True
          model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
      db.session.commit()
      return model_setting
  def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Disable load balancing for a model.

      Sets ``load_balancing_enabled=False`` on the existing ``ProviderModelSetting`` row
      (refreshing ``updated_at``) or creates a new row, then commits.

      :param model_type: model type
      :param model: model name
      :return: the persisted model setting
      """
      model_setting = self.get_provider_model_setting(model_type, model)
      if not model_setting:
          model_setting = ProviderModelSetting(
              tenant_id=self.tenant_id,
              provider_name=self.provider.provider,
              model_type=model_type.to_origin_model_type(),
              model_name=model,
              load_balancing_enabled=False
          )
          db.session.add(model_setting)
      else:
          model_setting.load_balancing_enabled = False
          model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
      db.session.commit()
      return model_setting
  def get_provider_instance(self) -> ModelProvider:
      """
      Resolve the runtime provider instance for this provider name.

      :return: the instance returned by ``model_provider_factory.get_provider_instance``
      """
      provider_name = self.provider.provider
      return model_provider_factory.get_provider_instance(provider_name)
  def get_model_type_instance(self, model_type: ModelType) -> AIModel:
      """
      Get the provider's model instance for one model type.

      :param model_type: model type
      :return: the result of ``get_model_instance(model_type)`` on the provider instance
      """
      return self.get_provider_instance().get_model_instance(model_type)
  def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
      """
      Persist the tenant's preferred provider type for this provider.

      Does nothing when ``provider_type`` already equals ``self.preferred_provider_type``,
      or when SYSTEM is requested while the system configuration is not enabled.
      Otherwise updates or creates the ``TenantPreferredModelProvider`` row and commits.
      ``self.preferred_provider_type`` itself is not updated.

      :param provider_type: provider type to prefer
      """
      if (provider_type == self.preferred_provider_type
              or (provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled)):
          return

      preferred_model_provider = db.session.query(TenantPreferredModelProvider).filter(
          TenantPreferredModelProvider.tenant_id == self.tenant_id,
          TenantPreferredModelProvider.provider_name == self.provider.provider
      ).first()

      if not preferred_model_provider:
          db.session.add(TenantPreferredModelProvider(
              tenant_id=self.tenant_id,
              provider_name=self.provider.provider,
              preferred_provider_type=provider_type.value
          ))
      else:
          preferred_model_provider.preferred_provider_type = provider_type.value
      db.session.commit()
  def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
      """
      List the variable names of secret-input form fields.

      :param credential_form_schemas: credential form schemas
      :return: ``variable`` of every schema whose type is ``FormType.SECRET_INPUT``,
          in schema order
      """
      return [
          schema.variable
          for schema in credential_form_schemas
          if schema.type == FormType.SECRET_INPUT
      ]
  def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
      """
      Return a copy of ``credentials`` with secret fields obfuscated.

      :param credentials: credentials
      :param credential_form_schemas: schemas used to decide which keys are secret
      :return: new dict with the same keys; secret values are passed through
          ``encrypter.obfuscated_token``, others are kept unchanged
      """
      secret_keys = self.extract_secret_variables(credential_form_schemas)
      return {
          key: encrypter.obfuscated_token(value) if key in secret_keys else value
          for key, value in credentials.items()
      }
  def get_provider_model(self, model_type: ModelType,
                         model: str,
                         only_active: bool = False) -> Optional[ModelWithProviderEntity]:
      """
      Find one model of this provider by name.

      :param model_type: model type
      :param model: model name
      :param only_active: restrict the search to ACTIVE models
      :return: the first entity from ``get_provider_models`` whose ``model`` matches, or None
      """
      return next(
          (candidate
           for candidate in self.get_provider_models(model_type, only_active)
           if candidate.model == model),
          None,
      )
  def get_provider_models(self, model_type: Optional[ModelType] = None,
                          only_active: bool = False) -> list[ModelWithProviderEntity]:
      """
      List this provider's models for the current provider type.

      :param model_type: restrict to this model type; when falsy, every model type the
          provider schema supports is used
      :param only_active: keep only models whose status is ACTIVE
      :return: model entities sorted by ``model_type.value`` (stable sort)
      """
      provider_instance = self.get_provider_instance()
      if model_type:
          model_types = [model_type]
      else:
          model_types = provider_instance.get_provider_schema().supported_model_types

      # model type -> model name -> setting
      model_setting_map = defaultdict(dict)
      for setting in self.model_settings:
          model_setting_map[setting.model_type][setting.model] = setting

      if self.using_provider_type == ProviderType.SYSTEM:
          collect_models = self._get_system_provider_models
      else:
          collect_models = self._get_custom_provider_models
      provider_models = collect_models(
          model_types=model_types,
          provider_instance=provider_instance,
          model_setting_map=model_setting_map
      )

      if only_active:
          provider_models = [entity for entity in provider_models if entity.status == ModelStatus.ACTIVE]

      return sorted(provider_models, key=lambda entity: entity.model_type.value)
  def _get_system_provider_models(self,
                                  model_types: list[ModelType],
                                  provider_instance: ModelProvider,
                                  model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
          -> list[ModelWithProviderEntity]:
      """
      Get system provider models.

      Lists the predefined models of ``provider_instance`` for each of ``model_types``
      (DISABLED when a model setting has ``enabled`` False), then applies the quota
      configuration matching the current quota type:

      * for providers whose original configurate methods are only CUSTOMIZABLE_MODEL, an
        entry is built for each restricted model from its customizable schema, using a
        copy of the system credentials (with ``base_model_name`` when set);
      * LLM entries not named in ``restrict_models`` are marked NO_PERMISSION, and every
        other entry is marked QUOTA_EXCEEDED when the quota configuration is not valid.

      :param model_types: model types
      :param provider_instance: provider instance
      :param model_setting_map: model type -> model name -> model setting
      :return: list of model entities
      """
      provider_models = []
      # Predefined models of each requested type.
      for model_type in model_types:
          for m in provider_instance.models(model_type):
              status = ModelStatus.ACTIVE
              if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
                  model_setting = model_setting_map[m.model_type][m.model]
                  if model_setting.enabled is False:
                      status = ModelStatus.DISABLED
              provider_models.append(
                  ModelWithProviderEntity(
                      model=m.model,
                      label=m.label,
                      model_type=m.model_type,
                      features=m.features,
                      fetch_from=m.fetch_from,
                      model_properties=m.model_properties,
                      deprecated=m.deprecated,
                      provider=SimpleModelProviderEntity(self.provider),
                      status=status
                  )
              )

      # Record the provider's declared configurate methods if not cached yet
      # (normally already done in __init__ for this provider name).
      if self.provider.provider not in original_provider_configurate_methods:
          original_provider_configurate_methods[self.provider.provider] = []
          for configurate_method in provider_instance.get_provider_schema().configurate_methods:
              original_provider_configurate_methods[self.provider.provider].append(configurate_method)

      should_use_custom_model = False
      if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
          should_use_custom_model = True

      for quota_configuration in self.system_configuration.quota_configurations:
          # Only the quota configuration of the current quota type is applied.
          if self.system_configuration.current_quota_type != quota_configuration.quota_type:
              continue

          restrict_models = quota_configuration.restrict_models
          # NOTE(review): with no restricted models the loop stops here, so the
          # QUOTA_EXCEEDED marking below is skipped even when this quota is not valid —
          # confirm that is intended.
          if len(restrict_models) == 0:
              break

          if should_use_custom_model:
              # Same condition as should_use_custom_model (the cache is not modified in between).
              if original_provider_configurate_methods[self.provider.provider] == [
                      ConfigurateMethod.CUSTOMIZABLE_MODEL]:
                  # only customizable model
                  for restrict_model in restrict_models:
                      # Per-model copy of the system credentials, optionally pinned to a base model.
                      copy_credentials = self.system_configuration.credentials.copy()
                      if restrict_model.base_model_name:
                          copy_credentials['base_model_name'] = restrict_model.base_model_name

                      try:
                          custom_model_schema = (
                              provider_instance.get_model_instance(restrict_model.model_type)
                              .get_customizable_model_schema_from_credentials(
                                  restrict_model.model,
                                  copy_credentials
                              )
                          )
                      except Exception as ex:
                          # Best effort: a restricted model whose schema cannot be built is skipped.
                          logger.warning(f'get custom model schema failed, {ex}')
                          continue

                      if not custom_model_schema:
                          continue

                      if custom_model_schema.model_type not in model_types:
                          continue

                      status = ModelStatus.ACTIVE
                      if (custom_model_schema.model_type in model_setting_map
                              and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
                          model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
                          if model_setting.enabled is False:
                              status = ModelStatus.DISABLED

                      # Exposed as a predefined model even though built from a customizable schema.
                      provider_models.append(
                          ModelWithProviderEntity(
                              model=custom_model_schema.model,
                              label=custom_model_schema.label,
                              model_type=custom_model_schema.model_type,
                              features=custom_model_schema.features,
                              fetch_from=FetchFrom.PREDEFINED_MODEL,
                              model_properties=custom_model_schema.model_properties,
                              deprecated=custom_model_schema.deprecated,
                              provider=SimpleModelProviderEntity(self.provider),
                              status=status
                          )
                      )

          # LLMs not in the restricted list are kept but marked NO_PERMISSION.
          restrict_model_names = [rm.model for rm in restrict_models]
          for m in provider_models:
              if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
                  m.status = ModelStatus.NO_PERMISSION
              elif not quota_configuration.is_valid:
                  # NOTE(review): this also overwrites DISABLED for disabled models — confirm.
                  m.status = ModelStatus.QUOTA_EXCEEDED

      return provider_models
  721. def _get_custom_provider_models(self,
  722. model_types: list[ModelType],
  723. provider_instance: ModelProvider,
  724. model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
  725. -> list[ModelWithProviderEntity]:
  726. """
  727. Get custom provider models.
  728. :param model_types: model types
  729. :param provider_instance: provider instance
  730. :param model_setting_map: model setting map
  731. :return:
  732. """
  733. provider_models = []
  734. credentials = None
  735. if self.custom_configuration.provider:
  736. credentials = self.custom_configuration.provider.credentials
  737. for model_type in model_types:
  738. if model_type not in self.provider.supported_model_types:
  739. continue
  740. models = provider_instance.models(model_type)
  741. for m in models:
  742. status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
  743. load_balancing_enabled = False
  744. if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
  745. model_setting = model_setting_map[m.model_type][m.model]
  746. if model_setting.enabled is False:
  747. status = ModelStatus.DISABLED
  748. if len(model_setting.load_balancing_configs) > 1:
  749. load_balancing_enabled = True
  750. provider_models.append(
  751. ModelWithProviderEntity(
  752. model=m.model,
  753. label=m.label,
  754. model_type=m.model_type,
  755. features=m.features,
  756. fetch_from=m.fetch_from,
  757. model_properties=m.model_properties,
  758. deprecated=m.deprecated,
  759. provider=SimpleModelProviderEntity(self.provider),
  760. status=status,
  761. load_balancing_enabled=load_balancing_enabled
  762. )
  763. )
  764. # custom models
  765. for model_configuration in self.custom_configuration.models:
  766. if model_configuration.model_type not in model_types:
  767. continue
  768. try:
  769. custom_model_schema = (
  770. provider_instance.get_model_instance(model_configuration.model_type)
  771. .get_customizable_model_schema_from_credentials(
  772. model_configuration.model,
  773. model_configuration.credentials
  774. )
  775. )
  776. except Exception as ex:
  777. logger.warning(f'get custom model schema failed, {ex}')
  778. continue
  779. if not custom_model_schema:
  780. continue
  781. status = ModelStatus.ACTIVE
  782. load_balancing_enabled = False
  783. if (custom_model_schema.model_type in model_setting_map
  784. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
  785. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  786. if model_setting.enabled is False:
  787. status = ModelStatus.DISABLED
  788. if len(model_setting.load_balancing_configs) > 1:
  789. load_balancing_enabled = True
  790. provider_models.append(
  791. ModelWithProviderEntity(
  792. model=custom_model_schema.model,
  793. label=custom_model_schema.label,
  794. model_type=custom_model_schema.model_type,
  795. features=custom_model_schema.features,
  796. fetch_from=custom_model_schema.fetch_from,
  797. model_properties=custom_model_schema.model_properties,
  798. deprecated=custom_model_schema.deprecated,
  799. provider=SimpleModelProviderEntity(self.provider),
  800. status=status,
  801. load_balancing_enabled=load_balancing_enabled
  802. )
  803. )
  804. return provider_models
  805. class ProviderConfigurations(BaseModel):
  806. """
  807. Model class for provider configuration dict.
  808. """
  809. tenant_id: str
  810. configurations: dict[str, ProviderConfiguration] = {}
  811. def __init__(self, tenant_id: str):
  812. super().__init__(tenant_id=tenant_id)
  813. def get_models(self,
  814. provider: Optional[str] = None,
  815. model_type: Optional[ModelType] = None,
  816. only_active: bool = False) \
  817. -> list[ModelWithProviderEntity]:
  818. """
  819. Get available models.
  820. If preferred provider type is `system`:
  821. Get the current **system mode** if provider supported,
  822. if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
  823. If there is no model configured in custom mode, it is treated as no_configure.
  824. system > custom > no_configure
  825. If preferred provider type is `custom`:
  826. If custom credentials are configured, it is treated as custom mode.
  827. Otherwise, get the current **system mode** if supported,
  828. If all system modes are not available (no quota), it is treated as no_configure.
  829. custom > system > no_configure
  830. If real mode is `system`, use system credentials to get models,
  831. paid quotas > provider free quotas > system free quotas
  832. include pre-defined models (exclude GPT-4, status marked as `no_permission`).
  833. If real mode is `custom`, use workspace custom credentials to get models,
  834. include pre-defined models, custom models(manual append).
  835. If real mode is `no_configure`, only return pre-defined models from `model runtime`.
  836. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
  837. model status marked as `active` is available.
  838. :param provider: provider name
  839. :param model_type: model type
  840. :param only_active: only active models
  841. :return:
  842. """
  843. all_models = []
  844. for provider_configuration in self.values():
  845. if provider and provider_configuration.provider.provider != provider:
  846. continue
  847. all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
  848. return all_models
  849. def to_list(self) -> list[ProviderConfiguration]:
  850. """
  851. Convert to list.
  852. :return:
  853. """
  854. return list(self.values())
  855. def __getitem__(self, key):
  856. return self.configurations[key]
  857. def __setitem__(self, key, value):
  858. self.configurations[key] = value
  859. def __iter__(self):
  860. return iter(self.configurations)
  861. def values(self) -> Iterator[ProviderConfiguration]:
  862. return self.configurations.values()
  863. def get(self, key, default=None):
  864. return self.configurations.get(key, default)
  865. class ProviderModelBundle(BaseModel):
  866. """
  867. Provider model bundle.
  868. """
  869. configuration: ProviderConfiguration
  870. provider_instance: ModelProvider
  871. model_type_instance: AIModel
  872. # pydantic configs
  873. model_config = ConfigDict(arbitrary_types_allowed=True,
  874. protected_namespaces=())