Add HomeAssistant Cloud ai_task (#157015)

This commit is contained in:
victorigualada
2025-11-25 17:01:32 +01:00
committed by GitHub
parent 242c02890f
commit 8f1abb6dbb
9 changed files with 1331 additions and 3 deletions

View File

@@ -77,7 +77,12 @@ from .subscription import async_subscription_info
DEFAULT_MODE = MODE_PROD
PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT, Platform.TTS]
PLATFORMS = [
Platform.AI_TASK,
Platform.BINARY_SENSOR,
Platform.STT,
Platform.TTS,
]
SERVICE_REMOTE_CONNECT = "remote_connect"
SERVICE_REMOTE_DISCONNECT = "remote_disconnect"

View File

@@ -0,0 +1,200 @@
"""AI Task integration for Home Assistant Cloud."""
from __future__ import annotations
import io
from json import JSONDecodeError
import logging
from hass_nabucasa.llm import (
LLMAuthenticationError,
LLMError,
LLMImageAttachment,
LLMRateLimitError,
LLMResponseError,
LLMServiceError,
)
from PIL import Image
from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.json import json_loads
from .const import AI_TASK_ENTITY_UNIQUE_ID, DATA_CLOUD
from .entity import BaseCloudLLMEntity
_LOGGER = logging.getLogger(__name__)
def _convert_image_for_editing(data: bytes) -> tuple[bytes, str]:
    """Ensure the image data is in a format accepted by OpenAI image edits.

    Any image that is not already in a grayscale/alpha mode ("RGBA", "LA",
    "L") is converted to RGBA first, and the result is always re-encoded as
    PNG, so the returned MIME type is always ``image/png``.

    Returns a ``(png_bytes, mime_type)`` tuple.

    Raises whatever PIL raises for invalid image data; the caller wraps
    failures in a ``HomeAssistantError``.
    """
    stream = io.BytesIO(data)
    with Image.open(stream) as img:
        # Normalize palette/CMYK/etc. modes so the PNG encoder accepts them.
        if img.mode not in ("RGBA", "LA", "L"):
            img = img.convert("RGBA")
        output = io.BytesIO()
        # After the conversion above the mode is always RGBA, LA or L, so
        # the image is always PNG-encoded.  (The previous fallback branch
        # that saved in the original format was unreachable dead code.)
        img.save(output, format="PNG")
        return output.getvalue(), "image/png"
async def async_prepare_image_generation_attachments(
    hass: HomeAssistant, attachments: list[conversation.Attachment]
) -> list[LLMImageAttachment]:
    """Load attachment data for image generation.

    Rejects non-image attachments and missing files, converts every image
    to a format accepted by the image-edit endpoint, and returns the
    prepared ``LLMImageAttachment`` payloads.
    """

    def prepare() -> list[LLMImageAttachment]:
        prepared: list[LLMImageAttachment] = []
        for attachment in attachments:
            mime_type = attachment.mime_type
            if not mime_type or not mime_type.startswith("image/"):
                raise HomeAssistantError(
                    "Only image attachments are supported for image generation"
                )
            source = attachment.path
            if not source.exists():
                raise HomeAssistantError(f"`{source}` does not exist")
            raw = source.read_bytes()
            try:
                converted, converted_mime = _convert_image_for_editing(raw)
            except HomeAssistantError:
                # Already a user-facing error; don't re-wrap it.
                raise
            except Exception as err:
                raise HomeAssistantError("Failed to process image attachment") from err
            prepared.append(
                LLMImageAttachment(
                    filename=source.name,
                    mime_type=converted_mime,
                    data=converted,
                )
            )
        return prepared

    # File reads and image decoding are blocking; run them in the executor.
    return await hass.async_add_executor_job(prepare)
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Set up Home Assistant Cloud AI Task entity."""
    cloud = hass.data[DATA_CLOUD]
    try:
        # Verify we can obtain an LLM token before exposing the entity.
        await cloud.llm.async_ensure_token()
    except LLMError:
        # NOTE(review): token failures silently skip entity creation with no
        # log message — presumably deliberate for accounts without LLM
        # access, but worth confirming a debug log isn't wanted here.
        return
    async_add_entities([CloudLLMTaskEntity(cloud, config_entry)])
class CloudLLMTaskEntity(ai_task.AITaskEntity, BaseCloudLLMEntity):
    """Home Assistant Cloud AI Task entity.

    Supports structured/plain data generation and image generation
    (including image edits when attachments are provided).
    """

    _attr_has_entity_name = True
    _attr_supported_features = (
        ai_task.AITaskEntityFeature.GENERATE_DATA
        | ai_task.AITaskEntityFeature.GENERATE_IMAGE
        | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
    )
    _attr_translation_key = "cloud_ai"
    _attr_unique_id = AI_TASK_ENTITY_UNIQUE_ID

    @property
    def available(self) -> bool:
        """Return if the entity is available."""
        # Requires both an authenticated cloud session and a subscription.
        return self._cloud.is_logged_in and self._cloud.valid_subscription

    async def _async_generate_data(
        self,
        task: ai_task.GenDataTask,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenDataTaskResult:
        """Handle a generate data task.

        Returns raw text for unstructured tasks; for structured tasks the
        model output is parsed as JSON and a ``HomeAssistantError`` is
        raised when parsing fails.
        """
        await self._async_handle_chat_log(
            "ai_task", chat_log, task.name, task.structure
        )
        last = chat_log.content[-1]
        if not isinstance(last, conversation.AssistantContent):
            raise HomeAssistantError(
                "Last content in chat log is not an AssistantContent"
            )
        text = last.content or ""
        if not task.structure:
            # No schema requested: hand back the model text unchanged.
            return ai_task.GenDataTaskResult(
                conversation_id=chat_log.conversation_id,
                data=text,
            )
        try:
            parsed = json_loads(text)
        except JSONDecodeError as err:
            _LOGGER.error(
                "Failed to parse JSON response: %s. Response: %s",
                err,
                text,
            )
            raise HomeAssistantError("Error with OpenAI structured response") from err
        return ai_task.GenDataTaskResult(
            conversation_id=chat_log.conversation_id,
            data=parsed,
        )

    async def _async_generate_image(
        self,
        task: ai_task.GenImageTask,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenImageTaskResult:
        """Handle a generate image task.

        Uses the image-edit endpoint when attachments are supplied,
        otherwise plain image generation.  Cloud LLM errors are mapped to
        the corresponding Home Assistant exceptions.
        """
        prepared: list[LLMImageAttachment] | None = None
        if task.attachments:
            prepared = await async_prepare_image_generation_attachments(
                self.hass, task.attachments
            )
        try:
            if prepared is None:
                result = await self._cloud.llm.async_generate_image(
                    prompt=task.instructions,
                )
            else:
                result = await self._cloud.llm.async_edit_image(
                    prompt=task.instructions,
                    attachments=prepared,
                )
        except LLMAuthenticationError as err:
            raise ConfigEntryAuthFailed("Cloud LLM authentication failed") from err
        except LLMRateLimitError as err:
            raise HomeAssistantError("Cloud LLM is rate limited") from err
        except LLMResponseError as err:
            raise HomeAssistantError(str(err)) from err
        except LLMServiceError as err:
            raise HomeAssistantError("Error talking to Cloud LLM") from err
        except LLMError as err:
            raise HomeAssistantError(str(err)) from err
        return ai_task.GenImageTaskResult(
            conversation_id=chat_log.conversation_id,
            mime_type=result["mime_type"],
            image_data=result["image_data"],
            model=result.get("model"),
            width=result.get("width"),
            height=result.get("height"),
            revised_prompt=result.get("revised_prompt"),
        )

View File

@@ -91,6 +91,7 @@ DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update")
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech"
AI_TASK_ENTITY_UNIQUE_ID = "cloud-ai-task"
LOGIN_MFA_TIMEOUT = 60

View File

@@ -0,0 +1,543 @@
"""Helpers for cloud LLM chat handling."""
import base64
from collections.abc import AsyncGenerator, Callable
from enum import Enum
import json
import logging
import re
from typing import Any, Literal, cast
from hass_nabucasa import Cloud
from hass_nabucasa.llm import (
LLMAuthenticationError,
LLMError,
LLMRateLimitError,
LLMResponseError,
LLMServiceError,
)
from litellm import ResponseFunctionToolCall, ResponsesAPIStreamEvents
from openai.types.responses import (
FunctionToolParam,
ResponseReasoningItem,
ToolParam,
WebSearchToolParam,
)
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.entity import Entity
from homeassistant.util import slugify
from .client import CloudClient
_LOGGER = logging.getLogger(__name__)
_MAX_TOOL_ITERATIONS = 10
class ResponseItemType(str, Enum):
    """Response item types.

    Values mirror the ``type`` field of Responses API output items; the
    ``str`` mixin lets members compare equal to the raw strings carried by
    the stream events.
    """

    FUNCTION_CALL = "function_call"
    MESSAGE = "message"
    REASONING = "reasoning"
    WEB_SEARCH_CALL = "web_search_call"
    IMAGE = "image"
def _convert_content_to_chat_message(
    content: conversation.Content,
) -> dict[str, Any] | None:
    """Convert ChatLog content to a responses message.

    Returns ``None`` for roles the Responses API does not accept and for
    entries without text content.
    """
    if content.role not in ("user", "system", "tool_result", "assistant"):
        return None
    # NOTE(review): "tool_result" passes the role check but is not part of
    # the cast union — presumably its `.content` attribute exists too;
    # verify against the conversation content types.
    text_content = cast(
        conversation.SystemContent
        | conversation.UserContent
        | conversation.AssistantContent,
        content,
    )
    if not text_content.content:
        return None
    # Assistant turns are model output; everything else is model input.
    if text_content.role == "assistant":
        part_type = "output_text"
    else:
        part_type = "input_text"
    return {
        "role": text_content.role,
        "content": [{"type": part_type, "text": text_content.content}],
    }
def _format_tool(
    tool: llm.Tool,
    custom_serializer: Callable[[Any], Any] | None,
) -> ToolParam:
    """Format a Home Assistant tool for the OpenAI Responses API."""
    # Strict mode is disabled because HA tool schemas are not guaranteed to
    # satisfy the strict-schema constraints.
    return FunctionToolParam(
        type="function",
        name=tool.name,
        strict=False,
        description=tool.description,
        parameters=convert(tool.parameters, custom_serializer=custom_serializer),
    )
def _adjust_schema(schema: dict[str, Any]) -> None:
"""Adjust the schema to be compatible with OpenAI API."""
if schema["type"] == "object":
schema.setdefault("strict", True)
schema.setdefault("additionalProperties", False)
if "properties" not in schema:
return
if "required" not in schema:
schema["required"] = []
# Ensure all properties are required
for prop, prop_info in schema["properties"].items():
_adjust_schema(prop_info)
if prop not in schema["required"]:
prop_info["type"] = [prop_info["type"], "null"]
schema["required"].append(prop)
elif schema["type"] == "array":
if "items" not in schema:
return
_adjust_schema(schema["items"])
def _format_structured_output(
    schema: vol.Schema, llm_api: llm.APIInstance | None
) -> dict[str, Any]:
    """Format the schema to be compatible with OpenAI API.

    Uses the LLM API's custom serializer when available, otherwise the
    generic selector serializer, then enforces the Responses API schema
    constraints on the converted result.
    """
    serializer = llm_api.custom_serializer if llm_api else llm.selector_serializer
    converted: dict[str, Any] = convert(schema, custom_serializer=serializer)
    _ensure_schema_constraints(converted)
    return converted
def _ensure_schema_constraints(schema: dict[str, Any]) -> None:
"""Ensure generated schemas match the Responses API expectations."""
schema_type = schema.get("type")
if schema_type == "object":
schema.setdefault("additionalProperties", False)
properties = schema.get("properties")
if isinstance(properties, dict):
for property_schema in properties.values():
if isinstance(property_schema, dict):
_ensure_schema_constraints(property_schema)
elif schema_type == "array":
items = schema.get("items")
if isinstance(items, dict):
_ensure_schema_constraints(items)
# Borrowed and adapted from openai_conversation component
async def _transform_stream(  # noqa: C901 - This is complex, but better to have it in one place
    chat_log: conversation.ChatLog,
    stream: Any,
    remove_citations: bool = False,
) -> AsyncGenerator[
    conversation.AssistantContentDeltaDict | conversation.ToolResultContentDeltaDict
]:
    """Transform stream result into HA format.

    Consumes Responses API stream events and yields Home Assistant chat-log
    delta dicts.  Raises ``HomeAssistantError`` when the response is
    reported as incomplete, failed, or errored.
    """
    # Index of the last reasoning summary seen (None until one arrives).
    last_summary_index = None
    # Role of the most recently started HA message, used to decide when a
    # fresh assistant message must be opened.
    last_role: Literal["assistant", "tool_result"] | None = None
    # Tool call currently accumulating streamed argument deltas.
    current_tool_call: ResponseFunctionToolCall | None = None
    # Non-reasoning models don't follow our request to remove citations, so we remove
    # them manually here. They always follow the same pattern: the citation is always
    # in parentheses in Markdown format, the citation is always in a single delta event,
    # and sometimes the closing parenthesis is split into a separate delta event.
    remove_parentheses: bool = False
    citation_regexp = re.compile(r"\(\[([^\]]+)\]\((https?:\/\/[^\)]+)\)")
    async for event in stream:
        # Events may be plain objects without these attributes; default to None.
        event_type = getattr(event, "type", None)
        event_item = getattr(event, "item", None)
        event_item_type = getattr(event_item, "type", None) if event_item else None
        _LOGGER.debug(
            "Event[%s] | item: %s",
            event_type,
            event_item_type,
        )
        if event_type == ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED:
            # Detect function_call even when it's a BaseLiteLLMOpenAIResponseObject
            if event_item_type == ResponseItemType.FUNCTION_CALL:
                # OpenAI has tool calls as individual events
                # while HA puts tool calls inside the assistant message.
                # We turn them into individual assistant content for HA
                # to ensure that tools are called as soon as possible.
                yield {"role": "assistant"}
                last_role = "assistant"
                last_summary_index = None
                current_tool_call = cast(ResponseFunctionToolCall, event.item)
            elif (
                event_item_type == ResponseItemType.MESSAGE
                or (
                    event_item_type == ResponseItemType.REASONING
                    and last_summary_index is not None
                )  # Subsequent ResponseReasoningItem
                or last_role != "assistant"
            ):
                yield {"role": "assistant"}
                last_role = "assistant"
                last_summary_index = None
        elif event_type == ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE:
            if event_item_type == ResponseItemType.REASONING:
                encrypted_content = getattr(event.item, "encrypted_content", None)
                summary = getattr(event.item, "summary", []) or []
                # Store the reasoning item (summary deliberately emptied;
                # summaries were already streamed as thinking deltas).
                yield {
                    "native": ResponseReasoningItem(
                        type="reasoning",
                        id=event.item.id,
                        summary=[],
                        encrypted_content=encrypted_content,
                    )
                }
                last_summary_index = len(summary) - 1 if summary else None
            elif event_item_type == ResponseItemType.WEB_SEARCH_CALL:
                # The action payload may be a dict, an object, or absent.
                action = getattr(event.item, "action", None)
                if isinstance(action, dict):
                    action_dict = action
                elif action is not None:
                    action_dict = action.to_dict()
                else:
                    action_dict = {}
                # Web searches run server-side, so mark the tool call as
                # external and immediately report its result.
                yield {
                    "tool_calls": [
                        llm.ToolInput(
                            id=event.item.id,
                            tool_name="web_search_call",
                            tool_args={"action": action_dict},
                            external=True,
                        )
                    ]
                }
                yield {
                    "role": "tool_result",
                    "tool_call_id": event.item.id,
                    "tool_name": "web_search_call",
                    "tool_result": {"status": event.item.status},
                }
                last_role = "tool_result"
            elif event_item_type == ResponseItemType.IMAGE:
                yield {"native": event.item}
                last_summary_index = -1  # Trigger new assistant message on next turn
        elif event_type == ResponsesAPIStreamEvents.OUTPUT_TEXT_DELTA:
            data = event.delta
            if remove_parentheses:
                # Closing parenthesis of a citation split across deltas.
                data = data.removeprefix(")")
                remove_parentheses = False
            elif remove_citations and (match := citation_regexp.search(data)):
                match_start, match_end = match.span()
                # remove leading space if any
                if data[match_start - 1 : match_start] == " ":
                    match_start -= 1
                # remove closing parenthesis:
                if data[match_end : match_end + 1] == ")":
                    match_end += 1
                else:
                    # Parenthesis arrives in the next delta; strip it there.
                    remove_parentheses = True
                data = data[:match_start] + data[match_end:]
            if data:
                yield {"content": data}
        elif event_type == ResponsesAPIStreamEvents.REASONING_SUMMARY_TEXT_DELTA:
            # OpenAI can output several reasoning summaries
            # in a single ResponseReasoningItem. We split them as separate
            # AssistantContent messages. Only last of them will have
            # the reasoning `native` field set.
            if (
                last_summary_index is not None
                and event.summary_index != last_summary_index
            ):
                yield {"role": "assistant"}
                last_role = "assistant"
            last_summary_index = event.summary_index
            yield {"thinking_content": event.delta}
        elif event_type == ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DELTA:
            if current_tool_call is not None:
                # Accumulate streamed JSON fragments of the tool arguments.
                current_tool_call.arguments += event.delta
        elif event_type == ResponsesAPIStreamEvents.WEB_SEARCH_CALL_SEARCHING:
            yield {"role": "assistant"}
        elif event_type == ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DONE:
            if current_tool_call is not None:
                current_tool_call.status = "completed"
                raw_args = json.loads(current_tool_call.arguments)
                for key in ("area", "floor"):
                    if key in raw_args and not raw_args[key]:
                        # Remove keys that are "" or None
                        raw_args.pop(key, None)
                yield {
                    "tool_calls": [
                        llm.ToolInput(
                            id=current_tool_call.call_id,
                            tool_name=current_tool_call.name,
                            tool_args=raw_args,
                        )
                    ]
                }
        elif event_type == ResponsesAPIStreamEvents.RESPONSE_COMPLETED:
            if event.response.usage is not None:
                chat_log.async_trace(
                    {
                        "stats": {
                            "input_tokens": event.response.usage.input_tokens,
                            "output_tokens": event.response.usage.output_tokens,
                        }
                    }
                )
        elif event_type == ResponsesAPIStreamEvents.RESPONSE_INCOMPLETE:
            if event.response.usage is not None:
                chat_log.async_trace(
                    {
                        "stats": {
                            "input_tokens": event.response.usage.input_tokens,
                            "output_tokens": event.response.usage.output_tokens,
                        }
                    }
                )
            if (
                event.response.incomplete_details
                and event.response.incomplete_details.reason
            ):
                reason: str = event.response.incomplete_details.reason
            else:
                reason = "unknown reason"
            # Map API reason codes to human-readable messages.
            if reason == "max_output_tokens":
                reason = "max output tokens reached"
            elif reason == "content_filter":
                reason = "content filter triggered"
            raise HomeAssistantError(f"OpenAI response incomplete: {reason}")
        elif event_type == ResponsesAPIStreamEvents.RESPONSE_FAILED:
            if event.response.usage is not None:
                chat_log.async_trace(
                    {
                        "stats": {
                            "input_tokens": event.response.usage.input_tokens,
                            "output_tokens": event.response.usage.output_tokens,
                        }
                    }
                )
            reason = "unknown reason"
            if event.response.error is not None:
                reason = event.response.error.message
            raise HomeAssistantError(f"OpenAI response failed: {reason}")
        elif event_type == ResponsesAPIStreamEvents.ERROR:
            raise HomeAssistantError(f"OpenAI response error: {event.message}")
class BaseCloudLLMEntity(Entity):
    """Cloud LLM conversation agent.

    Shared base for the cloud conversation and AI Task entities: turns a
    Home Assistant chat log into Cloud LLM request kwargs, streams the
    response back into the chat log, and loops while tool calls remain
    unanswered.
    """

    def __init__(self, cloud: Cloud[CloudClient], config_entry: ConfigEntry) -> None:
        """Initialize the entity."""
        self._cloud = cloud
        self._entry = config_entry

    async def _prepare_chat_for_generation(
        self,
        chat_log: conversation.ChatLog,
        response_format: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Prepare kwargs for Cloud LLM from the chat log.

        Raises ``HomeAssistantError`` when the chat log does not end with a
        user message.
        """
        messages = [
            message
            for content in chat_log.content
            if (message := _convert_content_to_chat_message(content))
        ]
        if not messages or messages[-1]["role"] != "user":
            raise HomeAssistantError("No user prompt found")
        last_content = chat_log.content[-1]
        if last_content.role == "user" and last_content.attachments:
            # Append attachment payloads to the final user message parts.
            files = await self._async_prepare_files_for_prompt(last_content.attachments)
            user_message = messages[-1]
            current_content = user_message.get("content", [])
            user_message["content"] = [*(current_content or []), *files]
        tools: list[ToolParam] = []
        tool_choice: str | None = None
        if chat_log.llm_api:
            ha_tools: list[ToolParam] = [
                _format_tool(tool, chat_log.llm_api.custom_serializer)
                for tool in chat_log.llm_api.tools
            ]
            if ha_tools:
                if not chat_log.unresponded_tool_results:
                    tools = ha_tools
                    tool_choice = "auto"
                else:
                    # Tool results still pending: forbid further tool calls.
                    tools = []
                    tool_choice = "none"
        # NOTE(review): web search appears to be offered unconditionally,
        # even without an LLM API — confirm this placement is intended.
        web_search = WebSearchToolParam(
            type="web_search",
            search_context_size="medium",
        )
        tools.append(web_search)
        response_kwargs: dict[str, Any] = {
            "messages": messages,
            "conversation_id": chat_log.conversation_id,
        }
        if response_format is not None:
            response_kwargs["response_format"] = response_format
        # NOTE(review): `tools` is initialized to a list above, so this
        # None-check is always true — possibly `if tools:` was intended.
        if tools is not None:
            response_kwargs["tools"] = tools
        if tool_choice is not None:
            response_kwargs["tool_choice"] = tool_choice
        response_kwargs["stream"] = True
        return response_kwargs

    async def _async_prepare_files_for_prompt(
        self,
        attachments: list[conversation.Attachment],
    ) -> list[dict[str, Any]]:
        """Prepare files for multimodal prompts.

        Images become ``input_image`` parts and PDFs ``input_file`` parts,
        both base64-encoded as data URLs.  Raises ``HomeAssistantError``
        for missing files or unsupported MIME types.
        """

        def prepare() -> list[dict[str, Any]]:
            content: list[dict[str, Any]] = []
            for attachment in attachments:
                mime_type = attachment.mime_type
                path = attachment.path
                if not path.exists():
                    raise HomeAssistantError(f"`{path}` does not exist")
                data = base64.b64encode(path.read_bytes()).decode("utf-8")
                if mime_type and mime_type.startswith("image/"):
                    content.append(
                        {
                            "type": "input_image",
                            "image_url": f"data:{mime_type};base64,{data}",
                            "detail": "auto",
                        }
                    )
                elif mime_type and mime_type.startswith("application/pdf"):
                    content.append(
                        {
                            "type": "input_file",
                            "filename": str(path.name),
                            "file_data": f"data:{mime_type};base64,{data}",
                        }
                    )
                else:
                    raise HomeAssistantError(
                        "Only images and PDF are currently supported as attachments"
                    )
            return content

        # File reads and encoding are blocking; run them in the executor.
        return await self.hass.async_add_executor_job(prepare)

    async def _async_handle_chat_log(
        self,
        type: Literal["ai_task", "conversation"],
        chat_log: conversation.ChatLog,
        structure_name: str | None = None,
        structure: vol.Schema | None = None,
    ) -> None:
        """Generate a response for the chat log.

        Repeats the request (up to ``_MAX_TOOL_ITERATIONS``) while tool
        calls remain unanswered.  ``structure``/``structure_name`` enable
        strict JSON-schema structured output.

        NOTE(review): the first parameter is named ``type``, shadowing the
        builtin — kept for interface compatibility with existing callers.
        """
        for _ in range(_MAX_TOOL_ITERATIONS):
            response_format: dict[str, Any] | None = None
            if structure and structure_name:
                response_format = {
                    "type": "json_schema",
                    "json_schema": {
                        "name": slugify(structure_name),
                        "schema": _format_structured_output(
                            structure, chat_log.llm_api
                        ),
                        "strict": True,
                    },
                }
            response_kwargs = await self._prepare_chat_for_generation(
                chat_log,
                response_format,
            )
            try:
                if type == "conversation":
                    raw_stream = await self._cloud.llm.async_process_conversation(
                        **response_kwargs,
                    )
                else:
                    raw_stream = await self._cloud.llm.async_generate_data(
                        **response_kwargs,
                    )
                # Drain the transformed stream into the chat log; citation
                # removal is always enabled here.
                async for _ in chat_log.async_add_delta_content_stream(
                    agent_id=self.entity_id,
                    stream=_transform_stream(
                        chat_log,
                        raw_stream,
                        True,
                    ),
                ):
                    pass
            except LLMAuthenticationError as err:
                raise ConfigEntryAuthFailed("Cloud LLM authentication failed") from err
            except LLMRateLimitError as err:
                raise HomeAssistantError("Cloud LLM is rate limited") from err
            except LLMResponseError as err:
                raise HomeAssistantError(str(err)) from err
            except LLMServiceError as err:
                raise HomeAssistantError("Error talking to Cloud LLM") from err
            except LLMError as err:
                raise HomeAssistantError(str(err)) from err
            if not chat_log.unresponded_tool_results:
                break

View File

@@ -1,5 +1,7 @@
"""Helpers for the cloud component."""
from __future__ import annotations
from collections import deque
import logging

View File

@@ -1,4 +1,11 @@
{
"entity": {
"ai_task": {
"cloud_ai": {
"name": "Home Assistant Cloud AI"
}
}
},
"exceptions": {
"backup_size_too_large": {
"message": "The backup size of {size}GB is too large to be uploaded to Home Assistant Cloud."

View File

@@ -21,22 +21,26 @@
## Active Integrations
Built-in integrations: 15
Built-in integrations: 19
Custom integrations: 1
<details><summary>Built-in integrations</summary>
Domain | Name
--- | ---
ai_task | AI Task
auth | Auth
binary_sensor | Binary Sensor
cloud | Home Assistant Cloud
cloud.binary_sensor | Unknown
cloud.stt | Unknown
cloud.tts | Unknown
conversation | Conversation
ffmpeg | FFmpeg
homeassistant | Home Assistant Core Integration
http | HTTP
intent | Intent
media_source | Media Source
mock_no_info_integration | mock_no_info_integration
repairs | Repairs
stt | Speech-to-text (STT)
@@ -116,22 +120,26 @@
## Active Integrations
Built-in integrations: 15
Built-in integrations: 19
Custom integrations: 0
<details><summary>Built-in integrations</summary>
Domain | Name
--- | ---
ai_task | AI Task
auth | Auth
binary_sensor | Binary Sensor
cloud | Home Assistant Cloud
cloud.binary_sensor | Unknown
cloud.stt | Unknown
cloud.tts | Unknown
conversation | Conversation
ffmpeg | FFmpeg
homeassistant | Home Assistant Core Integration
http | HTTP
intent | Intent
media_source | Media Source
mock_no_info_integration | mock_no_info_integration
repairs | Repairs
stt | Speech-to-text (STT)

View File

@@ -0,0 +1,343 @@
"""Tests for the Home Assistant Cloud AI Task entity."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
from hass_nabucasa.llm import (
LLMAuthenticationError,
LLMError,
LLMImageAttachment,
LLMRateLimitError,
LLMResponseError,
LLMServiceError,
)
from PIL import Image
import pytest
import voluptuous as vol
from homeassistant.components import ai_task, conversation
from homeassistant.components.cloud.ai_task import (
CloudLLMTaskEntity,
async_prepare_image_generation_attachments,
)
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
from tests.common import MockConfigEntry
@pytest.fixture
def mock_cloud_ai_task_entity(hass: HomeAssistant) -> CloudLLMTaskEntity:
    """Return a CloudLLMTaskEntity with a mocked cloud LLM."""
    mock_cloud = MagicMock()
    mock_cloud.llm = MagicMock(
        async_generate_image=AsyncMock(),
        async_edit_image=AsyncMock(),
    )
    # Make the entity report as available.
    mock_cloud.is_logged_in = True
    mock_cloud.valid_subscription = True
    config_entry = MockConfigEntry(domain="cloud")
    config_entry.add_to_hass(hass)
    task_entity = CloudLLMTaskEntity(mock_cloud, config_entry)
    task_entity.entity_id = "ai_task.cloud_ai_task"
    task_entity.hass = hass
    return task_entity
@pytest.fixture(name="mock_handle_chat_log")
def mock_handle_chat_log_fixture() -> AsyncMock:
    """Patch the chat log handler."""
    target = (
        "homeassistant.components.cloud.ai_task."
        "CloudLLMTaskEntity._async_handle_chat_log"
    )
    with patch(target, AsyncMock()) as mocked:
        yield mocked
@pytest.fixture(name="mock_prepare_generation_attachments")
def mock_prepare_generation_attachments_fixture() -> AsyncMock:
    """Patch image generation attachment preparation."""
    target = (
        "homeassistant.components.cloud.ai_task."
        "async_prepare_image_generation_attachments"
    )
    with patch(target, AsyncMock()) as mocked:
        yield mocked
async def test_prepare_image_generation_attachments(
    hass: HomeAssistant, tmp_path: Path
) -> None:
    """Test preparing attachments for image generation."""
    source_path = tmp_path / "snapshot.jpg"
    Image.new("RGB", (2, 2), "red").save(source_path, "JPEG")
    prepared = await async_prepare_image_generation_attachments(
        hass,
        [
            conversation.Attachment(
                media_content_id="media-source://media/snapshot.jpg",
                mime_type="image/jpeg",
                path=source_path,
            )
        ],
    )
    assert len(prepared) == 1
    first = prepared[0]
    assert first["filename"] == "snapshot.jpg"
    # JPEG input is re-encoded as PNG for the edit endpoint.
    assert first["mime_type"] == "image/png"
    assert first["data"].startswith(b"\x89PNG")
async def test_prepare_image_generation_attachments_only_images(
    hass: HomeAssistant, tmp_path: Path
) -> None:
    """Test non image attachments are rejected."""
    text_file = tmp_path / "context.txt"
    text_file.write_text("context")
    text_attachment = conversation.Attachment(
        media_content_id="media-source://media/context.txt",
        mime_type="text/plain",
        path=text_file,
    )
    with pytest.raises(
        HomeAssistantError,
        match="Only image attachments are supported for image generation",
    ):
        await async_prepare_image_generation_attachments(hass, [text_attachment])
async def test_prepare_image_generation_attachments_missing_file(
    hass: HomeAssistant, tmp_path: Path
) -> None:
    """Test missing attachments raise a helpful error."""
    absent = conversation.Attachment(
        media_content_id="media-source://media/missing.png",
        mime_type="image/png",
        path=tmp_path / "missing.png",
    )
    with pytest.raises(HomeAssistantError, match="`.*missing.png` does not exist"):
        await async_prepare_image_generation_attachments(hass, [absent])
async def test_prepare_image_generation_attachments_processing_error(
    hass: HomeAssistant, tmp_path: Path
) -> None:
    """Test invalid image data raises a processing error."""
    corrupt = tmp_path / "broken.png"
    corrupt.write_bytes(b"not-an-image")
    with pytest.raises(
        HomeAssistantError,
        match="Failed to process image attachment",
    ):
        await async_prepare_image_generation_attachments(
            hass,
            [
                conversation.Attachment(
                    media_content_id="media-source://media/broken.png",
                    mime_type="image/png",
                    path=corrupt,
                )
            ],
        )
async def test_generate_data_returns_text(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudLLMTaskEntity,
    mock_handle_chat_log: AsyncMock,
) -> None:
    """Test generating plain text data."""
    chat_log = conversation.ChatLog(hass, "conversation-id")
    chat_log.async_add_user_content(
        conversation.UserContent(content="Tell me something")
    )
    task = ai_task.GenDataTask(name="Task", instructions="Say hi")

    async def inject_assistant_reply(chat_type, log, task_name, structure):
        """Simulate the LLM by appending assistant output to the log."""
        assert chat_type == "ai_task"
        log.async_add_assistant_content_without_tools(
            conversation.AssistantContent(
                agent_id=mock_cloud_ai_task_entity.entity_id or "",
                content="Hello from the cloud",
            )
        )

    mock_handle_chat_log.side_effect = inject_assistant_reply

    result = await mock_cloud_ai_task_entity._async_generate_data(task, chat_log)

    assert result.conversation_id == "conversation-id"
    assert result.data == "Hello from the cloud"
async def test_generate_data_returns_json(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudLLMTaskEntity,
    mock_handle_chat_log: AsyncMock,
) -> None:
    """Test generating structured data."""
    chat_log = conversation.ChatLog(hass, "conversation-id")
    chat_log.async_add_user_content(conversation.UserContent(content="List names"))
    task = ai_task.GenDataTask(
        name="Task",
        instructions="Return JSON",
        structure=vol.Schema({vol.Required("names"): [str]}),
    )

    async def inject_json_reply(chat_type, log, task_name, structure):
        """Simulate a structured JSON response from the LLM."""
        log.async_add_assistant_content_without_tools(
            conversation.AssistantContent(
                agent_id=mock_cloud_ai_task_entity.entity_id or "",
                content='{"names": ["A", "B"]}',
            )
        )

    mock_handle_chat_log.side_effect = inject_json_reply

    result = await mock_cloud_ai_task_entity._async_generate_data(task, chat_log)

    # The JSON text must be parsed into Python data.
    assert result.data == {"names": ["A", "B"]}
async def test_generate_data_invalid_json(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudLLMTaskEntity,
    mock_handle_chat_log: AsyncMock,
) -> None:
    """Test invalid JSON responses raise an error."""
    chat_log = conversation.ChatLog(hass, "conversation-id")
    chat_log.async_add_user_content(conversation.UserContent(content="List names"))
    task = ai_task.GenDataTask(
        name="Task",
        instructions="Return JSON",
        structure=vol.Schema({vol.Required("names"): [str]}),
    )

    async def inject_bad_reply(chat_type, log, task_name, structure):
        """Simulate a malformed (non-JSON) LLM response."""
        log.async_add_assistant_content_without_tools(
            conversation.AssistantContent(
                agent_id=mock_cloud_ai_task_entity.entity_id or "",
                content="not-json",
            )
        )

    mock_handle_chat_log.side_effect = inject_bad_reply

    with pytest.raises(
        HomeAssistantError, match="Error with OpenAI structured response"
    ):
        await mock_cloud_ai_task_entity._async_generate_data(task, chat_log)
async def test_generate_image_no_attachments(
    hass: HomeAssistant, mock_cloud_ai_task_entity: CloudLLMTaskEntity
) -> None:
    """Test generating an image without attachments."""
    llm_mock = mock_cloud_ai_task_entity._cloud.llm
    llm_mock.async_generate_image.return_value = {
        "mime_type": "image/png",
        "image_data": b"IMG",
        "model": "mock-image",
        "width": 1024,
        "height": 768,
        "revised_prompt": "Improved prompt",
    }
    chat_log = conversation.ChatLog(hass, "conversation-id")
    task = ai_task.GenImageTask(name="Task", instructions="Draw something")

    result = await mock_cloud_ai_task_entity._async_generate_image(task, chat_log)

    assert result.image_data == b"IMG"
    assert result.mime_type == "image/png"
    # Without attachments the plain generation endpoint must be used.
    llm_mock.async_generate_image.assert_awaited_once_with(prompt="Draw something")
async def test_generate_image_with_attachments(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudLLMTaskEntity,
    mock_prepare_generation_attachments: AsyncMock,
) -> None:
    """Test generating an edited image when attachments are provided."""
    llm_mock = mock_cloud_ai_task_entity._cloud.llm
    llm_mock.async_edit_image.return_value = {
        "mime_type": "image/png",
        "image_data": b"IMG",
    }
    source_attachment = conversation.Attachment(
        media_content_id="media-source://media/snapshot.png",
        mime_type="image/png",
        path=hass.config.path("snapshot.png"),
    )
    task = ai_task.GenImageTask(
        name="Task",
        instructions="Edit this",
        attachments=[source_attachment],
    )
    chat_log = conversation.ChatLog(hass, "conversation-id")
    prepared = [
        LLMImageAttachment(filename="snapshot.png", mime_type="image/png", data=b"IMG")
    ]
    mock_prepare_generation_attachments.return_value = prepared

    await mock_cloud_ai_task_entity._async_generate_image(task, chat_log)

    # With attachments the edit endpoint must receive the prepared payloads.
    llm_mock.async_edit_image.assert_awaited_once_with(
        prompt="Edit this",
        attachments=prepared,
    )
@pytest.mark.parametrize(
    ("err", "expected_exception", "message"),
    [
        (
            LLMAuthenticationError("auth"),
            ConfigEntryAuthFailed,
            "Cloud LLM authentication failed",
        ),
        (LLMRateLimitError("limit"), HomeAssistantError, "Cloud LLM is rate limited"),
        (LLMResponseError("bad response"), HomeAssistantError, "bad response"),
        (LLMServiceError("service"), HomeAssistantError, "Error talking to Cloud LLM"),
        (LLMError("generic"), HomeAssistantError, "generic"),
    ],
)
async def test_generate_image_error_handling(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudLLMTaskEntity,
    err: Exception,
    expected_exception: type[Exception],
    message: str,
) -> None:
    """Test image generation error handling."""
    # Each cloud LLM error type must map to the matching HA exception.
    mock_cloud_ai_task_entity._cloud.llm.async_generate_image.side_effect = err
    chat_log = conversation.ChatLog(hass, "conversation-id")
    task = ai_task.GenImageTask(name="Task", instructions="Draw something")
    with pytest.raises(expected_exception, match=message):
        await mock_cloud_ai_task_entity._async_generate_image(task, chat_log)

View File

@@ -0,0 +1,219 @@
"""Tests for helpers in the Home Assistant Cloud conversation entity."""
from __future__ import annotations

import base64
from collections.abc import Generator
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch

from PIL import Image
import pytest
import voluptuous as vol

from homeassistant.components import conversation
from homeassistant.components.cloud.entity import (
    BaseCloudLLMEntity,
    _format_structured_output,
)
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm, selector

from tests.common import MockConfigEntry
@pytest.fixture
def cloud_entity(hass: HomeAssistant) -> BaseCloudLLMEntity:
    """Return a BaseCloudLLMEntity wired to a mocked, logged-in cloud client."""
    # Mocked Nabu Casa client exposing an LLM sub-client and an active account.
    cloud = MagicMock(llm=MagicMock(), is_logged_in=True, valid_subscription=True)
    config_entry = MockConfigEntry(domain="cloud")
    config_entry.add_to_hass(hass)
    entity = BaseCloudLLMEntity(cloud, config_entry)
    entity.hass = hass
    entity.entity_id = "ai_task.cloud_ai_task"
    return entity
class DummyTool(llm.Tool):
    """Minimal llm.Tool implementation used to exercise schema conversion."""

    name = "do_something"
    description = "Test tool"
    parameters = vol.Schema({vol.Required("value"): str})

    async def async_call(self, hass: HomeAssistant, tool_input, llm_context):
        """Return a canned result; the tool never touches hass."""
        return {"value": "done"}
async def test_format_structured_output() -> None:
    """Verify selector-based voluptuous schemas convert to JSON-schema dicts."""
    schema = vol.Schema(
        {
            vol.Required("name"): selector.TextSelector(),
            vol.Optional("age"): selector.NumberSelector(
                config=selector.NumberSelectorConfig(min=0, max=120),
            ),
            vol.Required("stuff"): selector.ObjectSelector(
                {
                    "multiple": True,
                    "fields": {
                        "item_name": {"selector": {"text": None}},
                        "item_value": {"selector": {"text": None}},
                    },
                }
            ),
        }
    )
    # Expected JSON schema: required keys tracked, extras rejected, and the
    # multi-object selector rendered as an array of objects.
    expected = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "number", "minimum": 0.0, "maximum": 120.0},
            "stuff": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "item_name": {"type": "string"},
                        "item_value": {"type": "string"},
                    },
                    "additionalProperties": False,
                },
            },
        },
        "required": ["name", "stuff"],
        "additionalProperties": False,
    }
    assert _format_structured_output(schema, None) == expected
async def test_prepare_files_for_prompt(
    cloud_entity: BaseCloudLLMEntity, tmp_path: Path
) -> None:
    """Verify image and PDF attachments become inline base64 file payloads."""
    image_path = tmp_path / "doorbell.jpg"
    Image.new("RGB", (2, 2), "blue").save(image_path, "JPEG")
    pdf_path = tmp_path / "context.pdf"
    pdf_path.write_bytes(b"%PDF-1.3\nmock\n")

    files = await cloud_entity._async_prepare_files_for_prompt(
        [
            conversation.Attachment(
                media_content_id="media-source://media/doorbell.jpg",
                mime_type="image/jpeg",
                path=image_path,
            ),
            conversation.Attachment(
                media_content_id="media-source://media/context.pdf",
                mime_type="application/pdf",
                path=pdf_path,
            ),
        ]
    )

    image_b64 = base64.b64encode(image_path.read_bytes()).decode()
    pdf_b64 = base64.b64encode(pdf_path.read_bytes()).decode()
    # Images become input_image entries with a data URL; PDFs become
    # input_file entries keeping their filename.
    assert files[0] == {
        "type": "input_image",
        "image_url": f"data:image/jpeg;base64,{image_b64}",
        "detail": "auto",
    }
    assert files[1] == {
        "type": "input_file",
        "filename": "context.pdf",
        "file_data": f"data:application/pdf;base64,{pdf_b64}",
    }
async def test_prepare_files_for_prompt_invalid_type(
    cloud_entity: BaseCloudLLMEntity, tmp_path: Path
) -> None:
    """Verify attachments that are neither image nor PDF are rejected."""
    text_path = tmp_path / "notes.txt"
    text_path.write_text("notes")
    with pytest.raises(
        HomeAssistantError,
        match="Only images and PDF are currently supported as attachments",
    ):
        await cloud_entity._async_prepare_files_for_prompt(
            [
                conversation.Attachment(
                    media_content_id="media-source://media/notes.txt",
                    mime_type="text/plain",
                    path=text_path,
                )
            ]
        )
async def test_prepare_chat_for_generation_appends_attachments(
    hass: HomeAssistant,
    cloud_entity: BaseCloudLLMEntity,
    mock_prepare_files_for_prompt: AsyncMock,
) -> None:
    """Verify chat preparation merges tools, attachment payloads, and metadata."""
    # Chat log holding a single user turn that carries one image attachment.
    image_attachment = conversation.Attachment(
        media_content_id="media-source://media/doorbell.jpg",
        mime_type="image/jpeg",
        path=hass.config.path("doorbell.jpg"),
    )
    chat_log = conversation.ChatLog(hass, "conversation-id")
    chat_log.async_add_user_content(
        conversation.UserContent(
            content="Describe the door", attachments=[image_attachment]
        )
    )
    chat_log.llm_api = SimpleNamespace(
        tools=[DummyTool()],
        custom_serializer=None,
    )
    prepared_files = [
        {"type": "input_image", "image_url": "data://img", "detail": "auto"}
    ]
    mock_prepare_files_for_prompt.return_value = prepared_files

    response = await cloud_entity._prepare_chat_for_generation(
        chat_log, response_format={"type": "json"}
    )

    assert response["conversation_id"] == "conversation-id"
    assert response["response_format"] == {"type": "json"}
    assert response["tool_choice"] == "auto"
    # One converted llm.Tool plus the built-in web_search tool.
    tools = response["tools"]
    assert len(tools) == 2
    assert tools[0]["name"] == "do_something"
    assert tools[1]["type"] == "web_search"
    user_message = response["messages"][-1]
    assert user_message["content"][0] == {
        "type": "input_text",
        "text": "Describe the door",
    }
    # Prepared file payloads are appended after the text part.
    assert user_message["content"][1:] == prepared_files
async def test_prepare_chat_for_generation_requires_user_prompt(
    hass: HomeAssistant, cloud_entity: BaseCloudLLMEntity
) -> None:
    """Verify preparation rejects a chat log that contains no user message."""
    chat_log = conversation.ChatLog(hass, "conversation-id")
    # Only an assistant turn exists, so there is no prompt to send.
    chat_log.async_add_assistant_content_without_tools(
        conversation.AssistantContent(agent_id="agent", content="Ready")
    )
    with pytest.raises(HomeAssistantError, match="No user prompt found"):
        await cloud_entity._prepare_chat_for_generation(chat_log)
@pytest.fixture
def mock_prepare_files_for_prompt(
    cloud_entity: BaseCloudLLMEntity,
) -> Generator[AsyncMock]:
    """Patch the file preparation helper on the entity.

    The fixture yields inside ``patch.object`` so the patch is undone on
    teardown, hence the return type is ``Generator[AsyncMock]`` rather than
    ``AsyncMock`` (the original annotation was wrong for a yield fixture).
    Tests receive the AsyncMock to set return values and assert calls.
    """
    with patch.object(
        cloud_entity,
        "_async_prepare_files_for_prompt",
        AsyncMock(),
    ) as mock:
        yield mock