diff --git a/homeassistant/components/tibber/__init__.py b/homeassistant/components/tibber/__init__.py index e14a717fcf4..40a882a5b04 100644 --- a/homeassistant/components/tibber/__init__.py +++ b/homeassistant/components/tibber/__init__.py @@ -8,7 +8,6 @@ import logging import aiohttp from aiohttp.client_exceptions import ClientError, ClientResponseError import tibber -from tibber import data_api as tibber_data_api from homeassistant.const import CONF_ACCESS_TOKEN, EVENT_HOMEASSISTANT_STOP, Platform from homeassistant.core import Event, HomeAssistant @@ -23,13 +22,7 @@ from homeassistant.helpers.config_entry_oauth2_flow import ( from homeassistant.helpers.typing import ConfigType from homeassistant.util import dt as dt_util, ssl as ssl_util -from .const import ( - AUTH_IMPLEMENTATION, - CONF_LEGACY_ACCESS_TOKEN, - DATA_HASS_CONFIG, - DOMAIN, - TibberConfigEntry, -) +from .const import AUTH_IMPLEMENTATION, DATA_HASS_CONFIG, DOMAIN, TibberConfigEntry from .coordinator import TibberDataAPICoordinator from .services import async_setup_services @@ -44,24 +37,23 @@ _LOGGER = logging.getLogger(__name__) class TibberRuntimeData: """Runtime data for Tibber API entries.""" - tibber_connection: tibber.Tibber session: OAuth2Session data_api_coordinator: TibberDataAPICoordinator | None = field(default=None) - _client: tibber_data_api.TibberDataAPI | None = None + _client: tibber.Tibber | None = None - async def async_get_client( - self, hass: HomeAssistant - ) -> tibber_data_api.TibberDataAPI: - """Return an authenticated Tibber Data API client.""" + async def async_get_client(self, hass: HomeAssistant) -> tibber.Tibber: + """Return an authenticated Tibber client.""" await self.session.async_ensure_token_valid() token = self.session.token access_token = token.get(CONF_ACCESS_TOKEN) if not access_token: raise ConfigEntryAuthFailed("Access token missing from OAuth session") if self._client is None: - self._client = tibber_data_api.TibberDataAPI( - access_token, + self._client = tibber.Tibber( + access_token=access_token, websession=async_get_clientsession(hass), + time_zone=dt_util.get_default_time_zone(), + ssl=ssl_util.get_default_context(), ) self._client.set_access_token(access_token) return self._client @@ -88,32 +80,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: TibberConfigEntry) -> bo translation_key="data_api_reauth_required", ) - tibber_connection = tibber.Tibber( - access_token=entry.data[CONF_LEGACY_ACCESS_TOKEN], - websession=async_get_clientsession(hass), - time_zone=dt_util.get_default_time_zone(), - ssl=ssl_util.get_default_context(), - ) - - async def _close(event: Event) -> None: - await tibber_connection.rt_disconnect() - - entry.async_on_unload(hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _close)) - - try: - await tibber_connection.update_info() - except ( - TimeoutError, - aiohttp.ClientError, - tibber.RetryableHttpExceptionError, - ) as err: - raise ConfigEntryNotReady("Unable to connect") from err - except tibber.InvalidLoginError as exp: - _LOGGER.error("Failed to login. %s", exp) - return False - except tibber.FatalHttpExceptionError: - return False - try: implementation = await async_get_config_entry_implementation(hass, entry) except ImplementationUnavailableError as err: @@ -135,10 +101,29 @@ async def async_setup_entry(hass: HomeAssistant, entry: TibberConfigEntry) -> bo raise ConfigEntryNotReady from err entry.runtime_data = TibberRuntimeData( - tibber_connection=tibber_connection, session=session, ) + tibber_connection = await entry.runtime_data.async_get_client(hass) + + async def _close(event: Event) -> None: + await tibber_connection.rt_disconnect() + + entry.async_on_unload(hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _close)) + + try: + await tibber_connection.update_info() + except ( + TimeoutError, + aiohttp.ClientError, + tibber.RetryableHttpExceptionError, + ) as err: + raise ConfigEntryNotReady("Unable to connect") from err + except tibber.InvalidLoginError as err: + raise ConfigEntryAuthFailed("Invalid login credentials") from err + except tibber.FatalHttpExceptionError as err: + raise ConfigEntryNotReady("Fatal HTTP error from Tibber API") from err + coordinator = TibberDataAPICoordinator(hass, entry) await coordinator.async_config_entry_first_refresh() entry.runtime_data.data_api_coordinator = coordinator @@ -154,5 +139,6 @@ async def async_unload_entry( if unload_ok := await hass.config_entries.async_unload_platforms( config_entry, PLATFORMS ): - await config_entry.runtime_data.tibber_connection.rt_disconnect() + tibber_connection = await config_entry.runtime_data.async_get_client(hass) + await tibber_connection.rt_disconnect() return unload_ok diff --git a/homeassistant/components/tibber/config_flow.py b/homeassistant/components/tibber/config_flow.py index bc8173312c6..c4a2109b8f9 100644 --- a/homeassistant/components/tibber/config_flow.py +++ b/homeassistant/components/tibber/config_flow.py @@ -8,21 +8,16 @@ from typing import Any import aiohttp import tibber -from tibber import data_api as tibber_data_api -import voluptuous as vol -from homeassistant.config_entries import SOURCE_REAUTH, SOURCE_USER, ConfigFlowResult +from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlowResult from homeassistant.const import CONF_ACCESS_TOKEN, CONF_TOKEN from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.config_entry_oauth2_flow import AbstractOAuth2FlowHandler -from .const import CONF_LEGACY_ACCESS_TOKEN, DATA_API_DEFAULT_SCOPES, DOMAIN +from .const import DATA_API_DEFAULT_SCOPES, DOMAIN -DATA_SCHEMA = vol.Schema({vol.Required(CONF_LEGACY_ACCESS_TOKEN): str}) -ERR_TIMEOUT = "timeout" ERR_CLIENT = "cannot_connect" ERR_TOKEN = "invalid_access_token" -TOKEN_URL = "https://developer.tibber.com/settings/access-token" _LOGGER = logging.getLogger(__name__) @@ -36,8 +31,7 @@ class TibberConfigFlow(AbstractOAuth2FlowHandler, domain=DOMAIN): def __init__(self) -> None: """Initialize the config flow.""" super().__init__() - self._access_token: str | None = None - self._title = "" + self._oauth_data: dict[str, Any] | None = None @property def logger(self) -> logging.Logger: @@ -52,114 +46,70 @@ class TibberConfigFlow(AbstractOAuth2FlowHandler, domain=DOMAIN): "scope": " ".join(DATA_API_DEFAULT_SCOPES), } - async def async_step_user( - self, user_input: dict[str, Any] | None = None - ) -> ConfigFlowResult: - """Handle the initial step.""" - if user_input is None: - data_schema = self.add_suggested_values_to_schema( - DATA_SCHEMA, {CONF_LEGACY_ACCESS_TOKEN: self._access_token or ""} - ) - - return self.async_show_form( - step_id=SOURCE_USER, - data_schema=data_schema, - description_placeholders={"url": TOKEN_URL}, - errors={}, - ) - - self._access_token = user_input[CONF_LEGACY_ACCESS_TOKEN].replace(" ", "") - tibber_connection = tibber.Tibber( - access_token=self._access_token, - websession=async_get_clientsession(self.hass), - ) - self._title = tibber_connection.name or "Tibber" - - errors: dict[str, str] = {} - try: - await tibber_connection.update_info() - except TimeoutError: - errors[CONF_LEGACY_ACCESS_TOKEN] = ERR_TIMEOUT - except tibber.InvalidLoginError: - errors[CONF_LEGACY_ACCESS_TOKEN] = ERR_TOKEN - except ( - aiohttp.ClientError, - tibber.RetryableHttpExceptionError, - tibber.FatalHttpExceptionError, - ): - errors[CONF_LEGACY_ACCESS_TOKEN] = ERR_CLIENT - - if errors: - data_schema = self.add_suggested_values_to_schema( - DATA_SCHEMA, {CONF_LEGACY_ACCESS_TOKEN: self._access_token or ""} - ) - - return self.async_show_form( - step_id=SOURCE_USER, - data_schema=data_schema, - description_placeholders={"url": TOKEN_URL}, - errors=errors, - ) - - await self.async_set_unique_id(tibber_connection.user_id) - - if self.source == SOURCE_REAUTH: - reauth_entry = self._get_reauth_entry() - self._abort_if_unique_id_mismatch( - reason="wrong_account", - description_placeholders={"title": reauth_entry.title}, - ) - else: - self._abort_if_unique_id_configured() - - return await self.async_step_pick_implementation() - async def async_step_reauth( self, entry_data: Mapping[str, Any] ) -> ConfigFlowResult: """Handle a reauth flow.""" - reauth_entry = self._get_reauth_entry() - self._access_token = reauth_entry.data.get(CONF_LEGACY_ACCESS_TOKEN) - self._title = reauth_entry.title return await self.async_step_reauth_confirm() async def async_step_reauth_confirm( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Confirm reauthentication by reusing the user step.""" - reauth_entry = self._get_reauth_entry() - self._access_token = reauth_entry.data.get(CONF_LEGACY_ACCESS_TOKEN) - self._title = reauth_entry.title if user_input is None: - return self.async_show_form( - step_id="reauth_confirm", - ) + return self.async_show_form(step_id="reauth_confirm") return await self.async_step_user() async def async_oauth_create_entry(self, data: dict) -> ConfigFlowResult: """Finalize the OAuth flow and create the config entry.""" - if self._access_token is None: - return self.async_abort(reason="missing_configuration") + self._oauth_data = data + return await self._async_validate_and_create() - data[CONF_LEGACY_ACCESS_TOKEN] = self._access_token + async def async_step_connection_error( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Handle connection error retry.""" + if user_input is not None: + return await self._async_validate_and_create() + return self.async_show_form(step_id="connection_error") - access_token = data[CONF_TOKEN][CONF_ACCESS_TOKEN] - data_api_client = tibber_data_api.TibberDataAPI( - access_token, + async def _async_validate_and_create(self) -> ConfigFlowResult: + """Validate the OAuth token and create the config entry.""" + assert self._oauth_data is not None + access_token = self._oauth_data[CONF_TOKEN][CONF_ACCESS_TOKEN] + tibber_connection = tibber.Tibber( + access_token=access_token, websession=async_get_clientsession(self.hass), ) try: - await data_api_client.get_userinfo() - except (aiohttp.ClientError, TimeoutError): - return self.async_abort(reason="cannot_connect") + await tibber_connection.update_info() + except TimeoutError: + return await self.async_step_connection_error() + except tibber.InvalidLoginError: + return self.async_abort(reason=ERR_TOKEN) + except ( + aiohttp.ClientError, + tibber.RetryableHttpExceptionError, + ): + return await self.async_step_connection_error() + except tibber.FatalHttpExceptionError: + return self.async_abort(reason=ERR_CLIENT) + await self.async_set_unique_id(tibber_connection.user_id) + + title = tibber_connection.name or "Tibber" if self.source == SOURCE_REAUTH: reauth_entry = self._get_reauth_entry() + self._abort_if_unique_id_mismatch( + reason="wrong_account", + description_placeholders={"title": reauth_entry.title}, + ) return self.async_update_reload_and_abort( reauth_entry, - data=data, - title=self._title, + data=self._oauth_data, + title=title, ) - return self.async_create_entry(title=self._title, data=data) + self._abort_if_unique_id_configured() + return self.async_create_entry(title=title, data=self._oauth_data) diff --git a/homeassistant/components/tibber/const.py b/homeassistant/components/tibber/const.py index 8a856bb95c4..4151f21e444 100644 --- a/homeassistant/components/tibber/const.py +++ b/homeassistant/components/tibber/const.py @@ -5,7 +5,6 @@ from __future__ import annotations from typing import TYPE_CHECKING from homeassistant.config_entries import ConfigEntry -from homeassistant.const import CONF_ACCESS_TOKEN if TYPE_CHECKING: from . import TibberRuntimeData @@ -13,8 +12,6 @@ if TYPE_CHECKING: type TibberConfigEntry = ConfigEntry[TibberRuntimeData] -CONF_LEGACY_ACCESS_TOKEN = CONF_ACCESS_TOKEN - AUTH_IMPLEMENTATION = "auth_implementation" DATA_HASS_CONFIG = "tibber_hass_config" DOMAIN = "tibber" diff --git a/homeassistant/components/tibber/coordinator.py b/homeassistant/components/tibber/coordinator.py index 39fca55238c..43e51bc8c45 100644 --- a/homeassistant/components/tibber/coordinator.py +++ b/homeassistant/components/tibber/coordinator.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, cast from aiohttp.client_exceptions import ClientError import tibber -from tibber.data_api import TibberDataAPI, TibberDevice +from tibber.data_api import TibberDevice from homeassistant.components.recorder import get_instance from homeassistant.components.recorder.models import ( @@ -230,28 +230,26 @@ class TibberDataAPICoordinator(DataUpdateCoordinator[dict[str, TibberDevice]]): return device_sensors.get(sensor_id) return None - async def _async_get_client(self) -> TibberDataAPI: - """Get the Tibber Data API client with error handling.""" + async def _async_get_client(self) -> tibber.Tibber: + """Get the Tibber client with error handling.""" try: return await self._runtime_data.async_get_client(self.hass) except ConfigEntryAuthFailed: raise except (ClientError, TimeoutError, tibber.UserAgentMissingError) as err: - raise UpdateFailed( - f"Unable to create Tibber Data API client: {err}" - ) from err + raise UpdateFailed(f"Unable to create Tibber client: {err}") from err async def _async_setup(self) -> None: """Initial load of Tibber Data API devices.""" client = await self._async_get_client() - devices = await client.get_all_devices() + devices = await client.data_api.get_all_devices() self._build_sensor_lookup(devices) async def _async_update_data(self) -> dict[str, TibberDevice]: """Fetch the latest device capabilities from the Tibber Data API.""" client = await self._async_get_client() try: - devices: dict[str, TibberDevice] = await client.update_devices() + devices: dict[str, TibberDevice] = await client.data_api.update_devices() except tibber.exceptions.RateLimitExceededError as err: raise UpdateFailed( f"Rate limit exceeded, retry after {err.retry_after} seconds", diff --git a/homeassistant/components/tibber/diagnostics.py b/homeassistant/components/tibber/diagnostics.py index 9c8f9ff5ae8..bde48b75972 100644 --- a/homeassistant/components/tibber/diagnostics.py +++ b/homeassistant/components/tibber/diagnostics.py @@ -15,6 +15,7 @@ async def async_get_config_entry_diagnostics( """Return diagnostics for a config entry.""" runtime = config_entry.runtime_data + tibber_connection = await runtime.async_get_client(hass) result: dict[str, Any] = { "homes": [ { @@ -24,7 +25,7 @@ async def async_get_config_entry_diagnostics( "last_cons_data_timestamp": home.last_cons_data_timestamp, "country": home.country, } - for home in runtime.tibber_connection.get_homes(only_active=False) + for home in tibber_connection.get_homes(only_active=False) ] } diff --git a/homeassistant/components/tibber/notify.py b/homeassistant/components/tibber/notify.py index b5e54a23b76..7dc5c2c259b 100644 --- a/homeassistant/components/tibber/notify.py +++ b/homeassistant/components/tibber/notify.py @@ -2,6 +2,8 @@ from __future__ import annotations +import tibber + from homeassistant.components.notify import ( ATTR_TITLE_DEFAULT, NotifyEntity, @@ -37,7 +39,9 @@ class TibberNotificationEntity(NotifyEntity): async def async_send_message(self, message: str, title: str | None = None) -> None: """Send a message to Tibber devices.""" - tibber_connection = self._entry.runtime_data.tibber_connection + tibber_connection: tibber.Tibber = ( + await self._entry.runtime_data.async_get_client(self.hass) + ) try: await tibber_connection.send_notification( title or ATTR_TITLE_DEFAULT, message diff --git a/homeassistant/components/tibber/sensor.py b/homeassistant/components/tibber/sensor.py index 5c3632c32d1..7038c39f516 100644 --- a/homeassistant/components/tibber/sensor.py +++ b/homeassistant/components/tibber/sensor.py @@ -605,7 +605,7 @@ async def _async_setup_graphql_sensors( ) -> None: """Set up the Tibber sensor.""" - tibber_connection = entry.runtime_data.tibber_connection + tibber_connection = await entry.runtime_data.async_get_client(hass) entity_registry = er.async_get(hass) diff --git a/homeassistant/components/tibber/services.py b/homeassistant/components/tibber/services.py index cbe90ddda64..099739e4478 100644 --- a/homeassistant/components/tibber/services.py +++ b/homeassistant/components/tibber/services.py @@ -42,7 +42,7 @@ async def __get_prices(call: ServiceCall) -> ServiceResponse: translation_domain=DOMAIN, translation_key="no_config_entry", ) - tibber_connection = entries[0].runtime_data.tibber_connection + tibber_connection = await entries[0].runtime_data.async_get_client(call.hass) start = __get_date(call.data.get(ATTR_START), "start") end = __get_date(call.data.get(ATTR_END), "end") diff --git a/homeassistant/components/tibber/strings.json b/homeassistant/components/tibber/strings.json index 1e6011381e3..d07f295785e 100644 --- a/homeassistant/components/tibber/strings.json +++ b/homeassistant/components/tibber/strings.json @@ -2,26 +2,21 @@ "config": { "abort": { "already_configured": "[%key:common::config_flow::abort::already_configured_service%]", + "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", + "invalid_access_token": "[%key:common::config_flow::error::invalid_access_token%]", "missing_configuration": "[%key:common::config_flow::abort::oauth2_missing_configuration%]", "missing_credentials": "[%key:common::config_flow::abort::oauth2_missing_credentials%]", "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]", "wrong_account": "The connected account does not match {title}. Sign in with the same Tibber account and try again." }, - "error": { - "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", - "invalid_access_token": "[%key:common::config_flow::error::invalid_access_token%]", - "timeout": "[%key:common::config_flow::error::timeout_connect%]" - }, "step": { + "connection_error": { + "description": "Could not connect to Tibber. Check your internet connection and try again.", + "title": "Connection failed" + }, "reauth_confirm": { "description": "Reconnect your Tibber account to refresh access.", "title": "[%key:common::config_flow::title::reauth%]" - }, - "user": { - "data": { - "access_token": "[%key:common::config_flow::data::access_token%]" - }, - "description": "Enter your access token from {url}" } } }, diff --git a/tests/components/tibber/conftest.py b/tests/components/tibber/conftest.py index bc6ecd7d8a9..cbd5ece648a 100644 --- a/tests/components/tibber/conftest.py +++ b/tests/components/tibber/conftest.py @@ -19,6 +19,14 @@ from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry +from tests.typing import RecorderInstanceContextManager + + +@pytest.fixture +async def mock_recorder_before_hass( + async_test_recorder: RecorderInstanceContextManager, +) -> None: + """Set up recorder before hass fixture runs.""" def create_tibber_device( @@ -158,21 +166,15 @@ def config_entry(hass: HomeAssistant) -> MockConfigEntry: @pytest.fixture -def _tibber_patches() -> AsyncGenerator[tuple[MagicMock, MagicMock]]: +def tibber_mock() -> AsyncGenerator[MagicMock]: """Patch the Tibber libraries used by the integration.""" unique_user_id = "unique_user_id" title = "title" - with ( - patch( - "tibber.Tibber", - autospec=True, - ) as mock_tibber, - patch( - "tibber.data_api.TibberDataAPI", - autospec=True, - ) as mock_data_api_client, - ): + with patch( + "tibber.Tibber", + autospec=True, + ) as mock_tibber: tibber_mock = mock_tibber.return_value tibber_mock.update_info = AsyncMock(return_value=True) tibber_mock.user_id = unique_user_id @@ -180,24 +182,21 @@ def _tibber_patches() -> AsyncGenerator[tuple[MagicMock, MagicMock]]: tibber_mock.send_notification = AsyncMock() tibber_mock.rt_disconnect = AsyncMock() tibber_mock.get_homes = MagicMock(return_value=[]) + tibber_mock.set_access_token = MagicMock() - data_api_client_mock = mock_data_api_client.return_value - data_api_client_mock.get_all_devices = AsyncMock(return_value={}) - data_api_client_mock.update_devices = AsyncMock(return_value={}) + data_api_mock = MagicMock() + data_api_mock.get_all_devices = AsyncMock(return_value={}) + data_api_mock.update_devices = AsyncMock(return_value={}) + data_api_mock.get_userinfo = AsyncMock() + tibber_mock.data_api = data_api_mock - yield tibber_mock, data_api_client_mock + yield tibber_mock @pytest.fixture -def tibber_mock(_tibber_patches: tuple[MagicMock, MagicMock]) -> MagicMock: - """Return the patched Tibber connection mock.""" - return _tibber_patches[0] - - -@pytest.fixture -def data_api_client_mock(_tibber_patches: tuple[MagicMock, MagicMock]) -> MagicMock: +def data_api_client_mock(tibber_mock: MagicMock) -> MagicMock: """Return the patched Tibber Data API client mock.""" - return _tibber_patches[1] + return tibber_mock.data_api @pytest.fixture diff --git a/tests/components/tibber/test_config_flow.py b/tests/components/tibber/test_config_flow.py index bcd77b29eb2..1d6f20a66dc 100644 --- a/tests/components/tibber/test_config_flow.py +++ b/tests/components/tibber/test_config_flow.py @@ -19,7 +19,6 @@ from homeassistant.components.tibber.application_credentials import TOKEN_URL from homeassistant.components.tibber.config_flow import ( DATA_API_DEFAULT_SCOPES, ERR_CLIENT, - ERR_TIMEOUT, ERR_TOKEN, ) from homeassistant.components.tibber.const import AUTH_IMPLEMENTATION, DOMAIN @@ -55,226 +54,164 @@ def _mock_tibber( return tibber_mock -async def test_show_config_form(recorder_mock: Recorder, hass: HomeAssistant) -> None: - """Test show configuration form.""" - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": config_entries.SOURCE_USER} - ) - - assert result["type"] is FlowResultType.FORM - assert result["step_id"] == "user" - - +@pytest.mark.usefixtures("setup_credentials", "current_request_with_host") @pytest.mark.parametrize( ("exception", "expected_error"), [ - (builtins.TimeoutError(), ERR_TIMEOUT), - (ClientError(), ERR_CLIENT), (InvalidLoginError(401), ERR_TOKEN), - (RetryableHttpExceptionError(503), ERR_CLIENT), (FatalHttpExceptionError(404), ERR_CLIENT), ], ) -async def test_graphql_step_exceptions( +async def test_oauth_create_entry_abort_exceptions( recorder_mock: Recorder, hass: HomeAssistant, tibber_mock: MagicMock, exception: Exception, expected_error: str, ) -> None: - """Validate GraphQL errors are surfaced.""" + """Validate fatal errors during OAuth finalization abort the flow.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + handler = hass.config_entries.flow._progress[result["flow_id"]] + + _mock_tibber(tibber_mock, update_side_effect=exception) + flow_result = await handler.async_oauth_create_entry( + {CONF_TOKEN: {CONF_ACCESS_TOKEN: "rest-token"}} + ) + + assert flow_result["type"] is FlowResultType.ABORT + assert flow_result["reason"] == expected_error + + +@pytest.mark.usefixtures("setup_credentials", "current_request_with_host") +@pytest.mark.parametrize( + "exception", + [ + builtins.TimeoutError(), + ClientError(), + RetryableHttpExceptionError(503), + ], +) +async def test_oauth_create_entry_connection_error_retry( + recorder_mock: Recorder, + hass: HomeAssistant, + hass_client_no_auth: ClientSessionGenerator, + aioclient_mock: AiohttpClientMocker, + tibber_mock: MagicMock, + exception: Exception, +) -> None: + """Validate transient connection errors show retry form.""" result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) _mock_tibber(tibber_mock, update_side_effect=exception) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], {CONF_ACCESS_TOKEN: "invalid"} + assert result["type"] is FlowResultType.EXTERNAL_STEP + authorize_url = result["url"] + state = parse_qs(urlparse(authorize_url).query)["state"][0] + + client = await hass_client_no_auth() + resp = await client.get(f"/auth/external/callback?code=abcd&state={state}") + assert resp.status == HTTPStatus.OK + + aioclient_mock.post( + TOKEN_URL, + json={ + "access_token": "mock-access-token", + "refresh_token": "mock-refresh-token", + "token_type": "bearer", + "expires_in": 3600, + }, ) + result = await hass.config_entries.flow.async_configure(result["flow_id"]) assert result["type"] is FlowResultType.FORM - assert result["step_id"] == "user" - assert result["errors"][CONF_ACCESS_TOKEN] == expected_error + assert result["step_id"] == "connection_error" - -async def test_flow_entry_already_exists( - recorder_mock: Recorder, - hass: HomeAssistant, - config_entry, - tibber_mock: MagicMock, -) -> None: - """Test user input for config_entry that already exists.""" - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": config_entries.SOURCE_USER} - ) - - _mock_tibber(tibber_mock, user_id="tibber") + tibber_mock.update_info.side_effect = None result = await hass.config_entries.flow.async_configure( - result["flow_id"], {CONF_ACCESS_TOKEN: "valid"} + result["flow_id"], user_input={} ) - assert result["type"] is FlowResultType.ABORT - assert result["reason"] == "already_configured" - - -async def test_reauth_flow_steps( - recorder_mock: Recorder, - hass: HomeAssistant, - config_entry: MockConfigEntry, -) -> None: - """Test the reauth flow goes through reauth_confirm to user step.""" - reauth_flow = await config_entry.start_reauth_flow(hass) - - assert reauth_flow["type"] is FlowResultType.FORM - assert reauth_flow["step_id"] == "reauth_confirm" - - result = await hass.config_entries.flow.async_configure(reauth_flow["flow_id"]) - assert result["type"] is FlowResultType.FORM - assert result["step_id"] == "reauth_confirm" - - result = await hass.config_entries.flow.async_configure( - reauth_flow["flow_id"], - user_input={}, - ) - assert result["type"] is FlowResultType.FORM - assert result["step_id"] == "user" - - -async def test_oauth_create_entry_missing_configuration( - recorder_mock: Recorder, - hass: HomeAssistant, -) -> None: - """Abort OAuth finalize if GraphQL step did not run.""" - result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": config_entries.SOURCE_USER}, - ) - handler = hass.config_entries.flow._progress[result["flow_id"]] - - flow_result = await handler.async_oauth_create_entry( - {CONF_TOKEN: {CONF_ACCESS_TOKEN: "rest-token"}} - ) - - assert flow_result["type"] is FlowResultType.ABORT - assert flow_result["reason"] == "missing_configuration" - - -async def test_oauth_create_entry_cannot_connect_userinfo( - recorder_mock: Recorder, - hass: HomeAssistant, - data_api_client_mock: MagicMock, -) -> None: - """Abort OAuth finalize when Data API userinfo cannot be retrieved.""" - result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": config_entries.SOURCE_USER}, - ) - handler = hass.config_entries.flow._progress[result["flow_id"]] - handler._access_token = "graphql-token" - - data_api_client_mock.get_userinfo = AsyncMock(side_effect=ClientError()) - flow_result = await handler.async_oauth_create_entry( - {CONF_TOKEN: {CONF_ACCESS_TOKEN: "rest-token"}} - ) - - assert flow_result["type"] is FlowResultType.ABORT - assert flow_result["reason"] == "cannot_connect" + assert result["type"] is FlowResultType.CREATE_ENTRY + assert result["title"] == "Mock Name" async def test_data_api_requires_credentials( recorder_mock: Recorder, hass: HomeAssistant, - tibber_mock: MagicMock, ) -> None: """Abort when OAuth credentials are missing.""" result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) - _mock_tibber(tibber_mock) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], {CONF_ACCESS_TOKEN: "valid"} - ) - assert result["type"] is FlowResultType.ABORT assert result["reason"] == "missing_credentials" @pytest.mark.usefixtures("setup_credentials", "current_request_with_host") async def test_data_api_extra_authorize_scope( + recorder_mock: Recorder, hass: HomeAssistant, - tibber_mock: MagicMock, ) -> None: """Ensure the OAuth implementation requests Tibber scopes.""" - with patch("homeassistant.components.recorder.async_setup", return_value=True): - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": config_entries.SOURCE_USER} - ) + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) - _mock_tibber(tibber_mock) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], {CONF_ACCESS_TOKEN: "valid"} - ) - - handler = hass.config_entries.flow._progress[result["flow_id"]] - assert handler.extra_authorize_data["scope"] == " ".join( - DATA_API_DEFAULT_SCOPES - ) + handler = hass.config_entries.flow._progress[result["flow_id"]] + assert handler.extra_authorize_data["scope"] == " ".join(DATA_API_DEFAULT_SCOPES) @pytest.mark.usefixtures("setup_credentials", "current_request_with_host") async def test_full_flow_success( + recorder_mock: Recorder, hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator, aioclient_mock: AiohttpClientMocker, tibber_mock: MagicMock, - data_api_client_mock: MagicMock, ) -> None: - """Test configuring Tibber via GraphQL + OAuth.""" - with patch("homeassistant.components.recorder.async_setup", return_value=True): - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": config_entries.SOURCE_USER} - ) + """Test configuring Tibber via OAuth.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) - _mock_tibber(tibber_mock) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], {CONF_ACCESS_TOKEN: "graphql-token"} - ) + _mock_tibber(tibber_mock) + assert result["type"] is FlowResultType.EXTERNAL_STEP + authorize_url = result["url"] + state = parse_qs(urlparse(authorize_url).query)["state"][0] - assert result["type"] is FlowResultType.EXTERNAL_STEP - authorize_url = result["url"] - state = parse_qs(urlparse(authorize_url).query)["state"][0] + client = await hass_client_no_auth() + resp = await client.get(f"/auth/external/callback?code=abcd&state={state}") + assert resp.status == HTTPStatus.OK - client = await hass_client_no_auth() - resp = await client.get(f"/auth/external/callback?code=abcd&state={state}") - assert resp.status == HTTPStatus.OK + aioclient_mock.post( + TOKEN_URL, + json={ + "access_token": "mock-access-token", + "refresh_token": "mock-refresh-token", + "token_type": "bearer", + "expires_in": 3600, + }, + ) - aioclient_mock.post( - TOKEN_URL, - json={ - "access_token": "mock-access-token", - "refresh_token": "mock-refresh-token", - "token_type": "bearer", - "expires_in": 3600, - }, - ) + result = await hass.config_entries.flow.async_configure(result["flow_id"]) - data_api_client_mock.get_userinfo = AsyncMock( - return_value={"name": "Mock Name"} - ) - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - - assert result["type"] is FlowResultType.CREATE_ENTRY - data = result["data"] - assert data[CONF_TOKEN]["access_token"] == "mock-access-token" - assert data[CONF_ACCESS_TOKEN] == "graphql-token" - assert data[AUTH_IMPLEMENTATION] == DOMAIN - assert result["title"] == "Mock Name" + assert result["type"] is FlowResultType.CREATE_ENTRY + data = result["data"] + assert data[CONF_TOKEN]["access_token"] == "mock-access-token" + assert data[AUTH_IMPLEMENTATION] == DOMAIN + assert result["title"] == "Mock Name" +@pytest.mark.usefixtures("setup_credentials", "current_request_with_host") async def test_data_api_abort_when_already_configured( recorder_mock: Recorder, hass: HomeAssistant, + hass_client_no_auth: ClientSessionGenerator, + aioclient_mock: AiohttpClientMocker, tibber_mock: MagicMock, ) -> None: """Ensure only a single Data API entry can be configured.""" @@ -283,7 +220,6 @@ async def test_data_api_abort_when_already_configured( data={ AUTH_IMPLEMENTATION: DOMAIN, CONF_TOKEN: {"access_token": "existing"}, - CONF_ACCESS_TOKEN: "stored-graphql", }, unique_id="unique_user_id", title="Existing", @@ -295,9 +231,133 @@ async def test_data_api_abort_when_already_configured( ) _mock_tibber(tibber_mock) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], {CONF_ACCESS_TOKEN: "new-token"} + assert result["type"] is FlowResultType.EXTERNAL_STEP + authorize_url = result["url"] + state = parse_qs(urlparse(authorize_url).query)["state"][0] + + client = await hass_client_no_auth() + resp = await client.get(f"/auth/external/callback?code=abcd&state={state}") + assert resp.status == HTTPStatus.OK + + aioclient_mock.post( + TOKEN_URL, + json={ + "access_token": "mock-access-token", + "refresh_token": "mock-refresh-token", + "token_type": "bearer", + "expires_in": 3600, + }, ) + result = await hass.config_entries.flow.async_configure(result["flow_id"]) + assert result["type"] is FlowResultType.ABORT assert result["reason"] == "already_configured" + + +@pytest.mark.usefixtures("setup_credentials", "current_request_with_host") +async def test_reauth_flow_success( + recorder_mock: Recorder, + hass: HomeAssistant, + hass_client_no_auth: ClientSessionGenerator, + aioclient_mock: AiohttpClientMocker, + tibber_mock: MagicMock, +) -> None: + """Test successful reauthentication flow.""" + existing_entry = MockConfigEntry( + domain=DOMAIN, + data={ + AUTH_IMPLEMENTATION: DOMAIN, + CONF_TOKEN: {"access_token": "old-token"}, + }, + unique_id="unique_user_id", + title="Existing", + ) + existing_entry.add_to_hass(hass) + + result = await existing_entry.start_reauth_flow(hass) + + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "reauth_confirm" + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input={}, + ) + + _mock_tibber(tibber_mock) + assert result["type"] is FlowResultType.EXTERNAL_STEP + authorize_url = result["url"] + state = parse_qs(urlparse(authorize_url).query)["state"][0] + + client = await hass_client_no_auth() + resp = await client.get(f"/auth/external/callback?code=abcd&state={state}") + assert resp.status == HTTPStatus.OK + + aioclient_mock.post( + TOKEN_URL, + json={ + "access_token": "new-access-token", + "refresh_token": "new-refresh-token", + "token_type": "bearer", + "expires_in": 3600, + }, + ) + + result = await hass.config_entries.flow.async_configure(result["flow_id"]) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "reauth_successful" + assert existing_entry.data[CONF_TOKEN]["access_token"] == "new-access-token" + + +@pytest.mark.usefixtures("setup_credentials", "current_request_with_host") +async def test_reauth_flow_wrong_account( + recorder_mock: Recorder, + hass: HomeAssistant, + hass_client_no_auth: ClientSessionGenerator, + aioclient_mock: AiohttpClientMocker, + tibber_mock: MagicMock, +) -> None: + """Test reauthentication with wrong account aborts.""" + existing_entry = MockConfigEntry( + domain=DOMAIN, + data={ + AUTH_IMPLEMENTATION: DOMAIN, + CONF_TOKEN: {"access_token": "old-token"}, + }, + unique_id="original_user_id", + title="Existing", + ) + existing_entry.add_to_hass(hass) + + result = await existing_entry.start_reauth_flow(hass) + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input={}, + ) + + # Mock a different user_id than the existing entry + _mock_tibber(tibber_mock, user_id="different_user_id") + assert result["type"] is FlowResultType.EXTERNAL_STEP + authorize_url = result["url"] + state = parse_qs(urlparse(authorize_url).query)["state"][0] + + client = await hass_client_no_auth() + resp = await client.get(f"/auth/external/callback?code=abcd&state={state}") + assert resp.status == HTTPStatus.OK + + aioclient_mock.post( + TOKEN_URL, + json={ + "access_token": "new-access-token", + "refresh_token": "new-refresh-token", + "token_type": "bearer", + "expires_in": 3600, + }, + ) + + result = await hass.config_entries.flow.async_configure(result["flow_id"]) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "wrong_account" diff --git a/tests/components/tibber/test_init.py b/tests/components/tibber/test_init.py index 3007ef34e13..111ff50c0c3 100644 --- a/tests/components/tibber/test_init.py +++ b/tests/components/tibber/test_init.py @@ -36,19 +36,18 @@ async def test_data_api_runtime_creates_client(hass: HomeAssistant) -> None: runtime = TibberRuntimeData( session=session, - tibber_connection=MagicMock(), ) - with patch( - "homeassistant.components.tibber.tibber_data_api.TibberDataAPI" - ) as mock_client_cls: + with patch("homeassistant.components.tibber.tibber.Tibber") as mock_client_cls: mock_client = MagicMock() mock_client.set_access_token = MagicMock() mock_client_cls.return_value = mock_client client = await runtime.async_get_client(hass) - mock_client_cls.assert_called_once_with("access-token", websession=ANY) + mock_client_cls.assert_called_once_with( + access_token="access-token", websession=ANY, time_zone=ANY, ssl=ANY + ) session.async_ensure_token_valid.assert_awaited_once() mock_client.set_access_token.assert_called_once_with("access-token") assert client is mock_client @@ -73,7 +72,6 @@ async def test_data_api_runtime_missing_token_raises(hass: HomeAssistant) -> Non runtime = TibberRuntimeData( session=session, - tibber_connection=MagicMock(), ) with pytest.raises(ConfigEntryAuthFailed):