From 8f1abb6dbb8a3e28d1ffafdbfd429e7358c4fe35 Mon Sep 17 00:00:00 2001 From: victorigualada <21220224+victorigualada@users.noreply.github.com> Date: Tue, 25 Nov 2025 17:01:32 +0100 Subject: [PATCH] Add HomeAssistant Cloud ai_task (#157015) --- homeassistant/components/cloud/__init__.py | 7 +- homeassistant/components/cloud/ai_task.py | 200 +++++++ homeassistant/components/cloud/const.py | 1 + homeassistant/components/cloud/entity.py | 543 ++++++++++++++++++ homeassistant/components/cloud/helpers.py | 2 + homeassistant/components/cloud/strings.json | 7 + .../cloud/snapshots/test_http_api.ambr | 12 +- tests/components/cloud/test_ai_task.py | 343 +++++++++++ tests/components/cloud/test_entity.py | 219 +++++++ 9 files changed, 1331 insertions(+), 3 deletions(-) create mode 100644 homeassistant/components/cloud/ai_task.py create mode 100644 homeassistant/components/cloud/entity.py create mode 100644 tests/components/cloud/test_ai_task.py create mode 100644 tests/components/cloud/test_entity.py diff --git a/homeassistant/components/cloud/__init__.py b/homeassistant/components/cloud/__init__.py index 5377075fe54..0a0811fddc8 100644 --- a/homeassistant/components/cloud/__init__.py +++ b/homeassistant/components/cloud/__init__.py @@ -77,7 +77,12 @@ from .subscription import async_subscription_info DEFAULT_MODE = MODE_PROD -PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT, Platform.TTS] +PLATFORMS = [ + Platform.AI_TASK, + Platform.BINARY_SENSOR, + Platform.STT, + Platform.TTS, +] SERVICE_REMOTE_CONNECT = "remote_connect" SERVICE_REMOTE_DISCONNECT = "remote_disconnect" diff --git a/homeassistant/components/cloud/ai_task.py b/homeassistant/components/cloud/ai_task.py new file mode 100644 index 00000000000..4b6d0223f49 --- /dev/null +++ b/homeassistant/components/cloud/ai_task.py @@ -0,0 +1,200 @@ +"""AI Task integration for Home Assistant Cloud.""" + +from __future__ import annotations + +import io +from json import JSONDecodeError +import logging + +from hass_nabucasa.llm import ( + LLMAuthenticationError, + LLMError, + LLMImageAttachment, + LLMRateLimitError, + LLMResponseError, + LLMServiceError, +) +from PIL import Image + +from homeassistant.components import ai_task, conversation +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError +from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback +from homeassistant.util.json import json_loads + +from .const import AI_TASK_ENTITY_UNIQUE_ID, DATA_CLOUD +from .entity import BaseCloudLLMEntity + +_LOGGER = logging.getLogger(__name__) + + +def _convert_image_for_editing(data: bytes) -> tuple[bytes, str]: + """Ensure the image data is in a format accepted by OpenAI image edits.""" + stream = io.BytesIO(data) + with Image.open(stream) as img: + mode = img.mode + if mode not in ("RGBA", "LA", "L"): + img = img.convert("RGBA") + + output = io.BytesIO() + if img.mode in ("RGBA", "LA", "L"): + img.save(output, format="PNG") + return output.getvalue(), "image/png" + + img.save(output, format=img.format or "PNG") + return output.getvalue(), f"image/{(img.format or 'png').lower()}" + + +async def async_prepare_image_generation_attachments( + hass: HomeAssistant, attachments: list[conversation.Attachment] +) -> list[LLMImageAttachment]: + """Load attachment data for image generation.""" + + def prepare() -> list[LLMImageAttachment]: + items: list[LLMImageAttachment] = [] + for attachment in attachments: + if not attachment.mime_type or not attachment.mime_type.startswith( + "image/" + ): + raise HomeAssistantError( + "Only image attachments are supported for image generation" + ) + path = attachment.path + if not path.exists(): + raise HomeAssistantError(f"`{path}` does not exist") + + data = path.read_bytes() + mime_type = attachment.mime_type + + try: + data, mime_type = _convert_image_for_editing(data) + except HomeAssistantError: + raise + except Exception as err: + raise HomeAssistantError("Failed to process image attachment") from err + + items.append( + LLMImageAttachment( + filename=path.name, + mime_type=mime_type, + data=data, + ) + ) + + return items + + return await hass.async_add_executor_job(prepare) + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddConfigEntryEntitiesCallback, +) -> None: + """Set up Home Assistant Cloud AI Task entity.""" + cloud = hass.data[DATA_CLOUD] + try: + await cloud.llm.async_ensure_token() + except LLMError: + return + + async_add_entities([CloudLLMTaskEntity(cloud, config_entry)]) + + +class CloudLLMTaskEntity(ai_task.AITaskEntity, BaseCloudLLMEntity): + """Home Assistant Cloud AI Task entity.""" + + _attr_has_entity_name = True + _attr_supported_features = ( + ai_task.AITaskEntityFeature.GENERATE_DATA + | ai_task.AITaskEntityFeature.GENERATE_IMAGE + | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS + ) + _attr_translation_key = "cloud_ai" + _attr_unique_id = AI_TASK_ENTITY_UNIQUE_ID + + @property + def available(self) -> bool: + """Return if the entity is available.""" + return self._cloud.is_logged_in and self._cloud.valid_subscription + + async def _async_generate_data( + self, + task: ai_task.GenDataTask, + chat_log: conversation.ChatLog, + ) -> ai_task.GenDataTaskResult: + """Handle a generate data task.""" + await self._async_handle_chat_log( + "ai_task", chat_log, task.name, task.structure + ) + + if not isinstance(chat_log.content[-1], conversation.AssistantContent): + raise HomeAssistantError( + "Last content in chat log is not an AssistantContent" + ) + + text = chat_log.content[-1].content or "" + + if not task.structure: + return ai_task.GenDataTaskResult( + conversation_id=chat_log.conversation_id, + data=text, + ) + try: + data = json_loads(text) + except JSONDecodeError as err: + _LOGGER.error( + "Failed to parse JSON response: %s. Response: %s", + err, + text, + ) + raise HomeAssistantError("Error with OpenAI structured response") from err + + return ai_task.GenDataTaskResult( + conversation_id=chat_log.conversation_id, + data=data, + ) + + async def _async_generate_image( + self, + task: ai_task.GenImageTask, + chat_log: conversation.ChatLog, + ) -> ai_task.GenImageTaskResult: + """Handle a generate image task.""" + attachments: list[LLMImageAttachment] | None = None + if task.attachments: + attachments = await async_prepare_image_generation_attachments( + self.hass, task.attachments + ) + + try: + if attachments is None: + image = await self._cloud.llm.async_generate_image( + prompt=task.instructions, + ) + else: + image = await self._cloud.llm.async_edit_image( + prompt=task.instructions, + attachments=attachments, + ) + except LLMAuthenticationError as err: + raise ConfigEntryAuthFailed("Cloud LLM authentication failed") from err + except LLMRateLimitError as err: + raise HomeAssistantError("Cloud LLM is rate limited") from err + except LLMResponseError as err: + raise HomeAssistantError(str(err)) from err + except LLMServiceError as err: + raise HomeAssistantError("Error talking to Cloud LLM") from err + except LLMError as err: + raise HomeAssistantError(str(err)) from err + + return ai_task.GenImageTaskResult( + conversation_id=chat_log.conversation_id, + mime_type=image["mime_type"], + image_data=image["image_data"], + model=image.get("model"), + width=image.get("width"), + height=image.get("height"), + revised_prompt=image.get("revised_prompt"), + ) diff --git a/homeassistant/components/cloud/const.py b/homeassistant/components/cloud/const.py index ab3ce2365de..787ca97bf13 100644 --- a/homeassistant/components/cloud/const.py +++ b/homeassistant/components/cloud/const.py @@ -91,6 +91,7 @@ DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update") STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text" TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech" +AI_TASK_ENTITY_UNIQUE_ID = "cloud-ai-task" LOGIN_MFA_TIMEOUT = 60 diff --git a/homeassistant/components/cloud/entity.py b/homeassistant/components/cloud/entity.py new file mode 100644 index 00000000000..e0e9334f27e --- /dev/null +++ b/homeassistant/components/cloud/entity.py @@ -0,0 +1,543 @@ +"""Helpers for cloud LLM chat handling.""" + +import base64 +from collections.abc import AsyncGenerator, Callable +from enum import Enum +import json +import logging +import re +from typing import Any, Literal, cast + +from hass_nabucasa import Cloud +from hass_nabucasa.llm import ( + LLMAuthenticationError, + LLMError, + LLMRateLimitError, + LLMResponseError, + LLMServiceError, +) +from litellm import ResponseFunctionToolCall, ResponsesAPIStreamEvents +from openai.types.responses import ( + FunctionToolParam, + ResponseReasoningItem, + ToolParam, + WebSearchToolParam, +) +import voluptuous as vol +from voluptuous_openapi import convert + +from homeassistant.components import conversation +from homeassistant.config_entries import ConfigEntry +from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError +from homeassistant.helpers import llm +from homeassistant.helpers.entity import Entity +from homeassistant.util import slugify + +from .client import CloudClient + +_LOGGER = logging.getLogger(__name__) + +_MAX_TOOL_ITERATIONS = 10 + + +class ResponseItemType(str, Enum): + """Response item types.""" + + FUNCTION_CALL = "function_call" + MESSAGE = "message" + REASONING = "reasoning" + WEB_SEARCH_CALL = "web_search_call" + IMAGE = "image" + + +def _convert_content_to_chat_message( + content: conversation.Content, +) -> dict[str, Any] | None: + """Convert ChatLog content to a responses message.""" + if content.role not in ("user", "system", "tool_result", "assistant"): + return None + + text_content = cast( + conversation.SystemContent + | conversation.UserContent + | conversation.AssistantContent, + content, + ) + + if not text_content.content: + return None + + content_type = "output_text" if text_content.role == "assistant" else "input_text" + + return { + "role": text_content.role, + "content": [ + { + "type": content_type, + "text": text_content.content, + } + ], + } + + +def _format_tool( + tool: llm.Tool, + custom_serializer: Callable[[Any], Any] | None, +) -> ToolParam: + """Format a Home Assistant tool for the OpenAI Responses API.""" + parameters = convert(tool.parameters, custom_serializer=custom_serializer) + + spec: FunctionToolParam = { + "type": "function", + "name": tool.name, + "strict": False, + "description": tool.description, + "parameters": parameters, + } + + return spec + + +def _adjust_schema(schema: dict[str, Any]) -> None: + """Adjust the schema to be compatible with OpenAI API.""" + if schema["type"] == "object": + schema.setdefault("strict", True) + schema.setdefault("additionalProperties", False) + if "properties" not in schema: + return + + if "required" not in schema: + schema["required"] = [] + + # Ensure all properties are required + for prop, prop_info in schema["properties"].items(): + _adjust_schema(prop_info) + if prop not in schema["required"]: + prop_info["type"] = [prop_info["type"], "null"] + schema["required"].append(prop) + + elif schema["type"] == "array": + if "items" not in schema: + return + + _adjust_schema(schema["items"]) + + +def _format_structured_output( + schema: vol.Schema, llm_api: llm.APIInstance | None +) -> dict[str, Any]: + """Format the schema to be compatible with OpenAI API.""" + result: dict[str, Any] = convert( + schema, + custom_serializer=( + llm_api.custom_serializer if llm_api else llm.selector_serializer + ), + ) + + _ensure_schema_constraints(result) + + return result + + +def _ensure_schema_constraints(schema: dict[str, Any]) -> None: + """Ensure generated schemas match the Responses API expectations.""" + schema_type = schema.get("type") + + if schema_type == "object": + schema.setdefault("additionalProperties", False) + properties = schema.get("properties") + if isinstance(properties, dict): + for property_schema in properties.values(): + if isinstance(property_schema, dict): + _ensure_schema_constraints(property_schema) + elif schema_type == "array": + items = schema.get("items") + if isinstance(items, dict): + _ensure_schema_constraints(items) + + +# Borrowed and adapted from openai_conversation component +async def _transform_stream( # noqa: C901 - This is complex, but better to have it in one place + chat_log: conversation.ChatLog, + stream: Any, + remove_citations: bool = False, +) -> AsyncGenerator[ + conversation.AssistantContentDeltaDict | conversation.ToolResultContentDeltaDict +]: + """Transform stream result into HA format.""" + last_summary_index = None + last_role: Literal["assistant", "tool_result"] | None = None + current_tool_call: ResponseFunctionToolCall | None = None + + # Non-reasoning models don't follow our request to remove citations, so we remove + # them manually here. They always follow the same pattern: the citation is always + # in parentheses in Markdown format, the citation is always in a single delta event, + # and sometimes the closing parenthesis is split into a separate delta event. + remove_parentheses: bool = False + citation_regexp = re.compile(r"\(\[([^\]]+)\]\((https?:\/\/[^\)]+)\)") + + async for event in stream: + event_type = getattr(event, "type", None) + event_item = getattr(event, "item", None) + event_item_type = getattr(event_item, "type", None) if event_item else None + + _LOGGER.debug( + "Event[%s] | item: %s", + event_type, + event_item_type, + ) + + if event_type == ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED: + # Detect function_call even when it's a BaseLiteLLMOpenAIResponseObject + if event_item_type == ResponseItemType.FUNCTION_CALL: + # OpenAI has tool calls as individual events + # while HA puts tool calls inside the assistant message. + # We turn them into individual assistant content for HA + # to ensure that tools are called as soon as possible. + yield {"role": "assistant"} + last_role = "assistant" + last_summary_index = None + current_tool_call = cast(ResponseFunctionToolCall, event.item) + elif ( + event_item_type == ResponseItemType.MESSAGE + or ( + event_item_type == ResponseItemType.REASONING + and last_summary_index is not None + ) # Subsequent ResponseReasoningItem + or last_role != "assistant" + ): + yield {"role": "assistant"} + last_role = "assistant" + last_summary_index = None + + elif event_type == ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE: + if event_item_type == ResponseItemType.REASONING: + encrypted_content = getattr(event.item, "encrypted_content", None) + summary = getattr(event.item, "summary", []) or [] + + yield { + "native": ResponseReasoningItem( + type="reasoning", + id=event.item.id, + summary=[], + encrypted_content=encrypted_content, + ) + } + + last_summary_index = len(summary) - 1 if summary else None + elif event_item_type == ResponseItemType.WEB_SEARCH_CALL: + action = getattr(event.item, "action", None) + if isinstance(action, dict): + action_dict = action + elif action is not None: + action_dict = action.to_dict() + else: + action_dict = {} + yield { + "tool_calls": [ + llm.ToolInput( + id=event.item.id, + tool_name="web_search_call", + tool_args={"action": action_dict}, + external=True, + ) + ] + } + yield { + "role": "tool_result", + "tool_call_id": event.item.id, + "tool_name": "web_search_call", + "tool_result": {"status": event.item.status}, + } + last_role = "tool_result" + elif event_item_type == ResponseItemType.IMAGE: + yield {"native": event.item} + last_summary_index = -1 # Trigger new assistant message on next turn + + elif event_type == ResponsesAPIStreamEvents.OUTPUT_TEXT_DELTA: + data = event.delta + if remove_parentheses: + data = data.removeprefix(")") + remove_parentheses = False + elif remove_citations and (match := citation_regexp.search(data)): + match_start, match_end = match.span() + # remove leading space if any + if data[match_start - 1 : match_start] == " ": + match_start -= 1 + # remove closing parenthesis: + if data[match_end : match_end + 1] == ")": + match_end += 1 + else: + remove_parentheses = True + data = data[:match_start] + data[match_end:] + if data: + yield {"content": data} + + elif event_type == ResponsesAPIStreamEvents.REASONING_SUMMARY_TEXT_DELTA: + # OpenAI can output several reasoning summaries + # in a single ResponseReasoningItem. We split them as separate + # AssistantContent messages. Only last of them will have + # the reasoning `native` field set. + if ( + last_summary_index is not None + and event.summary_index != last_summary_index + ): + yield {"role": "assistant"} + last_role = "assistant" + last_summary_index = event.summary_index + yield {"thinking_content": event.delta} + + elif event_type == ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DELTA: + if current_tool_call is not None: + current_tool_call.arguments += event.delta + + elif event_type == ResponsesAPIStreamEvents.WEB_SEARCH_CALL_SEARCHING: + yield {"role": "assistant"} + + elif event_type == ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DONE: + if current_tool_call is not None: + current_tool_call.status = "completed" + + raw_args = json.loads(current_tool_call.arguments) + for key in ("area", "floor"): + if key in raw_args and not raw_args[key]: + # Remove keys that are "" or None + raw_args.pop(key, None) + + yield { + "tool_calls": [ + llm.ToolInput( + id=current_tool_call.call_id, + tool_name=current_tool_call.name, + tool_args=raw_args, + ) + ] + } + + elif event_type == ResponsesAPIStreamEvents.RESPONSE_COMPLETED: + if event.response.usage is not None: + chat_log.async_trace( + { + "stats": { + "input_tokens": event.response.usage.input_tokens, + "output_tokens": event.response.usage.output_tokens, + } + } + ) + + elif event_type == ResponsesAPIStreamEvents.RESPONSE_INCOMPLETE: + if event.response.usage is not None: + chat_log.async_trace( + { + "stats": { + "input_tokens": event.response.usage.input_tokens, + "output_tokens": event.response.usage.output_tokens, + } + } + ) + + if ( + event.response.incomplete_details + and event.response.incomplete_details.reason + ): + reason: str = event.response.incomplete_details.reason + else: + reason = "unknown reason" + + if reason == "max_output_tokens": + reason = "max output tokens reached" + elif reason == "content_filter": + reason = "content filter triggered" + + raise HomeAssistantError(f"OpenAI response incomplete: {reason}") + + elif event_type == ResponsesAPIStreamEvents.RESPONSE_FAILED: + if event.response.usage is not None: + chat_log.async_trace( + { + "stats": { + "input_tokens": event.response.usage.input_tokens, + "output_tokens": event.response.usage.output_tokens, + } + } + ) + reason = "unknown reason" + if event.response.error is not None: + reason = event.response.error.message + raise HomeAssistantError(f"OpenAI response failed: {reason}") + + elif event_type == ResponsesAPIStreamEvents.ERROR: + raise HomeAssistantError(f"OpenAI response error: {event.message}") + + +class BaseCloudLLMEntity(Entity): + """Cloud LLM conversation agent.""" + + def __init__(self, cloud: Cloud[CloudClient], config_entry: ConfigEntry) -> None: + """Initialize the entity.""" + self._cloud = cloud + self._entry = config_entry + + async def _prepare_chat_for_generation( + self, + chat_log: conversation.ChatLog, + response_format: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Prepare kwargs for Cloud LLM from the chat log.""" + + messages = [ + message + for content in chat_log.content + if (message := _convert_content_to_chat_message(content)) + ] + + if not messages or messages[-1]["role"] != "user": + raise HomeAssistantError("No user prompt found") + + last_content = chat_log.content[-1] + if last_content.role == "user" and last_content.attachments: + files = await self._async_prepare_files_for_prompt(last_content.attachments) + user_message = messages[-1] + current_content = user_message.get("content", []) + user_message["content"] = [*(current_content or []), *files] + + tools: list[ToolParam] = [] + tool_choice: str | None = None + + if chat_log.llm_api: + ha_tools: list[ToolParam] = [ + _format_tool(tool, chat_log.llm_api.custom_serializer) + for tool in chat_log.llm_api.tools + ] + + if ha_tools: + if not chat_log.unresponded_tool_results: + tools = ha_tools + tool_choice = "auto" + else: + tools = [] + tool_choice = "none" + + web_search = WebSearchToolParam( + type="web_search", + search_context_size="medium", + ) + tools.append(web_search) + + response_kwargs: dict[str, Any] = { + "messages": messages, + "conversation_id": chat_log.conversation_id, + } + + if response_format is not None: + response_kwargs["response_format"] = response_format + if tools is not None: + response_kwargs["tools"] = tools + if tool_choice is not None: + response_kwargs["tool_choice"] = tool_choice + + response_kwargs["stream"] = True + + return response_kwargs + + async def _async_prepare_files_for_prompt( + self, + attachments: list[conversation.Attachment], + ) -> list[dict[str, Any]]: + """Prepare files for multimodal prompts.""" + + def prepare() -> list[dict[str, Any]]: + content: list[dict[str, Any]] = [] + for attachment in attachments: + mime_type = attachment.mime_type + path = attachment.path + if not path.exists(): + raise HomeAssistantError(f"`{path}` does not exist") + + data = base64.b64encode(path.read_bytes()).decode("utf-8") + if mime_type and mime_type.startswith("image/"): + content.append( + { + "type": "input_image", + "image_url": f"data:{mime_type};base64,{data}", + "detail": "auto", + } + ) + elif mime_type and mime_type.startswith("application/pdf"): + content.append( + { + "type": "input_file", + "filename": str(path.name), + "file_data": f"data:{mime_type};base64,{data}", + } + ) + else: + raise HomeAssistantError( + "Only images and PDF are currently supported as attachments" + ) + + return content + + return await self.hass.async_add_executor_job(prepare) + + async def _async_handle_chat_log( + self, + type: Literal["ai_task", "conversation"], + chat_log: conversation.ChatLog, + structure_name: str | None = None, + structure: vol.Schema | None = None, + ) -> None: + """Generate a response for the chat log.""" + + for _ in range(_MAX_TOOL_ITERATIONS): + response_format: dict[str, Any] | None = None + if structure and structure_name: + response_format = { + "type": "json_schema", + "json_schema": { + "name": slugify(structure_name), + "schema": _format_structured_output( + structure, chat_log.llm_api + ), + "strict": True, + }, + } + + response_kwargs = await self._prepare_chat_for_generation( + chat_log, + response_format, + ) + + try: + if type == "conversation": + raw_stream = await self._cloud.llm.async_process_conversation( + **response_kwargs, + ) + else: + raw_stream = await self._cloud.llm.async_generate_data( + **response_kwargs, + ) + + async for _ in chat_log.async_add_delta_content_stream( + agent_id=self.entity_id, + stream=_transform_stream( + chat_log, + raw_stream, + True, + ), + ): + pass + + except LLMAuthenticationError as err: + raise ConfigEntryAuthFailed("Cloud LLM authentication failed") from err + except LLMRateLimitError as err: + raise HomeAssistantError("Cloud LLM is rate limited") from err + except LLMResponseError as err: + raise HomeAssistantError(str(err)) from err + except LLMServiceError as err: + raise HomeAssistantError("Error talking to Cloud LLM") from err + except LLMError as err: + raise HomeAssistantError(str(err)) from err + + if not chat_log.unresponded_tool_results: + break diff --git a/homeassistant/components/cloud/helpers.py b/homeassistant/components/cloud/helpers.py index 7795a314fb7..61abab18c75 100644 --- a/homeassistant/components/cloud/helpers.py +++ b/homeassistant/components/cloud/helpers.py @@ -1,5 +1,7 @@ """Helpers for the cloud component.""" +from __future__ import annotations + from collections import deque import logging diff --git a/homeassistant/components/cloud/strings.json b/homeassistant/components/cloud/strings.json index 43dbd9c0810..d642f1df682 100644 --- a/homeassistant/components/cloud/strings.json +++ b/homeassistant/components/cloud/strings.json @@ -1,4 +1,11 @@ { + "entity": { + "ai_task": { + "cloud_ai": { + "name": "Home Assistant Cloud AI" + } + } + }, "exceptions": { "backup_size_too_large": { "message": "The backup size of {size}GB is too large to be uploaded to Home Assistant Cloud." diff --git a/tests/components/cloud/snapshots/test_http_api.ambr b/tests/components/cloud/snapshots/test_http_api.ambr index 9e1f68e23f8..ad8afbe695e 100644 --- a/tests/components/cloud/snapshots/test_http_api.ambr +++ b/tests/components/cloud/snapshots/test_http_api.ambr @@ -21,22 +21,26 @@ ## Active Integrations - Built-in integrations: 15 + Built-in integrations: 19 Custom integrations: 1
Built-in integrations Domain | Name --- | --- + ai_task | AI Task auth | Auth binary_sensor | Binary Sensor cloud | Home Assistant Cloud cloud.binary_sensor | Unknown cloud.stt | Unknown cloud.tts | Unknown + conversation | Conversation ffmpeg | FFmpeg homeassistant | Home Assistant Core Integration http | HTTP + intent | Intent + media_source | Media Source mock_no_info_integration | mock_no_info_integration repairs | Repairs stt | Speech-to-text (STT) @@ -116,22 +120,26 @@ ## Active Integrations - Built-in integrations: 15 + Built-in integrations: 19 Custom integrations: 0
Built-in integrations Domain | Name --- | --- + ai_task | AI Task auth | Auth binary_sensor | Binary Sensor cloud | Home Assistant Cloud cloud.binary_sensor | Unknown cloud.stt | Unknown cloud.tts | Unknown + conversation | Conversation ffmpeg | FFmpeg homeassistant | Home Assistant Core Integration http | HTTP + intent | Intent + media_source | Media Source mock_no_info_integration | mock_no_info_integration repairs | Repairs stt | Speech-to-text (STT) diff --git a/tests/components/cloud/test_ai_task.py b/tests/components/cloud/test_ai_task.py new file mode 100644 index 00000000000..ab1f65e6f3e --- /dev/null +++ b/tests/components/cloud/test_ai_task.py @@ -0,0 +1,343 @@ +"""Tests for the Home Assistant Cloud AI Task entity.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +from hass_nabucasa.llm import ( + LLMAuthenticationError, + LLMError, + LLMImageAttachment, + LLMRateLimitError, + LLMResponseError, + LLMServiceError, +) +from PIL import Image +import pytest +import voluptuous as vol + +from homeassistant.components import ai_task, conversation +from homeassistant.components.cloud.ai_task import ( + CloudLLMTaskEntity, + async_prepare_image_generation_attachments, +) +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError + +from tests.common import MockConfigEntry + + +@pytest.fixture +def mock_cloud_ai_task_entity(hass: HomeAssistant) -> CloudLLMTaskEntity: + """Return a CloudLLMTaskEntity with a mocked cloud LLM.""" + cloud = MagicMock() + cloud.llm = MagicMock( + async_generate_image=AsyncMock(), + async_edit_image=AsyncMock(), + ) + cloud.is_logged_in = True + cloud.valid_subscription = True + entry = MockConfigEntry(domain="cloud") + entry.add_to_hass(hass) + entity = CloudLLMTaskEntity(cloud, entry) + entity.entity_id = "ai_task.cloud_ai_task" + entity.hass = hass + return entity + + +@pytest.fixture(name="mock_handle_chat_log") +def mock_handle_chat_log_fixture() -> AsyncMock: + """Patch the chat log handler.""" + with patch( + "homeassistant.components.cloud.ai_task.CloudLLMTaskEntity._async_handle_chat_log", + AsyncMock(), + ) as mock: + yield mock + + +@pytest.fixture(name="mock_prepare_generation_attachments") +def mock_prepare_generation_attachments_fixture() -> AsyncMock: + """Patch image generation attachment preparation.""" + with patch( + "homeassistant.components.cloud.ai_task.async_prepare_image_generation_attachments", + AsyncMock(), + ) as mock: + yield mock + + +async def test_prepare_image_generation_attachments( + hass: HomeAssistant, tmp_path: Path +) -> None: + """Test preparing attachments for image generation.""" + image_path = tmp_path / "snapshot.jpg" + Image.new("RGB", (2, 2), "red").save(image_path, "JPEG") + + attachments = [ + conversation.Attachment( + media_content_id="media-source://media/snapshot.jpg", + mime_type="image/jpeg", + path=image_path, + ) + ] + + result = await async_prepare_image_generation_attachments(hass, attachments) + + assert len(result) == 1 + attachment = result[0] + assert attachment["filename"] == "snapshot.jpg" + assert attachment["mime_type"] == "image/png" + assert attachment["data"].startswith(b"\x89PNG") + + +async def test_prepare_image_generation_attachments_only_images( + hass: HomeAssistant, tmp_path: Path +) -> None: + """Test non image attachments are rejected.""" + doc_path = tmp_path / "context.txt" + doc_path.write_text("context") + + attachments = [ + conversation.Attachment( + media_content_id="media-source://media/context.txt", + mime_type="text/plain", + path=doc_path, + ) + ] + + with pytest.raises( + HomeAssistantError, + match="Only image attachments are supported for image generation", + ): + await async_prepare_image_generation_attachments(hass, attachments) + + +async def test_prepare_image_generation_attachments_missing_file( + hass: HomeAssistant, tmp_path: Path +) -> None: + """Test missing attachments raise a helpful error.""" + missing_path = tmp_path / "missing.png" + + attachments = [ + conversation.Attachment( + media_content_id="media-source://media/missing.png", + mime_type="image/png", + path=missing_path, + ) + ] + + with pytest.raises(HomeAssistantError, match="`.*missing.png` does not exist"): + await async_prepare_image_generation_attachments(hass, attachments) + + +async def test_prepare_image_generation_attachments_processing_error( + hass: HomeAssistant, tmp_path: Path +) -> None: + """Test invalid image data raises a processing error.""" + broken_path = tmp_path / "broken.png" + broken_path.write_bytes(b"not-an-image") + + attachments = [ + conversation.Attachment( + media_content_id="media-source://media/broken.png", + mime_type="image/png", + path=broken_path, + ) + ] + + with pytest.raises( + HomeAssistantError, + match="Failed to process image attachment", + ): + await async_prepare_image_generation_attachments(hass, attachments) + + +async def test_generate_data_returns_text( + hass: HomeAssistant, + mock_cloud_ai_task_entity: CloudLLMTaskEntity, + mock_handle_chat_log: AsyncMock, +) -> None: + """Test generating plain text data.""" + chat_log = conversation.ChatLog(hass, "conversation-id") + chat_log.async_add_user_content( + conversation.UserContent(content="Tell me something") + ) + task = ai_task.GenDataTask(name="Task", instructions="Say hi") + + async def fake_handle(chat_type, log, task_name, structure): + """Inject assistant output.""" + assert chat_type == "ai_task" + log.async_add_assistant_content_without_tools( + conversation.AssistantContent( + agent_id=mock_cloud_ai_task_entity.entity_id or "", + content="Hello from the cloud", + ) + ) + + mock_handle_chat_log.side_effect = fake_handle + result = await mock_cloud_ai_task_entity._async_generate_data(task, chat_log) + + assert result.conversation_id == "conversation-id" + assert result.data == "Hello from the cloud" + + +async def test_generate_data_returns_json( + hass: HomeAssistant, + mock_cloud_ai_task_entity: CloudLLMTaskEntity, + mock_handle_chat_log: AsyncMock, +) -> None: + """Test generating structured data.""" + chat_log = conversation.ChatLog(hass, "conversation-id") + chat_log.async_add_user_content(conversation.UserContent(content="List names")) + task = ai_task.GenDataTask( + name="Task", + instructions="Return JSON", + structure=vol.Schema({vol.Required("names"): [str]}), + ) + + async def fake_handle(chat_type, log, task_name, structure): + log.async_add_assistant_content_without_tools( + conversation.AssistantContent( + agent_id=mock_cloud_ai_task_entity.entity_id or "", + content='{"names": ["A", "B"]}', + ) + ) + + mock_handle_chat_log.side_effect = fake_handle + result = await mock_cloud_ai_task_entity._async_generate_data(task, chat_log) + + assert result.data == {"names": ["A", "B"]} + + +async def test_generate_data_invalid_json( + hass: HomeAssistant, + mock_cloud_ai_task_entity: CloudLLMTaskEntity, + mock_handle_chat_log: AsyncMock, +) -> None: + """Test invalid JSON responses raise an error.""" + chat_log = conversation.ChatLog(hass, "conversation-id") + chat_log.async_add_user_content(conversation.UserContent(content="List names")) + task = ai_task.GenDataTask( + name="Task", + instructions="Return JSON", + structure=vol.Schema({vol.Required("names"): [str]}), + ) + + async def fake_handle(chat_type, log, task_name, structure): + log.async_add_assistant_content_without_tools( + conversation.AssistantContent( + agent_id=mock_cloud_ai_task_entity.entity_id or "", + content="not-json", + ) + ) + + mock_handle_chat_log.side_effect = fake_handle + with pytest.raises( + HomeAssistantError, match="Error with OpenAI structured response" + ): + await mock_cloud_ai_task_entity._async_generate_data(task, chat_log) + + +async def test_generate_image_no_attachments( + hass: HomeAssistant, mock_cloud_ai_task_entity: CloudLLMTaskEntity +) -> None: + """Test generating an image without attachments.""" + mock_cloud_ai_task_entity._cloud.llm.async_generate_image.return_value = { + "mime_type": "image/png", + "image_data": b"IMG", + "model": "mock-image", + "width": 1024, + "height": 768, + "revised_prompt": "Improved prompt", + } + task = ai_task.GenImageTask(name="Task", instructions="Draw something") + chat_log = conversation.ChatLog(hass, "conversation-id") + + result = await mock_cloud_ai_task_entity._async_generate_image(task, chat_log) + + assert result.image_data == b"IMG" + assert result.mime_type == "image/png" + mock_cloud_ai_task_entity._cloud.llm.async_generate_image.assert_awaited_once_with( + prompt="Draw something" + ) + + +async def test_generate_image_with_attachments( + hass: HomeAssistant, + mock_cloud_ai_task_entity: CloudLLMTaskEntity, + mock_prepare_generation_attachments: AsyncMock, +) -> None: + """Test generating an edited image when attachments are provided.""" + mock_cloud_ai_task_entity._cloud.llm.async_edit_image.return_value = { + "mime_type": "image/png", + "image_data": b"IMG", + } + task = ai_task.GenImageTask( + name="Task", + instructions="Edit this", + attachments=[ + conversation.Attachment( + media_content_id="media-source://media/snapshot.png", + mime_type="image/png", + path=hass.config.path("snapshot.png"), + ) + ], + ) + chat_log = conversation.ChatLog(hass, "conversation-id") + prepared_attachments = [ + LLMImageAttachment(filename="snapshot.png", mime_type="image/png", data=b"IMG") + ] + + mock_prepare_generation_attachments.return_value = prepared_attachments + await mock_cloud_ai_task_entity._async_generate_image(task, chat_log) + + mock_cloud_ai_task_entity._cloud.llm.async_edit_image.assert_awaited_once_with( + prompt="Edit this", + attachments=prepared_attachments, + ) + + +@pytest.mark.parametrize( + ("err", "expected_exception", "message"), + [ + ( + LLMAuthenticationError("auth"), + ConfigEntryAuthFailed, + "Cloud LLM authentication failed", + ), + ( + LLMRateLimitError("limit"), + HomeAssistantError, + "Cloud LLM is rate limited", + ), + ( + LLMResponseError("bad response"), + HomeAssistantError, + "bad response", + ), + ( + LLMServiceError("service"), + HomeAssistantError, + "Error talking to Cloud LLM", + ), + ( + LLMError("generic"), + HomeAssistantError, + "generic", + ), + ], +) +async def test_generate_image_error_handling( + hass: HomeAssistant, + mock_cloud_ai_task_entity: CloudLLMTaskEntity, + err: Exception, + expected_exception: type[Exception], + message: str, +) -> None: + """Test image generation error handling.""" + mock_cloud_ai_task_entity._cloud.llm.async_generate_image.side_effect = err + task = ai_task.GenImageTask(name="Task", instructions="Draw something") + chat_log = conversation.ChatLog(hass, "conversation-id") + + with pytest.raises(expected_exception, match=message): + await mock_cloud_ai_task_entity._async_generate_image(task, chat_log) diff --git a/tests/components/cloud/test_entity.py b/tests/components/cloud/test_entity.py new file mode 100644 index 00000000000..f24acda3c69 --- /dev/null +++ b/tests/components/cloud/test_entity.py @@ -0,0 +1,219 @@ +"""Tests for helpers in the Home Assistant Cloud conversation entity.""" + +from __future__ import annotations + +import base64 +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +from PIL import Image +import pytest +import voluptuous as vol + +from homeassistant.components import conversation +from homeassistant.components.cloud.entity import ( + BaseCloudLLMEntity, + _format_structured_output, +) +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import llm, selector + +from tests.common import MockConfigEntry + + +@pytest.fixture +def cloud_entity(hass: HomeAssistant) -> BaseCloudLLMEntity: + """Return a CloudLLMTaskEntity attached to hass.""" + cloud = MagicMock() + cloud.llm = MagicMock() + cloud.is_logged_in = True + cloud.valid_subscription = True + entry = MockConfigEntry(domain="cloud") + entry.add_to_hass(hass) + entity = BaseCloudLLMEntity(cloud, entry) + entity.entity_id = "ai_task.cloud_ai_task" + entity.hass = hass + return entity + + +class DummyTool(llm.Tool): + """Simple tool used for schema conversion tests.""" + + name = "do_something" + description = "Test tool" + parameters = vol.Schema({vol.Required("value"): str}) + + async def async_call(self, hass: HomeAssistant, tool_input, llm_context): + """No-op implementation.""" + return {"value": "done"} + + +async def test_format_structured_output() -> None: + """Test that structured output schemas are normalized.""" + schema = vol.Schema( + { + vol.Required("name"): selector.TextSelector(), + vol.Optional("age"): selector.NumberSelector( + config=selector.NumberSelectorConfig(min=0, max=120), + ), + vol.Required("stuff"): selector.ObjectSelector( + { + "multiple": True, + "fields": { + "item_name": {"selector": {"text": None}}, + "item_value": {"selector": {"text": None}}, + }, + } + ), + } + ) + + assert _format_structured_output(schema, None) == { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "number", "minimum": 0.0, "maximum": 120.0}, + "stuff": { + "type": "array", + "items": { + "type": "object", + "properties": { + "item_name": {"type": "string"}, + "item_value": {"type": "string"}, + }, + "additionalProperties": False, + }, + }, + }, + "required": ["name", "stuff"], + "additionalProperties": False, + } + + +async def test_prepare_files_for_prompt( + cloud_entity: BaseCloudLLMEntity, tmp_path: Path +) -> None: + """Test that media attachments are converted to the expected payload.""" + image_path = tmp_path / "doorbell.jpg" + Image.new("RGB", (2, 2), "blue").save(image_path, "JPEG") + pdf_path = tmp_path / "context.pdf" + pdf_path.write_bytes(b"%PDF-1.3\nmock\n") + + attachments = [ + conversation.Attachment( + media_content_id="media-source://media/doorbell.jpg", + mime_type="image/jpeg", + path=image_path, + ), + conversation.Attachment( + media_content_id="media-source://media/context.pdf", + mime_type="application/pdf", + path=pdf_path, + ), + ] + + files = await cloud_entity._async_prepare_files_for_prompt(attachments) + + assert files[0] == { + "type": "input_image", + "image_url": "data:image/jpeg;base64," + + base64.b64encode(image_path.read_bytes()).decode(), + "detail": "auto", + } + assert files[1] == { + "type": "input_file", + "filename": "context.pdf", + "file_data": "data:application/pdf;base64," + + base64.b64encode(pdf_path.read_bytes()).decode(), + } + + +async def test_prepare_files_for_prompt_invalid_type( + cloud_entity: BaseCloudLLMEntity, tmp_path: Path +) -> None: + """Test that unsupported attachments raise an error.""" + text_path = tmp_path / "notes.txt" + text_path.write_text("notes") + + attachments = [ + conversation.Attachment( + media_content_id="media-source://media/notes.txt", + mime_type="text/plain", + path=text_path, + ) + ] + + with pytest.raises( + HomeAssistantError, + match="Only images and PDF are currently supported as attachments", + ): + await cloud_entity._async_prepare_files_for_prompt(attachments) + + +async def test_prepare_chat_for_generation_appends_attachments( + hass: HomeAssistant, + cloud_entity: BaseCloudLLMEntity, + mock_prepare_files_for_prompt: AsyncMock, +) -> None: + """Test chat preparation adds LLM tools, attachments, and metadata.""" + chat_log = conversation.ChatLog(hass, "conversation-id") + attachment = conversation.Attachment( + media_content_id="media-source://media/doorbell.jpg", + mime_type="image/jpeg", + path=hass.config.path("doorbell.jpg"), + ) + chat_log.async_add_user_content( + conversation.UserContent(content="Describe the door", attachments=[attachment]) + ) + chat_log.llm_api = SimpleNamespace( + tools=[DummyTool()], + custom_serializer=None, + ) + + files = [{"type": "input_image", "image_url": "data://img", "detail": "auto"}] + + mock_prepare_files_for_prompt.return_value = files + response = await cloud_entity._prepare_chat_for_generation( + chat_log, response_format={"type": "json"} + ) + + assert response["conversation_id"] == "conversation-id" + assert response["response_format"] == {"type": "json"} + assert response["tool_choice"] == "auto" + assert len(response["tools"]) == 2 + assert response["tools"][0]["name"] == "do_something" + assert response["tools"][1]["type"] == "web_search" + user_message = response["messages"][-1] + assert user_message["content"][0] == { + "type": "input_text", + "text": "Describe the door", + } + assert user_message["content"][1:] == files + + +async def test_prepare_chat_for_generation_requires_user_prompt( + hass: HomeAssistant, cloud_entity: BaseCloudLLMEntity +) -> None: + """Test that we fail fast when there is no user input to process.""" + chat_log = conversation.ChatLog(hass, "conversation-id") + chat_log.async_add_assistant_content_without_tools( + conversation.AssistantContent(agent_id="agent", content="Ready") + ) + + with pytest.raises(HomeAssistantError, match="No user prompt found"): + await cloud_entity._prepare_chat_for_generation(chat_log) + + +@pytest.fixture +def mock_prepare_files_for_prompt( + cloud_entity: BaseCloudLLMEntity, +) -> AsyncMock: + """Patch file preparation helper on the entity.""" + with patch.object( + cloud_entity, + "_async_prepare_files_for_prompt", + AsyncMock(), + ) as mock: + yield mock