From d1a1183b9a1fe0bc01766e08832c6a74b50b1eb4 Mon Sep 17 00:00:00 2001 From: Erwin Douna Date: Wed, 18 Feb 2026 15:36:53 +0100 Subject: [PATCH] OAuth2.0 token request error handling (#153167) Co-authored-by: Martin Hjelmare --- .../components/homeassistant/strings.json | 9 ++ homeassistant/exceptions.py | 60 ++++++++++++ .../helpers/config_entry_oauth2_flow.py | 98 ++++++++++++++----- homeassistant/helpers/update_coordinator.py | 36 +++++++ tests/components/nest/test_config_flow.py | 2 +- .../helpers/test_config_entry_oauth2_flow.py | 51 ++++++++-- tests/helpers/test_update_coordinator.py | 80 +++++++++++++++ 7 files changed, 303 insertions(+), 33 deletions(-) diff --git a/homeassistant/components/homeassistant/strings.json b/homeassistant/components/homeassistant/strings.json index 95fc7c5aa5b..16cad4835ab 100644 --- a/homeassistant/components/homeassistant/strings.json +++ b/homeassistant/components/homeassistant/strings.json @@ -27,6 +27,15 @@ "multiple_integration_config_errors": { "message": "Failed to process config for integration {domain} due to multiple ({errors}) errors. Check the logs for more information." }, + "oauth2_helper_reauth_required": { + "message": "Credentials are invalid, re-authentication required" + }, + "oauth2_helper_refresh_failed": { + "message": "OAuth2 token refresh failed for {domain}" + }, + "oauth2_helper_refresh_transient": { + "message": "Temporary error refreshing credentials for {domain}, try again later" + }, "platform_component_load_err": { "message": "Platform error: {domain} - {error}." }, diff --git a/homeassistant/exceptions.py b/homeassistant/exceptions.py index 23416480dd7..58d8c22092c 100644 --- a/homeassistant/exceptions.py +++ b/homeassistant/exceptions.py @@ -6,6 +6,9 @@ from collections.abc import Callable, Generator, Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Any +from aiohttp import ClientResponse, ClientResponseError, RequestInfo +from multidict import MultiMapping + from .util.event_type import EventType if TYPE_CHECKING: @@ -218,6 +221,63 @@ class ConfigEntryAuthFailed(IntegrationError): """Error to indicate that config entry could not authenticate.""" +class OAuth2TokenRequestError(ClientResponseError, HomeAssistantError): + """Error to indicate that the OAuth 2.0 flow could not refresh token.""" + + def __init__( + self, + *, + request_info: RequestInfo, + history: tuple[ClientResponse, ...] = (), + status: int = 0, + message: str = "OAuth 2.0 token refresh failed", + headers: MultiMapping[str] | None = None, + domain: str, + ) -> None: + """Initialize OAuth2RefreshTokenFailed.""" + ClientResponseError.__init__( + self, + request_info=request_info, + history=history, + status=status, + message=message, + headers=headers, + ) + HomeAssistantError.__init__(self) + self.domain = domain + self.translation_domain = "homeassistant" + self.translation_key = "oauth2_helper_refresh_failed" + self.translation_placeholders = {"domain": domain} + self.generate_message = True + + +class OAuth2TokenRequestTransientError(OAuth2TokenRequestError): + """Recoverable error to indicate flow could not refresh token.""" + + def __init__(self, *, domain: str, **kwargs: Any) -> None: + """Initialize OAuth2RefreshTokenTransientError.""" + super().__init__(domain=domain, **kwargs) + self.translation_domain = "homeassistant" + self.translation_key = "oauth2_helper_refresh_transient" + self.translation_placeholders = {"domain": domain} + self.generate_message = True + + +class OAuth2TokenRequestReauthError(OAuth2TokenRequestError): + """Non recoverable error to indicate the flow could not refresh token. + + Re-authentication is required. + """ + + def __init__(self, *, domain: str, **kwargs: Any) -> None: + """Initialize OAuth2RefreshTokenReauthError.""" + super().__init__(domain=domain, **kwargs) + self.translation_domain = "homeassistant" + self.translation_key = "oauth2_helper_reauth_required" + self.translation_placeholders = {"domain": domain} + self.generate_message = True + + class InvalidStateError(HomeAssistantError): """When an invalid state is encountered.""" diff --git a/homeassistant/helpers/config_entry_oauth2_flow.py b/homeassistant/helpers/config_entry_oauth2_flow.py index d7fc606b591..c25c609dd06 100644 --- a/homeassistant/helpers/config_entry_oauth2_flow.py +++ b/homeassistant/helpers/config_entry_oauth2_flow.py @@ -29,7 +29,12 @@ from yarl import URL from homeassistant import config_entries from homeassistant.core import HomeAssistant, callback -from homeassistant.exceptions import HomeAssistantError +from homeassistant.exceptions import ( + HomeAssistantError, + OAuth2TokenRequestError, + OAuth2TokenRequestReauthError, + OAuth2TokenRequestTransientError, +) from homeassistant.loader import async_get_application_credentials from homeassistant.util.hass_dict import HassKey @@ -56,6 +61,7 @@ AUTH_CALLBACK_PATH = "/auth/external/callback" HEADER_FRONTEND_BASE = "HA-Frontend-Base" MY_AUTH_CALLBACK_PATH = "https://my.home-assistant.io/redirect/oauth" + CLOCK_OUT_OF_SYNC_MAX_SEC = 20 OAUTH_AUTHORIZE_URL_TIMEOUT_SEC = 30 @@ -134,7 +140,10 @@ class AbstractOAuth2Implementation(ABC): @abstractmethod async def _async_refresh_token(self, token: dict) -> dict: - """Refresh a token.""" + """Refresh a token. + + Should raise OAuth2TokenRequestError on token refresh failure. + """ class LocalOAuth2Implementation(AbstractOAuth2Implementation): @@ -211,7 +220,8 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation): return await self._token_request(request_data) async def _async_refresh_token(self, token: dict) -> dict: - """Refresh tokens.""" + """Refresh a token.""" + new_token = await self._token_request( { "grant_type": "refresh_token", @@ -219,33 +229,71 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation): "refresh_token": token["refresh_token"], } ) + return {**token, **new_token} async def _token_request(self, data: dict) -> dict: - """Make a token request.""" + """Make a token request. + + Raises OAuth2TokenRequestError on token request failure. + """ session = async_get_clientsession(self.hass) data["client_id"] = self.client_id - if self.client_secret: data["client_secret"] = self.client_secret _LOGGER.debug("Sending token request to %s", self.token_url) - resp = await session.post(self.token_url, data=data) - if resp.status >= 400: - try: - error_response = await resp.json() - except ClientError, JSONDecodeError: - error_response = {} - error_code = error_response.get("error", "unknown") - error_description = error_response.get("error_description", "unknown error") - _LOGGER.error( - "Token request for %s failed (%s): %s", - self.domain, - error_code, - error_description, - ) - resp.raise_for_status() + + try: + resp = await session.post(self.token_url, data=data) + if resp.status >= 400: + try: + error_response = await resp.json() + except ClientError, JSONDecodeError: + error_response = {} + error_code = error_response.get("error", "unknown") + error_description = error_response.get( + "error_description", "unknown error" + ) + _LOGGER.error( + "Token request for %s failed (%s): %s", + self.domain, + error_code, + error_description, + ) + resp.raise_for_status() + except ClientResponseError as err: + if err.status == HTTPStatus.TOO_MANY_REQUESTS or 500 <= err.status <= 599: + # Recoverable error + raise OAuth2TokenRequestTransientError( + request_info=err.request_info, + history=err.history, + status=err.status, + message=err.message, + headers=err.headers, + domain=self._domain, + ) from err + if 400 <= err.status <= 499: + # Non-recoverable error + raise OAuth2TokenRequestReauthError( + request_info=err.request_info, + history=err.history, + status=err.status, + message=err.message, + headers=err.headers, + domain=self._domain, + ) from err + + raise OAuth2TokenRequestError( + request_info=err.request_info, + history=err.history, + status=err.status, + message=err.message, + headers=err.headers, + domain=self._domain, + ) from err + return cast(dict, await resp.json()) @@ -458,12 +506,12 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta): except TimeoutError as err: _LOGGER.error("Timeout resolving OAuth token: %s", err) return self.async_abort(reason="oauth_timeout") - except (ClientResponseError, ClientError) as err: + except ( + OAuth2TokenRequestError, + ClientError, + ) as err: _LOGGER.error("Error resolving OAuth token: %s", err) - if ( - isinstance(err, ClientResponseError) - and err.status == HTTPStatus.UNAUTHORIZED - ): + if isinstance(err, OAuth2TokenRequestReauthError): return self.async_abort(reason="oauth_unauthorized") return self.async_abort(reason="oauth_failed") diff --git a/homeassistant/helpers/update_coordinator.py b/homeassistant/helpers/update_coordinator.py index 0bbea1ac6f4..7bed9ca1f28 100644 --- a/homeassistant/helpers/update_coordinator.py +++ b/homeassistant/helpers/update_coordinator.py @@ -25,6 +25,8 @@ from homeassistant.exceptions import ( ConfigEntryError, ConfigEntryNotReady, HomeAssistantError, + OAuth2TokenRequestError, + OAuth2TokenRequestReauthError, ) from homeassistant.util.dt import utcnow @@ -352,6 +354,14 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]): """Error handling for _async_setup.""" try: await self._async_setup() + + except OAuth2TokenRequestError as err: + self.last_exception = err + if isinstance(err, OAuth2TokenRequestReauthError): + self.last_update_success = False + # Non-recoverable error + raise ConfigEntryAuthFailed from err + except ( TimeoutError, requests.exceptions.Timeout, @@ -423,6 +433,32 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]): self.logger.debug("Full error:", exc_info=True) self.last_update_success = False + except (OAuth2TokenRequestError,) as err: + self.last_exception = err + if isinstance(err, OAuth2TokenRequestReauthError): + # Non-recoverable error + auth_failed = True + if self.last_update_success: + if log_failures: + self.logger.error( + "Authentication failed while fetching %s data: %s", + self.name, + err, + ) + self.last_update_success = False + if raise_on_auth_failed: + raise ConfigEntryAuthFailed from err + + if self.config_entry: + self.config_entry.async_start_reauth(self.hass) + return + + # Recoverable error + if self.last_update_success: + if log_failures: + self.logger.error("Error fetching %s data: %s", self.name, err) + self.last_update_success = False + except (aiohttp.ClientError, requests.exceptions.RequestException) as err: self.last_exception = err if self.last_update_success: diff --git a/tests/components/nest/test_config_flow.py b/tests/components/nest/test_config_flow.py index 24b12b047bf..9ff7713e9ed 100644 --- a/tests/components/nest/test_config_flow.py +++ b/tests/components/nest/test_config_flow.py @@ -1368,7 +1368,7 @@ async def test_dhcp_discovery_with_creds( ("status_code", "error_reason"), [ (HTTPStatus.UNAUTHORIZED, "oauth_unauthorized"), - (HTTPStatus.NOT_FOUND, "oauth_failed"), + (HTTPStatus.NOT_FOUND, "oauth_unauthorized"), (HTTPStatus.INTERNAL_SERVER_ERROR, "oauth_failed"), ], ) diff --git a/tests/helpers/test_config_entry_oauth2_flow.py b/tests/helpers/test_config_entry_oauth2_flow.py index dc56910785c..0ba5e9543ae 100644 --- a/tests/helpers/test_config_entry_oauth2_flow.py +++ b/tests/helpers/test_config_entry_oauth2_flow.py @@ -7,11 +7,15 @@ import time from typing import Any from unittest.mock import AsyncMock, patch -import aiohttp import pytest from homeassistant import config_entries, data_entry_flow, setup from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ( + OAuth2TokenRequestError, + OAuth2TokenRequestReauthError, + OAuth2TokenRequestTransientError, +) from homeassistant.helpers import config_entry_oauth2_flow from homeassistant.helpers.network import NoURLAvailableError @@ -478,7 +482,7 @@ async def test_abort_discovered_multiple( ( HTTPStatus.NOT_FOUND, {}, - "oauth_failed", + "oauth_unauthorized", "Token request for oauth2_test failed (unknown): unknown", ), ( @@ -494,7 +498,7 @@ async def test_abort_discovered_multiple( "error_description": "Request was missing the 'redirect_uri' parameter.", "error_uri": "See the full API docs at https://authorization-server.com/docs/access_token", }, - "oauth_failed", + "oauth_unauthorized", "Token request for oauth2_test failed (invalid_request): Request was missing the", ), ], @@ -979,16 +983,42 @@ async def test_implementation_provider(hass: HomeAssistant, local_impl) -> None: } -async def test_oauth_session_refresh_failure( +@pytest.mark.parametrize( + ("status_code", "expected_exception"), + [ + ( + HTTPStatus.BAD_REQUEST, + OAuth2TokenRequestReauthError, + ), + ( + HTTPStatus.TOO_MANY_REQUESTS, # 429, odd one, but treated as transient + OAuth2TokenRequestTransientError, + ), + ( + HTTPStatus.INTERNAL_SERVER_ERROR, # 500 range, so treated as transient + OAuth2TokenRequestTransientError, + ), + ( + 600, # Nonsense code, just to hit the generic error branch + OAuth2TokenRequestError, + ), + ], +) +async def test_oauth_session_refresh_failure_exceptions( hass: HomeAssistant, flow_handler: type[config_entry_oauth2_flow.AbstractOAuth2FlowHandler], local_impl: config_entry_oauth2_flow.LocalOAuth2Implementation, aioclient_mock: AiohttpClientMocker, + status_code: int, + expected_exception: type[Exception], + caplog: pytest.LogCaptureFixture, ) -> None: - """Test the OAuth2 session helper when no refresh is needed.""" + """Test OAuth2 session refresh failures raise mapped exceptions.""" + mock_integration(hass, MockModule(domain=TEST_DOMAIN)) + flow_handler.async_register_implementation(hass, local_impl) - aioclient_mock.post(TOKEN_URL, status=400) + aioclient_mock.post(TOKEN_URL, status=status_code, json={}) config_entry = MockConfigEntry( domain=TEST_DOMAIN, @@ -1005,11 +1035,18 @@ async def test_oauth_session_refresh_failure( }, }, ) + config_entry.add_to_hass(hass) session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, local_impl) - with pytest.raises(aiohttp.client_exceptions.ClientResponseError): + with ( + caplog.at_level(logging.WARNING), + pytest.raises(expected_exception) as err, + ): await session.async_request("post", "https://example.com") + assert err.value.status == status_code + assert f"Token request for {TEST_DOMAIN} failed" in caplog.text + async def test_oauth2_without_secret_init( local_impl: config_entry_oauth2_flow.LocalOAuth2Implementation, diff --git a/tests/helpers/test_update_coordinator.py b/tests/helpers/test_update_coordinator.py index 612b39293a2..77a3c90ee0e 100644 --- a/tests/helpers/test_update_coordinator.py +++ b/tests/helpers/test_update_coordinator.py @@ -19,6 +19,8 @@ from homeassistant.exceptions import ( ConfigEntryAuthFailed, ConfigEntryError, ConfigEntryNotReady, + OAuth2TokenRequestError, + OAuth2TokenRequestReauthError, ) from homeassistant.helpers import frame, update_coordinator from homeassistant.util.dt import utcnow @@ -322,6 +324,84 @@ async def test_refresh_fail_unknown( assert "Unexpected error fetching test data" in caplog.text +@pytest.mark.parametrize( + ("exception", "expected_exception"), + [(OAuth2TokenRequestReauthError, ConfigEntryAuthFailed)], +) +async def test_oauth_token_request_refresh_errors( + crd: update_coordinator.DataUpdateCoordinator[int], + exception: type[OAuth2TokenRequestError], + expected_exception: type[Exception], +) -> None: + """Test OAuth2 token request errors are mapped during refresh.""" + request_info = Mock() + request_info.real_url = "http://example.com/token" + request_info.method = "POST" + + oauth_exception = exception( + request_info=request_info, + history=(), + status=400, + message="OAuth 2.0 token refresh failed", + domain="domain", + ) + + crd.update_method = AsyncMock(side_effect=oauth_exception) + + with pytest.raises(expected_exception) as err: + # Raise on auth failed, needs to be set + await crd._async_refresh(raise_on_auth_failed=True) + + # Check thoroughly the chain + assert isinstance(err.value, expected_exception) + assert isinstance(err.value.__cause__, exception) + assert isinstance(err.value.__cause__, OAuth2TokenRequestError) + + +@pytest.mark.parametrize( + ("exception", "expected_exception"), + [ + (OAuth2TokenRequestReauthError, ConfigEntryAuthFailed), + (OAuth2TokenRequestError, ConfigEntryNotReady), + ], +) +async def test_token_request_setup_errors( + hass: HomeAssistant, + exception: type[OAuth2TokenRequestError], + expected_exception: type[Exception], +) -> None: + """Test OAuth2 token request errors raised from setup.""" + entry = MockConfigEntry() + entry._async_set_state( + hass, config_entries.ConfigEntryState.SETUP_IN_PROGRESS, "For testing, duh" + ) + crd = get_crd(hass, DEFAULT_UPDATE_INTERVAL, entry) + + # Patch the underlying request info to raise ClientResponseError + request_info = Mock() + request_info.real_url = "http://example.com/token" + request_info.method = "POST" + oauth_exception = exception( + request_info=request_info, + history=(), + status=400, + message="OAuth 2.0 token refresh failed", + domain="domain", + ) + + crd.setup_method = AsyncMock(side_effect=oauth_exception) + + with pytest.raises(expected_exception) as err: + await crd.async_config_entry_first_refresh() + + assert crd.last_update_success is False + + # Check thoroughly the chain + assert isinstance(err.value, expected_exception) + assert isinstance(err.value.__cause__, exception) + assert isinstance(err.value.__cause__, OAuth2TokenRequestError) + + async def test_refresh_no_update_method( crd: update_coordinator.DataUpdateCoordinator[int], ) -> None: