mirror of
https://github.com/Electric-Special/ha-core.git
synced 2026-03-21 00:03:16 +01:00
OAuth2.0 token request error handling (#153167)
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
@@ -27,6 +27,15 @@
|
||||
"multiple_integration_config_errors": {
|
||||
"message": "Failed to process config for integration {domain} due to multiple ({errors}) errors. Check the logs for more information."
|
||||
},
|
||||
"oauth2_helper_reauth_required": {
|
||||
"message": "Credentials are invalid, re-authentication required"
|
||||
},
|
||||
"oauth2_helper_refresh_failed": {
|
||||
"message": "OAuth2 token refresh failed for {domain}"
|
||||
},
|
||||
"oauth2_helper_refresh_transient": {
|
||||
"message": "Temporary error refreshing credentials for {domain}, try again later"
|
||||
},
|
||||
"platform_component_load_err": {
|
||||
"message": "Platform error: {domain} - {error}."
|
||||
},
|
||||
|
||||
@@ -6,6 +6,9 @@ from collections.abc import Callable, Generator, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from aiohttp import ClientResponse, ClientResponseError, RequestInfo
|
||||
from multidict import MultiMapping
|
||||
|
||||
from .util.event_type import EventType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -218,6 +221,63 @@ class ConfigEntryAuthFailed(IntegrationError):
|
||||
"""Error to indicate that config entry could not authenticate."""
|
||||
|
||||
|
||||
class OAuth2TokenRequestError(ClientResponseError, HomeAssistantError):
|
||||
"""Error to indicate that the OAuth 2.0 flow could not refresh token."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
request_info: RequestInfo,
|
||||
history: tuple[ClientResponse, ...] = (),
|
||||
status: int = 0,
|
||||
message: str = "OAuth 2.0 token refresh failed",
|
||||
headers: MultiMapping[str] | None = None,
|
||||
domain: str,
|
||||
) -> None:
|
||||
"""Initialize OAuth2RefreshTokenFailed."""
|
||||
ClientResponseError.__init__(
|
||||
self,
|
||||
request_info=request_info,
|
||||
history=history,
|
||||
status=status,
|
||||
message=message,
|
||||
headers=headers,
|
||||
)
|
||||
HomeAssistantError.__init__(self)
|
||||
self.domain = domain
|
||||
self.translation_domain = "homeassistant"
|
||||
self.translation_key = "oauth2_helper_refresh_failed"
|
||||
self.translation_placeholders = {"domain": domain}
|
||||
self.generate_message = True
|
||||
|
||||
|
||||
class OAuth2TokenRequestTransientError(OAuth2TokenRequestError):
|
||||
"""Recoverable error to indicate flow could not refresh token."""
|
||||
|
||||
def __init__(self, *, domain: str, **kwargs: Any) -> None:
|
||||
"""Initialize OAuth2RefreshTokenTransientError."""
|
||||
super().__init__(domain=domain, **kwargs)
|
||||
self.translation_domain = "homeassistant"
|
||||
self.translation_key = "oauth2_helper_refresh_transient"
|
||||
self.translation_placeholders = {"domain": domain}
|
||||
self.generate_message = True
|
||||
|
||||
|
||||
class OAuth2TokenRequestReauthError(OAuth2TokenRequestError):
|
||||
"""Non recoverable error to indicate the flow could not refresh token.
|
||||
|
||||
Re-authentication is required.
|
||||
"""
|
||||
|
||||
def __init__(self, *, domain: str, **kwargs: Any) -> None:
|
||||
"""Initialize OAuth2RefreshTokenReauthError."""
|
||||
super().__init__(domain=domain, **kwargs)
|
||||
self.translation_domain = "homeassistant"
|
||||
self.translation_key = "oauth2_helper_reauth_required"
|
||||
self.translation_placeholders = {"domain": domain}
|
||||
self.generate_message = True
|
||||
|
||||
|
||||
class InvalidStateError(HomeAssistantError):
|
||||
"""When an invalid state is encountered."""
|
||||
|
||||
|
||||
@@ -29,7 +29,12 @@ from yarl import URL
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.exceptions import (
|
||||
HomeAssistantError,
|
||||
OAuth2TokenRequestError,
|
||||
OAuth2TokenRequestReauthError,
|
||||
OAuth2TokenRequestTransientError,
|
||||
)
|
||||
from homeassistant.loader import async_get_application_credentials
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
@@ -56,6 +61,7 @@ AUTH_CALLBACK_PATH = "/auth/external/callback"
|
||||
HEADER_FRONTEND_BASE = "HA-Frontend-Base"
|
||||
MY_AUTH_CALLBACK_PATH = "https://my.home-assistant.io/redirect/oauth"
|
||||
|
||||
|
||||
CLOCK_OUT_OF_SYNC_MAX_SEC = 20
|
||||
|
||||
OAUTH_AUTHORIZE_URL_TIMEOUT_SEC = 30
|
||||
@@ -134,7 +140,10 @@ class AbstractOAuth2Implementation(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def _async_refresh_token(self, token: dict) -> dict:
|
||||
"""Refresh a token."""
|
||||
"""Refresh a token.
|
||||
|
||||
Should raise OAuth2TokenRequestError on token refresh failure.
|
||||
"""
|
||||
|
||||
|
||||
class LocalOAuth2Implementation(AbstractOAuth2Implementation):
|
||||
@@ -211,7 +220,8 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
|
||||
return await self._token_request(request_data)
|
||||
|
||||
async def _async_refresh_token(self, token: dict) -> dict:
|
||||
"""Refresh tokens."""
|
||||
"""Refresh a token."""
|
||||
|
||||
new_token = await self._token_request(
|
||||
{
|
||||
"grant_type": "refresh_token",
|
||||
@@ -219,33 +229,71 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
|
||||
"refresh_token": token["refresh_token"],
|
||||
}
|
||||
)
|
||||
|
||||
return {**token, **new_token}
|
||||
|
||||
async def _token_request(self, data: dict) -> dict:
|
||||
"""Make a token request."""
|
||||
"""Make a token request.
|
||||
|
||||
Raises OAuth2TokenRequestError on token request failure.
|
||||
"""
|
||||
session = async_get_clientsession(self.hass)
|
||||
|
||||
data["client_id"] = self.client_id
|
||||
|
||||
if self.client_secret:
|
||||
data["client_secret"] = self.client_secret
|
||||
|
||||
_LOGGER.debug("Sending token request to %s", self.token_url)
|
||||
resp = await session.post(self.token_url, data=data)
|
||||
if resp.status >= 400:
|
||||
try:
|
||||
error_response = await resp.json()
|
||||
except ClientError, JSONDecodeError:
|
||||
error_response = {}
|
||||
error_code = error_response.get("error", "unknown")
|
||||
error_description = error_response.get("error_description", "unknown error")
|
||||
_LOGGER.error(
|
||||
"Token request for %s failed (%s): %s",
|
||||
self.domain,
|
||||
error_code,
|
||||
error_description,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
|
||||
try:
|
||||
resp = await session.post(self.token_url, data=data)
|
||||
if resp.status >= 400:
|
||||
try:
|
||||
error_response = await resp.json()
|
||||
except ClientError, JSONDecodeError:
|
||||
error_response = {}
|
||||
error_code = error_response.get("error", "unknown")
|
||||
error_description = error_response.get(
|
||||
"error_description", "unknown error"
|
||||
)
|
||||
_LOGGER.error(
|
||||
"Token request for %s failed (%s): %s",
|
||||
self.domain,
|
||||
error_code,
|
||||
error_description,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
except ClientResponseError as err:
|
||||
if err.status == HTTPStatus.TOO_MANY_REQUESTS or 500 <= err.status <= 599:
|
||||
# Recoverable error
|
||||
raise OAuth2TokenRequestTransientError(
|
||||
request_info=err.request_info,
|
||||
history=err.history,
|
||||
status=err.status,
|
||||
message=err.message,
|
||||
headers=err.headers,
|
||||
domain=self._domain,
|
||||
) from err
|
||||
if 400 <= err.status <= 499:
|
||||
# Non-recoverable error
|
||||
raise OAuth2TokenRequestReauthError(
|
||||
request_info=err.request_info,
|
||||
history=err.history,
|
||||
status=err.status,
|
||||
message=err.message,
|
||||
headers=err.headers,
|
||||
domain=self._domain,
|
||||
) from err
|
||||
|
||||
raise OAuth2TokenRequestError(
|
||||
request_info=err.request_info,
|
||||
history=err.history,
|
||||
status=err.status,
|
||||
message=err.message,
|
||||
headers=err.headers,
|
||||
domain=self._domain,
|
||||
) from err
|
||||
|
||||
return cast(dict, await resp.json())
|
||||
|
||||
|
||||
@@ -458,12 +506,12 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
|
||||
except TimeoutError as err:
|
||||
_LOGGER.error("Timeout resolving OAuth token: %s", err)
|
||||
return self.async_abort(reason="oauth_timeout")
|
||||
except (ClientResponseError, ClientError) as err:
|
||||
except (
|
||||
OAuth2TokenRequestError,
|
||||
ClientError,
|
||||
) as err:
|
||||
_LOGGER.error("Error resolving OAuth token: %s", err)
|
||||
if (
|
||||
isinstance(err, ClientResponseError)
|
||||
and err.status == HTTPStatus.UNAUTHORIZED
|
||||
):
|
||||
if isinstance(err, OAuth2TokenRequestReauthError):
|
||||
return self.async_abort(reason="oauth_unauthorized")
|
||||
return self.async_abort(reason="oauth_failed")
|
||||
|
||||
|
||||
@@ -25,6 +25,8 @@ from homeassistant.exceptions import (
|
||||
ConfigEntryError,
|
||||
ConfigEntryNotReady,
|
||||
HomeAssistantError,
|
||||
OAuth2TokenRequestError,
|
||||
OAuth2TokenRequestReauthError,
|
||||
)
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
@@ -352,6 +354,14 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
|
||||
"""Error handling for _async_setup."""
|
||||
try:
|
||||
await self._async_setup()
|
||||
|
||||
except OAuth2TokenRequestError as err:
|
||||
self.last_exception = err
|
||||
if isinstance(err, OAuth2TokenRequestReauthError):
|
||||
self.last_update_success = False
|
||||
# Non-recoverable error
|
||||
raise ConfigEntryAuthFailed from err
|
||||
|
||||
except (
|
||||
TimeoutError,
|
||||
requests.exceptions.Timeout,
|
||||
@@ -423,6 +433,32 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
|
||||
self.logger.debug("Full error:", exc_info=True)
|
||||
self.last_update_success = False
|
||||
|
||||
except (OAuth2TokenRequestError,) as err:
|
||||
self.last_exception = err
|
||||
if isinstance(err, OAuth2TokenRequestReauthError):
|
||||
# Non-recoverable error
|
||||
auth_failed = True
|
||||
if self.last_update_success:
|
||||
if log_failures:
|
||||
self.logger.error(
|
||||
"Authentication failed while fetching %s data: %s",
|
||||
self.name,
|
||||
err,
|
||||
)
|
||||
self.last_update_success = False
|
||||
if raise_on_auth_failed:
|
||||
raise ConfigEntryAuthFailed from err
|
||||
|
||||
if self.config_entry:
|
||||
self.config_entry.async_start_reauth(self.hass)
|
||||
return
|
||||
|
||||
# Recoverable error
|
||||
if self.last_update_success:
|
||||
if log_failures:
|
||||
self.logger.error("Error fetching %s data: %s", self.name, err)
|
||||
self.last_update_success = False
|
||||
|
||||
except (aiohttp.ClientError, requests.exceptions.RequestException) as err:
|
||||
self.last_exception = err
|
||||
if self.last_update_success:
|
||||
|
||||
@@ -1368,7 +1368,7 @@ async def test_dhcp_discovery_with_creds(
|
||||
("status_code", "error_reason"),
|
||||
[
|
||||
(HTTPStatus.UNAUTHORIZED, "oauth_unauthorized"),
|
||||
(HTTPStatus.NOT_FOUND, "oauth_failed"),
|
||||
(HTTPStatus.NOT_FOUND, "oauth_unauthorized"),
|
||||
(HTTPStatus.INTERNAL_SERVER_ERROR, "oauth_failed"),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -7,11 +7,15 @@ import time
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries, data_entry_flow, setup
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import (
|
||||
OAuth2TokenRequestError,
|
||||
OAuth2TokenRequestReauthError,
|
||||
OAuth2TokenRequestTransientError,
|
||||
)
|
||||
from homeassistant.helpers import config_entry_oauth2_flow
|
||||
from homeassistant.helpers.network import NoURLAvailableError
|
||||
|
||||
@@ -478,7 +482,7 @@ async def test_abort_discovered_multiple(
|
||||
(
|
||||
HTTPStatus.NOT_FOUND,
|
||||
{},
|
||||
"oauth_failed",
|
||||
"oauth_unauthorized",
|
||||
"Token request for oauth2_test failed (unknown): unknown",
|
||||
),
|
||||
(
|
||||
@@ -494,7 +498,7 @@ async def test_abort_discovered_multiple(
|
||||
"error_description": "Request was missing the 'redirect_uri' parameter.",
|
||||
"error_uri": "See the full API docs at https://authorization-server.com/docs/access_token",
|
||||
},
|
||||
"oauth_failed",
|
||||
"oauth_unauthorized",
|
||||
"Token request for oauth2_test failed (invalid_request): Request was missing the",
|
||||
),
|
||||
],
|
||||
@@ -979,16 +983,42 @@ async def test_implementation_provider(hass: HomeAssistant, local_impl) -> None:
|
||||
}
|
||||
|
||||
|
||||
async def test_oauth_session_refresh_failure(
|
||||
@pytest.mark.parametrize(
|
||||
("status_code", "expected_exception"),
|
||||
[
|
||||
(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
OAuth2TokenRequestReauthError,
|
||||
),
|
||||
(
|
||||
HTTPStatus.TOO_MANY_REQUESTS, # 429, odd one, but treated as transient
|
||||
OAuth2TokenRequestTransientError,
|
||||
),
|
||||
(
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR, # 500 range, so treated as transient
|
||||
OAuth2TokenRequestTransientError,
|
||||
),
|
||||
(
|
||||
600, # Nonsense code, just to hit the generic error branch
|
||||
OAuth2TokenRequestError,
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_oauth_session_refresh_failure_exceptions(
|
||||
hass: HomeAssistant,
|
||||
flow_handler: type[config_entry_oauth2_flow.AbstractOAuth2FlowHandler],
|
||||
local_impl: config_entry_oauth2_flow.LocalOAuth2Implementation,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
status_code: int,
|
||||
expected_exception: type[Exception],
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test the OAuth2 session helper when no refresh is needed."""
|
||||
"""Test OAuth2 session refresh failures raise mapped exceptions."""
|
||||
mock_integration(hass, MockModule(domain=TEST_DOMAIN))
|
||||
|
||||
flow_handler.async_register_implementation(hass, local_impl)
|
||||
|
||||
aioclient_mock.post(TOKEN_URL, status=400)
|
||||
aioclient_mock.post(TOKEN_URL, status=status_code, json={})
|
||||
|
||||
config_entry = MockConfigEntry(
|
||||
domain=TEST_DOMAIN,
|
||||
@@ -1005,11 +1035,18 @@ async def test_oauth_session_refresh_failure(
|
||||
},
|
||||
},
|
||||
)
|
||||
config_entry.add_to_hass(hass)
|
||||
|
||||
session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, local_impl)
|
||||
with pytest.raises(aiohttp.client_exceptions.ClientResponseError):
|
||||
with (
|
||||
caplog.at_level(logging.WARNING),
|
||||
pytest.raises(expected_exception) as err,
|
||||
):
|
||||
await session.async_request("post", "https://example.com")
|
||||
|
||||
assert err.value.status == status_code
|
||||
assert f"Token request for {TEST_DOMAIN} failed" in caplog.text
|
||||
|
||||
|
||||
async def test_oauth2_without_secret_init(
|
||||
local_impl: config_entry_oauth2_flow.LocalOAuth2Implementation,
|
||||
|
||||
@@ -19,6 +19,8 @@ from homeassistant.exceptions import (
|
||||
ConfigEntryAuthFailed,
|
||||
ConfigEntryError,
|
||||
ConfigEntryNotReady,
|
||||
OAuth2TokenRequestError,
|
||||
OAuth2TokenRequestReauthError,
|
||||
)
|
||||
from homeassistant.helpers import frame, update_coordinator
|
||||
from homeassistant.util.dt import utcnow
|
||||
@@ -322,6 +324,84 @@ async def test_refresh_fail_unknown(
|
||||
assert "Unexpected error fetching test data" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exception", "expected_exception"),
|
||||
[(OAuth2TokenRequestReauthError, ConfigEntryAuthFailed)],
|
||||
)
|
||||
async def test_oauth_token_request_refresh_errors(
|
||||
crd: update_coordinator.DataUpdateCoordinator[int],
|
||||
exception: type[OAuth2TokenRequestError],
|
||||
expected_exception: type[Exception],
|
||||
) -> None:
|
||||
"""Test OAuth2 token request errors are mapped during refresh."""
|
||||
request_info = Mock()
|
||||
request_info.real_url = "http://example.com/token"
|
||||
request_info.method = "POST"
|
||||
|
||||
oauth_exception = exception(
|
||||
request_info=request_info,
|
||||
history=(),
|
||||
status=400,
|
||||
message="OAuth 2.0 token refresh failed",
|
||||
domain="domain",
|
||||
)
|
||||
|
||||
crd.update_method = AsyncMock(side_effect=oauth_exception)
|
||||
|
||||
with pytest.raises(expected_exception) as err:
|
||||
# Raise on auth failed, needs to be set
|
||||
await crd._async_refresh(raise_on_auth_failed=True)
|
||||
|
||||
# Check thoroughly the chain
|
||||
assert isinstance(err.value, expected_exception)
|
||||
assert isinstance(err.value.__cause__, exception)
|
||||
assert isinstance(err.value.__cause__, OAuth2TokenRequestError)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exception", "expected_exception"),
|
||||
[
|
||||
(OAuth2TokenRequestReauthError, ConfigEntryAuthFailed),
|
||||
(OAuth2TokenRequestError, ConfigEntryNotReady),
|
||||
],
|
||||
)
|
||||
async def test_token_request_setup_errors(
|
||||
hass: HomeAssistant,
|
||||
exception: type[OAuth2TokenRequestError],
|
||||
expected_exception: type[Exception],
|
||||
) -> None:
|
||||
"""Test OAuth2 token request errors raised from setup."""
|
||||
entry = MockConfigEntry()
|
||||
entry._async_set_state(
|
||||
hass, config_entries.ConfigEntryState.SETUP_IN_PROGRESS, "For testing, duh"
|
||||
)
|
||||
crd = get_crd(hass, DEFAULT_UPDATE_INTERVAL, entry)
|
||||
|
||||
# Patch the underlying request info to raise ClientResponseError
|
||||
request_info = Mock()
|
||||
request_info.real_url = "http://example.com/token"
|
||||
request_info.method = "POST"
|
||||
oauth_exception = exception(
|
||||
request_info=request_info,
|
||||
history=(),
|
||||
status=400,
|
||||
message="OAuth 2.0 token refresh failed",
|
||||
domain="domain",
|
||||
)
|
||||
|
||||
crd.setup_method = AsyncMock(side_effect=oauth_exception)
|
||||
|
||||
with pytest.raises(expected_exception) as err:
|
||||
await crd.async_config_entry_first_refresh()
|
||||
|
||||
assert crd.last_update_success is False
|
||||
|
||||
# Check thoroughly the chain
|
||||
assert isinstance(err.value, expected_exception)
|
||||
assert isinstance(err.value.__cause__, exception)
|
||||
assert isinstance(err.value.__cause__, OAuth2TokenRequestError)
|
||||
|
||||
|
||||
async def test_refresh_no_update_method(
|
||||
crd: update_coordinator.DataUpdateCoordinator[int],
|
||||
) -> None:
|
||||
|
||||
Reference in New Issue
Block a user