OAuth2.0 token request error handling (#153167)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Erwin Douna
2026-02-18 15:36:53 +01:00
committed by GitHub
parent 4dcfd5fb91
commit d1a1183b9a
7 changed files with 303 additions and 33 deletions

View File

@@ -27,6 +27,15 @@
"multiple_integration_config_errors": {
"message": "Failed to process config for integration {domain} due to multiple ({errors}) errors. Check the logs for more information."
},
"oauth2_helper_reauth_required": {
"message": "Credentials are invalid, re-authentication required"
},
"oauth2_helper_refresh_failed": {
"message": "OAuth2 token refresh failed for {domain}"
},
"oauth2_helper_refresh_transient": {
"message": "Temporary error refreshing credentials for {domain}, try again later"
},
"platform_component_load_err": {
"message": "Platform error: {domain} - {error}."
},

View File

@@ -6,6 +6,9 @@ from collections.abc import Callable, Generator, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from aiohttp import ClientResponse, ClientResponseError, RequestInfo
from multidict import MultiMapping
from .util.event_type import EventType
if TYPE_CHECKING:
@@ -218,6 +221,63 @@ class ConfigEntryAuthFailed(IntegrationError):
"""Error to indicate that config entry could not authenticate."""
class OAuth2TokenRequestError(ClientResponseError, HomeAssistantError):
"""Error to indicate that the OAuth 2.0 flow could not refresh token."""
def __init__(
self,
*,
request_info: RequestInfo,
history: tuple[ClientResponse, ...] = (),
status: int = 0,
message: str = "OAuth 2.0 token refresh failed",
headers: MultiMapping[str] | None = None,
domain: str,
) -> None:
"""Initialize OAuth2RefreshTokenFailed."""
ClientResponseError.__init__(
self,
request_info=request_info,
history=history,
status=status,
message=message,
headers=headers,
)
HomeAssistantError.__init__(self)
self.domain = domain
self.translation_domain = "homeassistant"
self.translation_key = "oauth2_helper_refresh_failed"
self.translation_placeholders = {"domain": domain}
self.generate_message = True
class OAuth2TokenRequestTransientError(OAuth2TokenRequestError):
"""Recoverable error to indicate flow could not refresh token."""
def __init__(self, *, domain: str, **kwargs: Any) -> None:
"""Initialize OAuth2RefreshTokenTransientError."""
super().__init__(domain=domain, **kwargs)
self.translation_domain = "homeassistant"
self.translation_key = "oauth2_helper_refresh_transient"
self.translation_placeholders = {"domain": domain}
self.generate_message = True
class OAuth2TokenRequestReauthError(OAuth2TokenRequestError):
"""Non recoverable error to indicate the flow could not refresh token.
Re-authentication is required.
"""
def __init__(self, *, domain: str, **kwargs: Any) -> None:
"""Initialize OAuth2RefreshTokenReauthError."""
super().__init__(domain=domain, **kwargs)
self.translation_domain = "homeassistant"
self.translation_key = "oauth2_helper_reauth_required"
self.translation_placeholders = {"domain": domain}
self.generate_message = True
class InvalidStateError(HomeAssistantError):
"""When an invalid state is encountered."""

View File

@@ -29,7 +29,12 @@ from yarl import URL
from homeassistant import config_entries
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.exceptions import (
HomeAssistantError,
OAuth2TokenRequestError,
OAuth2TokenRequestReauthError,
OAuth2TokenRequestTransientError,
)
from homeassistant.loader import async_get_application_credentials
from homeassistant.util.hass_dict import HassKey
@@ -56,6 +61,7 @@ AUTH_CALLBACK_PATH = "/auth/external/callback"
HEADER_FRONTEND_BASE = "HA-Frontend-Base"
MY_AUTH_CALLBACK_PATH = "https://my.home-assistant.io/redirect/oauth"
CLOCK_OUT_OF_SYNC_MAX_SEC = 20
OAUTH_AUTHORIZE_URL_TIMEOUT_SEC = 30
@@ -134,7 +140,10 @@ class AbstractOAuth2Implementation(ABC):
@abstractmethod
async def _async_refresh_token(self, token: dict) -> dict:
"""Refresh a token."""
"""Refresh a token.
Should raise OAuth2TokenRequestError on token refresh failure.
"""
class LocalOAuth2Implementation(AbstractOAuth2Implementation):
@@ -211,7 +220,8 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
return await self._token_request(request_data)
async def _async_refresh_token(self, token: dict) -> dict:
"""Refresh tokens."""
"""Refresh a token."""
new_token = await self._token_request(
{
"grant_type": "refresh_token",
@@ -219,33 +229,71 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
"refresh_token": token["refresh_token"],
}
)
return {**token, **new_token}
async def _token_request(self, data: dict) -> dict:
"""Make a token request."""
"""Make a token request.
Raises OAuth2TokenRequestError on token request failure.
"""
session = async_get_clientsession(self.hass)
data["client_id"] = self.client_id
if self.client_secret:
data["client_secret"] = self.client_secret
_LOGGER.debug("Sending token request to %s", self.token_url)
resp = await session.post(self.token_url, data=data)
if resp.status >= 400:
try:
error_response = await resp.json()
except ClientError, JSONDecodeError:
error_response = {}
error_code = error_response.get("error", "unknown")
error_description = error_response.get("error_description", "unknown error")
_LOGGER.error(
"Token request for %s failed (%s): %s",
self.domain,
error_code,
error_description,
)
resp.raise_for_status()
try:
resp = await session.post(self.token_url, data=data)
if resp.status >= 400:
try:
error_response = await resp.json()
except ClientError, JSONDecodeError:
error_response = {}
error_code = error_response.get("error", "unknown")
error_description = error_response.get(
"error_description", "unknown error"
)
_LOGGER.error(
"Token request for %s failed (%s): %s",
self.domain,
error_code,
error_description,
)
resp.raise_for_status()
except ClientResponseError as err:
if err.status == HTTPStatus.TOO_MANY_REQUESTS or 500 <= err.status <= 599:
# Recoverable error
raise OAuth2TokenRequestTransientError(
request_info=err.request_info,
history=err.history,
status=err.status,
message=err.message,
headers=err.headers,
domain=self._domain,
) from err
if 400 <= err.status <= 499:
# Non-recoverable error
raise OAuth2TokenRequestReauthError(
request_info=err.request_info,
history=err.history,
status=err.status,
message=err.message,
headers=err.headers,
domain=self._domain,
) from err
raise OAuth2TokenRequestError(
request_info=err.request_info,
history=err.history,
status=err.status,
message=err.message,
headers=err.headers,
domain=self._domain,
) from err
return cast(dict, await resp.json())
@@ -458,12 +506,12 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
except TimeoutError as err:
_LOGGER.error("Timeout resolving OAuth token: %s", err)
return self.async_abort(reason="oauth_timeout")
except (ClientResponseError, ClientError) as err:
except (
OAuth2TokenRequestError,
ClientError,
) as err:
_LOGGER.error("Error resolving OAuth token: %s", err)
if (
isinstance(err, ClientResponseError)
and err.status == HTTPStatus.UNAUTHORIZED
):
if isinstance(err, OAuth2TokenRequestReauthError):
return self.async_abort(reason="oauth_unauthorized")
return self.async_abort(reason="oauth_failed")

View File

@@ -25,6 +25,8 @@ from homeassistant.exceptions import (
ConfigEntryError,
ConfigEntryNotReady,
HomeAssistantError,
OAuth2TokenRequestError,
OAuth2TokenRequestReauthError,
)
from homeassistant.util.dt import utcnow
@@ -352,6 +354,14 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
"""Error handling for _async_setup."""
try:
await self._async_setup()
except OAuth2TokenRequestError as err:
self.last_exception = err
if isinstance(err, OAuth2TokenRequestReauthError):
self.last_update_success = False
# Non-recoverable error
raise ConfigEntryAuthFailed from err
except (
TimeoutError,
requests.exceptions.Timeout,
@@ -423,6 +433,32 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
self.logger.debug("Full error:", exc_info=True)
self.last_update_success = False
except (OAuth2TokenRequestError,) as err:
self.last_exception = err
if isinstance(err, OAuth2TokenRequestReauthError):
# Non-recoverable error
auth_failed = True
if self.last_update_success:
if log_failures:
self.logger.error(
"Authentication failed while fetching %s data: %s",
self.name,
err,
)
self.last_update_success = False
if raise_on_auth_failed:
raise ConfigEntryAuthFailed from err
if self.config_entry:
self.config_entry.async_start_reauth(self.hass)
return
# Recoverable error
if self.last_update_success:
if log_failures:
self.logger.error("Error fetching %s data: %s", self.name, err)
self.last_update_success = False
except (aiohttp.ClientError, requests.exceptions.RequestException) as err:
self.last_exception = err
if self.last_update_success:

View File

@@ -1368,7 +1368,7 @@ async def test_dhcp_discovery_with_creds(
("status_code", "error_reason"),
[
(HTTPStatus.UNAUTHORIZED, "oauth_unauthorized"),
(HTTPStatus.NOT_FOUND, "oauth_failed"),
(HTTPStatus.NOT_FOUND, "oauth_unauthorized"),
(HTTPStatus.INTERNAL_SERVER_ERROR, "oauth_failed"),
],
)

View File

@@ -7,11 +7,15 @@ import time
from typing import Any
from unittest.mock import AsyncMock, patch
import aiohttp
import pytest
from homeassistant import config_entries, data_entry_flow, setup
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import (
OAuth2TokenRequestError,
OAuth2TokenRequestReauthError,
OAuth2TokenRequestTransientError,
)
from homeassistant.helpers import config_entry_oauth2_flow
from homeassistant.helpers.network import NoURLAvailableError
@@ -478,7 +482,7 @@ async def test_abort_discovered_multiple(
(
HTTPStatus.NOT_FOUND,
{},
"oauth_failed",
"oauth_unauthorized",
"Token request for oauth2_test failed (unknown): unknown",
),
(
@@ -494,7 +498,7 @@ async def test_abort_discovered_multiple(
"error_description": "Request was missing the 'redirect_uri' parameter.",
"error_uri": "See the full API docs at https://authorization-server.com/docs/access_token",
},
"oauth_failed",
"oauth_unauthorized",
"Token request for oauth2_test failed (invalid_request): Request was missing the",
),
],
@@ -979,16 +983,42 @@ async def test_implementation_provider(hass: HomeAssistant, local_impl) -> None:
}
async def test_oauth_session_refresh_failure(
@pytest.mark.parametrize(
("status_code", "expected_exception"),
[
(
HTTPStatus.BAD_REQUEST,
OAuth2TokenRequestReauthError,
),
(
HTTPStatus.TOO_MANY_REQUESTS, # 429, odd one, but treated as transient
OAuth2TokenRequestTransientError,
),
(
HTTPStatus.INTERNAL_SERVER_ERROR, # 500 range, so treated as transient
OAuth2TokenRequestTransientError,
),
(
600, # Nonsense code, just to hit the generic error branch
OAuth2TokenRequestError,
),
],
)
async def test_oauth_session_refresh_failure_exceptions(
hass: HomeAssistant,
flow_handler: type[config_entry_oauth2_flow.AbstractOAuth2FlowHandler],
local_impl: config_entry_oauth2_flow.LocalOAuth2Implementation,
aioclient_mock: AiohttpClientMocker,
status_code: int,
expected_exception: type[Exception],
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test the OAuth2 session helper when no refresh is needed."""
"""Test OAuth2 session refresh failures raise mapped exceptions."""
mock_integration(hass, MockModule(domain=TEST_DOMAIN))
flow_handler.async_register_implementation(hass, local_impl)
aioclient_mock.post(TOKEN_URL, status=400)
aioclient_mock.post(TOKEN_URL, status=status_code, json={})
config_entry = MockConfigEntry(
domain=TEST_DOMAIN,
@@ -1005,11 +1035,18 @@ async def test_oauth_session_refresh_failure(
},
},
)
config_entry.add_to_hass(hass)
session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, local_impl)
with pytest.raises(aiohttp.client_exceptions.ClientResponseError):
with (
caplog.at_level(logging.WARNING),
pytest.raises(expected_exception) as err,
):
await session.async_request("post", "https://example.com")
assert err.value.status == status_code
assert f"Token request for {TEST_DOMAIN} failed" in caplog.text
async def test_oauth2_without_secret_init(
local_impl: config_entry_oauth2_flow.LocalOAuth2Implementation,

View File

@@ -19,6 +19,8 @@ from homeassistant.exceptions import (
ConfigEntryAuthFailed,
ConfigEntryError,
ConfigEntryNotReady,
OAuth2TokenRequestError,
OAuth2TokenRequestReauthError,
)
from homeassistant.helpers import frame, update_coordinator
from homeassistant.util.dt import utcnow
@@ -322,6 +324,84 @@ async def test_refresh_fail_unknown(
assert "Unexpected error fetching test data" in caplog.text
@pytest.mark.parametrize(
("exception", "expected_exception"),
[(OAuth2TokenRequestReauthError, ConfigEntryAuthFailed)],
)
async def test_oauth_token_request_refresh_errors(
crd: update_coordinator.DataUpdateCoordinator[int],
exception: type[OAuth2TokenRequestError],
expected_exception: type[Exception],
) -> None:
"""Test OAuth2 token request errors are mapped during refresh."""
request_info = Mock()
request_info.real_url = "http://example.com/token"
request_info.method = "POST"
oauth_exception = exception(
request_info=request_info,
history=(),
status=400,
message="OAuth 2.0 token refresh failed",
domain="domain",
)
crd.update_method = AsyncMock(side_effect=oauth_exception)
with pytest.raises(expected_exception) as err:
# Raise on auth failed, needs to be set
await crd._async_refresh(raise_on_auth_failed=True)
# Check thoroughly the chain
assert isinstance(err.value, expected_exception)
assert isinstance(err.value.__cause__, exception)
assert isinstance(err.value.__cause__, OAuth2TokenRequestError)
@pytest.mark.parametrize(
("exception", "expected_exception"),
[
(OAuth2TokenRequestReauthError, ConfigEntryAuthFailed),
(OAuth2TokenRequestError, ConfigEntryNotReady),
],
)
async def test_token_request_setup_errors(
hass: HomeAssistant,
exception: type[OAuth2TokenRequestError],
expected_exception: type[Exception],
) -> None:
"""Test OAuth2 token request errors raised from setup."""
entry = MockConfigEntry()
entry._async_set_state(
hass, config_entries.ConfigEntryState.SETUP_IN_PROGRESS, "For testing, duh"
)
crd = get_crd(hass, DEFAULT_UPDATE_INTERVAL, entry)
# Patch the underlying request info to raise ClientResponseError
request_info = Mock()
request_info.real_url = "http://example.com/token"
request_info.method = "POST"
oauth_exception = exception(
request_info=request_info,
history=(),
status=400,
message="OAuth 2.0 token refresh failed",
domain="domain",
)
crd.setup_method = AsyncMock(side_effect=oauth_exception)
with pytest.raises(expected_exception) as err:
await crd.async_config_entry_first_refresh()
assert crd.last_update_success is False
# Check thoroughly the chain
assert isinstance(err.value, expected_exception)
assert isinstance(err.value.__cause__, exception)
assert isinstance(err.value.__cause__, OAuth2TokenRequestError)
async def test_refresh_no_update_method(
crd: update_coordinator.DataUpdateCoordinator[int],
) -> None: