provider_configuration.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697
  1. import datetime
  2. import json
  3. import logging
  4. import time
  5. from json import JSONDecodeError
  6. from typing import Optional, List, Dict, Tuple, Iterator
  7. from pydantic import BaseModel
  8. from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity
  9. from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
  10. from core.helper import encrypter
  11. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  12. from core.model_runtime.entities.model_entities import ModelType
  13. from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
  14. from core.model_runtime.model_providers import model_provider_factory
  15. from core.model_runtime.model_providers.__base.ai_model import AIModel
  16. from core.model_runtime.model_providers.__base.model_provider import ModelProvider
  17. from core.model_runtime.utils import encoders
  18. from extensions.ext_database import db
  19. from models.provider import ProviderType, Provider, ProviderModel, TenantPreferredModelProvider
  20. logger = logging.getLogger(__name__)
  21. class ProviderConfiguration(BaseModel):
  22. """
  23. Model class for provider configuration.
  24. """
  25. tenant_id: str
  26. provider: ProviderEntity
  27. preferred_provider_type: ProviderType
  28. using_provider_type: ProviderType
  29. system_configuration: SystemConfiguration
  30. custom_configuration: CustomConfiguration
  31. def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
  32. """
  33. Get current credentials.
  34. :param model_type: model type
  35. :param model: model name
  36. :return:
  37. """
  38. if self.using_provider_type == ProviderType.SYSTEM:
  39. return self.system_configuration.credentials
  40. else:
  41. if self.custom_configuration.models:
  42. for model_configuration in self.custom_configuration.models:
  43. if model_configuration.model_type == model_type and model_configuration.model == model:
  44. return model_configuration.credentials
  45. if self.custom_configuration.provider:
  46. return self.custom_configuration.provider.credentials
  47. else:
  48. return None
  49. def get_system_configuration_status(self) -> SystemConfigurationStatus:
  50. """
  51. Get system configuration status.
  52. :return:
  53. """
  54. if self.system_configuration.enabled is False:
  55. return SystemConfigurationStatus.UNSUPPORTED
  56. current_quota_type = self.system_configuration.current_quota_type
  57. current_quota_configuration = next(
  58. (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
  59. None
  60. )
  61. return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
  62. SystemConfigurationStatus.QUOTA_EXCEEDED
  63. def is_custom_configuration_available(self) -> bool:
  64. """
  65. Check custom configuration available.
  66. :return:
  67. """
  68. return (self.custom_configuration.provider is not None
  69. or len(self.custom_configuration.models) > 0)
  70. def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
  71. """
  72. Get custom credentials.
  73. :param obfuscated: obfuscated secret data in credentials
  74. :return:
  75. """
  76. if self.custom_configuration.provider is None:
  77. return None
  78. credentials = self.custom_configuration.provider.credentials
  79. if not obfuscated:
  80. return credentials
  81. # Obfuscate credentials
  82. return self._obfuscated_credentials(
  83. credentials=credentials,
  84. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  85. if self.provider.provider_credential_schema else []
  86. )
  87. def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
  88. """
  89. Validate custom credentials.
  90. :param credentials: provider credentials
  91. :return:
  92. """
  93. # get provider
  94. provider_record = db.session.query(Provider) \
  95. .filter(
  96. Provider.tenant_id == self.tenant_id,
  97. Provider.provider_name == self.provider.provider,
  98. Provider.provider_type == ProviderType.CUSTOM.value
  99. ).first()
  100. # Get provider credential secret variables
  101. provider_credential_secret_variables = self._extract_secret_variables(
  102. self.provider.provider_credential_schema.credential_form_schemas
  103. if self.provider.provider_credential_schema else []
  104. )
  105. if provider_record:
  106. try:
  107. original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {}
  108. except JSONDecodeError:
  109. original_credentials = {}
  110. # encrypt credentials
  111. for key, value in credentials.items():
  112. if key in provider_credential_secret_variables:
  113. # if send [__HIDDEN__] in secret input, it will be same as original value
  114. if value == '[__HIDDEN__]' and key in original_credentials:
  115. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  116. model_provider_factory.provider_credentials_validate(
  117. self.provider.provider,
  118. credentials
  119. )
  120. for key, value in credentials.items():
  121. if key in provider_credential_secret_variables:
  122. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  123. return provider_record, credentials
  124. def add_or_update_custom_credentials(self, credentials: dict) -> None:
  125. """
  126. Add or update custom provider credentials.
  127. :param credentials:
  128. :return:
  129. """
  130. # validate custom provider config
  131. provider_record, credentials = self.custom_credentials_validate(credentials)
  132. # save provider
  133. # Note: Do not switch the preferred provider, which allows users to use quotas first
  134. if provider_record:
  135. provider_record.encrypted_config = json.dumps(credentials)
  136. provider_record.is_valid = True
  137. provider_record.updated_at = datetime.datetime.utcnow()
  138. db.session.commit()
  139. else:
  140. provider_record = Provider(
  141. tenant_id=self.tenant_id,
  142. provider_name=self.provider.provider,
  143. provider_type=ProviderType.CUSTOM.value,
  144. encrypted_config=json.dumps(credentials),
  145. is_valid=True
  146. )
  147. db.session.add(provider_record)
  148. db.session.commit()
  149. provider_model_credentials_cache = ProviderCredentialsCache(
  150. tenant_id=self.tenant_id,
  151. identity_id=provider_record.id,
  152. cache_type=ProviderCredentialsCacheType.PROVIDER
  153. )
  154. provider_model_credentials_cache.delete()
  155. self.switch_preferred_provider_type(ProviderType.CUSTOM)
  156. def delete_custom_credentials(self) -> None:
  157. """
  158. Delete custom provider credentials.
  159. :return:
  160. """
  161. # get provider
  162. provider_record = db.session.query(Provider) \
  163. .filter(
  164. Provider.tenant_id == self.tenant_id,
  165. Provider.provider_name == self.provider.provider,
  166. Provider.provider_type == ProviderType.CUSTOM.value
  167. ).first()
  168. # delete provider
  169. if provider_record:
  170. self.switch_preferred_provider_type(ProviderType.SYSTEM)
  171. db.session.delete(provider_record)
  172. db.session.commit()
  173. provider_model_credentials_cache = ProviderCredentialsCache(
  174. tenant_id=self.tenant_id,
  175. identity_id=provider_record.id,
  176. cache_type=ProviderCredentialsCacheType.PROVIDER
  177. )
  178. provider_model_credentials_cache.delete()
  179. def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
  180. -> Optional[dict]:
  181. """
  182. Get custom model credentials.
  183. :param model_type: model type
  184. :param model: model name
  185. :param obfuscated: obfuscated secret data in credentials
  186. :return:
  187. """
  188. if not self.custom_configuration.models:
  189. return None
  190. for model_configuration in self.custom_configuration.models:
  191. if model_configuration.model_type == model_type and model_configuration.model == model:
  192. credentials = model_configuration.credentials
  193. if not obfuscated:
  194. return credentials
  195. # Obfuscate credentials
  196. return self._obfuscated_credentials(
  197. credentials=credentials,
  198. credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
  199. if self.provider.model_credential_schema else []
  200. )
  201. return None
  202. def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
  203. -> Tuple[ProviderModel, dict]:
  204. """
  205. Validate custom model credentials.
  206. :param model_type: model type
  207. :param model: model name
  208. :param credentials: model credentials
  209. :return:
  210. """
  211. # get provider model
  212. provider_model_record = db.session.query(ProviderModel) \
  213. .filter(
  214. ProviderModel.tenant_id == self.tenant_id,
  215. ProviderModel.provider_name == self.provider.provider,
  216. ProviderModel.model_name == model,
  217. ProviderModel.model_type == model_type.to_origin_model_type()
  218. ).first()
  219. # Get provider credential secret variables
  220. provider_credential_secret_variables = self._extract_secret_variables(
  221. self.provider.model_credential_schema.credential_form_schemas
  222. if self.provider.model_credential_schema else []
  223. )
  224. if provider_model_record:
  225. try:
  226. original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
  227. except JSONDecodeError:
  228. original_credentials = {}
  229. # decrypt credentials
  230. for key, value in credentials.items():
  231. if key in provider_credential_secret_variables:
  232. # if send [__HIDDEN__] in secret input, it will be same as original value
  233. if value == '[__HIDDEN__]' and key in original_credentials:
  234. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  235. model_provider_factory.model_credentials_validate(
  236. provider=self.provider.provider,
  237. model_type=model_type,
  238. model=model,
  239. credentials=credentials
  240. )
  241. model_schema = (
  242. model_provider_factory.get_provider_instance(self.provider.provider)
  243. .get_model_instance(model_type)._get_customizable_model_schema(
  244. model=model,
  245. credentials=credentials
  246. )
  247. )
  248. if model_schema:
  249. credentials['schema'] = json.dumps(encoders.jsonable_encoder(model_schema))
  250. for key, value in credentials.items():
  251. if key in provider_credential_secret_variables:
  252. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  253. return provider_model_record, credentials
  254. def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
  255. """
  256. Add or update custom model credentials.
  257. :param model_type: model type
  258. :param model: model name
  259. :param credentials: model credentials
  260. :return:
  261. """
  262. # validate custom model config
  263. provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
  264. # save provider model
  265. # Note: Do not switch the preferred provider, which allows users to use quotas first
  266. if provider_model_record:
  267. provider_model_record.encrypted_config = json.dumps(credentials)
  268. provider_model_record.is_valid = True
  269. provider_model_record.updated_at = datetime.datetime.utcnow()
  270. db.session.commit()
  271. else:
  272. provider_model_record = ProviderModel(
  273. tenant_id=self.tenant_id,
  274. provider_name=self.provider.provider,
  275. model_name=model,
  276. model_type=model_type.to_origin_model_type(),
  277. encrypted_config=json.dumps(credentials),
  278. is_valid=True
  279. )
  280. db.session.add(provider_model_record)
  281. db.session.commit()
  282. provider_model_credentials_cache = ProviderCredentialsCache(
  283. tenant_id=self.tenant_id,
  284. identity_id=provider_model_record.id,
  285. cache_type=ProviderCredentialsCacheType.MODEL
  286. )
  287. provider_model_credentials_cache.delete()
  288. def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
  289. """
  290. Delete custom model credentials.
  291. :param model_type: model type
  292. :param model: model name
  293. :return:
  294. """
  295. # get provider model
  296. provider_model_record = db.session.query(ProviderModel) \
  297. .filter(
  298. ProviderModel.tenant_id == self.tenant_id,
  299. ProviderModel.provider_name == self.provider.provider,
  300. ProviderModel.model_name == model,
  301. ProviderModel.model_type == model_type.to_origin_model_type()
  302. ).first()
  303. # delete provider model
  304. if provider_model_record:
  305. db.session.delete(provider_model_record)
  306. db.session.commit()
  307. provider_model_credentials_cache = ProviderCredentialsCache(
  308. tenant_id=self.tenant_id,
  309. identity_id=provider_model_record.id,
  310. cache_type=ProviderCredentialsCacheType.MODEL
  311. )
  312. provider_model_credentials_cache.delete()
  313. def get_provider_instance(self) -> ModelProvider:
  314. """
  315. Get provider instance.
  316. :return:
  317. """
  318. return model_provider_factory.get_provider_instance(self.provider.provider)
  319. def get_model_type_instance(self, model_type: ModelType) -> AIModel:
  320. """
  321. Get current model type instance.
  322. :param model_type: model type
  323. :return:
  324. """
  325. # Get provider instance
  326. provider_instance = self.get_provider_instance()
  327. # Get model instance of LLM
  328. return provider_instance.get_model_instance(model_type)
  329. def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
  330. """
  331. Switch preferred provider type.
  332. :param provider_type:
  333. :return:
  334. """
  335. if provider_type == self.preferred_provider_type:
  336. return
  337. if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
  338. return
  339. # get preferred provider
  340. preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
  341. .filter(
  342. TenantPreferredModelProvider.tenant_id == self.tenant_id,
  343. TenantPreferredModelProvider.provider_name == self.provider.provider
  344. ).first()
  345. if preferred_model_provider:
  346. preferred_model_provider.preferred_provider_type = provider_type.value
  347. else:
  348. preferred_model_provider = TenantPreferredModelProvider(
  349. tenant_id=self.tenant_id,
  350. provider_name=self.provider.provider,
  351. preferred_provider_type=provider_type.value
  352. )
  353. db.session.add(preferred_model_provider)
  354. db.session.commit()
  355. def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
  356. """
  357. Extract secret input form variables.
  358. :param credential_form_schemas:
  359. :return:
  360. """
  361. secret_input_form_variables = []
  362. for credential_form_schema in credential_form_schemas:
  363. if credential_form_schema.type == FormType.SECRET_INPUT:
  364. secret_input_form_variables.append(credential_form_schema.variable)
  365. return secret_input_form_variables
  366. def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
  367. """
  368. Obfuscated credentials.
  369. :param credentials: credentials
  370. :param credential_form_schemas: credential form schemas
  371. :return:
  372. """
  373. # Get provider credential secret variables
  374. credential_secret_variables = self._extract_secret_variables(
  375. credential_form_schemas
  376. )
  377. # Obfuscate provider credentials
  378. copy_credentials = credentials.copy()
  379. for key, value in copy_credentials.items():
  380. if key in credential_secret_variables:
  381. copy_credentials[key] = encrypter.obfuscated_token(value)
  382. return copy_credentials
  383. def get_provider_model(self, model_type: ModelType,
  384. model: str,
  385. only_active: bool = False) -> Optional[ModelWithProviderEntity]:
  386. """
  387. Get provider model.
  388. :param model_type: model type
  389. :param model: model name
  390. :param only_active: return active model only
  391. :return:
  392. """
  393. provider_models = self.get_provider_models(model_type, only_active)
  394. for provider_model in provider_models:
  395. if provider_model.model == model:
  396. return provider_model
  397. return None
  398. def get_provider_models(self, model_type: Optional[ModelType] = None,
  399. only_active: bool = False) -> list[ModelWithProviderEntity]:
  400. """
  401. Get provider models.
  402. :param model_type: model type
  403. :param only_active: only active models
  404. :return:
  405. """
  406. provider_instance = self.get_provider_instance()
  407. model_types = []
  408. if model_type:
  409. model_types.append(model_type)
  410. else:
  411. model_types = provider_instance.get_provider_schema().supported_model_types
  412. if self.using_provider_type == ProviderType.SYSTEM:
  413. provider_models = self._get_system_provider_models(
  414. model_types=model_types,
  415. provider_instance=provider_instance
  416. )
  417. else:
  418. provider_models = self._get_custom_provider_models(
  419. model_types=model_types,
  420. provider_instance=provider_instance
  421. )
  422. if only_active:
  423. provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
  424. # resort provider_models
  425. return sorted(provider_models, key=lambda x: x.model_type.value)
  426. def _get_system_provider_models(self,
  427. model_types: list[ModelType],
  428. provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
  429. """
  430. Get system provider models.
  431. :param model_types: model types
  432. :param provider_instance: provider instance
  433. :return:
  434. """
  435. provider_models = []
  436. for model_type in model_types:
  437. provider_models.extend(
  438. [
  439. ModelWithProviderEntity(
  440. **m.dict(),
  441. provider=SimpleModelProviderEntity(self.provider),
  442. status=ModelStatus.ACTIVE
  443. )
  444. for m in provider_instance.models(model_type)
  445. ]
  446. )
  447. for quota_configuration in self.system_configuration.quota_configurations:
  448. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  449. continue
  450. restrict_llms = quota_configuration.restrict_llms
  451. if not restrict_llms:
  452. break
  453. # if llm name not in restricted llm list, remove it
  454. for m in provider_models:
  455. if m.model_type == ModelType.LLM and m.model not in restrict_llms:
  456. m.status = ModelStatus.NO_PERMISSION
  457. elif not quota_configuration.is_valid:
  458. m.status = ModelStatus.QUOTA_EXCEEDED
  459. return provider_models
  460. def _get_custom_provider_models(self,
  461. model_types: list[ModelType],
  462. provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
  463. """
  464. Get custom provider models.
  465. :param model_types: model types
  466. :param provider_instance: provider instance
  467. :return:
  468. """
  469. provider_models = []
  470. credentials = None
  471. if self.custom_configuration.provider:
  472. credentials = self.custom_configuration.provider.credentials
  473. for model_type in model_types:
  474. if model_type not in self.provider.supported_model_types:
  475. continue
  476. models = provider_instance.models(model_type)
  477. for m in models:
  478. provider_models.append(
  479. ModelWithProviderEntity(
  480. **m.dict(),
  481. provider=SimpleModelProviderEntity(self.provider),
  482. status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
  483. )
  484. )
  485. # custom models
  486. for model_configuration in self.custom_configuration.models:
  487. if model_configuration.model_type not in model_types:
  488. continue
  489. try:
  490. custom_model_schema = (
  491. provider_instance.get_model_instance(model_configuration.model_type)
  492. .get_customizable_model_schema_from_credentials(
  493. model_configuration.model,
  494. model_configuration.credentials
  495. )
  496. )
  497. except Exception as ex:
  498. logger.warning(f'get custom model schema failed, {ex}')
  499. continue
  500. if not custom_model_schema:
  501. continue
  502. provider_models.append(
  503. ModelWithProviderEntity(
  504. **custom_model_schema.dict(),
  505. provider=SimpleModelProviderEntity(self.provider),
  506. status=ModelStatus.ACTIVE
  507. )
  508. )
  509. return provider_models
  510. class ProviderConfigurations(BaseModel):
  511. """
  512. Model class for provider configuration dict.
  513. """
  514. tenant_id: str
  515. configurations: Dict[str, ProviderConfiguration] = {}
  516. def __init__(self, tenant_id: str):
  517. super().__init__(tenant_id=tenant_id)
  518. def get_models(self,
  519. provider: Optional[str] = None,
  520. model_type: Optional[ModelType] = None,
  521. only_active: bool = False) \
  522. -> list[ModelWithProviderEntity]:
  523. """
  524. Get available models.
  525. If preferred provider type is `system`:
  526. Get the current **system mode** if provider supported,
  527. if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
  528. If there is no model configured in custom mode, it is treated as no_configure.
  529. system > custom > no_configure
  530. If preferred provider type is `custom`:
  531. If custom credentials are configured, it is treated as custom mode.
  532. Otherwise, get the current **system mode** if supported,
  533. If all system modes are not available (no quota), it is treated as no_configure.
  534. custom > system > no_configure
  535. If real mode is `system`, use system credentials to get models,
  536. paid quotas > provider free quotas > system free quotas
  537. include pre-defined models (exclude GPT-4, status marked as `no_permission`).
  538. If real mode is `custom`, use workspace custom credentials to get models,
  539. include pre-defined models, custom models(manual append).
  540. If real mode is `no_configure`, only return pre-defined models from `model runtime`.
  541. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
  542. model status marked as `active` is available.
  543. :param provider: provider name
  544. :param model_type: model type
  545. :param only_active: only active models
  546. :return:
  547. """
  548. all_models = []
  549. for provider_configuration in self.values():
  550. if provider and provider_configuration.provider.provider != provider:
  551. continue
  552. all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
  553. return all_models
  554. def to_list(self) -> List[ProviderConfiguration]:
  555. """
  556. Convert to list.
  557. :return:
  558. """
  559. return list(self.values())
  560. def __getitem__(self, key):
  561. return self.configurations[key]
  562. def __setitem__(self, key, value):
  563. self.configurations[key] = value
  564. def __iter__(self):
  565. return iter(self.configurations)
  566. def values(self) -> Iterator[ProviderConfiguration]:
  567. return self.configurations.values()
  568. def get(self, key, default=None):
  569. return self.configurations.get(key, default)
  570. class ProviderModelBundle(BaseModel):
  571. """
  572. Provider model bundle.
  573. """
  574. configuration: ProviderConfiguration
  575. provider_instance: ModelProvider
  576. model_type_instance: AIModel
  577. class Config:
  578. """Configuration for this pydantic object."""
  579. arbitrary_types_allowed = True