provider_configuration.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664
  1. import datetime
  2. import json
  3. import logging
  4. import time
  5. from json import JSONDecodeError
  6. from typing import Optional, List, Dict, Tuple, Iterator
  7. from pydantic import BaseModel
  8. from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity
  9. from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
  10. from core.helper import encrypter
  11. from core.model_runtime.entities.model_entities import ModelType
  12. from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
  13. from core.model_runtime.model_providers import model_provider_factory
  14. from core.model_runtime.model_providers.__base.ai_model import AIModel
  15. from core.model_runtime.model_providers.__base.model_provider import ModelProvider
  16. from core.model_runtime.utils import encoders
  17. from extensions.ext_database import db
  18. from models.provider import ProviderType, Provider, ProviderModel, TenantPreferredModelProvider
  19. logger = logging.getLogger(__name__)
  20. class ProviderConfiguration(BaseModel):
  21. """
  22. Model class for provider configuration.
  23. """
  24. tenant_id: str
  25. provider: ProviderEntity
  26. preferred_provider_type: ProviderType
  27. using_provider_type: ProviderType
  28. system_configuration: SystemConfiguration
  29. custom_configuration: CustomConfiguration
  30. def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
  31. """
  32. Get current credentials.
  33. :param model_type: model type
  34. :param model: model name
  35. :return:
  36. """
  37. if self.using_provider_type == ProviderType.SYSTEM:
  38. return self.system_configuration.credentials
  39. else:
  40. if self.custom_configuration.models:
  41. for model_configuration in self.custom_configuration.models:
  42. if model_configuration.model_type == model_type and model_configuration.model == model:
  43. return model_configuration.credentials
  44. if self.custom_configuration.provider:
  45. return self.custom_configuration.provider.credentials
  46. else:
  47. return None
  48. def get_system_configuration_status(self) -> SystemConfigurationStatus:
  49. """
  50. Get system configuration status.
  51. :return:
  52. """
  53. if self.system_configuration.enabled is False:
  54. return SystemConfigurationStatus.UNSUPPORTED
  55. current_quota_type = self.system_configuration.current_quota_type
  56. current_quota_configuration = next(
  57. (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
  58. None
  59. )
  60. return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
  61. SystemConfigurationStatus.QUOTA_EXCEEDED
  62. def is_custom_configuration_available(self) -> bool:
  63. """
  64. Check custom configuration available.
  65. :return:
  66. """
  67. return (self.custom_configuration.provider is not None
  68. or len(self.custom_configuration.models) > 0)
  69. def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
  70. """
  71. Get custom credentials.
  72. :param obfuscated: obfuscated secret data in credentials
  73. :return:
  74. """
  75. if self.custom_configuration.provider is None:
  76. return None
  77. credentials = self.custom_configuration.provider.credentials
  78. if not obfuscated:
  79. return credentials
  80. # Obfuscate credentials
  81. return self._obfuscated_credentials(
  82. credentials=credentials,
  83. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  84. if self.provider.provider_credential_schema else []
  85. )
  86. def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
  87. """
  88. Validate custom credentials.
  89. :param credentials: provider credentials
  90. :return:
  91. """
  92. # get provider
  93. provider_record = db.session.query(Provider) \
  94. .filter(
  95. Provider.tenant_id == self.tenant_id,
  96. Provider.provider_name == self.provider.provider,
  97. Provider.provider_type == ProviderType.CUSTOM.value
  98. ).first()
  99. # Get provider credential secret variables
  100. provider_credential_secret_variables = self._extract_secret_variables(
  101. self.provider.provider_credential_schema.credential_form_schemas
  102. if self.provider.provider_credential_schema else []
  103. )
  104. if provider_record:
  105. try:
  106. original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {}
  107. except JSONDecodeError:
  108. original_credentials = {}
  109. # encrypt credentials
  110. for key, value in credentials.items():
  111. if key in provider_credential_secret_variables:
  112. # if send [__HIDDEN__] in secret input, it will be same as original value
  113. if value == '[__HIDDEN__]' and key in original_credentials:
  114. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  115. model_provider_factory.provider_credentials_validate(
  116. self.provider.provider,
  117. credentials
  118. )
  119. for key, value in credentials.items():
  120. if key in provider_credential_secret_variables:
  121. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  122. return provider_record, credentials
  123. def add_or_update_custom_credentials(self, credentials: dict) -> None:
  124. """
  125. Add or update custom provider credentials.
  126. :param credentials:
  127. :return:
  128. """
  129. # validate custom provider config
  130. provider_record, credentials = self.custom_credentials_validate(credentials)
  131. # save provider
  132. # Note: Do not switch the preferred provider, which allows users to use quotas first
  133. if provider_record:
  134. provider_record.encrypted_config = json.dumps(credentials)
  135. provider_record.is_valid = True
  136. provider_record.updated_at = datetime.datetime.utcnow()
  137. db.session.commit()
  138. else:
  139. provider_record = Provider(
  140. tenant_id=self.tenant_id,
  141. provider_name=self.provider.provider,
  142. provider_type=ProviderType.CUSTOM.value,
  143. encrypted_config=json.dumps(credentials),
  144. is_valid=True
  145. )
  146. db.session.add(provider_record)
  147. db.session.commit()
  148. self.switch_preferred_provider_type(ProviderType.CUSTOM)
  149. def delete_custom_credentials(self) -> None:
  150. """
  151. Delete custom provider credentials.
  152. :return:
  153. """
  154. # get provider
  155. provider_record = db.session.query(Provider) \
  156. .filter(
  157. Provider.tenant_id == self.tenant_id,
  158. Provider.provider_name == self.provider.provider,
  159. Provider.provider_type == ProviderType.CUSTOM.value
  160. ).first()
  161. # delete provider
  162. if provider_record:
  163. self.switch_preferred_provider_type(ProviderType.SYSTEM)
  164. db.session.delete(provider_record)
  165. db.session.commit()
  166. def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
  167. -> Optional[dict]:
  168. """
  169. Get custom model credentials.
  170. :param model_type: model type
  171. :param model: model name
  172. :param obfuscated: obfuscated secret data in credentials
  173. :return:
  174. """
  175. if not self.custom_configuration.models:
  176. return None
  177. for model_configuration in self.custom_configuration.models:
  178. if model_configuration.model_type == model_type and model_configuration.model == model:
  179. credentials = model_configuration.credentials
  180. if not obfuscated:
  181. return credentials
  182. # Obfuscate credentials
  183. return self._obfuscated_credentials(
  184. credentials=credentials,
  185. credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
  186. if self.provider.model_credential_schema else []
  187. )
  188. return None
  189. def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
  190. -> Tuple[ProviderModel, dict]:
  191. """
  192. Validate custom model credentials.
  193. :param model_type: model type
  194. :param model: model name
  195. :param credentials: model credentials
  196. :return:
  197. """
  198. # get provider model
  199. provider_model_record = db.session.query(ProviderModel) \
  200. .filter(
  201. ProviderModel.tenant_id == self.tenant_id,
  202. ProviderModel.provider_name == self.provider.provider,
  203. ProviderModel.model_name == model,
  204. ProviderModel.model_type == model_type.to_origin_model_type()
  205. ).first()
  206. # Get provider credential secret variables
  207. provider_credential_secret_variables = self._extract_secret_variables(
  208. self.provider.model_credential_schema.credential_form_schemas
  209. if self.provider.model_credential_schema else []
  210. )
  211. if provider_model_record:
  212. try:
  213. original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
  214. except JSONDecodeError:
  215. original_credentials = {}
  216. # decrypt credentials
  217. for key, value in credentials.items():
  218. if key in provider_credential_secret_variables:
  219. # if send [__HIDDEN__] in secret input, it will be same as original value
  220. if value == '[__HIDDEN__]' and key in original_credentials:
  221. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  222. model_provider_factory.model_credentials_validate(
  223. provider=self.provider.provider,
  224. model_type=model_type,
  225. model=model,
  226. credentials=credentials
  227. )
  228. model_schema = (
  229. model_provider_factory.get_provider_instance(self.provider.provider)
  230. .get_model_instance(model_type)._get_customizable_model_schema(
  231. model=model,
  232. credentials=credentials
  233. )
  234. )
  235. if model_schema:
  236. credentials['schema'] = json.dumps(encoders.jsonable_encoder(model_schema))
  237. for key, value in credentials.items():
  238. if key in provider_credential_secret_variables:
  239. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  240. return provider_model_record, credentials
  241. def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
  242. """
  243. Add or update custom model credentials.
  244. :param model_type: model type
  245. :param model: model name
  246. :param credentials: model credentials
  247. :return:
  248. """
  249. # validate custom model config
  250. provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
  251. # save provider model
  252. # Note: Do not switch the preferred provider, which allows users to use quotas first
  253. if provider_model_record:
  254. provider_model_record.encrypted_config = json.dumps(credentials)
  255. provider_model_record.is_valid = True
  256. provider_model_record.updated_at = datetime.datetime.utcnow()
  257. db.session.commit()
  258. else:
  259. provider_model_record = ProviderModel(
  260. tenant_id=self.tenant_id,
  261. provider_name=self.provider.provider,
  262. model_name=model,
  263. model_type=model_type.to_origin_model_type(),
  264. encrypted_config=json.dumps(credentials),
  265. is_valid=True
  266. )
  267. db.session.add(provider_model_record)
  268. db.session.commit()
  269. def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
  270. """
  271. Delete custom model credentials.
  272. :param model_type: model type
  273. :param model: model name
  274. :return:
  275. """
  276. # get provider model
  277. provider_model_record = db.session.query(ProviderModel) \
  278. .filter(
  279. ProviderModel.tenant_id == self.tenant_id,
  280. ProviderModel.provider_name == self.provider.provider,
  281. ProviderModel.model_name == model,
  282. ProviderModel.model_type == model_type.to_origin_model_type()
  283. ).first()
  284. # delete provider model
  285. if provider_model_record:
  286. db.session.delete(provider_model_record)
  287. db.session.commit()
  288. def get_provider_instance(self) -> ModelProvider:
  289. """
  290. Get provider instance.
  291. :return:
  292. """
  293. return model_provider_factory.get_provider_instance(self.provider.provider)
  294. def get_model_type_instance(self, model_type: ModelType) -> AIModel:
  295. """
  296. Get current model type instance.
  297. :param model_type: model type
  298. :return:
  299. """
  300. # Get provider instance
  301. provider_instance = self.get_provider_instance()
  302. # Get model instance of LLM
  303. return provider_instance.get_model_instance(model_type)
  304. def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
  305. """
  306. Switch preferred provider type.
  307. :param provider_type:
  308. :return:
  309. """
  310. if provider_type == self.preferred_provider_type:
  311. return
  312. if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
  313. return
  314. # get preferred provider
  315. preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
  316. .filter(
  317. TenantPreferredModelProvider.tenant_id == self.tenant_id,
  318. TenantPreferredModelProvider.provider_name == self.provider.provider
  319. ).first()
  320. if preferred_model_provider:
  321. preferred_model_provider.preferred_provider_type = provider_type.value
  322. else:
  323. preferred_model_provider = TenantPreferredModelProvider(
  324. tenant_id=self.tenant_id,
  325. provider_name=self.provider.provider,
  326. preferred_provider_type=provider_type.value
  327. )
  328. db.session.add(preferred_model_provider)
  329. db.session.commit()
  330. def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
  331. """
  332. Extract secret input form variables.
  333. :param credential_form_schemas:
  334. :return:
  335. """
  336. secret_input_form_variables = []
  337. for credential_form_schema in credential_form_schemas:
  338. if credential_form_schema.type == FormType.SECRET_INPUT:
  339. secret_input_form_variables.append(credential_form_schema.variable)
  340. return secret_input_form_variables
  341. def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
  342. """
  343. Obfuscated credentials.
  344. :param credentials: credentials
  345. :param credential_form_schemas: credential form schemas
  346. :return:
  347. """
  348. # Get provider credential secret variables
  349. credential_secret_variables = self._extract_secret_variables(
  350. credential_form_schemas
  351. )
  352. # Obfuscate provider credentials
  353. copy_credentials = credentials.copy()
  354. for key, value in copy_credentials.items():
  355. if key in credential_secret_variables:
  356. copy_credentials[key] = encrypter.obfuscated_token(value)
  357. return copy_credentials
  358. def get_provider_model(self, model_type: ModelType,
  359. model: str,
  360. only_active: bool = False) -> Optional[ModelWithProviderEntity]:
  361. """
  362. Get provider model.
  363. :param model_type: model type
  364. :param model: model name
  365. :param only_active: return active model only
  366. :return:
  367. """
  368. provider_models = self.get_provider_models(model_type, only_active)
  369. for provider_model in provider_models:
  370. if provider_model.model == model:
  371. return provider_model
  372. return None
  373. def get_provider_models(self, model_type: Optional[ModelType] = None,
  374. only_active: bool = False) -> list[ModelWithProviderEntity]:
  375. """
  376. Get provider models.
  377. :param model_type: model type
  378. :param only_active: only active models
  379. :return:
  380. """
  381. provider_instance = self.get_provider_instance()
  382. model_types = []
  383. if model_type:
  384. model_types.append(model_type)
  385. else:
  386. model_types = provider_instance.get_provider_schema().supported_model_types
  387. if self.using_provider_type == ProviderType.SYSTEM:
  388. provider_models = self._get_system_provider_models(
  389. model_types=model_types,
  390. provider_instance=provider_instance
  391. )
  392. else:
  393. provider_models = self._get_custom_provider_models(
  394. model_types=model_types,
  395. provider_instance=provider_instance
  396. )
  397. if only_active:
  398. provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
  399. # resort provider_models
  400. return sorted(provider_models, key=lambda x: x.model_type.value)
  401. def _get_system_provider_models(self,
  402. model_types: list[ModelType],
  403. provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
  404. """
  405. Get system provider models.
  406. :param model_types: model types
  407. :param provider_instance: provider instance
  408. :return:
  409. """
  410. provider_models = []
  411. for model_type in model_types:
  412. provider_models.extend(
  413. [
  414. ModelWithProviderEntity(
  415. **m.dict(),
  416. provider=SimpleModelProviderEntity(self.provider),
  417. status=ModelStatus.ACTIVE
  418. )
  419. for m in provider_instance.models(model_type)
  420. ]
  421. )
  422. for quota_configuration in self.system_configuration.quota_configurations:
  423. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  424. continue
  425. restrict_llms = quota_configuration.restrict_llms
  426. if not restrict_llms:
  427. break
  428. # if llm name not in restricted llm list, remove it
  429. for m in provider_models:
  430. if m.model_type == ModelType.LLM and m.model not in restrict_llms:
  431. m.status = ModelStatus.NO_PERMISSION
  432. elif not quota_configuration.is_valid:
  433. m.status = ModelStatus.QUOTA_EXCEEDED
  434. return provider_models
  435. def _get_custom_provider_models(self,
  436. model_types: list[ModelType],
  437. provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
  438. """
  439. Get custom provider models.
  440. :param model_types: model types
  441. :param provider_instance: provider instance
  442. :return:
  443. """
  444. provider_models = []
  445. credentials = None
  446. if self.custom_configuration.provider:
  447. credentials = self.custom_configuration.provider.credentials
  448. for model_type in model_types:
  449. if model_type not in self.provider.supported_model_types:
  450. continue
  451. models = provider_instance.models(model_type)
  452. for m in models:
  453. provider_models.append(
  454. ModelWithProviderEntity(
  455. **m.dict(),
  456. provider=SimpleModelProviderEntity(self.provider),
  457. status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
  458. )
  459. )
  460. # custom models
  461. for model_configuration in self.custom_configuration.models:
  462. if model_configuration.model_type not in model_types:
  463. continue
  464. try:
  465. custom_model_schema = (
  466. provider_instance.get_model_instance(model_configuration.model_type)
  467. .get_customizable_model_schema_from_credentials(
  468. model_configuration.model,
  469. model_configuration.credentials
  470. )
  471. )
  472. except Exception as ex:
  473. logger.warning(f'get custom model schema failed, {ex}')
  474. continue
  475. if not custom_model_schema:
  476. continue
  477. provider_models.append(
  478. ModelWithProviderEntity(
  479. **custom_model_schema.dict(),
  480. provider=SimpleModelProviderEntity(self.provider),
  481. status=ModelStatus.ACTIVE
  482. )
  483. )
  484. return provider_models
  485. class ProviderConfigurations(BaseModel):
  486. """
  487. Model class for provider configuration dict.
  488. """
  489. tenant_id: str
  490. configurations: Dict[str, ProviderConfiguration] = {}
  491. def __init__(self, tenant_id: str):
  492. super().__init__(tenant_id=tenant_id)
  493. def get_models(self,
  494. provider: Optional[str] = None,
  495. model_type: Optional[ModelType] = None,
  496. only_active: bool = False) \
  497. -> list[ModelWithProviderEntity]:
  498. """
  499. Get available models.
  500. If preferred provider type is `system`:
  501. Get the current **system mode** if provider supported,
  502. if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
  503. If there is no model configured in custom mode, it is treated as no_configure.
  504. system > custom > no_configure
  505. If preferred provider type is `custom`:
  506. If custom credentials are configured, it is treated as custom mode.
  507. Otherwise, get the current **system mode** if supported,
  508. If all system modes are not available (no quota), it is treated as no_configure.
  509. custom > system > no_configure
  510. If real mode is `system`, use system credentials to get models,
  511. paid quotas > provider free quotas > system free quotas
  512. include pre-defined models (exclude GPT-4, status marked as `no_permission`).
  513. If real mode is `custom`, use workspace custom credentials to get models,
  514. include pre-defined models, custom models(manual append).
  515. If real mode is `no_configure`, only return pre-defined models from `model runtime`.
  516. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
  517. model status marked as `active` is available.
  518. :param provider: provider name
  519. :param model_type: model type
  520. :param only_active: only active models
  521. :return:
  522. """
  523. all_models = []
  524. for provider_configuration in self.values():
  525. if provider and provider_configuration.provider.provider != provider:
  526. continue
  527. all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
  528. return all_models
  529. def to_list(self) -> List[ProviderConfiguration]:
  530. """
  531. Convert to list.
  532. :return:
  533. """
  534. return list(self.values())
  535. def __getitem__(self, key):
  536. return self.configurations[key]
  537. def __setitem__(self, key, value):
  538. self.configurations[key] = value
  539. def __iter__(self):
  540. return iter(self.configurations)
  541. def values(self) -> Iterator[ProviderConfiguration]:
  542. return self.configurations.values()
  543. def get(self, key, default=None):
  544. return self.configurations.get(key, default)
  545. class ProviderModelBundle(BaseModel):
  546. """
  547. Provider model bundle.
  548. """
  549. configuration: ProviderConfiguration
  550. provider_instance: ModelProvider
  551. model_type_instance: AIModel
  552. class Config:
  553. """Configuration for this pydantic object."""
  554. arbitrary_types_allowed = True