Don't register Home Assistant Cloud LLM platforms if not logged in (#157630)

This commit is contained in:
victorigualada
2025-12-02 17:47:08 +01:00
committed by GitHub
parent ae8980ce5b
commit 84d2686517
8 changed files with 51 additions and 121 deletions

View File

@@ -4,12 +4,13 @@ from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from contextlib import suppress
from datetime import datetime, timedelta
from enum import Enum
import logging
from typing import Any, cast
from hass_nabucasa import Cloud
from hass_nabucasa import Cloud, NabuCasaBaseError
import voluptuous as vol
from homeassistant.components import alexa, google_assistant
@@ -78,13 +79,16 @@ from .subscription import async_subscription_info
DEFAULT_MODE = MODE_PROD
PLATFORMS = [
Platform.AI_TASK,
Platform.BINARY_SENSOR,
Platform.CONVERSATION,
Platform.STT,
Platform.TTS,
]
LLM_PLATFORMS = [
Platform.AI_TASK,
Platform.CONVERSATION,
]
SERVICE_REMOTE_CONNECT = "remote_connect"
SERVICE_REMOTE_DISCONNECT = "remote_disconnect"
@@ -431,7 +435,14 @@ def _handle_prefs_updated(hass: HomeAssistant, cloud: Cloud[CloudClient]) -> Non
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
platforms = PLATFORMS.copy()
if (cloud := hass.data[DATA_CLOUD]).is_logged_in:
with suppress(NabuCasaBaseError):
await cloud.llm.async_ensure_token()
platforms += LLM_PLATFORMS
await hass.config_entries.async_forward_entry_setups(entry, platforms)
entry.runtime_data = {"platforms": platforms}
stt_tts_entities_added = hass.data[DATA_PLATFORMS_SETUP]["stt_tts_entities_added"]
stt_tts_entities_added.set()
@@ -440,7 +451,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
return await hass.config_entries.async_unload_platforms(
entry, entry.runtime_data["platforms"]
)
@callback

View File

@@ -6,7 +6,6 @@ import io
from json import JSONDecodeError
import logging
from hass_nabucasa import NabuCasaBaseError
from hass_nabucasa.llm import (
LLMAuthenticationError,
LLMError,
@@ -20,7 +19,7 @@ from PIL import Image
from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.json import json_loads
@@ -94,17 +93,11 @@ async def async_setup_entry(
async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
"""Set up Home Assistant Cloud AI Task entity."""
if not (cloud := hass.data[DATA_CLOUD]).is_logged_in:
return
try:
await cloud.llm.async_ensure_token()
except (LLMError, NabuCasaBaseError):
return
async_add_entities([CloudLLMTaskEntity(cloud, config_entry)])
cloud = hass.data[DATA_CLOUD]
async_add_entities([CloudAITaskEntity(cloud, config_entry)])
class CloudLLMTaskEntity(ai_task.AITaskEntity, BaseCloudLLMEntity):
class CloudAITaskEntity(BaseCloudLLMEntity, ai_task.AITaskEntity):
"""Home Assistant Cloud AI Task entity."""
_attr_has_entity_name = True
@@ -181,7 +174,7 @@ class CloudLLMTaskEntity(ai_task.AITaskEntity, BaseCloudLLMEntity):
attachments=attachments,
)
except LLMAuthenticationError as err:
raise ConfigEntryAuthFailed("Cloud LLM authentication failed") from err
raise HomeAssistantError("Cloud LLM authentication failed") from err
except LLMRateLimitError as err:
raise HomeAssistantError("Cloud LLM is rate limited") from err
except LLMResponseError as err:

View File

@@ -4,9 +4,6 @@ from __future__ import annotations
from typing import Literal
from hass_nabucasa import NabuCasaBaseError
from hass_nabucasa.llm import LLMError
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
@@ -24,19 +21,13 @@ async def async_setup_entry(
async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
"""Set up the Home Assistant Cloud conversation entity."""
if not (cloud := hass.data[DATA_CLOUD]).is_logged_in:
return
try:
await cloud.llm.async_ensure_token()
except (LLMError, NabuCasaBaseError):
return
cloud = hass.data[DATA_CLOUD]
async_add_entities([CloudConversationEntity(cloud, config_entry)])
class CloudConversationEntity(
conversation.ConversationEntity,
BaseCloudLLMEntity,
conversation.ConversationEntity,
):
"""Home Assistant Cloud conversation agent."""

View File

@@ -8,10 +8,9 @@ import logging
import re
from typing import Any, Literal, cast
from hass_nabucasa import Cloud
from hass_nabucasa import Cloud, NabuCasaBaseError
from hass_nabucasa.llm import (
LLMAuthenticationError,
LLMError,
LLMRateLimitError,
LLMResponseError,
LLMServiceError,
@@ -37,7 +36,7 @@ from voluptuous_openapi import convert
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.entity import Entity
from homeassistant.util import slugify
@@ -601,14 +600,14 @@ class BaseCloudLLMEntity(Entity):
)
except LLMAuthenticationError as err:
raise ConfigEntryAuthFailed("Cloud LLM authentication failed") from err
raise HomeAssistantError("Cloud LLM authentication failed") from err
except LLMRateLimitError as err:
raise HomeAssistantError("Cloud LLM is rate limited") from err
except LLMResponseError as err:
raise HomeAssistantError(str(err)) from err
except LLMServiceError as err:
raise HomeAssistantError("Error talking to Cloud LLM") from err
except LLMError as err:
except NabuCasaBaseError as err:
raise HomeAssistantError(str(err)) from err
if not chat_log.unresponded_tool_results:

View File

@@ -85,6 +85,7 @@ async def cloud_fixture() -> AsyncGenerator[MagicMock]:
return_value=lambda: "mock-unregister"
),
)
mock_cloud.llm = MagicMock(async_ensure_token=AsyncMock())
def set_up_mock_cloud(
cloud_client: CloudClient, mode: str, **kwargs: Any

View File

@@ -21,7 +21,7 @@
## Active Integrations
Built-in integrations: 19
Built-in integrations: 21
Custom integrations: 1
<details><summary>Built-in integrations</summary>
@@ -32,7 +32,9 @@
auth | Auth
binary_sensor | Binary Sensor
cloud | Home Assistant Cloud
cloud.ai_task | Unknown
cloud.binary_sensor | Unknown
cloud.conversation | Unknown
cloud.stt | Unknown
cloud.tts | Unknown
conversation | Conversation
@@ -120,7 +122,7 @@
## Active Integrations
Built-in integrations: 19
Built-in integrations: 21
Custom integrations: 0
<details><summary>Built-in integrations</summary>
@@ -131,7 +133,9 @@
auth | Auth
binary_sensor | Binary Sensor
cloud | Home Assistant Cloud
cloud.ai_task | Unknown
cloud.binary_sensor | Unknown
cloud.conversation | Unknown
cloud.stt | Unknown
cloud.tts | Unknown
conversation | Conversation

View File

@@ -19,20 +19,18 @@ import voluptuous as vol
from homeassistant.components import ai_task, conversation
from homeassistant.components.cloud.ai_task import (
CloudLLMTaskEntity,
CloudAITaskEntity,
async_prepare_image_generation_attachments,
async_setup_entry,
)
from homeassistant.components.cloud.const import DATA_CLOUD
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
from homeassistant.exceptions import HomeAssistantError
from tests.common import MockConfigEntry
@pytest.fixture
def mock_cloud_ai_task_entity(hass: HomeAssistant) -> CloudLLMTaskEntity:
"""Return a CloudLLMTaskEntity with a mocked cloud LLM."""
def mock_cloud_ai_task_entity(hass: HomeAssistant) -> CloudAITaskEntity:
"""Return a CloudAITaskEntity with a mocked cloud LLM."""
cloud = MagicMock()
cloud.llm = MagicMock(
async_generate_image=AsyncMock(),
@@ -42,32 +40,17 @@ def mock_cloud_ai_task_entity(hass: HomeAssistant) -> CloudLLMTaskEntity:
cloud.valid_subscription = True
entry = MockConfigEntry(domain="cloud")
entry.add_to_hass(hass)
entity = CloudLLMTaskEntity(cloud, entry)
entity = CloudAITaskEntity(cloud, entry)
entity.entity_id = "ai_task.cloud_ai_task"
entity.hass = hass
return entity
async def test_setup_entry_skips_when_not_logged_in(
hass: HomeAssistant,
) -> None:
"""Test setup_entry exits early when not logged in."""
cloud = MagicMock()
cloud.is_logged_in = False
entry = MockConfigEntry(domain="cloud")
entry.add_to_hass(hass)
hass.data[DATA_CLOUD] = cloud
async_add_entities = AsyncMock()
await async_setup_entry(hass, entry, async_add_entities)
async_add_entities.assert_not_called()
@pytest.fixture(name="mock_handle_chat_log")
def mock_handle_chat_log_fixture() -> AsyncMock:
"""Patch the chat log handler."""
with patch(
"homeassistant.components.cloud.ai_task.CloudLLMTaskEntity._async_handle_chat_log",
"homeassistant.components.cloud.ai_task.CloudAITaskEntity._async_handle_chat_log",
AsyncMock(),
) as mock:
yield mock
@@ -171,7 +154,7 @@ async def test_prepare_image_generation_attachments_processing_error(
async def test_generate_data_returns_text(
hass: HomeAssistant,
mock_cloud_ai_task_entity: CloudLLMTaskEntity,
mock_cloud_ai_task_entity: CloudAITaskEntity,
mock_handle_chat_log: AsyncMock,
) -> None:
"""Test generating plain text data."""
@@ -200,7 +183,7 @@ async def test_generate_data_returns_text(
async def test_generate_data_returns_json(
hass: HomeAssistant,
mock_cloud_ai_task_entity: CloudLLMTaskEntity,
mock_cloud_ai_task_entity: CloudAITaskEntity,
mock_handle_chat_log: AsyncMock,
) -> None:
"""Test generating structured data."""
@@ -228,7 +211,7 @@ async def test_generate_data_returns_json(
async def test_generate_data_invalid_json(
hass: HomeAssistant,
mock_cloud_ai_task_entity: CloudLLMTaskEntity,
mock_cloud_ai_task_entity: CloudAITaskEntity,
mock_handle_chat_log: AsyncMock,
) -> None:
"""Test invalid JSON responses raise an error."""
@@ -256,7 +239,7 @@ async def test_generate_data_invalid_json(
async def test_generate_image_no_attachments(
hass: HomeAssistant, mock_cloud_ai_task_entity: CloudLLMTaskEntity
hass: HomeAssistant, mock_cloud_ai_task_entity: CloudAITaskEntity
) -> None:
"""Test generating an image without attachments."""
mock_cloud_ai_task_entity._cloud.llm.async_generate_image.return_value = {
@@ -281,7 +264,7 @@ async def test_generate_image_no_attachments(
async def test_generate_image_with_attachments(
hass: HomeAssistant,
mock_cloud_ai_task_entity: CloudLLMTaskEntity,
mock_cloud_ai_task_entity: CloudAITaskEntity,
mock_prepare_generation_attachments: AsyncMock,
) -> None:
"""Test generating an edited image when attachments are provided."""
@@ -319,7 +302,7 @@ async def test_generate_image_with_attachments(
[
(
LLMAuthenticationError("auth"),
ConfigEntryAuthFailed,
HomeAssistantError,
"Cloud LLM authentication failed",
),
(
@@ -346,7 +329,7 @@ async def test_generate_image_with_attachments(
)
async def test_generate_image_error_handling(
hass: HomeAssistant,
mock_cloud_ai_task_entity: CloudLLMTaskEntity,
mock_cloud_ai_task_entity: CloudAITaskEntity,
err: Exception,
expected_exception: type[Exception],
message: str,

View File

@@ -4,15 +4,11 @@ from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
from hass_nabucasa.llm import LLMError
import pytest
from homeassistant.components import conversation
from homeassistant.components.cloud.const import DATA_CLOUD, DOMAIN
from homeassistant.components.cloud.conversation import (
CloudConversationEntity,
async_setup_entry,
)
from homeassistant.components.cloud.const import DOMAIN
from homeassistant.components.cloud.conversation import CloudConversationEntity
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import intent, llm
@@ -34,21 +30,6 @@ def cloud_conversation_entity(hass: HomeAssistant) -> CloudConversationEntity:
return entity
async def test_setup_entry_skips_when_not_logged_in(
hass: HomeAssistant,
) -> None:
"""Test setup_entry exits early when not logged in."""
cloud = MagicMock()
cloud.is_logged_in = False
entry = MockConfigEntry(domain="cloud")
entry.add_to_hass(hass)
hass.data[DATA_CLOUD] = cloud
async_add_entities = AsyncMock()
await async_setup_entry(hass, entry, async_add_entities)
async_add_entities.assert_not_called()
def test_entity_availability(
cloud_conversation_entity: CloudConversationEntity,
) -> None:
@@ -145,38 +126,3 @@ async def test_async_handle_message_converse_error(
handle_chat_log.assert_not_called()
assert result.response is error_response
assert result.conversation_id == user_input.conversation_id
async def test_async_setup_entry_adds_entity(hass: HomeAssistant) -> None:
"""Test the platform setup adds the conversation entity."""
cloud = MagicMock()
cloud.llm = MagicMock(async_ensure_token=AsyncMock())
cloud.is_logged_in = True
cloud.valid_subscription = True
hass.data[DATA_CLOUD] = cloud
entry = MockConfigEntry(domain=DOMAIN)
entry.add_to_hass(hass)
add_entities = MagicMock()
await async_setup_entry(hass, entry, add_entities)
cloud.llm.async_ensure_token.assert_awaited_once()
assert add_entities.call_count == 1
assert isinstance(add_entities.call_args[0][0][0], CloudConversationEntity)
async def test_async_setup_entry_llm_error(hass: HomeAssistant) -> None:
"""Test entity setup is aborted when ensuring the token fails."""
cloud = MagicMock()
cloud.llm = MagicMock(async_ensure_token=AsyncMock(side_effect=LLMError("fail")))
cloud.is_logged_in = True
cloud.valid_subscription = True
hass.data[DATA_CLOUD] = cloud
entry = MockConfigEntry(domain=DOMAIN)
entry.add_to_hass(hass)
add_entities = MagicMock()
await async_setup_entry(hass, entry, add_entities)
cloud.llm.async_ensure_token.assert_awaited_once()
add_entities.assert_not_called()