provider_service.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596
  1. import datetime
  2. import json
  3. import logging
  4. import os
  5. from collections import defaultdict
  6. from typing import Optional
  7. import requests
  8. from core.model_providers.model_factory import ModelFactory
  9. from extensions.ext_database import db
  10. from core.model_providers.model_provider_factory import ModelProviderFactory
  11. from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
  12. from models.provider import Provider, ProviderModel, TenantPreferredModelProvider, ProviderType, ProviderQuotaType, \
  13. TenantDefaultModel
  14. class ProviderService:
  def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list:
      """
      Get the provider list of a tenant.

      For every provider rule returned by ``ModelProviderFactory.get_provider_rules()``
      this builds a summary holding ``preferred_provider_type``, ``model_flexibility``,
      ``supported_model_types`` and ``providers`` (one entry per supported system quota
      type plus one entry for the custom configuration, when supported).

      :param tenant_id: workspace id
      :param model_type: when given, only providers whose rule lists this value under
          ``supported_model_types`` are included
      :return: mapping of provider name to its summary (a dict, although the signature
          is annotated ``list``; the annotation is kept unchanged for callers)
      """
      system_type = ProviderType.SYSTEM.value
      custom_type = ProviderType.CUSTOM.value

      # get rules for all providers
      model_provider_rules = ModelProviderFactory.get_provider_rules()
      model_provider_names = list(model_provider_rules.keys())

      # Resolve the preferred provider of every system provider that offers a trial quota.
      # The return value is unused.
      # NOTE(review): presumably called for a side effect inside ModelProviderFactory - confirm.
      for model_provider_name, model_provider_rule in model_provider_rules.items():
          system_config = model_provider_rule.get('system_config') or {}
          if system_type in model_provider_rule['support_provider_types'] \
                  and 'trial' in (system_config.get('supported_quota_types') or []):
              ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

      # Distinct loop names: the original comprehension reused `model_provider_rules`
      # as its loop variable, which read as if the outer mapping were being rebound.
      configurable_model_provider_names = [
          name
          for name, rule in model_provider_rules.items()
          if 'custom' in rule['support_provider_types']
          and rule['model_flexibility'] == 'configurable'
      ]

      # get all providers for the tenant
      providers = db.session.query(Provider) \
          .filter(
              Provider.tenant_id == tenant_id,
              Provider.provider_name.in_(model_provider_names),
              Provider.is_valid == True
          ).order_by(Provider.created_at.desc()).all()

      provider_name_to_provider_dict = defaultdict(list)
      for provider in providers:
          provider_name_to_provider_dict[provider.provider_name].append(provider)

      # get all configurable provider models for the tenant
      provider_models = db.session.query(ProviderModel) \
          .filter(
              ProviderModel.tenant_id == tenant_id,
              ProviderModel.provider_name.in_(configurable_model_provider_names),
              ProviderModel.is_valid == True
          ).order_by(ProviderModel.created_at.desc()).all()

      provider_name_to_provider_model_dict = defaultdict(list)
      for provider_model in provider_models:
          provider_name_to_provider_model_dict[provider_model.provider_name].append(provider_model)

      # get all preferred provider type for the tenant
      preferred_provider_types = db.session.query(TenantPreferredModelProvider) \
          .filter(
              TenantPreferredModelProvider.tenant_id == tenant_id,
              TenantPreferredModelProvider.provider_name.in_(model_provider_names)
          ).all()

      provider_name_to_preferred_provider_type_dict = {
          preferred.provider_name: preferred for preferred in preferred_provider_types
      }

      providers_list = {}
      for model_provider_name, model_provider_rule in model_provider_rules.items():
          if model_type and model_type not in model_provider_rule.get('supported_model_types', []):
              continue

          supported_provider_types = model_provider_rule['support_provider_types']

          # get preferred provider type
          preferred_provider_type = ModelProviderFactory.get_preferred_type_by_preferred_model_provider(
              tenant_id,
              model_provider_name,
              provider_name_to_preferred_provider_type_dict.get(model_provider_name)
          )

          provider_config_dict = {
              "preferred_provider_type": preferred_provider_type,
              "model_flexibility": model_provider_rule['model_flexibility'],
              "supported_model_types": model_provider_rule.get("supported_model_types", []),
          }

          # Defaults for every supported slot; values marked "need update" are
          # overwritten below from the tenant's stored provider records.
          provider_parameter_dict = {}
          if system_type in supported_provider_types:
              # Tolerate a missing/empty system_config, matching the guarded check above,
              # instead of failing with KeyError/TypeError.
              system_config = model_provider_rule.get('system_config') or {}
              supported_quota_types = system_config.get('supported_quota_types') or []
              for quota_type_enum in ProviderQuotaType:
                  quota_type = quota_type_enum.value
                  if quota_type not in supported_quota_types:
                      continue

                  provider_parameter_dict[f'{system_type}:{quota_type}'] = {
                      "provider_name": model_provider_name,
                      "provider_type": system_type,
                      "config": None,
                      "is_valid": False,  # need update
                      "quota_type": quota_type,
                      "quota_unit": system_config['quota_unit'],  # need update
                      # only the trial quota takes its limit from the rule
                      "quota_limit": system_config['quota_limit']
                      if quota_type == ProviderQuotaType.TRIAL.value else 0,  # need update
                      "quota_used": 0,  # need update
                      "last_used": None  # need update
                  }

          if custom_type in supported_provider_types:
              provider_parameter_dict[custom_type] = {
                  "provider_name": model_provider_name,
                  "provider_type": custom_type,
                  "config": None,  # need update
                  "models": [],  # need update
                  "is_valid": False,
                  "last_used": None  # need update
              }

          model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)

          for provider in provider_name_to_provider_dict[model_provider_name]:
              last_used = int(provider.last_used.timestamp()) if provider.last_used else None

              if provider.provider_type == system_type:
                  key = f'{system_type}:{provider.quota_type}'
                  if key in provider_parameter_dict:
                      system_entry = provider_parameter_dict[key]
                      system_entry['is_valid'] = provider.is_valid
                      system_entry['quota_used'] = provider.quota_used
                      system_entry['quota_limit'] = provider.quota_limit
                      system_entry['last_used'] = last_used
              elif provider.provider_type == custom_type and custom_type in provider_parameter_dict:
                  custom_entry = provider_parameter_dict[custom_type]
                  custom_entry['last_used'] = last_used
                  custom_entry['is_valid'] = provider.is_valid

                  # One provider instance serves every credential lookup for this record
                  # (the original rebuilt it for each model).
                  model_provider = model_provider_class(provider=provider)
                  if model_provider_rule['model_flexibility'] == 'fixed':
                      custom_entry['config'] = model_provider.get_provider_credentials(obfuscated=True)
                  else:
                      custom_entry['models'] = [
                          {
                              "model_name": stored_model.model_name,
                              "model_type": stored_model.model_type,
                              "config": model_provider.get_model_credentials(
                                  stored_model.model_name,
                                  ModelType.value_of(stored_model.model_type),
                                  obfuscated=True
                              ),
                              "is_valid": stored_model.is_valid
                          }
                          for stored_model in provider_name_to_provider_model_dict[model_provider_name]
                      ]

          provider_config_dict['providers'] = list(provider_parameter_dict.values())
          providers_list[model_provider_name] = provider_config_dict

      return providers_list
  def custom_provider_config_validate(self, provider_name: str, config: dict) -> None:
      """
      Check credentials for a custom provider whose model set is fixed.

      :param provider_name: provider whose rule and provider class are looked up
      :param config: credentials handed to the provider class for verification
      :return:
      :raises ValueError: when the provider rule is not ``fixed`` or does not list the
          custom provider type
      :raises CredentialsValidateFailedError: When the config credential verification fails.
      """
      rule = ModelProviderFactory.get_provider_rule(provider_name)

      # provider-level credentials only exist for fixed-model providers
      flexibility = rule['model_flexibility']
      if flexibility != 'fixed':
          raise ValueError('Only support fixed model provider')

      # the rule must allow the CUSTOM provider type
      allowed_types = rule['support_provider_types']
      if ProviderType.CUSTOM.value not in allowed_types:
          raise ValueError('Only support provider type CUSTOM')

      # delegate the actual credential check to the provider class
      provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
      provider_class.is_provider_credentials_valid_or_raise(config)
  def save_custom_provider_config(self, tenant_id: str, provider_name: str, config: dict) -> None:
      """
      Validate, encrypt and persist the custom credentials of a provider.

      Updates the tenant's existing custom provider record when there is one,
      otherwise inserts a new record, then commits the session.

      :param tenant_id: workspace id
      :param provider_name: provider the credentials belong to
      :param config: plain credentials; validated first, stored encrypted as JSON
      :return:
      """
      # raises before anything is written when the credentials are rejected
      self.custom_provider_config_validate(provider_name, config)

      custom_type = ProviderType.CUSTOM.value
      existing = db.session.query(Provider).filter(
          Provider.tenant_id == tenant_id,
          Provider.provider_name == provider_name,
          Provider.provider_type == custom_type
      ).first()

      provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
      encrypted = provider_class.encrypt_provider_credentials(tenant_id, config)
      serialized = json.dumps(encrypted)

      if not existing:
          # first custom configuration for this tenant/provider pair
          db.session.add(Provider(
              tenant_id=tenant_id,
              provider_name=provider_name,
              provider_type=custom_type,
              encrypted_config=serialized,
              is_valid=True
          ))
      else:
          existing.encrypted_config = serialized
          existing.is_valid = True
          existing.updated_at = datetime.datetime.utcnow()

      db.session.commit()
  def delete_custom_provider(self, tenant_id: str, provider_name: str) -> None:
      """
      Remove the tenant's custom provider record, if one exists.

      Before deleting, the preferred provider type is switched to the system type
      on a best-effort basis.

      :param tenant_id: workspace id
      :param provider_name: provider whose custom record is removed
      :return:
      """
      record = db.session.query(Provider).filter(
          Provider.tenant_id == tenant_id,
          Provider.provider_name == provider_name,
          Provider.provider_type == ProviderType.CUSTOM.value
      ).first()

      if not record:
          return

      # switch_preferred_provider raises ValueError when the requested type is invalid
      # or not supported by the provider's rule; the deletion proceeds regardless.
      try:
          self.switch_preferred_provider(tenant_id, provider_name, ProviderType.SYSTEM.value)
      except ValueError:
          pass

      db.session.delete(record)
      db.session.commit()
  def custom_provider_model_config_validate(self,
                                            provider_name: str,
                                            model_name: str,
                                            model_type: str,
                                            config: dict) -> None:
      """
      Check model-level credentials for a custom provider with configurable models.

      :param provider_name: provider whose rule and provider class are looked up
      :param model_name: model the credentials belong to
      :param model_type: model type value, converted with ``ModelType.value_of``
      :param config: credentials handed to the provider class for verification
      :return:
      :raises ValueError: when the provider rule is not ``configurable`` or does not
          list the custom provider type
      :raises CredentialsValidateFailedError: When the config credential verification fails.
      """
      rule = ModelProviderFactory.get_provider_rule(provider_name)

      # model-level credentials only exist for configurable providers
      flexibility = rule['model_flexibility']
      if flexibility != 'configurable':
          raise ValueError('Only support configurable model provider')

      # the rule must allow the CUSTOM provider type
      allowed_types = rule['support_provider_types']
      if ProviderType.CUSTOM.value not in allowed_types:
          raise ValueError('Only support provider type CUSTOM')

      # delegate the actual credential check to the provider class
      resolved_type = ModelType.value_of(model_type)
      provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
      provider_class.is_model_credentials_valid_or_raise(model_name, resolved_type, config)
  def add_or_save_custom_provider_model_config(self,
                                               tenant_id: str,
                                               provider_name: str,
                                               model_name: str,
                                               model_type: str,
                                               config: dict) -> None:
      """
      Validate, encrypt and persist the credentials of one custom provider model.

      Ensures a valid custom provider record exists for the tenant (creating or
      re-validating it, committed immediately), then updates or inserts the
      matching provider model record.

      :param tenant_id: workspace id
      :param provider_name: provider the model belongs to
      :param model_name: model the credentials belong to
      :param model_type: model type value as stored on the provider model record
      :param config: plain credentials; validated first, stored encrypted as JSON
      :return:
      """
      # raises before anything is written when the credentials are rejected
      self.custom_provider_model_config_validate(provider_name, model_name, model_type, config)

      custom_type = ProviderType.CUSTOM.value
      parent = db.session.query(Provider).filter(
          Provider.tenant_id == tenant_id,
          Provider.provider_name == provider_name,
          Provider.provider_type == custom_type
      ).first()

      if not parent:
          # no custom provider yet: create a valid one without provider-level credentials
          db.session.add(Provider(
              tenant_id=tenant_id,
              provider_name=provider_name,
              provider_type=custom_type,
              is_valid=True
          ))
          db.session.commit()
      elif not parent.is_valid:
          # re-enable an invalidated provider and clear its stored credentials
          parent.is_valid = True
          parent.encrypted_config = None
          db.session.commit()

      provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
      encrypted = provider_class.encrypt_model_credentials(
          tenant_id,
          model_name,
          ModelType.value_of(model_type),
          config
      )

      stored_model = db.session.query(ProviderModel).filter(
          ProviderModel.tenant_id == tenant_id,
          ProviderModel.provider_name == provider_name,
          ProviderModel.model_name == model_name,
          ProviderModel.model_type == model_type
      ).first()

      serialized = json.dumps(encrypted)
      if not stored_model:
          db.session.add(ProviderModel(
              tenant_id=tenant_id,
              provider_name=provider_name,
              model_name=model_name,
              model_type=model_type,
              encrypted_config=serialized,
              is_valid=True
          ))
      else:
          stored_model.encrypted_config = serialized
          stored_model.is_valid = True

      db.session.commit()
  def delete_custom_provider_model(self,
                                   tenant_id: str,
                                   provider_name: str,
                                   model_name: str,
                                   model_type: str) -> None:
      """
      Remove one custom provider model record of a tenant, if it exists.

      :param tenant_id: workspace id
      :param provider_name: provider the model belongs to
      :param model_name: name of the model to remove
      :param model_type: model type value as stored on the record
      :return:
      """
      target = db.session.query(ProviderModel).filter(
          ProviderModel.tenant_id == tenant_id,
          ProviderModel.provider_name == provider_name,
          ProviderModel.model_name == model_name,
          ProviderModel.model_type == model_type
      ).first()

      # nothing stored for this model: nothing to do
      if not target:
          return

      db.session.delete(target)
      db.session.commit()
  def switch_preferred_provider(self, tenant_id: str, provider_name: str, preferred_provider_type: str) -> None:
      """
      Record which provider type a tenant prefers for a provider.

      Returns without writing anything when the provider class reports that the
      system provider type is not supported.

      :param tenant_id: workspace id
      :param provider_name: provider the preference applies to
      :param preferred_provider_type: provider type value to prefer
      :return:
      :raises ValueError: when the type is not a valid ``ProviderType`` value or is not
          listed in the provider rule's ``support_provider_types``
      """
      if not ProviderType.value_of(preferred_provider_type):
          raise ValueError(f'Invalid preferred provider type: {preferred_provider_type}')

      rule = ModelProviderFactory.get_provider_rule(provider_name)
      if preferred_provider_type not in rule['support_provider_types']:
          raise ValueError(f'Not support provider type: {preferred_provider_type}')

      provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
      if not provider_class.is_provider_type_system_supported():
          return

      # update the stored preference, or create it on first use
      preference = db.session.query(TenantPreferredModelProvider).filter(
          TenantPreferredModelProvider.tenant_id == tenant_id,
          TenantPreferredModelProvider.provider_name == provider_name
      ).first()

      if not preference:
          db.session.add(TenantPreferredModelProvider(
              tenant_id=tenant_id,
              provider_name=provider_name,
              preferred_provider_type=preferred_provider_type
          ))
      else:
          preference.preferred_provider_type = preferred_provider_type

      db.session.commit()
  def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[TenantDefaultModel]:
      """
      Look up the tenant's default model for a model type via ``ModelFactory``.

      :param tenant_id: workspace id
      :param model_type: model type value, converted with ``ModelType.value_of``
      :return: whatever ``ModelFactory.get_default_model`` returns for the tenant and type
      """
      resolved_type = ModelType.value_of(model_type)
      return ModelFactory.get_default_model(tenant_id, resolved_type)
  def update_default_model_of_model_type(self,
                                         tenant_id: str,
                                         model_type: str,
                                         provider_name: str,
                                         model_name: str) -> TenantDefaultModel:
      """
      Set the tenant's default model for a model type via ``ModelFactory``.

      :param tenant_id: workspace id
      :param model_type: model type value, converted with ``ModelType.value_of``
      :param provider_name: provider of the new default model
      :param model_name: name of the new default model
      :return: whatever ``ModelFactory.update_default_model`` returns
      """
      resolved_type = ModelType.value_of(model_type)
      return ModelFactory.update_default_model(tenant_id, resolved_type, provider_name, model_name)
  def get_valid_model_list(self, tenant_id: str, model_type: str) -> list:
      """
      List the models of the given type offered by the tenant's preferred providers.

      Providers for which no preferred model provider can be resolved are skipped.

      :param tenant_id: workspace id
      :param model_type: model type value, converted with ``ModelType.value_of``
      :return: list of dicts with ``model_name``, ``model_display_name``, ``model_type``,
          ``model_provider`` and ``features`` (plus ``model_mode`` when the model has a mode)
      """
      collected = []

      rules = ModelProviderFactory.get_provider_rules()
      for provider_name, rule in rules.items():
          preferred = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)
          if not preferred:
              continue

          supported_models = preferred.get_supported_model_list(ModelType.value_of(model_type))
          provider = preferred.provider

          for model in supported_models:
              entry = {
                  "model_name": model['id'],
                  "model_display_name": model['name'],
                  "model_type": model_type,
                  "model_provider": {
                      "provider_name": provider.provider_name,
                      "provider_type": provider.provider_type
                  },
                  'features': model['features'] if 'features' in model else []
              }

              if 'mode' in model:
                  entry['model_mode'] = model['mode']

              # system providers additionally expose their quota state
              if provider.provider_type == ProviderType.SYSTEM.value:
                  entry['model_provider'].update({
                      'quota_type': provider.quota_type,
                      'quota_unit': rule['system_config']['quota_unit'],
                      'quota_limit': provider.quota_limit,
                      'quota_used': provider.quota_used
                  })

              collected.append(entry)

      return collected
  def get_model_parameter_rules(self, tenant_id: str, model_provider_name: str, model_name: str, model_type: str) \
          -> ModelKwargsRules:
      """
      Get the parameter rules of a model from the tenant's preferred provider.

      :param tenant_id: workspace id
      :param model_provider_name: provider whose preferred model provider is resolved
      :param model_name: model whose rules are requested
      :param model_type: model type value, converted with ``ModelType.value_of``
      :return: the provider's rules, or an empty ``ModelKwargsRules()`` when no
          preferred model provider is available
      """
      preferred = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
      if preferred:
          return preferred.get_model_parameter_rules(model_name, ModelType.value_of(model_type))

      # no usable provider: fall back to empty rules
      return ModelKwargsRules()
  def free_quota_submit(self, tenant_id: str, provider_name: str):
      """
      Submit a free-quota application for a provider to the free quota apply server.

      The server location and key are read from the ``FREE_QUOTA_APPLY_BASE_URL`` and
      ``FREE_QUOTA_APPLY_API_KEY`` environment variables.

      :param tenant_id: workspace id, sent as ``workspace_id``
      :param provider_name: provider the quota is requested for
      :return: ``{'type': ..., 'redirect_url': ...}`` when the server answers with type
          ``redirect``, otherwise ``{'type': ..., 'result': 'success'}``
      :raises ValueError: when the base URL is not configured, the HTTP status is not OK,
          or the response ``code`` is not ``success``
      """
      api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
      api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
      if not api_base_url:
          # Without this guard `None + str` fails with an opaque TypeError.
          raise ValueError("FREE_QUOTA_APPLY_BASE_URL is not configured")

      api_url = api_base_url + '/api/v1/providers/apply'

      headers = {
          'Content-Type': 'application/json',
          'Authorization': f"Bearer {api_key}"
      }
      # NOTE(review): no timeout is passed, so a stalled server blocks this call; adding
      # one changes which exceptions can surface, so it is left for the owner to decide.
      response = requests.post(api_url, headers=headers,
                               json={'workspace_id': tenant_id, 'provider_name': provider_name})
      if not response.ok:
          logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
          raise ValueError(f"Error: {response.status_code} ")

      # decode the body once and reuse it for every check below
      rst = response.json()
      if rst["code"] != 'success':
          raise ValueError(
              f"error: {rst['message']}"
          )

      if rst['type'] == 'redirect':
          return {
              'type': rst['type'],
              'redirect_url': rst['redirect_url']
          }

      return {
          'type': rst['type'],
          'result': 'success'
      }
  def free_quota_qualification_verify(self, tenant_id: str, provider_name: str, token: Optional[str]):
      """
      Ask the free quota apply server whether the tenant qualifies for a provider's free quota.

      The server location and key are read from the ``FREE_QUOTA_APPLY_BASE_URL`` and
      ``FREE_QUOTA_APPLY_API_KEY`` environment variables.

      :param tenant_id: workspace id, sent as ``workspace_id``
      :param provider_name: provider the qualification is checked for
      :param token: optional token; only included in the request when truthy
      :return: dict with ``result``, ``provider_name`` and ``flag``; when not qualified
          it also carries the server's ``reason``
      :raises ValueError: when the base URL is not configured, the HTTP status is not OK,
          or the response ``code`` is not ``success``
      """
      api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
      api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
      if not api_base_url:
          # Without this guard `None + str` fails with an opaque TypeError.
          raise ValueError("FREE_QUOTA_APPLY_BASE_URL is not configured")

      api_url = api_base_url + '/api/v1/providers/qualification-verify'

      headers = {
          'Content-Type': 'application/json',
          'Authorization': f"Bearer {api_key}"
      }
      json_data = {'workspace_id': tenant_id, 'provider_name': provider_name}
      if token:
          json_data['token'] = token

      # NOTE(review): no timeout is passed, so a stalled server blocks this call; adding
      # one changes which exceptions can surface, so it is left for the owner to decide.
      response = requests.post(api_url, headers=headers,
                               json=json_data)
      if not response.ok:
          logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
          raise ValueError(f"Error: {response.status_code} ")

      rst = response.json()
      if rst["code"] != 'success':
          raise ValueError(
              f"error: {rst['message']}"
          )

      data = rst['data']
      # only an explicit boolean True counts as qualified
      if data['qualified'] is True:
          return {
              'result': 'success',
              'provider_name': provider_name,
              'flag': True
          }

      return {
          'result': 'success',
          'provider_name': provider_name,
          'flag': False,
          'reason': data['reason']
      }
  509. }