mirror of
https://github.com/Electric-Special/ha-core.git
synced 2026-03-21 05:06:13 +01:00
Don't register Home Assistant Cloud LLM platforms if not logged in (#157630)
This commit is contained in:
@@ -4,12 +4,13 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import suppress
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from hass_nabucasa import Cloud
|
||||
from hass_nabucasa import Cloud, NabuCasaBaseError
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import alexa, google_assistant
|
||||
@@ -78,13 +79,16 @@ from .subscription import async_subscription_info
|
||||
DEFAULT_MODE = MODE_PROD
|
||||
|
||||
PLATFORMS = [
|
||||
Platform.AI_TASK,
|
||||
Platform.BINARY_SENSOR,
|
||||
Platform.CONVERSATION,
|
||||
Platform.STT,
|
||||
Platform.TTS,
|
||||
]
|
||||
|
||||
LLM_PLATFORMS = [
|
||||
Platform.AI_TASK,
|
||||
Platform.CONVERSATION,
|
||||
]
|
||||
|
||||
SERVICE_REMOTE_CONNECT = "remote_connect"
|
||||
SERVICE_REMOTE_DISCONNECT = "remote_disconnect"
|
||||
|
||||
@@ -431,7 +435,14 @@ def _handle_prefs_updated(hass: HomeAssistant, cloud: Cloud[CloudClient]) -> Non
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry."""
|
||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||
platforms = PLATFORMS.copy()
|
||||
if (cloud := hass.data[DATA_CLOUD]).is_logged_in:
|
||||
with suppress(NabuCasaBaseError):
|
||||
await cloud.llm.async_ensure_token()
|
||||
platforms += LLM_PLATFORMS
|
||||
|
||||
await hass.config_entries.async_forward_entry_setups(entry, platforms)
|
||||
entry.runtime_data = {"platforms": platforms}
|
||||
stt_tts_entities_added = hass.data[DATA_PLATFORMS_SETUP]["stt_tts_entities_added"]
|
||||
stt_tts_entities_added.set()
|
||||
|
||||
@@ -440,7 +451,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||
return await hass.config_entries.async_unload_platforms(
|
||||
entry, entry.runtime_data["platforms"]
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
|
||||
@@ -6,7 +6,6 @@ import io
|
||||
from json import JSONDecodeError
|
||||
import logging
|
||||
|
||||
from hass_nabucasa import NabuCasaBaseError
|
||||
from hass_nabucasa.llm import (
|
||||
LLMAuthenticationError,
|
||||
LLMError,
|
||||
@@ -20,7 +19,7 @@ from PIL import Image
|
||||
from homeassistant.components import ai_task, conversation
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
from homeassistant.util.json import json_loads
|
||||
|
||||
@@ -94,17 +93,11 @@ async def async_setup_entry(
|
||||
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Home Assistant Cloud AI Task entity."""
|
||||
if not (cloud := hass.data[DATA_CLOUD]).is_logged_in:
|
||||
return
|
||||
try:
|
||||
await cloud.llm.async_ensure_token()
|
||||
except (LLMError, NabuCasaBaseError):
|
||||
return
|
||||
|
||||
async_add_entities([CloudLLMTaskEntity(cloud, config_entry)])
|
||||
cloud = hass.data[DATA_CLOUD]
|
||||
async_add_entities([CloudAITaskEntity(cloud, config_entry)])
|
||||
|
||||
|
||||
class CloudLLMTaskEntity(ai_task.AITaskEntity, BaseCloudLLMEntity):
|
||||
class CloudAITaskEntity(BaseCloudLLMEntity, ai_task.AITaskEntity):
|
||||
"""Home Assistant Cloud AI Task entity."""
|
||||
|
||||
_attr_has_entity_name = True
|
||||
@@ -181,7 +174,7 @@ class CloudLLMTaskEntity(ai_task.AITaskEntity, BaseCloudLLMEntity):
|
||||
attachments=attachments,
|
||||
)
|
||||
except LLMAuthenticationError as err:
|
||||
raise ConfigEntryAuthFailed("Cloud LLM authentication failed") from err
|
||||
raise HomeAssistantError("Cloud LLM authentication failed") from err
|
||||
except LLMRateLimitError as err:
|
||||
raise HomeAssistantError("Cloud LLM is rate limited") from err
|
||||
except LLMResponseError as err:
|
||||
|
||||
@@ -4,9 +4,6 @@ from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from hass_nabucasa import NabuCasaBaseError
|
||||
from hass_nabucasa.llm import LLMError
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import MATCH_ALL
|
||||
@@ -24,19 +21,13 @@ async def async_setup_entry(
|
||||
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up the Home Assistant Cloud conversation entity."""
|
||||
if not (cloud := hass.data[DATA_CLOUD]).is_logged_in:
|
||||
return
|
||||
try:
|
||||
await cloud.llm.async_ensure_token()
|
||||
except (LLMError, NabuCasaBaseError):
|
||||
return
|
||||
|
||||
cloud = hass.data[DATA_CLOUD]
|
||||
async_add_entities([CloudConversationEntity(cloud, config_entry)])
|
||||
|
||||
|
||||
class CloudConversationEntity(
|
||||
conversation.ConversationEntity,
|
||||
BaseCloudLLMEntity,
|
||||
conversation.ConversationEntity,
|
||||
):
|
||||
"""Home Assistant Cloud conversation agent."""
|
||||
|
||||
|
||||
@@ -8,10 +8,9 @@ import logging
|
||||
import re
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from hass_nabucasa import Cloud
|
||||
from hass_nabucasa import Cloud, NabuCasaBaseError
|
||||
from hass_nabucasa.llm import (
|
||||
LLMAuthenticationError,
|
||||
LLMError,
|
||||
LLMRateLimitError,
|
||||
LLMResponseError,
|
||||
LLMServiceError,
|
||||
@@ -37,7 +36,7 @@ from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.util import slugify
|
||||
@@ -601,14 +600,14 @@ class BaseCloudLLMEntity(Entity):
|
||||
)
|
||||
|
||||
except LLMAuthenticationError as err:
|
||||
raise ConfigEntryAuthFailed("Cloud LLM authentication failed") from err
|
||||
raise HomeAssistantError("Cloud LLM authentication failed") from err
|
||||
except LLMRateLimitError as err:
|
||||
raise HomeAssistantError("Cloud LLM is rate limited") from err
|
||||
except LLMResponseError as err:
|
||||
raise HomeAssistantError(str(err)) from err
|
||||
except LLMServiceError as err:
|
||||
raise HomeAssistantError("Error talking to Cloud LLM") from err
|
||||
except LLMError as err:
|
||||
except NabuCasaBaseError as err:
|
||||
raise HomeAssistantError(str(err)) from err
|
||||
|
||||
if not chat_log.unresponded_tool_results:
|
||||
|
||||
@@ -85,6 +85,7 @@ async def cloud_fixture() -> AsyncGenerator[MagicMock]:
|
||||
return_value=lambda: "mock-unregister"
|
||||
),
|
||||
)
|
||||
mock_cloud.llm = MagicMock(async_ensure_token=AsyncMock())
|
||||
|
||||
def set_up_mock_cloud(
|
||||
cloud_client: CloudClient, mode: str, **kwargs: Any
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
|
||||
## Active Integrations
|
||||
|
||||
Built-in integrations: 19
|
||||
Built-in integrations: 21
|
||||
Custom integrations: 1
|
||||
|
||||
<details><summary>Built-in integrations</summary>
|
||||
@@ -32,7 +32,9 @@
|
||||
auth | Auth
|
||||
binary_sensor | Binary Sensor
|
||||
cloud | Home Assistant Cloud
|
||||
cloud.ai_task | Unknown
|
||||
cloud.binary_sensor | Unknown
|
||||
cloud.conversation | Unknown
|
||||
cloud.stt | Unknown
|
||||
cloud.tts | Unknown
|
||||
conversation | Conversation
|
||||
@@ -120,7 +122,7 @@
|
||||
|
||||
## Active Integrations
|
||||
|
||||
Built-in integrations: 19
|
||||
Built-in integrations: 21
|
||||
Custom integrations: 0
|
||||
|
||||
<details><summary>Built-in integrations</summary>
|
||||
@@ -131,7 +133,9 @@
|
||||
auth | Auth
|
||||
binary_sensor | Binary Sensor
|
||||
cloud | Home Assistant Cloud
|
||||
cloud.ai_task | Unknown
|
||||
cloud.binary_sensor | Unknown
|
||||
cloud.conversation | Unknown
|
||||
cloud.stt | Unknown
|
||||
cloud.tts | Unknown
|
||||
conversation | Conversation
|
||||
|
||||
@@ -19,20 +19,18 @@ import voluptuous as vol
|
||||
|
||||
from homeassistant.components import ai_task, conversation
|
||||
from homeassistant.components.cloud.ai_task import (
|
||||
CloudLLMTaskEntity,
|
||||
CloudAITaskEntity,
|
||||
async_prepare_image_generation_attachments,
|
||||
async_setup_entry,
|
||||
)
|
||||
from homeassistant.components.cloud.const import DATA_CLOUD
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cloud_ai_task_entity(hass: HomeAssistant) -> CloudLLMTaskEntity:
|
||||
"""Return a CloudLLMTaskEntity with a mocked cloud LLM."""
|
||||
def mock_cloud_ai_task_entity(hass: HomeAssistant) -> CloudAITaskEntity:
|
||||
"""Return a CloudAITaskEntity with a mocked cloud LLM."""
|
||||
cloud = MagicMock()
|
||||
cloud.llm = MagicMock(
|
||||
async_generate_image=AsyncMock(),
|
||||
@@ -42,32 +40,17 @@ def mock_cloud_ai_task_entity(hass: HomeAssistant) -> CloudLLMTaskEntity:
|
||||
cloud.valid_subscription = True
|
||||
entry = MockConfigEntry(domain="cloud")
|
||||
entry.add_to_hass(hass)
|
||||
entity = CloudLLMTaskEntity(cloud, entry)
|
||||
entity = CloudAITaskEntity(cloud, entry)
|
||||
entity.entity_id = "ai_task.cloud_ai_task"
|
||||
entity.hass = hass
|
||||
return entity
|
||||
|
||||
|
||||
async def test_setup_entry_skips_when_not_logged_in(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test setup_entry exits early when not logged in."""
|
||||
cloud = MagicMock()
|
||||
cloud.is_logged_in = False
|
||||
entry = MockConfigEntry(domain="cloud")
|
||||
entry.add_to_hass(hass)
|
||||
hass.data[DATA_CLOUD] = cloud
|
||||
|
||||
async_add_entities = AsyncMock()
|
||||
await async_setup_entry(hass, entry, async_add_entities)
|
||||
async_add_entities.assert_not_called()
|
||||
|
||||
|
||||
@pytest.fixture(name="mock_handle_chat_log")
|
||||
def mock_handle_chat_log_fixture() -> AsyncMock:
|
||||
"""Patch the chat log handler."""
|
||||
with patch(
|
||||
"homeassistant.components.cloud.ai_task.CloudLLMTaskEntity._async_handle_chat_log",
|
||||
"homeassistant.components.cloud.ai_task.CloudAITaskEntity._async_handle_chat_log",
|
||||
AsyncMock(),
|
||||
) as mock:
|
||||
yield mock
|
||||
@@ -171,7 +154,7 @@ async def test_prepare_image_generation_attachments_processing_error(
|
||||
|
||||
async def test_generate_data_returns_text(
|
||||
hass: HomeAssistant,
|
||||
mock_cloud_ai_task_entity: CloudLLMTaskEntity,
|
||||
mock_cloud_ai_task_entity: CloudAITaskEntity,
|
||||
mock_handle_chat_log: AsyncMock,
|
||||
) -> None:
|
||||
"""Test generating plain text data."""
|
||||
@@ -200,7 +183,7 @@ async def test_generate_data_returns_text(
|
||||
|
||||
async def test_generate_data_returns_json(
|
||||
hass: HomeAssistant,
|
||||
mock_cloud_ai_task_entity: CloudLLMTaskEntity,
|
||||
mock_cloud_ai_task_entity: CloudAITaskEntity,
|
||||
mock_handle_chat_log: AsyncMock,
|
||||
) -> None:
|
||||
"""Test generating structured data."""
|
||||
@@ -228,7 +211,7 @@ async def test_generate_data_returns_json(
|
||||
|
||||
async def test_generate_data_invalid_json(
|
||||
hass: HomeAssistant,
|
||||
mock_cloud_ai_task_entity: CloudLLMTaskEntity,
|
||||
mock_cloud_ai_task_entity: CloudAITaskEntity,
|
||||
mock_handle_chat_log: AsyncMock,
|
||||
) -> None:
|
||||
"""Test invalid JSON responses raise an error."""
|
||||
@@ -256,7 +239,7 @@ async def test_generate_data_invalid_json(
|
||||
|
||||
|
||||
async def test_generate_image_no_attachments(
|
||||
hass: HomeAssistant, mock_cloud_ai_task_entity: CloudLLMTaskEntity
|
||||
hass: HomeAssistant, mock_cloud_ai_task_entity: CloudAITaskEntity
|
||||
) -> None:
|
||||
"""Test generating an image without attachments."""
|
||||
mock_cloud_ai_task_entity._cloud.llm.async_generate_image.return_value = {
|
||||
@@ -281,7 +264,7 @@ async def test_generate_image_no_attachments(
|
||||
|
||||
async def test_generate_image_with_attachments(
|
||||
hass: HomeAssistant,
|
||||
mock_cloud_ai_task_entity: CloudLLMTaskEntity,
|
||||
mock_cloud_ai_task_entity: CloudAITaskEntity,
|
||||
mock_prepare_generation_attachments: AsyncMock,
|
||||
) -> None:
|
||||
"""Test generating an edited image when attachments are provided."""
|
||||
@@ -319,7 +302,7 @@ async def test_generate_image_with_attachments(
|
||||
[
|
||||
(
|
||||
LLMAuthenticationError("auth"),
|
||||
ConfigEntryAuthFailed,
|
||||
HomeAssistantError,
|
||||
"Cloud LLM authentication failed",
|
||||
),
|
||||
(
|
||||
@@ -346,7 +329,7 @@ async def test_generate_image_with_attachments(
|
||||
)
|
||||
async def test_generate_image_error_handling(
|
||||
hass: HomeAssistant,
|
||||
mock_cloud_ai_task_entity: CloudLLMTaskEntity,
|
||||
mock_cloud_ai_task_entity: CloudAITaskEntity,
|
||||
err: Exception,
|
||||
expected_exception: type[Exception],
|
||||
message: str,
|
||||
|
||||
@@ -4,15 +4,11 @@ from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from hass_nabucasa.llm import LLMError
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.cloud.const import DATA_CLOUD, DOMAIN
|
||||
from homeassistant.components.cloud.conversation import (
|
||||
CloudConversationEntity,
|
||||
async_setup_entry,
|
||||
)
|
||||
from homeassistant.components.cloud.const import DOMAIN
|
||||
from homeassistant.components.cloud.conversation import CloudConversationEntity
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import intent, llm
|
||||
|
||||
@@ -34,21 +30,6 @@ def cloud_conversation_entity(hass: HomeAssistant) -> CloudConversationEntity:
|
||||
return entity
|
||||
|
||||
|
||||
async def test_setup_entry_skips_when_not_logged_in(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test setup_entry exits early when not logged in."""
|
||||
cloud = MagicMock()
|
||||
cloud.is_logged_in = False
|
||||
entry = MockConfigEntry(domain="cloud")
|
||||
entry.add_to_hass(hass)
|
||||
hass.data[DATA_CLOUD] = cloud
|
||||
|
||||
async_add_entities = AsyncMock()
|
||||
await async_setup_entry(hass, entry, async_add_entities)
|
||||
async_add_entities.assert_not_called()
|
||||
|
||||
|
||||
def test_entity_availability(
|
||||
cloud_conversation_entity: CloudConversationEntity,
|
||||
) -> None:
|
||||
@@ -145,38 +126,3 @@ async def test_async_handle_message_converse_error(
|
||||
handle_chat_log.assert_not_called()
|
||||
assert result.response is error_response
|
||||
assert result.conversation_id == user_input.conversation_id
|
||||
|
||||
|
||||
async def test_async_setup_entry_adds_entity(hass: HomeAssistant) -> None:
|
||||
"""Test the platform setup adds the conversation entity."""
|
||||
cloud = MagicMock()
|
||||
cloud.llm = MagicMock(async_ensure_token=AsyncMock())
|
||||
cloud.is_logged_in = True
|
||||
cloud.valid_subscription = True
|
||||
hass.data[DATA_CLOUD] = cloud
|
||||
entry = MockConfigEntry(domain=DOMAIN)
|
||||
entry.add_to_hass(hass)
|
||||
add_entities = MagicMock()
|
||||
|
||||
await async_setup_entry(hass, entry, add_entities)
|
||||
|
||||
cloud.llm.async_ensure_token.assert_awaited_once()
|
||||
assert add_entities.call_count == 1
|
||||
assert isinstance(add_entities.call_args[0][0][0], CloudConversationEntity)
|
||||
|
||||
|
||||
async def test_async_setup_entry_llm_error(hass: HomeAssistant) -> None:
|
||||
"""Test entity setup is aborted when ensuring the token fails."""
|
||||
cloud = MagicMock()
|
||||
cloud.llm = MagicMock(async_ensure_token=AsyncMock(side_effect=LLMError("fail")))
|
||||
cloud.is_logged_in = True
|
||||
cloud.valid_subscription = True
|
||||
hass.data[DATA_CLOUD] = cloud
|
||||
entry = MockConfigEntry(domain=DOMAIN)
|
||||
entry.add_to_hass(hass)
|
||||
add_entities = MagicMock()
|
||||
|
||||
await async_setup_entry(hass, entry, add_entities)
|
||||
|
||||
cloud.llm.async_ensure_token.assert_awaited_once()
|
||||
add_entities.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user