model_provider_service.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. import logging
  2. import mimetypes
  3. import os
  4. from typing import Optional, cast, Tuple
  5. import requests
  6. from flask import current_app
  7. from core.entities.model_entities import ModelStatus
  8. from core.model_runtime.entities.model_entities import ModelType, ParameterRule
  9. from core.model_runtime.model_providers import model_provider_factory
  10. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  11. from core.provider_manager import ProviderManager
  12. from models.provider import ProviderType
  13. from services.entities.model_provider_entities import ProviderResponse, CustomConfigurationResponse, \
  14. SystemConfigurationResponse, CustomConfigurationStatus, ProviderWithModelsResponse, ModelResponse, \
  15. DefaultModelResponse, ModelWithProviderEntityResponse
  16. logger = logging.getLogger(__name__)
  17. class ModelProviderService:
  18. """
  19. Model Provider Service
  20. """
  21. def __init__(self) -> None:
  22. self.provider_manager = ProviderManager()
  23. def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]:
  24. """
  25. get provider list.
  26. :param tenant_id: workspace id
  27. :param model_type: model type
  28. :return:
  29. """
  30. # Get all provider configurations of the current workspace
  31. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  32. provider_responses = []
  33. for provider_configuration in provider_configurations.values():
  34. if model_type:
  35. model_type_entity = ModelType.value_of(model_type)
  36. if model_type_entity not in provider_configuration.provider.supported_model_types:
  37. continue
  38. provider_response = ProviderResponse(
  39. **provider_configuration.provider.dict(),
  40. preferred_provider_type=provider_configuration.preferred_provider_type,
  41. custom_configuration=CustomConfigurationResponse(
  42. status=CustomConfigurationStatus.ACTIVE
  43. if provider_configuration.is_custom_configuration_available()
  44. else CustomConfigurationStatus.NO_CONFIGURE
  45. ),
  46. system_configuration=SystemConfigurationResponse(
  47. **provider_configuration.system_configuration.dict()
  48. )
  49. )
  50. provider_responses.append(provider_response)
  51. return provider_responses
  52. def get_models_by_provider(self, tenant_id: str, provider: str) -> list[ModelWithProviderEntityResponse]:
  53. """
  54. get provider models.
  55. For the model provider page,
  56. only supports passing in a single provider to query the list of supported models.
  57. :param tenant_id:
  58. :param provider:
  59. :return:
  60. """
  61. # Get all provider configurations of the current workspace
  62. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  63. # Get provider available models
  64. return [ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(
  65. provider=provider
  66. )]
  67. def get_provider_credentials(self, tenant_id: str, provider: str) -> dict:
  68. """
  69. get provider credentials.
  70. :param tenant_id:
  71. :param provider:
  72. :return:
  73. """
  74. # Get all provider configurations of the current workspace
  75. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  76. # Get provider configuration
  77. provider_configuration = provider_configurations.get(provider)
  78. if not provider_configuration:
  79. raise ValueError(f"Provider {provider} does not exist.")
  80. # Get provider custom credentials from workspace
  81. return provider_configuration.get_custom_credentials(obfuscated=True)
  82. def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None:
  83. """
  84. validate provider credentials.
  85. :param tenant_id:
  86. :param provider:
  87. :param credentials:
  88. """
  89. # Get all provider configurations of the current workspace
  90. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  91. # Get provider configuration
  92. provider_configuration = provider_configurations.get(provider)
  93. if not provider_configuration:
  94. raise ValueError(f"Provider {provider} does not exist.")
  95. provider_configuration.custom_credentials_validate(credentials)
  96. def save_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None:
  97. """
  98. save custom provider config.
  99. :param tenant_id: workspace id
  100. :param provider: provider name
  101. :param credentials: provider credentials
  102. :return:
  103. """
  104. # Get all provider configurations of the current workspace
  105. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  106. # Get provider configuration
  107. provider_configuration = provider_configurations.get(provider)
  108. if not provider_configuration:
  109. raise ValueError(f"Provider {provider} does not exist.")
  110. # Add or update custom provider credentials.
  111. provider_configuration.add_or_update_custom_credentials(credentials)
  112. def remove_provider_credentials(self, tenant_id: str, provider: str) -> None:
  113. """
  114. remove custom provider config.
  115. :param tenant_id: workspace id
  116. :param provider: provider name
  117. :return:
  118. """
  119. # Get all provider configurations of the current workspace
  120. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  121. # Get provider configuration
  122. provider_configuration = provider_configurations.get(provider)
  123. if not provider_configuration:
  124. raise ValueError(f"Provider {provider} does not exist.")
  125. # Remove custom provider credentials.
  126. provider_configuration.delete_custom_credentials()
  127. def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> dict:
  128. """
  129. get model credentials.
  130. :param tenant_id: workspace id
  131. :param provider: provider name
  132. :param model_type: model type
  133. :param model: model name
  134. :return:
  135. """
  136. # Get all provider configurations of the current workspace
  137. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  138. # Get provider configuration
  139. provider_configuration = provider_configurations.get(provider)
  140. if not provider_configuration:
  141. raise ValueError(f"Provider {provider} does not exist.")
  142. # Get model custom credentials from ProviderModel if exists
  143. return provider_configuration.get_custom_model_credentials(
  144. model_type=ModelType.value_of(model_type),
  145. model=model,
  146. obfuscated=True
  147. )
  148. def model_credentials_validate(self, tenant_id: str, provider: str, model_type: str, model: str,
  149. credentials: dict) -> None:
  150. """
  151. validate model credentials.
  152. :param tenant_id: workspace id
  153. :param provider: provider name
  154. :param model_type: model type
  155. :param model: model name
  156. :param credentials: model credentials
  157. :return:
  158. """
  159. # Get all provider configurations of the current workspace
  160. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  161. # Get provider configuration
  162. provider_configuration = provider_configurations.get(provider)
  163. if not provider_configuration:
  164. raise ValueError(f"Provider {provider} does not exist.")
  165. # Validate model credentials
  166. provider_configuration.custom_model_credentials_validate(
  167. model_type=ModelType.value_of(model_type),
  168. model=model,
  169. credentials=credentials
  170. )
  171. def save_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str,
  172. credentials: dict) -> None:
  173. """
  174. save model credentials.
  175. :param tenant_id: workspace id
  176. :param provider: provider name
  177. :param model_type: model type
  178. :param model: model name
  179. :param credentials: model credentials
  180. :return:
  181. """
  182. # Get all provider configurations of the current workspace
  183. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  184. # Get provider configuration
  185. provider_configuration = provider_configurations.get(provider)
  186. if not provider_configuration:
  187. raise ValueError(f"Provider {provider} does not exist.")
  188. # Add or update custom model credentials
  189. provider_configuration.add_or_update_custom_model_credentials(
  190. model_type=ModelType.value_of(model_type),
  191. model=model,
  192. credentials=credentials
  193. )
  194. def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
  195. """
  196. remove model credentials.
  197. :param tenant_id: workspace id
  198. :param provider: provider name
  199. :param model_type: model type
  200. :param model: model name
  201. :return:
  202. """
  203. # Get all provider configurations of the current workspace
  204. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  205. # Get provider configuration
  206. provider_configuration = provider_configurations.get(provider)
  207. if not provider_configuration:
  208. raise ValueError(f"Provider {provider} does not exist.")
  209. # Remove custom model credentials
  210. provider_configuration.delete_custom_model_credentials(
  211. model_type=ModelType.value_of(model_type),
  212. model=model
  213. )
  214. def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]:
  215. """
  216. get models by model type.
  217. :param tenant_id: workspace id
  218. :param model_type: model type
  219. :return:
  220. """
  221. # Get all provider configurations of the current workspace
  222. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  223. # Get provider available models
  224. models = provider_configurations.get_models(
  225. model_type=ModelType.value_of(model_type)
  226. )
  227. # Group models by provider
  228. provider_models = {}
  229. for model in models:
  230. if model.provider.provider not in provider_models:
  231. provider_models[model.provider.provider] = []
  232. if model.deprecated:
  233. continue
  234. provider_models[model.provider.provider].append(model)
  235. # convert to ProviderWithModelsResponse list
  236. providers_with_models: list[ProviderWithModelsResponse] = []
  237. for provider, models in provider_models.items():
  238. if not models:
  239. continue
  240. first_model = models[0]
  241. has_active_models = any([model.status == ModelStatus.ACTIVE for model in models])
  242. providers_with_models.append(
  243. ProviderWithModelsResponse(
  244. provider=provider,
  245. label=first_model.provider.label,
  246. icon_small=first_model.provider.icon_small,
  247. icon_large=first_model.provider.icon_large,
  248. status=CustomConfigurationStatus.ACTIVE
  249. if has_active_models else CustomConfigurationStatus.NO_CONFIGURE,
  250. models=[ModelResponse(
  251. model=model.model,
  252. label=model.label,
  253. model_type=model.model_type,
  254. features=model.features,
  255. fetch_from=model.fetch_from,
  256. model_properties=model.model_properties,
  257. status=model.status
  258. ) for model in models]
  259. )
  260. )
  261. return providers_with_models
  262. def get_model_parameter_rules(self, tenant_id: str, provider: str, model: str) -> list[ParameterRule]:
  263. """
  264. get model parameter rules.
  265. Only supports LLM.
  266. :param tenant_id: workspace id
  267. :param provider: provider name
  268. :param model: model name
  269. :return:
  270. """
  271. # Get all provider configurations of the current workspace
  272. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  273. # Get provider configuration
  274. provider_configuration = provider_configurations.get(provider)
  275. if not provider_configuration:
  276. raise ValueError(f"Provider {provider} does not exist.")
  277. # Get model instance of LLM
  278. model_type_instance = provider_configuration.get_model_type_instance(ModelType.LLM)
  279. model_type_instance = cast(LargeLanguageModel, model_type_instance)
  280. # fetch credentials
  281. credentials = provider_configuration.get_current_credentials(
  282. model_type=ModelType.LLM,
  283. model=model
  284. )
  285. if not credentials:
  286. return []
  287. # Call get_parameter_rules method of model instance to get model parameter rules
  288. return model_type_instance.get_parameter_rules(
  289. model=model,
  290. credentials=credentials
  291. )
  292. def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
  293. """
  294. get default model of model type.
  295. :param tenant_id: workspace id
  296. :param model_type: model type
  297. :return:
  298. """
  299. model_type_enum = ModelType.value_of(model_type)
  300. result = self.provider_manager.get_default_model(
  301. tenant_id=tenant_id,
  302. model_type=model_type_enum
  303. )
  304. return DefaultModelResponse(
  305. **result.dict()
  306. ) if result else None
  307. def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
  308. """
  309. update default model of model type.
  310. :param tenant_id: workspace id
  311. :param model_type: model type
  312. :param provider: provider name
  313. :param model: model name
  314. :return:
  315. """
  316. model_type_enum = ModelType.value_of(model_type)
  317. self.provider_manager.update_default_model_record(
  318. tenant_id=tenant_id,
  319. model_type=model_type_enum,
  320. provider=provider,
  321. model=model
  322. )
  323. def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> Tuple[Optional[bytes], Optional[str]]:
  324. """
  325. get model provider icon.
  326. :param provider: provider name
  327. :param icon_type: icon type (icon_small or icon_large)
  328. :param lang: language (zh_Hans or en_US)
  329. :return:
  330. """
  331. provider_instance = model_provider_factory.get_provider_instance(provider)
  332. provider_schema = provider_instance.get_provider_schema()
  333. if icon_type.lower() == 'icon_small':
  334. if not provider_schema.icon_small:
  335. raise ValueError(f"Provider {provider} does not have small icon.")
  336. if lang.lower() == 'zh_hans':
  337. file_name = provider_schema.icon_small.zh_Hans
  338. else:
  339. file_name = provider_schema.icon_small.en_US
  340. else:
  341. if not provider_schema.icon_large:
  342. raise ValueError(f"Provider {provider} does not have large icon.")
  343. if lang.lower() == 'zh_hans':
  344. file_name = provider_schema.icon_large.zh_Hans
  345. else:
  346. file_name = provider_schema.icon_large.en_US
  347. root_path = current_app.root_path
  348. provider_instance_path = os.path.dirname(os.path.join(root_path, provider_instance.__class__.__module__.replace('.', '/')))
  349. file_path = os.path.join(provider_instance_path, "_assets")
  350. file_path = os.path.join(file_path, file_name)
  351. if not os.path.exists(file_path):
  352. return None, None
  353. mimetype, _ = mimetypes.guess_type(file_path)
  354. mimetype = mimetype or 'application/octet-stream'
  355. # read binary from file
  356. with open(file_path, 'rb') as f:
  357. byte_data = f.read()
  358. return byte_data, mimetype
  359. def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
  360. """
  361. switch preferred provider.
  362. :param tenant_id: workspace id
  363. :param provider: provider name
  364. :param preferred_provider_type: preferred provider type
  365. :return:
  366. """
  367. # Get all provider configurations of the current workspace
  368. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  369. # Convert preferred_provider_type to ProviderType
  370. preferred_provider_type_enum = ProviderType.value_of(preferred_provider_type)
  371. # Get provider configuration
  372. provider_configuration = provider_configurations.get(provider)
  373. if not provider_configuration:
  374. raise ValueError(f"Provider {provider} does not exist.")
  375. # Switch preferred provider type
  376. provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum)
  377. def free_quota_submit(self, tenant_id: str, provider: str):
  378. api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
  379. api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
  380. api_url = api_base_url + '/api/v1/providers/apply'
  381. headers = {
  382. 'Content-Type': 'application/json',
  383. 'Authorization': f"Bearer {api_key}"
  384. }
  385. response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider})
  386. if not response.ok:
  387. logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
  388. raise ValueError(f"Error: {response.status_code} ")
  389. if response.json()["code"] != 'success':
  390. raise ValueError(
  391. f"error: {response.json()['message']}"
  392. )
  393. rst = response.json()
  394. if rst['type'] == 'redirect':
  395. return {
  396. 'type': rst['type'],
  397. 'redirect_url': rst['redirect_url']
  398. }
  399. else:
  400. return {
  401. 'type': rst['type'],
  402. 'result': 'success'
  403. }
  404. def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]):
  405. api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
  406. api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
  407. api_url = api_base_url + '/api/v1/providers/qualification-verify'
  408. headers = {
  409. 'Content-Type': 'application/json',
  410. 'Authorization': f"Bearer {api_key}"
  411. }
  412. json_data = {'workspace_id': tenant_id, 'provider_name': provider}
  413. if token:
  414. json_data['token'] = token
  415. response = requests.post(api_url, headers=headers,
  416. json=json_data)
  417. if not response.ok:
  418. logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
  419. raise ValueError(f"Error: {response.status_code} ")
  420. rst = response.json()
  421. if rst["code"] != 'success':
  422. raise ValueError(
  423. f"error: {rst['message']}"
  424. )
  425. data = rst['data']
  426. if data['qualified'] is True:
  427. return {
  428. 'result': 'success',
  429. 'provider_name': provider,
  430. 'flag': True
  431. }
  432. else:
  433. return {
  434. 'result': 'success',
  435. 'provider_name': provider,
  436. 'flag': False,
  437. 'reason': data['reason']
  438. }