OAuth2.0 token request error handling (#153167)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Erwin Douna
2026-02-18 15:36:53 +01:00
committed by GitHub
parent 4dcfd5fb91
commit d1a1183b9a
7 changed files with 303 additions and 33 deletions

View File

@@ -1368,7 +1368,7 @@ async def test_dhcp_discovery_with_creds(
("status_code", "error_reason"),
[
(HTTPStatus.UNAUTHORIZED, "oauth_unauthorized"),
(HTTPStatus.NOT_FOUND, "oauth_failed"),
(HTTPStatus.NOT_FOUND, "oauth_unauthorized"),
(HTTPStatus.INTERNAL_SERVER_ERROR, "oauth_failed"),
],
)

View File

@@ -7,11 +7,15 @@ import time
from typing import Any
from unittest.mock import AsyncMock, patch
import aiohttp
import pytest
from homeassistant import config_entries, data_entry_flow, setup
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import (
OAuth2TokenRequestError,
OAuth2TokenRequestReauthError,
OAuth2TokenRequestTransientError,
)
from homeassistant.helpers import config_entry_oauth2_flow
from homeassistant.helpers.network import NoURLAvailableError
@@ -478,7 +482,7 @@ async def test_abort_discovered_multiple(
(
HTTPStatus.NOT_FOUND,
{},
"oauth_failed",
"oauth_unauthorized",
"Token request for oauth2_test failed (unknown): unknown",
),
(
@@ -494,7 +498,7 @@ async def test_abort_discovered_multiple(
"error_description": "Request was missing the 'redirect_uri' parameter.",
"error_uri": "See the full API docs at https://authorization-server.com/docs/access_token",
},
"oauth_failed",
"oauth_unauthorized",
"Token request for oauth2_test failed (invalid_request): Request was missing the",
),
],
@@ -979,16 +983,42 @@ async def test_implementation_provider(hass: HomeAssistant, local_impl) -> None:
}
async def test_oauth_session_refresh_failure(
@pytest.mark.parametrize(
("status_code", "expected_exception"),
[
(
HTTPStatus.BAD_REQUEST,
OAuth2TokenRequestReauthError,
),
(
HTTPStatus.TOO_MANY_REQUESTS, # 429, odd one, but treated as transient
OAuth2TokenRequestTransientError,
),
(
HTTPStatus.INTERNAL_SERVER_ERROR, # 500 range, so treated as transient
OAuth2TokenRequestTransientError,
),
(
600, # Nonsense code, just to hit the generic error branch
OAuth2TokenRequestError,
),
],
)
async def test_oauth_session_refresh_failure_exceptions(
hass: HomeAssistant,
flow_handler: type[config_entry_oauth2_flow.AbstractOAuth2FlowHandler],
local_impl: config_entry_oauth2_flow.LocalOAuth2Implementation,
aioclient_mock: AiohttpClientMocker,
status_code: int,
expected_exception: type[Exception],
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test the OAuth2 session helper when no refresh is needed."""
"""Test OAuth2 session refresh failures raise mapped exceptions."""
mock_integration(hass, MockModule(domain=TEST_DOMAIN))
flow_handler.async_register_implementation(hass, local_impl)
aioclient_mock.post(TOKEN_URL, status=400)
aioclient_mock.post(TOKEN_URL, status=status_code, json={})
config_entry = MockConfigEntry(
domain=TEST_DOMAIN,
@@ -1005,11 +1035,18 @@ async def test_oauth_session_refresh_failure(
},
},
)
config_entry.add_to_hass(hass)
session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, local_impl)
with pytest.raises(aiohttp.client_exceptions.ClientResponseError):
with (
caplog.at_level(logging.WARNING),
pytest.raises(expected_exception) as err,
):
await session.async_request("post", "https://example.com")
assert err.value.status == status_code
assert f"Token request for {TEST_DOMAIN} failed" in caplog.text
async def test_oauth2_without_secret_init(
local_impl: config_entry_oauth2_flow.LocalOAuth2Implementation,

View File

@@ -19,6 +19,8 @@ from homeassistant.exceptions import (
ConfigEntryAuthFailed,
ConfigEntryError,
ConfigEntryNotReady,
OAuth2TokenRequestError,
OAuth2TokenRequestReauthError,
)
from homeassistant.helpers import frame, update_coordinator
from homeassistant.util.dt import utcnow
@@ -322,6 +324,84 @@ async def test_refresh_fail_unknown(
assert "Unexpected error fetching test data" in caplog.text
@pytest.mark.parametrize(
("exception", "expected_exception"),
[(OAuth2TokenRequestReauthError, ConfigEntryAuthFailed)],
)
async def test_oauth_token_request_refresh_errors(
crd: update_coordinator.DataUpdateCoordinator[int],
exception: type[OAuth2TokenRequestError],
expected_exception: type[Exception],
) -> None:
"""Test OAuth2 token request errors are mapped during refresh."""
request_info = Mock()
request_info.real_url = "http://example.com/token"
request_info.method = "POST"
oauth_exception = exception(
request_info=request_info,
history=(),
status=400,
message="OAuth 2.0 token refresh failed",
domain="domain",
)
crd.update_method = AsyncMock(side_effect=oauth_exception)
with pytest.raises(expected_exception) as err:
# Raise on auth failed, needs to be set
await crd._async_refresh(raise_on_auth_failed=True)
# Check thoroughly the chain
assert isinstance(err.value, expected_exception)
assert isinstance(err.value.__cause__, exception)
assert isinstance(err.value.__cause__, OAuth2TokenRequestError)
@pytest.mark.parametrize(
("exception", "expected_exception"),
[
(OAuth2TokenRequestReauthError, ConfigEntryAuthFailed),
(OAuth2TokenRequestError, ConfigEntryNotReady),
],
)
async def test_token_request_setup_errors(
hass: HomeAssistant,
exception: type[OAuth2TokenRequestError],
expected_exception: type[Exception],
) -> None:
"""Test OAuth2 token request errors raised from setup."""
entry = MockConfigEntry()
entry._async_set_state(
hass, config_entries.ConfigEntryState.SETUP_IN_PROGRESS, "For testing, duh"
)
crd = get_crd(hass, DEFAULT_UPDATE_INTERVAL, entry)
# Patch the underlying request info to raise ClientResponseError
request_info = Mock()
request_info.real_url = "http://example.com/token"
request_info.method = "POST"
oauth_exception = exception(
request_info=request_info,
history=(),
status=400,
message="OAuth 2.0 token refresh failed",
domain="domain",
)
crd.setup_method = AsyncMock(side_effect=oauth_exception)
with pytest.raises(expected_exception) as err:
await crd.async_config_entry_first_refresh()
assert crd.last_update_success is False
# Check thoroughly the chain
assert isinstance(err.value, expected_exception)
assert isinstance(err.value.__cause__, exception)
assert isinstance(err.value.__cause__, OAuth2TokenRequestError)
async def test_refresh_no_update_method(
crd: update_coordinator.DataUpdateCoordinator[int],
) -> None: