Allow storing AI Task generate image preferred entity (#151938)

This commit is contained in:
Paulus Schoutsen
2025-09-09 13:29:14 -04:00
committed by GitHub
parent eaf400f3b7
commit 285619e913
5 changed files with 130 additions and 5 deletions

View File

@@ -126,7 +126,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
schema=vol.Schema(
{
vol.Required(ATTR_TASK_NAME): cv.string,
vol.Required(ATTR_ENTITY_ID): cv.entity_id,
vol.Optional(ATTR_ENTITY_ID): cv.entity_id,
vol.Required(ATTR_INSTRUCTIONS): cv.string,
vol.Optional(ATTR_ATTACHMENTS): vol.All(
cv.ensure_list, [selector.MediaSelector({"accept": ["*/*"]})]
@@ -163,9 +163,10 @@ async def async_service_generate_image(call: ServiceCall) -> ServiceResponse:
class AITaskPreferences:
"""AI Task preferences."""
KEYS = ("gen_data_entity_id",)
KEYS = ("gen_data_entity_id", "gen_image_entity_id")
gen_data_entity_id: str | None = None
gen_image_entity_id: str | None = None
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the preferences."""
@@ -179,17 +180,21 @@ class AITaskPreferences:
if data is None:
return
for key in self.KEYS:
setattr(self, key, data[key])
setattr(self, key, data.get(key))
@callback
def async_set_preferences(
self,
*,
gen_data_entity_id: str | None | UndefinedType = UNDEFINED,
gen_image_entity_id: str | None | UndefinedType = UNDEFINED,
) -> None:
"""Set the preferences."""
changed = False
for key, value in (("gen_data_entity_id", gen_data_entity_id),):
for key, value in (
("gen_data_entity_id", gen_data_entity_id),
("gen_image_entity_id", gen_image_entity_id),
):
if value is not UNDEFINED:
if getattr(self, key) != value:
setattr(self, key, value)

View File

@@ -37,6 +37,7 @@ def websocket_get_preferences(
{
vol.Required("type"): "ai_task/preferences/set",
vol.Optional("gen_data_entity_id"): vol.Any(str, None),
vol.Optional("gen_image_entity_id"): vol.Any(str, None),
}
)
@websocket_api.require_admin

View File

@@ -177,11 +177,17 @@ async def async_generate_image(
hass: HomeAssistant,
*,
task_name: str,
entity_id: str,
entity_id: str | None = None,
instructions: str,
attachments: list[dict] | None = None,
) -> ServiceResponse:
"""Run an image generation task in the AI Task integration."""
if entity_id is None:
entity_id = hass.data[DATA_PREFERENCES].gen_image_entity_id
if entity_id is None:
raise HomeAssistantError("No entity_id provided and no preferred entity set")
entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
if entity is None:
raise HomeAssistantError(f"AI Task entity {entity_id} not found")

View File

@@ -19,6 +19,7 @@ async def test_ws_preferences(
assert msg["success"]
assert msg["result"] == {
"gen_data_entity_id": None,
"gen_image_entity_id": None,
}
# Set preferences
@@ -32,6 +33,7 @@ async def test_ws_preferences(
assert msg["success"]
assert msg["result"] == {
"gen_data_entity_id": "ai_task.summary_1",
"gen_image_entity_id": None,
}
# Get updated preferences
@@ -40,6 +42,7 @@ async def test_ws_preferences(
assert msg["success"]
assert msg["result"] == {
"gen_data_entity_id": "ai_task.summary_1",
"gen_image_entity_id": None,
}
# Update an existing preference
@@ -53,6 +56,7 @@ async def test_ws_preferences(
assert msg["success"]
assert msg["result"] == {
"gen_data_entity_id": "ai_task.summary_2",
"gen_image_entity_id": None,
}
# Get updated preferences
@@ -61,6 +65,7 @@ async def test_ws_preferences(
assert msg["success"]
assert msg["result"] == {
"gen_data_entity_id": "ai_task.summary_2",
"gen_image_entity_id": None,
}
# No preferences set will preserve existing preferences
@@ -73,6 +78,7 @@ async def test_ws_preferences(
assert msg["success"]
assert msg["result"] == {
"gen_data_entity_id": "ai_task.summary_2",
"gen_image_entity_id": None,
}
# Get updated preferences
@@ -81,4 +87,43 @@ async def test_ws_preferences(
assert msg["success"]
assert msg["result"] == {
"gen_data_entity_id": "ai_task.summary_2",
"gen_image_entity_id": None,
}
# Set gen_image_entity_id preference
await client.send_json_auto_id(
{
"type": "ai_task/preferences/set",
"gen_image_entity_id": "ai_task.image_gen_1",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_data_entity_id": "ai_task.summary_2",
"gen_image_entity_id": "ai_task.image_gen_1",
}
# Update both preferences
await client.send_json_auto_id(
{
"type": "ai_task/preferences/set",
"gen_data_entity_id": "ai_task.summary_3",
"gen_image_entity_id": "ai_task.image_gen_2",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_data_entity_id": "ai_task.summary_3",
"gen_image_entity_id": "ai_task.image_gen_2",
}
# Get final preferences
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_data_entity_id": "ai_task.summary_3",
"gen_image_entity_id": "ai_task.image_gen_2",
}

View File

@@ -12,6 +12,7 @@ from homeassistant.components import media_source
from homeassistant.components.ai_task import AITaskPreferences
from homeassistant.components.ai_task.const import DATA_PREFERENCES
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import selector
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
@@ -277,3 +278,70 @@ async def test_generate_data_service_invalid_structure(
blocking=True,
return_response=True,
)
@pytest.mark.parametrize(
("set_preferences", "msg_extra"),
[
({}, {"entity_id": TEST_ENTITY_ID}),
({"gen_image_entity_id": TEST_ENTITY_ID}, {}),
(
{"gen_image_entity_id": "ai_task.other_entity"},
{"entity_id": TEST_ENTITY_ID},
),
],
)
async def test_generate_image_service(
hass: HomeAssistant,
init_components: None,
set_preferences: dict[str, str | None],
msg_extra: dict[str, str],
mock_ai_task_entity: MockAITaskEntity,
) -> None:
"""Test the generate image service."""
preferences = hass.data[DATA_PREFERENCES]
preferences.async_set_preferences(**set_preferences)
result = await hass.services.async_call(
"ai_task",
"generate_image",
{
"task_name": "Test Image",
"instructions": "Generate a test image",
}
| msg_extra,
blocking=True,
return_response=True,
)
assert "image_data" not in result
assert result["media_source_id"].startswith("media-source://ai_task/images/")
assert result["url"].startswith("http://10.10.10.10:8123/api/ai_task/images/")
assert result["mime_type"] == "image/png"
assert result["model"] == "mock_model"
assert result["revised_prompt"] == "mock_revised_prompt"
assert len(mock_ai_task_entity.mock_generate_image_tasks) == 1
task = mock_ai_task_entity.mock_generate_image_tasks[0]
assert task.instructions == "Generate a test image"
async def test_generate_image_service_no_entity(
hass: HomeAssistant,
init_components: None,
) -> None:
"""Test the generate image service with no entity specified."""
with pytest.raises(
HomeAssistantError,
match="No entity_id provided and no preferred entity set",
):
await hass.services.async_call(
"ai_task",
"generate_image",
{
"task_name": "Test Image",
"instructions": "Generate a test image",
},
blocking=True,
return_response=True,
)