diff --git a/homeassistant/components/accuweather/config_flow.py b/homeassistant/components/accuweather/config_flow.py index d16b9a1f77a..00c5f926456 100644 --- a/homeassistant/components/accuweather/config_flow.py +++ b/homeassistant/components/accuweather/config_flow.py @@ -3,6 +3,7 @@ from __future__ import annotations from asyncio import timeout +from collections.abc import Mapping from typing import Any from accuweather import AccuWeather, ApiError, InvalidApiKeyError, RequestsExceededError @@ -22,6 +23,8 @@ class AccuWeatherFlowHandler(ConfigFlow, domain=DOMAIN): """Config flow for AccuWeather.""" VERSION = 1 + _latitude: float | None = None + _longitude: float | None = None async def async_step_user( self, user_input: dict[str, Any] | None = None @@ -74,3 +77,46 @@ class AccuWeatherFlowHandler(ConfigFlow, domain=DOMAIN): ), errors=errors, ) + + async def async_step_reauth( + self, entry_data: Mapping[str, Any] + ) -> ConfigFlowResult: + """Handle configuration by re-auth.""" + self._latitude = entry_data[CONF_LATITUDE] + self._longitude = entry_data[CONF_LONGITUDE] + + return await self.async_step_reauth_confirm() + + async def async_step_reauth_confirm( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Dialog that informs the user that reauth is required.""" + errors: dict[str, str] = {} + + if user_input is not None: + websession = async_get_clientsession(self.hass) + try: + async with timeout(10): + accuweather = AccuWeather( + user_input[CONF_API_KEY], + websession, + latitude=self._latitude, + longitude=self._longitude, + ) + await accuweather.async_get_location() + except (ApiError, ClientConnectorError, TimeoutError, ClientError): + errors["base"] = "cannot_connect" + except InvalidApiKeyError: + errors["base"] = "invalid_api_key" + except RequestsExceededError: + errors["base"] = "requests_exceeded" + else: + return self.async_update_reload_and_abort( + self._get_reauth_entry(), data_updates=user_input + ) + + return self.async_show_form( + step_id="reauth_confirm", + data_schema=vol.Schema({vol.Required(CONF_API_KEY): str}), + errors=errors, + ) diff --git a/homeassistant/components/accuweather/coordinator.py b/homeassistant/components/accuweather/coordinator.py index 7056c6e81fd..3c4991d2c59 100644 --- a/homeassistant/components/accuweather/coordinator.py +++ b/homeassistant/components/accuweather/coordinator.py @@ -15,6 +15,7 @@ from aiohttp.client_exceptions import ClientConnectorError from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_NAME from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ConfigEntryAuthFailed from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.update_coordinator import ( DataUpdateCoordinator, @@ -30,7 +31,7 @@ from .const import ( UPDATE_INTERVAL_OBSERVATION, ) -EXCEPTIONS = (ApiError, ClientConnectorError, InvalidApiKeyError, RequestsExceededError) +EXCEPTIONS = (ApiError, ClientConnectorError, RequestsExceededError) _LOGGER = logging.getLogger(__name__) @@ -52,6 +53,8 @@ class AccuWeatherObservationDataUpdateCoordinator( ): """Class to manage fetching AccuWeather data API.""" + config_entry: AccuWeatherConfigEntry + def __init__( self, hass: HomeAssistant, @@ -87,6 +90,12 @@ class AccuWeatherObservationDataUpdateCoordinator( translation_key="current_conditions_update_error", translation_placeholders={"error": repr(error)}, ) from error + except InvalidApiKeyError as err: + raise ConfigEntryAuthFailed( + translation_domain=DOMAIN, + translation_key="auth_error", + translation_placeholders={"entry": self.config_entry.title}, + ) from err _LOGGER.debug("Requests remaining: %d", self.accuweather.requests_remaining) @@ -98,6 +107,8 @@ class AccuWeatherForecastDataUpdateCoordinator( ): """Base class for AccuWeather forecast.""" + config_entry: AccuWeatherConfigEntry + def __init__( self, hass: HomeAssistant, @@ -137,6 +148,12 @@ class AccuWeatherForecastDataUpdateCoordinator( translation_key="forecast_update_error", translation_placeholders={"error": repr(error)}, ) from error + except InvalidApiKeyError as err: + raise ConfigEntryAuthFailed( + translation_domain=DOMAIN, + translation_key="auth_error", + translation_placeholders={"entry": self.config_entry.title}, + ) from err _LOGGER.debug("Requests remaining: %d", self.accuweather.requests_remaining) return result diff --git a/homeassistant/components/accuweather/strings.json b/homeassistant/components/accuweather/strings.json index cbda5f8989f..b46393acf78 100644 --- a/homeassistant/components/accuweather/strings.json +++ b/homeassistant/components/accuweather/strings.json @@ -7,6 +7,17 @@ "api_key": "[%key:common::config_flow::data::api_key%]", "latitude": "[%key:common::config_flow::data::latitude%]", "longitude": "[%key:common::config_flow::data::longitude%]" + }, + "data_description": { + "api_key": "API key generated in the AccuWeather APIs portal." + } + }, + "reauth_confirm": { + "data": { + "api_key": "[%key:common::config_flow::data::api_key%]" + }, + "data_description": { + "api_key": "[%key:component::accuweather::config::step::user::data_description::api_key%]" } } }, @@ -19,7 +30,8 @@ "requests_exceeded": "The allowed number of requests to the AccuWeather API has been exceeded. You have to wait or change the API key." }, "abort": { - "already_configured": "[%key:common::config_flow::abort::already_configured_location%]" + "already_configured": "[%key:common::config_flow::abort::already_configured_location%]", + "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]" } }, "entity": { @@ -239,6 +251,9 @@ } }, "exceptions": { + "auth_error": { + "message": "Authentication failed for {entry}, please update your API key" + }, "current_conditions_update_error": { "message": "An error occurred while retrieving weather current conditions data from the AccuWeather API: {error}" }, diff --git a/tests/components/accuweather/test_config_flow.py b/tests/components/accuweather/test_config_flow.py index ff1f31f01bc..f17f4362aca 100644 --- a/tests/components/accuweather/test_config_flow.py +++ b/tests/components/accuweather/test_config_flow.py @@ -3,6 +3,7 @@ from unittest.mock import AsyncMock from accuweather import ApiError, InvalidApiKeyError, RequestsExceededError +import pytest from homeassistant.components.accuweather.const import DOMAIN from homeassistant.config_entries import SOURCE_USER @@ -10,6 +11,8 @@ from homeassistant.const import CONF_API_KEY, CONF_LATITUDE, CONF_LONGITUDE, CON from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType +from . import init_integration + from tests.common import MockConfigEntry VALID_CONFIG = { @@ -117,3 +120,64 @@ async def test_create_entry( assert result["data"][CONF_LATITUDE] == 55.55 assert result["data"][CONF_LONGITUDE] == 122.12 assert result["data"][CONF_API_KEY] == "32-character-string-1234567890qw" + + +async def test_reauth_successful( + hass: HomeAssistant, mock_accuweather_client: AsyncMock +) -> None: + """Test starting a reauthentication flow.""" + mock_config_entry = await init_integration(hass) + + result = await mock_config_entry.start_reauth_flow(hass) + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "reauth_confirm" + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input={CONF_API_KEY: "new_api_key"}, + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "reauth_successful" + assert mock_config_entry.data[CONF_API_KEY] == "new_api_key" + + +@pytest.mark.parametrize( + ("exc", "base_error"), + [ + (ApiError("API Error"), "cannot_connect"), + (InvalidApiKeyError("Invalid API Key"), "invalid_api_key"), + (TimeoutError, "cannot_connect"), + (RequestsExceededError("Requests Exceeded"), "requests_exceeded"), + ], +) +async def test_reauth_errors( + hass: HomeAssistant, + exc: Exception, + base_error: str, + mock_accuweather_client: AsyncMock, +) -> None: + """Test reauthentication flow with errors.""" + mock_config_entry = await init_integration(hass) + + result = await mock_config_entry.start_reauth_flow(hass) + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "reauth_confirm" + + mock_accuweather_client.async_get_location.side_effect = exc + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input={CONF_API_KEY: "new_api_key"}, + ) + + assert result["errors"] == {"base": base_error} + + mock_accuweather_client.async_get_location.side_effect = None + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input={CONF_API_KEY: "new_api_key"}, + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "reauth_successful" + assert mock_config_entry.data[CONF_API_KEY] == "new_api_key" diff --git a/tests/components/accuweather/test_init.py b/tests/components/accuweather/test_init.py index f88cde88e7e..f79ddaebb30 100644 --- a/tests/components/accuweather/test_init.py +++ b/tests/components/accuweather/test_init.py @@ -1,8 +1,9 @@ """Test init of AccuWeather integration.""" +from datetime import timedelta from unittest.mock import AsyncMock -from accuweather import ApiError +from accuweather import ApiError, InvalidApiKeyError from freezegun.api import FrozenDateTimeFactory from homeassistant.components.accuweather.const import ( @@ -11,7 +12,7 @@ from homeassistant.components.accuweather.const import ( UPDATE_INTERVAL_OBSERVATION, ) from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN -from homeassistant.config_entries import ConfigEntryState +from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntryState from homeassistant.const import STATE_UNAVAILABLE from homeassistant.core import HomeAssistant from homeassistant.helpers import entity_registry as er @@ -118,3 +119,60 @@ async def test_remove_ozone_sensors( entry = entity_registry.async_get("sensor.home_ozone_0d") assert entry is None + + +async def test_auth_error( + hass: HomeAssistant, + freezer: FrozenDateTimeFactory, + mock_accuweather_client: AsyncMock, +) -> None: + """Test authentication error when polling data.""" + mock_accuweather_client.async_get_current_conditions.side_effect = ( + InvalidApiKeyError("Invalid API Key") + ) + + mock_config_entry = await init_integration(hass) + + assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR + + flows = hass.config_entries.flow.async_progress() + assert len(flows) == 1 + + flow = flows[0] + assert flow.get("step_id") == "reauth_confirm" + assert flow.get("handler") == DOMAIN + + assert "context" in flow + assert flow["context"].get("source") == SOURCE_REAUTH + assert flow["context"].get("entry_id") == mock_config_entry.entry_id + + +async def test_auth_error_whe_polling_data( + hass: HomeAssistant, + freezer: FrozenDateTimeFactory, + mock_accuweather_client: AsyncMock, +) -> None: + """Test authentication error when polling data.""" + mock_config_entry = await init_integration(hass) + + assert mock_config_entry.state is ConfigEntryState.LOADED + + mock_accuweather_client.async_get_current_conditions.side_effect = ( + InvalidApiKeyError("Invalid API Key") + ) + freezer.tick(timedelta(minutes=10)) + async_fire_time_changed(hass) + await hass.async_block_till_done() + + assert mock_config_entry.state is ConfigEntryState.LOADED + + flows = hass.config_entries.flow.async_progress() + assert len(flows) == 1 + + flow = flows[0] + assert flow.get("step_id") == "reauth_confirm" + assert flow.get("handler") == DOMAIN + + assert "context" in flow + assert flow["context"].get("source") == SOURCE_REAUTH + assert flow["context"].get("entry_id") == mock_config_entry.entry_id diff --git a/tests/components/accuweather/test_sensor.py b/tests/components/accuweather/test_sensor.py index 855c9f3e4d5..69035d63990 100644 --- a/tests/components/accuweather/test_sensor.py +++ b/tests/components/accuweather/test_sensor.py @@ -2,7 +2,7 @@ from unittest.mock import AsyncMock, patch -from accuweather import ApiError, InvalidApiKeyError, RequestsExceededError +from accuweather import ApiError, RequestsExceededError from aiohttp.client_exceptions import ClientConnectorError from freezegun.api import FrozenDateTimeFactory import pytest @@ -86,7 +86,6 @@ async def test_availability( ApiError("API Error"), ConnectionError, ClientConnectorError, - InvalidApiKeyError("Invalid API key"), RequestsExceededError("Requests exceeded"), ], )