From 84d2686517ddcd4a044976320e807e570fdbd6d6 Mon Sep 17 00:00:00 2001 From: victorigualada <21220224+victorigualada@users.noreply.github.com> Date: Tue, 2 Dec 2025 17:47:08 +0100 Subject: [PATCH] Don't register Home Assistant Cloud LLM platforms if not logged in (#157630) --- homeassistant/components/cloud/__init__.py | 23 ++++++-- homeassistant/components/cloud/ai_task.py | 17 ++---- .../components/cloud/conversation.py | 13 +---- homeassistant/components/cloud/entity.py | 9 ++- tests/components/cloud/conftest.py | 1 + .../cloud/snapshots/test_http_api.ambr | 8 ++- tests/components/cloud/test_ai_task.py | 43 +++++--------- tests/components/cloud/test_conversation.py | 58 +------------------ 8 files changed, 51 insertions(+), 121 deletions(-) diff --git a/homeassistant/components/cloud/__init__.py b/homeassistant/components/cloud/__init__.py index e9b0f6ab294..410029c2716 100644 --- a/homeassistant/components/cloud/__init__.py +++ b/homeassistant/components/cloud/__init__.py @@ -4,12 +4,13 @@ from __future__ import annotations import asyncio from collections.abc import Awaitable, Callable +from contextlib import suppress from datetime import datetime, timedelta from enum import Enum import logging from typing import Any, cast -from hass_nabucasa import Cloud +from hass_nabucasa import Cloud, NabuCasaBaseError import voluptuous as vol from homeassistant.components import alexa, google_assistant @@ -78,13 +79,16 @@ from .subscription import async_subscription_info DEFAULT_MODE = MODE_PROD PLATFORMS = [ - Platform.AI_TASK, Platform.BINARY_SENSOR, - Platform.CONVERSATION, Platform.STT, Platform.TTS, ] +LLM_PLATFORMS = [ + Platform.AI_TASK, + Platform.CONVERSATION, +] + SERVICE_REMOTE_CONNECT = "remote_connect" SERVICE_REMOTE_DISCONNECT = "remote_disconnect" @@ -431,7 +435,14 @@ def _handle_prefs_updated(hass: HomeAssistant, cloud: Cloud[CloudClient]) -> Non async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up a config entry.""" - await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) + platforms = PLATFORMS.copy() + if (cloud := hass.data[DATA_CLOUD]).is_logged_in: + with suppress(NabuCasaBaseError): + await cloud.llm.async_ensure_token() + platforms += LLM_PLATFORMS + + await hass.config_entries.async_forward_entry_setups(entry, platforms) + entry.runtime_data = {"platforms": platforms} stt_tts_entities_added = hass.data[DATA_PLATFORMS_SETUP]["stt_tts_entities_added"] stt_tts_entities_added.set() @@ -440,7 +451,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" - return await hass.config_entries.async_unload_platforms(entry, PLATFORMS) + return await hass.config_entries.async_unload_platforms( + entry, entry.runtime_data["platforms"] + ) @callback diff --git a/homeassistant/components/cloud/ai_task.py b/homeassistant/components/cloud/ai_task.py index ff57144805e..a92060db7b1 100644 --- a/homeassistant/components/cloud/ai_task.py +++ b/homeassistant/components/cloud/ai_task.py @@ -6,7 +6,6 @@ import io from json import JSONDecodeError import logging -from hass_nabucasa import NabuCasaBaseError from hass_nabucasa.llm import ( LLMAuthenticationError, LLMError, @@ -20,7 +19,7 @@ from PIL import Image from homeassistant.components import ai_task, conversation from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant -from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.util.json import json_loads @@ -94,17 +93,11 @@ async def async_setup_entry( async_add_entities: AddConfigEntryEntitiesCallback, ) -> None: """Set up Home Assistant Cloud AI Task entity.""" - if not (cloud := hass.data[DATA_CLOUD]).is_logged_in: - return - try: - await cloud.llm.async_ensure_token() - except (LLMError, NabuCasaBaseError): - return - - async_add_entities([CloudLLMTaskEntity(cloud, config_entry)]) + cloud = hass.data[DATA_CLOUD] + async_add_entities([CloudAITaskEntity(cloud, config_entry)]) -class CloudLLMTaskEntity(ai_task.AITaskEntity, BaseCloudLLMEntity): +class CloudAITaskEntity(BaseCloudLLMEntity, ai_task.AITaskEntity): """Home Assistant Cloud AI Task entity.""" _attr_has_entity_name = True @@ -181,7 +174,7 @@ class CloudLLMTaskEntity(ai_task.AITaskEntity, BaseCloudLLMEntity): attachments=attachments, ) except LLMAuthenticationError as err: - raise ConfigEntryAuthFailed("Cloud LLM authentication failed") from err + raise HomeAssistantError("Cloud LLM authentication failed") from err except LLMRateLimitError as err: raise HomeAssistantError("Cloud LLM is rate limited") from err except LLMResponseError as err: diff --git a/homeassistant/components/cloud/conversation.py b/homeassistant/components/cloud/conversation.py index c7f197a3923..06a11feef6e 100644 --- a/homeassistant/components/cloud/conversation.py +++ b/homeassistant/components/cloud/conversation.py @@ -4,9 +4,6 @@ from __future__ import annotations from typing import Literal -from hass_nabucasa import NabuCasaBaseError -from hass_nabucasa.llm import LLMError - from homeassistant.components import conversation from homeassistant.config_entries import ConfigEntry from homeassistant.const import MATCH_ALL @@ -24,19 +21,13 @@ async def async_setup_entry( async_add_entities: AddConfigEntryEntitiesCallback, ) -> None: """Set up the Home Assistant Cloud conversation entity.""" - if not (cloud := hass.data[DATA_CLOUD]).is_logged_in: - return - try: - await cloud.llm.async_ensure_token() - except (LLMError, NabuCasaBaseError): - return - + cloud = hass.data[DATA_CLOUD] async_add_entities([CloudConversationEntity(cloud, config_entry)]) class CloudConversationEntity( - conversation.ConversationEntity, BaseCloudLLMEntity, + conversation.ConversationEntity, ): """Home Assistant Cloud conversation agent.""" diff --git a/homeassistant/components/cloud/entity.py b/homeassistant/components/cloud/entity.py index 94eefbb1f14..f16e136804f 100644 --- a/homeassistant/components/cloud/entity.py +++ b/homeassistant/components/cloud/entity.py @@ -8,10 +8,9 @@ import logging import re from typing import Any, Literal, cast -from hass_nabucasa import Cloud +from hass_nabucasa import Cloud, NabuCasaBaseError from hass_nabucasa.llm import ( LLMAuthenticationError, - LLMError, LLMRateLimitError, LLMResponseError, LLMServiceError, @@ -37,7 +36,7 @@ from voluptuous_openapi import convert from homeassistant.components import conversation from homeassistant.config_entries import ConfigEntry -from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import llm from homeassistant.helpers.entity import Entity from homeassistant.util import slugify @@ -601,14 +600,14 @@ class BaseCloudLLMEntity(Entity): ) except LLMAuthenticationError as err: - raise ConfigEntryAuthFailed("Cloud LLM authentication failed") from err + raise HomeAssistantError("Cloud LLM authentication failed") from err except LLMRateLimitError as err: raise HomeAssistantError("Cloud LLM is rate limited") from err except LLMResponseError as err: raise HomeAssistantError(str(err)) from err except LLMServiceError as err: raise HomeAssistantError("Error talking to Cloud LLM") from err - except LLMError as err: + except NabuCasaBaseError as err: raise HomeAssistantError(str(err)) from err if not chat_log.unresponded_tool_results: diff --git a/tests/components/cloud/conftest.py b/tests/components/cloud/conftest.py index 10d38c227f1..3b802d0af63 100644 --- a/tests/components/cloud/conftest.py +++ b/tests/components/cloud/conftest.py @@ -85,6 +85,7 @@ async def cloud_fixture() -> AsyncGenerator[MagicMock]: return_value=lambda: "mock-unregister" ), ) + mock_cloud.llm = MagicMock(async_ensure_token=AsyncMock()) def set_up_mock_cloud( cloud_client: CloudClient, mode: str, **kwargs: Any diff --git a/tests/components/cloud/snapshots/test_http_api.ambr b/tests/components/cloud/snapshots/test_http_api.ambr index ad8afbe695e..2876ba20eb8 100644 --- a/tests/components/cloud/snapshots/test_http_api.ambr +++ b/tests/components/cloud/snapshots/test_http_api.ambr @@ -21,7 +21,7 @@ ## Active Integrations - Built-in integrations: 19 + Built-in integrations: 21 Custom integrations: 1
Built-in integrations @@ -32,7 +32,9 @@ auth | Auth binary_sensor | Binary Sensor cloud | Home Assistant Cloud + cloud.ai_task | Unknown cloud.binary_sensor | Unknown + cloud.conversation | Unknown cloud.stt | Unknown cloud.tts | Unknown conversation | Conversation @@ -120,7 +122,7 @@ ## Active Integrations - Built-in integrations: 19 + Built-in integrations: 21 Custom integrations: 0
Built-in integrations @@ -131,7 +133,9 @@ auth | Auth binary_sensor | Binary Sensor cloud | Home Assistant Cloud + cloud.ai_task | Unknown cloud.binary_sensor | Unknown + cloud.conversation | Unknown cloud.stt | Unknown cloud.tts | Unknown conversation | Conversation diff --git a/tests/components/cloud/test_ai_task.py b/tests/components/cloud/test_ai_task.py index 308971be92b..461367cb2a9 100644 --- a/tests/components/cloud/test_ai_task.py +++ b/tests/components/cloud/test_ai_task.py @@ -19,20 +19,18 @@ import voluptuous as vol from homeassistant.components import ai_task, conversation from homeassistant.components.cloud.ai_task import ( - CloudLLMTaskEntity, + CloudAITaskEntity, async_prepare_image_generation_attachments, - async_setup_entry, ) -from homeassistant.components.cloud.const import DATA_CLOUD from homeassistant.core import HomeAssistant -from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError +from homeassistant.exceptions import HomeAssistantError from tests.common import MockConfigEntry @pytest.fixture -def mock_cloud_ai_task_entity(hass: HomeAssistant) -> CloudLLMTaskEntity: - """Return a CloudLLMTaskEntity with a mocked cloud LLM.""" +def mock_cloud_ai_task_entity(hass: HomeAssistant) -> CloudAITaskEntity: + """Return a CloudAITaskEntity with a mocked cloud LLM.""" cloud = MagicMock() cloud.llm = MagicMock( async_generate_image=AsyncMock(), @@ -42,32 +40,17 @@ def mock_cloud_ai_task_entity(hass: HomeAssistant) -> CloudLLMTaskEntity: cloud.valid_subscription = True entry = MockConfigEntry(domain="cloud") entry.add_to_hass(hass) - entity = CloudLLMTaskEntity(cloud, entry) + entity = CloudAITaskEntity(cloud, entry) entity.entity_id = "ai_task.cloud_ai_task" entity.hass = hass return entity -async def test_setup_entry_skips_when_not_logged_in( - hass: HomeAssistant, -) -> None: - """Test setup_entry exits early when not logged in.""" - cloud = MagicMock() - cloud.is_logged_in = False - entry = MockConfigEntry(domain="cloud") - entry.add_to_hass(hass) - hass.data[DATA_CLOUD] = cloud - - async_add_entities = AsyncMock() - await async_setup_entry(hass, entry, async_add_entities) - async_add_entities.assert_not_called() - - @pytest.fixture(name="mock_handle_chat_log") def mock_handle_chat_log_fixture() -> AsyncMock: """Patch the chat log handler.""" with patch( - "homeassistant.components.cloud.ai_task.CloudLLMTaskEntity._async_handle_chat_log", + "homeassistant.components.cloud.ai_task.CloudAITaskEntity._async_handle_chat_log", AsyncMock(), ) as mock: yield mock @@ -171,7 +154,7 @@ async def test_prepare_image_generation_attachments_processing_error( async def test_generate_data_returns_text( hass: HomeAssistant, - mock_cloud_ai_task_entity: CloudLLMTaskEntity, + mock_cloud_ai_task_entity: CloudAITaskEntity, mock_handle_chat_log: AsyncMock, ) -> None: """Test generating plain text data.""" @@ -200,7 +183,7 @@ async def test_generate_data_returns_text( async def test_generate_data_returns_json( hass: HomeAssistant, - mock_cloud_ai_task_entity: CloudLLMTaskEntity, + mock_cloud_ai_task_entity: CloudAITaskEntity, mock_handle_chat_log: AsyncMock, ) -> None: """Test generating structured data.""" @@ -228,7 +211,7 @@ async def test_generate_data_returns_json( async def test_generate_data_invalid_json( hass: HomeAssistant, - mock_cloud_ai_task_entity: CloudLLMTaskEntity, + mock_cloud_ai_task_entity: CloudAITaskEntity, mock_handle_chat_log: AsyncMock, ) -> None: """Test invalid JSON responses raise an error.""" @@ -256,7 +239,7 @@ async def test_generate_data_invalid_json( async def test_generate_image_no_attachments( - hass: HomeAssistant, mock_cloud_ai_task_entity: CloudLLMTaskEntity + hass: HomeAssistant, mock_cloud_ai_task_entity: CloudAITaskEntity ) -> None: """Test generating an image without attachments.""" mock_cloud_ai_task_entity._cloud.llm.async_generate_image.return_value = { @@ -281,7 +264,7 @@ async def test_generate_image_no_attachments( async def test_generate_image_with_attachments( hass: HomeAssistant, - mock_cloud_ai_task_entity: CloudLLMTaskEntity, + mock_cloud_ai_task_entity: CloudAITaskEntity, mock_prepare_generation_attachments: AsyncMock, ) -> None: """Test generating an edited image when attachments are provided.""" @@ -319,7 +302,7 @@ async def test_generate_image_with_attachments( [ ( LLMAuthenticationError("auth"), - ConfigEntryAuthFailed, + HomeAssistantError, "Cloud LLM authentication failed", ), ( @@ -346,7 +329,7 @@ async def test_generate_image_with_attachments( ) async def test_generate_image_error_handling( hass: HomeAssistant, - mock_cloud_ai_task_entity: CloudLLMTaskEntity, + mock_cloud_ai_task_entity: CloudAITaskEntity, err: Exception, expected_exception: type[Exception], message: str, diff --git a/tests/components/cloud/test_conversation.py b/tests/components/cloud/test_conversation.py index df1b7e8deb7..5481294d1f0 100644 --- a/tests/components/cloud/test_conversation.py +++ b/tests/components/cloud/test_conversation.py @@ -4,15 +4,11 @@ from __future__ import annotations from unittest.mock import AsyncMock, MagicMock, patch -from hass_nabucasa.llm import LLMError import pytest from homeassistant.components import conversation -from homeassistant.components.cloud.const import DATA_CLOUD, DOMAIN -from homeassistant.components.cloud.conversation import ( - CloudConversationEntity, - async_setup_entry, -) +from homeassistant.components.cloud.const import DOMAIN +from homeassistant.components.cloud.conversation import CloudConversationEntity from homeassistant.core import Context, HomeAssistant from homeassistant.helpers import intent, llm @@ -34,21 +30,6 @@ def cloud_conversation_entity(hass: HomeAssistant) -> CloudConversationEntity: return entity -async def test_setup_entry_skips_when_not_logged_in( - hass: HomeAssistant, -) -> None: - """Test setup_entry exits early when not logged in.""" - cloud = MagicMock() - cloud.is_logged_in = False - entry = MockConfigEntry(domain="cloud") - entry.add_to_hass(hass) - hass.data[DATA_CLOUD] = cloud - - async_add_entities = AsyncMock() - await async_setup_entry(hass, entry, async_add_entities) - async_add_entities.assert_not_called() - - def test_entity_availability( cloud_conversation_entity: CloudConversationEntity, ) -> None: @@ -145,38 +126,3 @@ async def test_async_handle_message_converse_error( handle_chat_log.assert_not_called() assert result.response is error_response assert result.conversation_id == user_input.conversation_id - - -async def test_async_setup_entry_adds_entity(hass: HomeAssistant) -> None: - """Test the platform setup adds the conversation entity.""" - cloud = MagicMock() - cloud.llm = MagicMock(async_ensure_token=AsyncMock()) - cloud.is_logged_in = True - cloud.valid_subscription = True - hass.data[DATA_CLOUD] = cloud - entry = MockConfigEntry(domain=DOMAIN) - entry.add_to_hass(hass) - add_entities = MagicMock() - - await async_setup_entry(hass, entry, add_entities) - - cloud.llm.async_ensure_token.assert_awaited_once() - assert add_entities.call_count == 1 - assert isinstance(add_entities.call_args[0][0][0], CloudConversationEntity) - - -async def test_async_setup_entry_llm_error(hass: HomeAssistant) -> None: - """Test entity setup is aborted when ensuring the token fails.""" - cloud = MagicMock() - cloud.llm = MagicMock(async_ensure_token=AsyncMock(side_effect=LLMError("fail"))) - cloud.is_logged_in = True - cloud.valid_subscription = True - hass.data[DATA_CLOUD] = cloud - entry = MockConfigEntry(domain=DOMAIN) - entry.add_to_hass(hass) - add_entities = MagicMock() - - await async_setup_entry(hass, entry, add_entities) - - cloud.llm.async_ensure_token.assert_awaited_once() - add_entities.assert_not_called()