xinference.py 1.0 KB

123456789101112131415161718192021222324
  1. import requests
  2. from core.tools.errors import ToolProviderCredentialValidationError
  3. from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
  class XinferenceProvider(BuiltinToolProviderController):
      """Built-in tool provider backed by an Xinference server.

      Credentials are validated by calling the server's ``/sdapi/v1/options``
      HTTP endpoint; see ``_validate_credentials``.
      """

      # (connect, read) timeout in seconds for the validation request. Without a
      # timeout, ``requests`` waits indefinitely, so an unreachable or stalled
      # server would hang credential validation forever. The read budget is kept
      # generous in case the server does real work when the option is set.
      # NOTE(review): values are a judgement call — tune if validation needs longer.
      _VALIDATION_TIMEOUT = (10, 60)

      def _validate_credentials(self, credentials: dict) -> None:
          """Check that ``credentials`` can talk to the configured Xinference server.

          Args:
              credentials: Provider settings. Reads ``base_url`` (required),
                  ``model`` (required) and ``api_key`` (optional).

          Side effects:
              If ``api_key`` is missing or empty, the placeholder ``"abc"`` is
              written back into ``credentials["api_key"]``.
              The check POSTs ``{"sd_model_checkpoint": model}`` to
              ``{base_url}/sdapi/v1/options``, i.e. it asks the server to set that
              option; it is not a read-only probe.

          Raises:
              ToolProviderCredentialValidationError: if ``base_url`` or ``model``
                  is missing, the server cannot be reached (connection error,
                  timeout, malformed URL), or it answers with a non-200 status.
          """
          # `or ""` also covers keys that are present but set to None. Trim
          # whitespace and *every* trailing slash so the endpoint below is joined
          # with exactly one "/".
          base_url = (credentials.get("base_url") or "").strip().rstrip("/")
          api_key = credentials.get("api_key", "")
          if not api_key:
              # Placeholder bearer token, presumably for servers that don't enforce
              # auth; persisted into credentials — NOTE(review): confirm later
              # calls rely on this value.
              api_key = "abc"
              credentials["api_key"] = api_key
          model = credentials.get("model", "")
          if not base_url or not model:
              raise ToolProviderCredentialValidationError("Xinference base_url and model is required")

          headers = {"Authorization": f"Bearer {api_key}"}
          try:
              res = requests.post(
                  f"{base_url}/sdapi/v1/options",
                  headers=headers,
                  json={"sd_model_checkpoint": model},
                  timeout=self._VALIDATION_TIMEOUT,
              )
          except requests.exceptions.RequestException as exc:
              # Report network-level failures as a validation failure instead of
              # leaking an unrelated requests exception to the caller.
              raise ToolProviderCredentialValidationError(
                  f"Failed to reach Xinference at {base_url}: {exc}"
              ) from exc

          # Only 401/403 actually indicate rejected credentials; other non-200
          # statuses (404, 5xx, ...) are reported with their code so they are not
          # misdiagnosed as a bad API key.
          if res.status_code in (401, 403):
              raise ToolProviderCredentialValidationError("Xinference API key is invalid")
          if res.status_code != 200:
              raise ToolProviderCredentialValidationError(
                  f"Xinference credential validation failed with HTTP status {res.status_code}"
              )