mirror of
https://github.com/Electric-Special/ha-core.git
synced 2026-03-21 08:06:00 +01:00
Add HomeAssistant Cloud ai_task (#157015)
This commit is contained in:
@@ -77,7 +77,12 @@ from .subscription import async_subscription_info
|
||||
|
||||
DEFAULT_MODE = MODE_PROD
|
||||
|
||||
PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT, Platform.TTS]
|
||||
PLATFORMS = [
|
||||
Platform.AI_TASK,
|
||||
Platform.BINARY_SENSOR,
|
||||
Platform.STT,
|
||||
Platform.TTS,
|
||||
]
|
||||
|
||||
SERVICE_REMOTE_CONNECT = "remote_connect"
|
||||
SERVICE_REMOTE_DISCONNECT = "remote_disconnect"
|
||||
|
||||
200
homeassistant/components/cloud/ai_task.py
Normal file
200
homeassistant/components/cloud/ai_task.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""AI Task integration for Home Assistant Cloud."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from json import JSONDecodeError
|
||||
import logging
|
||||
|
||||
from hass_nabucasa.llm import (
|
||||
LLMAuthenticationError,
|
||||
LLMError,
|
||||
LLMImageAttachment,
|
||||
LLMRateLimitError,
|
||||
LLMResponseError,
|
||||
LLMServiceError,
|
||||
)
|
||||
from PIL import Image
|
||||
|
||||
from homeassistant.components import ai_task, conversation
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
from homeassistant.util.json import json_loads
|
||||
|
||||
from .const import AI_TASK_ENTITY_UNIQUE_ID, DATA_CLOUD
|
||||
from .entity import BaseCloudLLMEntity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_image_for_editing(data: bytes) -> tuple[bytes, str]:
    """Return image bytes in a format accepted by OpenAI image edits.

    Modes the edits endpoint accepts directly ("RGBA", "LA", "L") are
    re-encoded as PNG; any other mode is first converted to RGBA.
    Returns a ``(payload, mime_type)`` tuple.
    """
    source = io.BytesIO(data)
    with Image.open(source) as image:
        # Normalize unsupported modes (e.g. RGB, P) to RGBA before encoding.
        if image.mode not in ("RGBA", "LA", "L"):
            image = image.convert("RGBA")

        encoded = io.BytesIO()
        if image.mode in ("RGBA", "LA", "L"):
            image.save(encoded, format="PNG")
            return encoded.getvalue(), "image/png"

        # Fallback: keep the original container format if one is known.
        fmt = image.format or "PNG"
        image.save(encoded, format=fmt)
        return encoded.getvalue(), f"image/{fmt.lower()}"
|
||||
|
||||
|
||||
async def async_prepare_image_generation_attachments(
    hass: HomeAssistant, attachments: list[conversation.Attachment]
) -> list[LLMImageAttachment]:
    """Load and normalize image attachments for image generation.

    File I/O and image re-encoding run in the executor so the event loop
    is never blocked. Raises HomeAssistantError for non-image attachments,
    missing files, or undecodable image data.
    """

    def _load_one(attachment: conversation.Attachment) -> LLMImageAttachment:
        """Read a single attachment and convert it for the edits endpoint."""
        mime = attachment.mime_type
        if not mime or not mime.startswith("image/"):
            raise HomeAssistantError(
                "Only image attachments are supported for image generation"
            )

        source = attachment.path
        if not source.exists():
            raise HomeAssistantError(f"`{source}` does not exist")

        payload = source.read_bytes()
        try:
            payload, mime = _convert_image_for_editing(payload)
        except HomeAssistantError:
            # Already a user-facing error; propagate untouched.
            raise
        except Exception as err:
            raise HomeAssistantError("Failed to process image attachment") from err

        return LLMImageAttachment(
            filename=source.name,
            mime_type=mime,
            data=payload,
        )

    def prepare() -> list[LLMImageAttachment]:
        return [_load_one(attachment) for attachment in attachments]

    return await hass.async_add_executor_job(prepare)
|
||||
|
||||
|
||||
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Set up the Home Assistant Cloud AI Task entity.

    The entity is only created when a Cloud LLM token can be obtained;
    otherwise setup is skipped without failing the config entry.
    """
    cloud = hass.data[DATA_CLOUD]
    try:
        await cloud.llm.async_ensure_token()
    except LLMError as err:
        # Fix: the error was previously swallowed silently, which made a
        # missing AI Task entity impossible to diagnose from the logs.
        _LOGGER.debug(
            "Cloud LLM token unavailable, skipping AI Task setup: %s", err
        )
        return

    async_add_entities([CloudLLMTaskEntity(cloud, config_entry)])
|
||||
|
||||
|
||||
class CloudLLMTaskEntity(ai_task.AITaskEntity, BaseCloudLLMEntity):
    """Home Assistant Cloud AI Task entity."""

    _attr_has_entity_name = True
    _attr_supported_features = (
        ai_task.AITaskEntityFeature.GENERATE_DATA
        | ai_task.AITaskEntityFeature.GENERATE_IMAGE
        | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
    )
    _attr_translation_key = "cloud_ai"
    _attr_unique_id = AI_TASK_ENTITY_UNIQUE_ID

    @property
    def available(self) -> bool:
        """Return True when logged in with a valid Cloud subscription."""
        return self._cloud.is_logged_in and self._cloud.valid_subscription

    async def _async_generate_data(
        self,
        task: ai_task.GenDataTask,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenDataTaskResult:
        """Handle a generate data task."""
        await self._async_handle_chat_log(
            "ai_task", chat_log, task.name, task.structure
        )

        last = chat_log.content[-1]
        if not isinstance(last, conversation.AssistantContent):
            raise HomeAssistantError(
                "Last content in chat log is not an AssistantContent"
            )

        raw_text = last.content or ""

        # Unstructured tasks return the model output verbatim.
        if not task.structure:
            return ai_task.GenDataTaskResult(
                conversation_id=chat_log.conversation_id,
                data=raw_text,
            )

        try:
            parsed = json_loads(raw_text)
        except JSONDecodeError as err:
            _LOGGER.error(
                "Failed to parse JSON response: %s. Response: %s",
                err,
                raw_text,
            )
            raise HomeAssistantError("Error with OpenAI structured response") from err

        return ai_task.GenDataTaskResult(
            conversation_id=chat_log.conversation_id,
            data=parsed,
        )

    async def _async_generate_image(
        self,
        task: ai_task.GenImageTask,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenImageTaskResult:
        """Handle a generate image task."""
        attachments: list[LLMImageAttachment] | None = None
        if task.attachments:
            attachments = await async_prepare_image_generation_attachments(
                self.hass, task.attachments
            )

        try:
            # With attachments we edit the supplied image(s); without, we
            # generate a fresh image from the prompt alone.
            if attachments is None:
                image = await self._cloud.llm.async_generate_image(
                    prompt=task.instructions,
                )
            else:
                image = await self._cloud.llm.async_edit_image(
                    prompt=task.instructions,
                    attachments=attachments,
                )
        except LLMAuthenticationError as err:
            raise ConfigEntryAuthFailed("Cloud LLM authentication failed") from err
        except LLMRateLimitError as err:
            raise HomeAssistantError("Cloud LLM is rate limited") from err
        except LLMResponseError as err:
            raise HomeAssistantError(str(err)) from err
        except LLMServiceError as err:
            raise HomeAssistantError("Error talking to Cloud LLM") from err
        except LLMError as err:
            raise HomeAssistantError(str(err)) from err

        return ai_task.GenImageTaskResult(
            conversation_id=chat_log.conversation_id,
            mime_type=image["mime_type"],
            image_data=image["image_data"],
            model=image.get("model"),
            width=image.get("width"),
            height=image.get("height"),
            revised_prompt=image.get("revised_prompt"),
        )
|
||||
@@ -91,6 +91,7 @@ DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update")
|
||||
|
||||
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
|
||||
TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech"
|
||||
AI_TASK_ENTITY_UNIQUE_ID = "cloud-ai-task"
|
||||
|
||||
LOGIN_MFA_TIMEOUT = 60
|
||||
|
||||
|
||||
543
homeassistant/components/cloud/entity.py
Normal file
543
homeassistant/components/cloud/entity.py
Normal file
@@ -0,0 +1,543 @@
|
||||
"""Helpers for cloud LLM chat handling."""
|
||||
|
||||
import base64
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from enum import Enum
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from hass_nabucasa import Cloud
|
||||
from hass_nabucasa.llm import (
|
||||
LLMAuthenticationError,
|
||||
LLMError,
|
||||
LLMRateLimitError,
|
||||
LLMResponseError,
|
||||
LLMServiceError,
|
||||
)
|
||||
from litellm import ResponseFunctionToolCall, ResponsesAPIStreamEvents
|
||||
from openai.types.responses import (
|
||||
FunctionToolParam,
|
||||
ResponseReasoningItem,
|
||||
ToolParam,
|
||||
WebSearchToolParam,
|
||||
)
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.util import slugify
|
||||
|
||||
from .client import CloudClient
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_MAX_TOOL_ITERATIONS = 10
|
||||
|
||||
|
||||
class ResponseItemType(str, Enum):
    """Output item types emitted by the Responses API stream."""

    # A function/tool invocation requested by the model.
    FUNCTION_CALL = "function_call"
    # A regular assistant text message.
    MESSAGE = "message"
    # A reasoning (thinking) item from reasoning-capable models.
    REASONING = "reasoning"
    # A built-in web search invocation.
    WEB_SEARCH_CALL = "web_search_call"
    # A generated image item.
    IMAGE = "image"
||||
|
||||
|
||||
def _convert_content_to_chat_message(
    content: conversation.Content,
) -> dict[str, Any] | None:
    """Convert ChatLog content into a Responses API input message.

    Returns None when the role is not one the API accepts or the content
    carries no text.
    """
    if content.role not in ("user", "system", "tool_result", "assistant"):
        return None

    text_content = cast(
        conversation.SystemContent
        | conversation.UserContent
        | conversation.AssistantContent,
        content,
    )

    text = text_content.content
    if not text:
        return None

    # Assistant turns are model output; everything else is input text.
    part_type = "output_text" if text_content.role == "assistant" else "input_text"

    return {
        "role": text_content.role,
        "content": [{"type": part_type, "text": text}],
    }
|
||||
|
||||
|
||||
def _format_tool(
    tool: llm.Tool,
    custom_serializer: Callable[[Any], Any] | None,
) -> ToolParam:
    """Translate a Home Assistant LLM tool into a Responses API function tool."""
    parameter_schema = convert(tool.parameters, custom_serializer=custom_serializer)

    tool_param: FunctionToolParam = {
        "type": "function",
        "name": tool.name,
        # HA tool schemas are not guaranteed strict-mode compatible.
        "strict": False,
        "description": tool.description,
        "parameters": parameter_schema,
    }
    return tool_param
|
||||
|
||||
|
||||
def _adjust_schema(schema: dict[str, Any]) -> None:
|
||||
"""Adjust the schema to be compatible with OpenAI API."""
|
||||
if schema["type"] == "object":
|
||||
schema.setdefault("strict", True)
|
||||
schema.setdefault("additionalProperties", False)
|
||||
if "properties" not in schema:
|
||||
return
|
||||
|
||||
if "required" not in schema:
|
||||
schema["required"] = []
|
||||
|
||||
# Ensure all properties are required
|
||||
for prop, prop_info in schema["properties"].items():
|
||||
_adjust_schema(prop_info)
|
||||
if prop not in schema["required"]:
|
||||
prop_info["type"] = [prop_info["type"], "null"]
|
||||
schema["required"].append(prop)
|
||||
|
||||
elif schema["type"] == "array":
|
||||
if "items" not in schema:
|
||||
return
|
||||
|
||||
_adjust_schema(schema["items"])
|
||||
|
||||
|
||||
def _format_structured_output(
    schema: vol.Schema, llm_api: llm.APIInstance | None
) -> dict[str, Any]:
    """Convert a voluptuous schema into an OpenAI-compatible JSON schema."""
    serializer = llm_api.custom_serializer if llm_api else llm.selector_serializer
    converted: dict[str, Any] = convert(schema, custom_serializer=serializer)

    # Apply the structural constraints the Responses API expects.
    _ensure_schema_constraints(converted)

    return converted
|
||||
|
||||
|
||||
def _ensure_schema_constraints(schema: dict[str, Any]) -> None:
|
||||
"""Ensure generated schemas match the Responses API expectations."""
|
||||
schema_type = schema.get("type")
|
||||
|
||||
if schema_type == "object":
|
||||
schema.setdefault("additionalProperties", False)
|
||||
properties = schema.get("properties")
|
||||
if isinstance(properties, dict):
|
||||
for property_schema in properties.values():
|
||||
if isinstance(property_schema, dict):
|
||||
_ensure_schema_constraints(property_schema)
|
||||
elif schema_type == "array":
|
||||
items = schema.get("items")
|
||||
if isinstance(items, dict):
|
||||
_ensure_schema_constraints(items)
|
||||
|
||||
|
||||
# Borrowed and adapted from openai_conversation component
async def _transform_stream(  # noqa: C901 - This is complex, but better to have it in one place
    chat_log: conversation.ChatLog,
    stream: Any,
    remove_citations: bool = False,
) -> AsyncGenerator[
    conversation.AssistantContentDeltaDict | conversation.ToolResultContentDeltaDict
]:
    """Transform a Responses API event stream into HA chat-log deltas."""
    prev_summary_idx = None
    prev_role: Literal["assistant", "tool_result"] | None = None
    pending_tool_call: ResponseFunctionToolCall | None = None

    # Non-reasoning models don't follow our request to remove citations, so
    # we strip them here. They always follow the same pattern: a Markdown
    # link in parentheses inside a single delta event, with the closing
    # parenthesis sometimes spilling into the next delta event.
    strip_next_paren: bool = False
    citation_re = re.compile(r"\(\[([^\]]+)\]\((https?:\/\/[^\)]+)\)")

    async for event in stream:
        kind = getattr(event, "type", None)
        item = getattr(event, "item", None)
        item_kind = getattr(item, "type", None) if item else None

        _LOGGER.debug("Event[%s] | item: %s", kind, item_kind)

        if kind == ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED:
            # Detect function_call even when it's a BaseLiteLLMOpenAIResponseObject
            if item_kind == ResponseItemType.FUNCTION_CALL:
                # OpenAI emits tool calls as standalone events while HA nests
                # them inside the assistant message; start a fresh assistant
                # content so tools are called as soon as possible.
                yield {"role": "assistant"}
                prev_role = "assistant"
                prev_summary_idx = None
                pending_tool_call = cast(ResponseFunctionToolCall, event.item)
            elif (
                item_kind == ResponseItemType.MESSAGE
                or (
                    item_kind == ResponseItemType.REASONING
                    and prev_summary_idx is not None
                )  # Subsequent ResponseReasoningItem
                or prev_role != "assistant"
            ):
                yield {"role": "assistant"}
                prev_role = "assistant"
                prev_summary_idx = None

        elif kind == ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE:
            if item_kind == ResponseItemType.REASONING:
                encrypted = getattr(event.item, "encrypted_content", None)
                summaries = getattr(event.item, "summary", []) or []

                yield {
                    "native": ResponseReasoningItem(
                        type="reasoning",
                        id=event.item.id,
                        summary=[],
                        encrypted_content=encrypted,
                    )
                }

                prev_summary_idx = len(summaries) - 1 if summaries else None
            elif item_kind == ResponseItemType.WEB_SEARCH_CALL:
                action = getattr(event.item, "action", None)
                if isinstance(action, dict):
                    action_payload = action
                elif action is not None:
                    action_payload = action.to_dict()
                else:
                    action_payload = {}
                # Surface the web search as an external tool call + result.
                yield {
                    "tool_calls": [
                        llm.ToolInput(
                            id=event.item.id,
                            tool_name="web_search_call",
                            tool_args={"action": action_payload},
                            external=True,
                        )
                    ]
                }
                yield {
                    "role": "tool_result",
                    "tool_call_id": event.item.id,
                    "tool_name": "web_search_call",
                    "tool_result": {"status": event.item.status},
                }
                prev_role = "tool_result"
            elif item_kind == ResponseItemType.IMAGE:
                yield {"native": event.item}
                prev_summary_idx = -1  # Trigger new assistant message on next turn

        elif kind == ResponsesAPIStreamEvents.OUTPUT_TEXT_DELTA:
            text = event.delta
            if strip_next_paren:
                text = text.removeprefix(")")
                strip_next_paren = False
            elif remove_citations and (match := citation_re.search(text)):
                start, end = match.span()
                # remove leading space if any
                if text[start - 1 : start] == " ":
                    start -= 1
                # remove closing parenthesis (may be in the next delta):
                if text[end : end + 1] == ")":
                    end += 1
                else:
                    strip_next_paren = True
                text = text[:start] + text[end:]
            if text:
                yield {"content": text}

        elif kind == ResponsesAPIStreamEvents.REASONING_SUMMARY_TEXT_DELTA:
            # OpenAI can output several reasoning summaries in a single
            # ResponseReasoningItem. Split them into separate
            # AssistantContent messages; only the last one will carry the
            # reasoning `native` field.
            if prev_summary_idx is not None and event.summary_index != prev_summary_idx:
                yield {"role": "assistant"}
                prev_role = "assistant"
            prev_summary_idx = event.summary_index
            yield {"thinking_content": event.delta}

        elif kind == ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DELTA:
            if pending_tool_call is not None:
                pending_tool_call.arguments += event.delta

        elif kind == ResponsesAPIStreamEvents.WEB_SEARCH_CALL_SEARCHING:
            yield {"role": "assistant"}

        elif kind == ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DONE:
            if pending_tool_call is not None:
                pending_tool_call.status = "completed"

                tool_args = json.loads(pending_tool_call.arguments)
                # Remove area/floor keys that are "" or None.
                for key in ("area", "floor"):
                    if key in tool_args and not tool_args[key]:
                        tool_args.pop(key, None)

                yield {
                    "tool_calls": [
                        llm.ToolInput(
                            id=pending_tool_call.call_id,
                            tool_name=pending_tool_call.name,
                            tool_args=tool_args,
                        )
                    ]
                }

        elif kind == ResponsesAPIStreamEvents.RESPONSE_COMPLETED:
            if event.response.usage is not None:
                chat_log.async_trace(
                    {
                        "stats": {
                            "input_tokens": event.response.usage.input_tokens,
                            "output_tokens": event.response.usage.output_tokens,
                        }
                    }
                )

        elif kind == ResponsesAPIStreamEvents.RESPONSE_INCOMPLETE:
            if event.response.usage is not None:
                chat_log.async_trace(
                    {
                        "stats": {
                            "input_tokens": event.response.usage.input_tokens,
                            "output_tokens": event.response.usage.output_tokens,
                        }
                    }
                )

            if (
                event.response.incomplete_details
                and event.response.incomplete_details.reason
            ):
                reason: str = event.response.incomplete_details.reason
            else:
                reason = "unknown reason"

            # Translate API reason codes into human-readable text.
            if reason == "max_output_tokens":
                reason = "max output tokens reached"
            elif reason == "content_filter":
                reason = "content filter triggered"

            raise HomeAssistantError(f"OpenAI response incomplete: {reason}")

        elif kind == ResponsesAPIStreamEvents.RESPONSE_FAILED:
            if event.response.usage is not None:
                chat_log.async_trace(
                    {
                        "stats": {
                            "input_tokens": event.response.usage.input_tokens,
                            "output_tokens": event.response.usage.output_tokens,
                        }
                    }
                )
            reason = "unknown reason"
            if event.response.error is not None:
                reason = event.response.error.message
            raise HomeAssistantError(f"OpenAI response failed: {reason}")

        elif kind == ResponsesAPIStreamEvents.ERROR:
            raise HomeAssistantError(f"OpenAI response error: {event.message}")
|
||||
|
||||
|
||||
class BaseCloudLLMEntity(Entity):
    """Cloud LLM conversation agent."""

    def __init__(self, cloud: Cloud[CloudClient], config_entry: ConfigEntry) -> None:
        """Initialize the entity."""
        self._cloud = cloud
        self._entry = config_entry

    async def _prepare_chat_for_generation(
        self,
        chat_log: conversation.ChatLog,
        response_format: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Build keyword arguments for a Cloud LLM request from the chat log."""
        messages = [
            msg
            for content in chat_log.content
            if (msg := _convert_content_to_chat_message(content))
        ]

        if not messages or messages[-1]["role"] != "user":
            raise HomeAssistantError("No user prompt found")

        # Fold any attachments from the latest user turn into that message.
        latest = chat_log.content[-1]
        if latest.role == "user" and latest.attachments:
            file_parts = await self._async_prepare_files_for_prompt(latest.attachments)
            user_msg = messages[-1]
            existing = user_msg.get("content", [])
            user_msg["content"] = [*(existing or []), *file_parts]

        tools: list[ToolParam] = []
        tool_choice: str | None = None

        if chat_log.llm_api:
            ha_tools: list[ToolParam] = [
                _format_tool(tool, chat_log.llm_api.custom_serializer)
                for tool in chat_log.llm_api.tools
            ]

            if ha_tools:
                if not chat_log.unresponded_tool_results:
                    tools = ha_tools
                    tool_choice = "auto"
                else:
                    # Tool results are still pending; don't request more calls.
                    tools = []
                    tool_choice = "none"

            # NOTE(review): source layout was ambiguous; web search is added
            # whenever an LLM API is attached — confirm against upstream.
            tools.append(
                WebSearchToolParam(
                    type="web_search",
                    search_context_size="medium",
                )
            )

        request: dict[str, Any] = {
            "messages": messages,
            "conversation_id": chat_log.conversation_id,
        }

        if response_format is not None:
            request["response_format"] = response_format
        if tools is not None:
            request["tools"] = tools
        if tool_choice is not None:
            request["tool_choice"] = tool_choice

        request["stream"] = True

        return request

    async def _async_prepare_files_for_prompt(
        self,
        attachments: list[conversation.Attachment],
    ) -> list[dict[str, Any]]:
        """Encode attachments as Responses API content parts (runs in executor)."""

        def encode() -> list[dict[str, Any]]:
            parts: list[dict[str, Any]] = []
            for attachment in attachments:
                mime = attachment.mime_type
                file_path = attachment.path
                if not file_path.exists():
                    raise HomeAssistantError(f"`{file_path}` does not exist")

                b64 = base64.b64encode(file_path.read_bytes()).decode("utf-8")
                if mime and mime.startswith("image/"):
                    parts.append(
                        {
                            "type": "input_image",
                            "image_url": f"data:{mime};base64,{b64}",
                            "detail": "auto",
                        }
                    )
                elif mime and mime.startswith("application/pdf"):
                    parts.append(
                        {
                            "type": "input_file",
                            "filename": str(file_path.name),
                            "file_data": f"data:{mime};base64,{b64}",
                        }
                    )
                else:
                    raise HomeAssistantError(
                        "Only images and PDF are currently supported as attachments"
                    )

            return parts

        return await self.hass.async_add_executor_job(encode)

    async def _async_handle_chat_log(
        self,
        type: Literal["ai_task", "conversation"],
        chat_log: conversation.ChatLog,
        structure_name: str | None = None,
        structure: vol.Schema | None = None,
    ) -> None:
        """Stream model output into the chat log, looping while tools run."""
        for _ in range(_MAX_TOOL_ITERATIONS):
            response_format: dict[str, Any] | None = None
            if structure and structure_name:
                response_format = {
                    "type": "json_schema",
                    "json_schema": {
                        "name": slugify(structure_name),
                        "schema": _format_structured_output(
                            structure, chat_log.llm_api
                        ),
                        "strict": True,
                    },
                }

            request = await self._prepare_chat_for_generation(
                chat_log,
                response_format,
            )

            try:
                if type == "conversation":
                    raw_stream = await self._cloud.llm.async_process_conversation(
                        **request,
                    )
                else:
                    raw_stream = await self._cloud.llm.async_generate_data(
                        **request,
                    )

                async for _ in chat_log.async_add_delta_content_stream(
                    agent_id=self.entity_id,
                    stream=_transform_stream(chat_log, raw_stream, True),
                ):
                    pass

            except LLMAuthenticationError as err:
                raise ConfigEntryAuthFailed("Cloud LLM authentication failed") from err
            except LLMRateLimitError as err:
                raise HomeAssistantError("Cloud LLM is rate limited") from err
            except LLMResponseError as err:
                raise HomeAssistantError(str(err)) from err
            except LLMServiceError as err:
                raise HomeAssistantError("Error talking to Cloud LLM") from err
            except LLMError as err:
                raise HomeAssistantError(str(err)) from err

            # Stop once no tool results remain unanswered.
            if not chat_log.unresponded_tool_results:
                break
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Helpers for the cloud component."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
import logging
|
||||
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
{
|
||||
"entity": {
|
||||
"ai_task": {
|
||||
"cloud_ai": {
|
||||
"name": "Home Assistant Cloud AI"
|
||||
}
|
||||
}
|
||||
},
|
||||
"exceptions": {
|
||||
"backup_size_too_large": {
|
||||
"message": "The backup size of {size}GB is too large to be uploaded to Home Assistant Cloud."
|
||||
|
||||
@@ -21,22 +21,26 @@
|
||||
|
||||
## Active Integrations
|
||||
|
||||
Built-in integrations: 15
|
||||
Built-in integrations: 19
|
||||
Custom integrations: 1
|
||||
|
||||
<details><summary>Built-in integrations</summary>
|
||||
|
||||
Domain | Name
|
||||
--- | ---
|
||||
ai_task | AI Task
|
||||
auth | Auth
|
||||
binary_sensor | Binary Sensor
|
||||
cloud | Home Assistant Cloud
|
||||
cloud.binary_sensor | Unknown
|
||||
cloud.stt | Unknown
|
||||
cloud.tts | Unknown
|
||||
conversation | Conversation
|
||||
ffmpeg | FFmpeg
|
||||
homeassistant | Home Assistant Core Integration
|
||||
http | HTTP
|
||||
intent | Intent
|
||||
media_source | Media Source
|
||||
mock_no_info_integration | mock_no_info_integration
|
||||
repairs | Repairs
|
||||
stt | Speech-to-text (STT)
|
||||
@@ -116,22 +120,26 @@
|
||||
|
||||
## Active Integrations
|
||||
|
||||
Built-in integrations: 15
|
||||
Built-in integrations: 19
|
||||
Custom integrations: 0
|
||||
|
||||
<details><summary>Built-in integrations</summary>
|
||||
|
||||
Domain | Name
|
||||
--- | ---
|
||||
ai_task | AI Task
|
||||
auth | Auth
|
||||
binary_sensor | Binary Sensor
|
||||
cloud | Home Assistant Cloud
|
||||
cloud.binary_sensor | Unknown
|
||||
cloud.stt | Unknown
|
||||
cloud.tts | Unknown
|
||||
conversation | Conversation
|
||||
ffmpeg | FFmpeg
|
||||
homeassistant | Home Assistant Core Integration
|
||||
http | HTTP
|
||||
intent | Intent
|
||||
media_source | Media Source
|
||||
mock_no_info_integration | mock_no_info_integration
|
||||
repairs | Repairs
|
||||
stt | Speech-to-text (STT)
|
||||
|
||||
343
tests/components/cloud/test_ai_task.py
Normal file
343
tests/components/cloud/test_ai_task.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""Tests for the Home Assistant Cloud AI Task entity."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from hass_nabucasa.llm import (
|
||||
LLMAuthenticationError,
|
||||
LLMError,
|
||||
LLMImageAttachment,
|
||||
LLMRateLimitError,
|
||||
LLMResponseError,
|
||||
LLMServiceError,
|
||||
)
|
||||
from PIL import Image
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import ai_task, conversation
|
||||
from homeassistant.components.cloud.ai_task import (
|
||||
CloudLLMTaskEntity,
|
||||
async_prepare_image_generation_attachments,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@pytest.fixture
def mock_cloud_ai_task_entity(hass: HomeAssistant) -> CloudLLMTaskEntity:
    """Return a CloudLLMTaskEntity backed by a mocked cloud LLM client."""
    cloud = MagicMock()
    cloud.llm = MagicMock(
        async_generate_image=AsyncMock(),
        async_edit_image=AsyncMock(),
    )
    # Make the entity report as available.
    cloud.is_logged_in = True
    cloud.valid_subscription = True

    config_entry = MockConfigEntry(domain="cloud")
    config_entry.add_to_hass(hass)

    entity = CloudLLMTaskEntity(cloud, config_entry)
    entity.entity_id = "ai_task.cloud_ai_task"
    entity.hass = hass
    return entity
|
||||
|
||||
|
||||
@pytest.fixture(name="mock_handle_chat_log")
|
||||
def mock_handle_chat_log_fixture() -> AsyncMock:
|
||||
"""Patch the chat log handler."""
|
||||
with patch(
|
||||
"homeassistant.components.cloud.ai_task.CloudLLMTaskEntity._async_handle_chat_log",
|
||||
AsyncMock(),
|
||||
) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture(name="mock_prepare_generation_attachments")
|
||||
def mock_prepare_generation_attachments_fixture() -> AsyncMock:
|
||||
"""Patch image generation attachment preparation."""
|
||||
with patch(
|
||||
"homeassistant.components.cloud.ai_task.async_prepare_image_generation_attachments",
|
||||
AsyncMock(),
|
||||
) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
async def test_prepare_image_generation_attachments(
    hass: HomeAssistant, tmp_path: Path
) -> None:
    """Test preparing attachments for image generation."""
    source_image = tmp_path / "snapshot.jpg"
    Image.new("RGB", (2, 2), "red").save(source_image, "JPEG")

    attachment = conversation.Attachment(
        media_content_id="media-source://media/snapshot.jpg",
        mime_type="image/jpeg",
        path=source_image,
    )

    result = await async_prepare_image_generation_attachments(hass, [attachment])

    # Exactly one prepared attachment, re-encoded as PNG.
    (prepared,) = result
    assert prepared["filename"] == "snapshot.jpg"
    assert prepared["mime_type"] == "image/png"
    assert prepared["data"].startswith(b"\x89PNG")
|
||||
|
||||
|
||||
async def test_prepare_image_generation_attachments_only_images(
    hass: HomeAssistant, tmp_path: Path
) -> None:
    """Test non image attachments are rejected."""
    text_file = tmp_path / "context.txt"
    text_file.write_text("context")

    attachment = conversation.Attachment(
        media_content_id="media-source://media/context.txt",
        mime_type="text/plain",
        path=text_file,
    )

    with pytest.raises(
        HomeAssistantError,
        match="Only image attachments are supported for image generation",
    ):
        await async_prepare_image_generation_attachments(hass, [attachment])
|
||||
|
||||
|
||||
async def test_prepare_image_generation_attachments_missing_file(
    hass: HomeAssistant, tmp_path: Path
) -> None:
    """Test missing attachments raise a helpful error."""
    # Path is never created on disk, so preparation must fail cleanly.
    absent = conversation.Attachment(
        media_content_id="media-source://media/missing.png",
        mime_type="image/png",
        path=tmp_path / "missing.png",
    )

    with pytest.raises(HomeAssistantError, match="`.*missing.png` does not exist"):
        await async_prepare_image_generation_attachments(hass, [absent])
async def test_prepare_image_generation_attachments_processing_error(
    hass: HomeAssistant, tmp_path: Path
) -> None:
    """Test invalid image data raises a processing error."""
    corrupt_file = tmp_path / "broken.png"
    corrupt_file.write_bytes(b"not-an-image")

    corrupt_attachment = conversation.Attachment(
        media_content_id="media-source://media/broken.png",
        mime_type="image/png",
        path=corrupt_file,
    )

    with pytest.raises(
        HomeAssistantError,
        match="Failed to process image attachment",
    ):
        await async_prepare_image_generation_attachments(hass, [corrupt_attachment])
async def test_generate_data_returns_text(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudLLMTaskEntity,
    mock_handle_chat_log: AsyncMock,
) -> None:
    """Test generating plain text data."""
    chat_log = conversation.ChatLog(hass, "conversation-id")
    chat_log.async_add_user_content(
        conversation.UserContent(content="Tell me something")
    )

    async def _inject_reply(chat_type, log, task_name, structure):
        """Simulate the cloud LLM appending an assistant message."""
        assert chat_type == "ai_task"
        reply = conversation.AssistantContent(
            agent_id=mock_cloud_ai_task_entity.entity_id or "",
            content="Hello from the cloud",
        )
        log.async_add_assistant_content_without_tools(reply)

    mock_handle_chat_log.side_effect = _inject_reply

    task = ai_task.GenDataTask(name="Task", instructions="Say hi")
    result = await mock_cloud_ai_task_entity._async_generate_data(task, chat_log)

    assert result.conversation_id == "conversation-id"
    assert result.data == "Hello from the cloud"
async def test_generate_data_returns_json(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudLLMTaskEntity,
    mock_handle_chat_log: AsyncMock,
) -> None:
    """Test generating structured data."""
    chat_log = conversation.ChatLog(hass, "conversation-id")
    chat_log.async_add_user_content(conversation.UserContent(content="List names"))

    async def _inject_json_reply(chat_type, log, task_name, structure):
        """Simulate a structured JSON reply from the model."""
        log.async_add_assistant_content_without_tools(
            conversation.AssistantContent(
                agent_id=mock_cloud_ai_task_entity.entity_id or "",
                content='{"names": ["A", "B"]}',
            )
        )

    mock_handle_chat_log.side_effect = _inject_json_reply

    structured_task = ai_task.GenDataTask(
        name="Task",
        instructions="Return JSON",
        structure=vol.Schema({vol.Required("names"): [str]}),
    )
    result = await mock_cloud_ai_task_entity._async_generate_data(
        structured_task, chat_log
    )

    # JSON text must be decoded into native Python structures.
    assert result.data == {"names": ["A", "B"]}
async def test_generate_data_invalid_json(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudLLMTaskEntity,
    mock_handle_chat_log: AsyncMock,
) -> None:
    """Test invalid JSON responses raise an error."""
    chat_log = conversation.ChatLog(hass, "conversation-id")
    chat_log.async_add_user_content(conversation.UserContent(content="List names"))

    async def _inject_garbage(chat_type, log, task_name, structure):
        """Simulate a reply that is not valid JSON."""
        log.async_add_assistant_content_without_tools(
            conversation.AssistantContent(
                agent_id=mock_cloud_ai_task_entity.entity_id or "",
                content="not-json",
            )
        )

    mock_handle_chat_log.side_effect = _inject_garbage

    structured_task = ai_task.GenDataTask(
        name="Task",
        instructions="Return JSON",
        structure=vol.Schema({vol.Required("names"): [str]}),
    )

    with pytest.raises(
        HomeAssistantError, match="Error with OpenAI structured response"
    ):
        await mock_cloud_ai_task_entity._async_generate_data(structured_task, chat_log)
async def test_generate_image_no_attachments(
    hass: HomeAssistant, mock_cloud_ai_task_entity: CloudLLMTaskEntity
) -> None:
    """Test generating an image without attachments."""
    llm_mock = mock_cloud_ai_task_entity._cloud.llm
    llm_mock.async_generate_image.return_value = {
        "mime_type": "image/png",
        "image_data": b"IMG",
        "model": "mock-image",
        "width": 1024,
        "height": 768,
        "revised_prompt": "Improved prompt",
    }

    chat_log = conversation.ChatLog(hass, "conversation-id")
    generated = await mock_cloud_ai_task_entity._async_generate_image(
        ai_task.GenImageTask(name="Task", instructions="Draw something"), chat_log
    )

    assert generated.image_data == b"IMG"
    assert generated.mime_type == "image/png"
    # Without attachments, the plain generation endpoint must be used.
    llm_mock.async_generate_image.assert_awaited_once_with(prompt="Draw something")
async def test_generate_image_with_attachments(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudLLMTaskEntity,
    mock_prepare_generation_attachments: AsyncMock,
) -> None:
    """Test generating an edited image when attachments are provided."""
    llm_mock = mock_cloud_ai_task_entity._cloud.llm
    llm_mock.async_edit_image.return_value = {
        "mime_type": "image/png",
        "image_data": b"IMG",
    }

    prepared = [
        LLMImageAttachment(filename="snapshot.png", mime_type="image/png", data=b"IMG")
    ]
    mock_prepare_generation_attachments.return_value = prepared

    edit_task = ai_task.GenImageTask(
        name="Task",
        instructions="Edit this",
        attachments=[
            conversation.Attachment(
                media_content_id="media-source://media/snapshot.png",
                mime_type="image/png",
                path=hass.config.path("snapshot.png"),
            )
        ],
    )
    await mock_cloud_ai_task_entity._async_generate_image(
        edit_task, conversation.ChatLog(hass, "conversation-id")
    )

    # Attachments must route the request through the edit endpoint.
    llm_mock.async_edit_image.assert_awaited_once_with(
        prompt="Edit this",
        attachments=prepared,
    )
@pytest.mark.parametrize(
    ("err", "expected_exception", "message"),
    [
        (
            LLMAuthenticationError("auth"),
            ConfigEntryAuthFailed,
            "Cloud LLM authentication failed",
        ),
        (
            LLMRateLimitError("limit"),
            HomeAssistantError,
            "Cloud LLM is rate limited",
        ),
        (
            LLMResponseError("bad response"),
            HomeAssistantError,
            "bad response",
        ),
        (
            LLMServiceError("service"),
            HomeAssistantError,
            "Error talking to Cloud LLM",
        ),
        (
            LLMError("generic"),
            HomeAssistantError,
            "generic",
        ),
    ],
)
async def test_generate_image_error_handling(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudLLMTaskEntity,
    err: Exception,
    expected_exception: type[Exception],
    message: str,
) -> None:
    """Test image generation error handling."""
    # Each backend failure must be translated to the matching HA exception.
    mock_cloud_ai_task_entity._cloud.llm.async_generate_image.side_effect = err

    draw_task = ai_task.GenImageTask(name="Task", instructions="Draw something")

    with pytest.raises(expected_exception, match=message):
        await mock_cloud_ai_task_entity._async_generate_image(
            draw_task, conversation.ChatLog(hass, "conversation-id")
        )
219
tests/components/cloud/test_entity.py
Normal file
219
tests/components/cloud/test_entity.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""Tests for helpers in the Home Assistant Cloud conversation entity."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from PIL import Image
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.cloud.entity import (
|
||||
BaseCloudLLMEntity,
|
||||
_format_structured_output,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import llm, selector
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@pytest.fixture
def cloud_entity(hass: HomeAssistant) -> BaseCloudLLMEntity:
    """Return a CloudLLMTaskEntity attached to hass."""
    mock_cloud = MagicMock()
    mock_cloud.llm = MagicMock()
    mock_cloud.is_logged_in = True
    mock_cloud.valid_subscription = True

    config_entry = MockConfigEntry(domain="cloud")
    config_entry.add_to_hass(hass)

    llm_entity = BaseCloudLLMEntity(mock_cloud, config_entry)
    llm_entity.entity_id = "ai_task.cloud_ai_task"
    llm_entity.hass = hass
    return llm_entity
class DummyTool(llm.Tool):
    """Minimal LLM tool used to exercise schema conversion in tests."""

    name = "do_something"
    description = "Test tool"
    parameters = vol.Schema({vol.Required("value"): str})

    async def async_call(self, hass: HomeAssistant, tool_input, llm_context):
        """Return a canned result without touching hass."""
        return {"value": "done"}
async def test_format_structured_output() -> None:
    """Test that structured output schemas are normalized."""
    input_schema = vol.Schema(
        {
            vol.Required("name"): selector.TextSelector(),
            vol.Optional("age"): selector.NumberSelector(
                config=selector.NumberSelectorConfig(min=0, max=120),
            ),
            vol.Required("stuff"): selector.ObjectSelector(
                {
                    "multiple": True,
                    "fields": {
                        "item_name": {"selector": {"text": None}},
                        "item_value": {"selector": {"text": None}},
                    },
                }
            ),
        }
    )

    expected = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "number", "minimum": 0.0, "maximum": 120.0},
            "stuff": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "item_name": {"type": "string"},
                        "item_value": {"type": "string"},
                    },
                    "additionalProperties": False,
                },
            },
        },
        "required": ["name", "stuff"],
        "additionalProperties": False,
    }

    assert _format_structured_output(input_schema, None) == expected
async def test_prepare_files_for_prompt(
    cloud_entity: BaseCloudLLMEntity, tmp_path: Path
) -> None:
    """Test that media attachments are converted to the expected payload."""
    jpeg_file = tmp_path / "doorbell.jpg"
    Image.new("RGB", (2, 2), "blue").save(jpeg_file, "JPEG")
    pdf_file = tmp_path / "context.pdf"
    pdf_file.write_bytes(b"%PDF-1.3\nmock\n")

    payload = await cloud_entity._async_prepare_files_for_prompt(
        [
            conversation.Attachment(
                media_content_id="media-source://media/doorbell.jpg",
                mime_type="image/jpeg",
                path=jpeg_file,
            ),
            conversation.Attachment(
                media_content_id="media-source://media/context.pdf",
                mime_type="application/pdf",
                path=pdf_file,
            ),
        ]
    )

    jpeg_b64 = base64.b64encode(jpeg_file.read_bytes()).decode()
    pdf_b64 = base64.b64encode(pdf_file.read_bytes()).decode()

    # Images become inline data URLs; PDFs become named file payloads.
    assert payload[0] == {
        "type": "input_image",
        "image_url": "data:image/jpeg;base64," + jpeg_b64,
        "detail": "auto",
    }
    assert payload[1] == {
        "type": "input_file",
        "filename": "context.pdf",
        "file_data": "data:application/pdf;base64," + pdf_b64,
    }
async def test_prepare_files_for_prompt_invalid_type(
    cloud_entity: BaseCloudLLMEntity, tmp_path: Path
) -> None:
    """Test that unsupported attachments raise an error."""
    notes_file = tmp_path / "notes.txt"
    notes_file.write_text("notes")

    unsupported = conversation.Attachment(
        media_content_id="media-source://media/notes.txt",
        mime_type="text/plain",
        path=notes_file,
    )

    with pytest.raises(
        HomeAssistantError,
        match="Only images and PDF are currently supported as attachments",
    ):
        await cloud_entity._async_prepare_files_for_prompt([unsupported])
async def test_prepare_chat_for_generation_appends_attachments(
    hass: HomeAssistant,
    cloud_entity: BaseCloudLLMEntity,
    mock_prepare_files_for_prompt: AsyncMock,
) -> None:
    """Test chat preparation adds LLM tools, attachments, and metadata."""
    chat_log = conversation.ChatLog(hass, "conversation-id")
    image_attachment = conversation.Attachment(
        media_content_id="media-source://media/doorbell.jpg",
        mime_type="image/jpeg",
        path=hass.config.path("doorbell.jpg"),
    )
    chat_log.async_add_user_content(
        conversation.UserContent(
            content="Describe the door", attachments=[image_attachment]
        )
    )
    chat_log.llm_api = SimpleNamespace(
        tools=[DummyTool()],
        custom_serializer=None,
    )

    prepared_files = [
        {"type": "input_image", "image_url": "data://img", "detail": "auto"}
    ]
    mock_prepare_files_for_prompt.return_value = prepared_files

    payload = await cloud_entity._prepare_chat_for_generation(
        chat_log, response_format={"type": "json"}
    )

    assert payload["conversation_id"] == "conversation-id"
    assert payload["response_format"] == {"type": "json"}
    assert payload["tool_choice"] == "auto"
    # Custom tool plus the built-in web_search tool.
    assert len(payload["tools"]) == 2
    assert payload["tools"][0]["name"] == "do_something"
    assert payload["tools"][1]["type"] == "web_search"

    final_message = payload["messages"][-1]
    assert final_message["content"][0] == {
        "type": "input_text",
        "text": "Describe the door",
    }
    assert final_message["content"][1:] == prepared_files
async def test_prepare_chat_for_generation_requires_user_prompt(
    hass: HomeAssistant, cloud_entity: BaseCloudLLMEntity
) -> None:
    """Test that we fail fast when there is no user input to process."""
    # A log containing only assistant content has no prompt to send.
    chat_log = conversation.ChatLog(hass, "conversation-id")
    chat_log.async_add_assistant_content_without_tools(
        conversation.AssistantContent(agent_id="agent", content="Ready")
    )

    with pytest.raises(HomeAssistantError, match="No user prompt found"):
        await cloud_entity._prepare_chat_for_generation(chat_log)
@pytest.fixture
def mock_prepare_files_for_prompt(
    cloud_entity: BaseCloudLLMEntity,
) -> AsyncMock:
    """Patch the file preparation helper on the entity."""
    mocked = AsyncMock()
    with patch.object(
        cloud_entity,
        "_async_prepare_files_for_prompt",
        mocked,
    ):
        yield mocked
Reference in New Issue
Block a user