mirror of
https://github.com/Electric-Special/ha-core.git
synced 2026-03-21 05:06:13 +01:00
Allow storing AI Task generate image preferred entity (#151938)
This commit is contained in:
@@ -126,7 +126,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
schema=vol.Schema(
|
||||
{
|
||||
vol.Required(ATTR_TASK_NAME): cv.string,
|
||||
vol.Required(ATTR_ENTITY_ID): cv.entity_id,
|
||||
vol.Optional(ATTR_ENTITY_ID): cv.entity_id,
|
||||
vol.Required(ATTR_INSTRUCTIONS): cv.string,
|
||||
vol.Optional(ATTR_ATTACHMENTS): vol.All(
|
||||
cv.ensure_list, [selector.MediaSelector({"accept": ["*/*"]})]
|
||||
@@ -163,9 +163,10 @@ async def async_service_generate_image(call: ServiceCall) -> ServiceResponse:
|
||||
class AITaskPreferences:
|
||||
"""AI Task preferences."""
|
||||
|
||||
KEYS = ("gen_data_entity_id",)
|
||||
KEYS = ("gen_data_entity_id", "gen_image_entity_id")
|
||||
|
||||
gen_data_entity_id: str | None = None
|
||||
gen_image_entity_id: str | None = None
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the preferences."""
|
||||
@@ -179,17 +180,21 @@ class AITaskPreferences:
|
||||
if data is None:
|
||||
return
|
||||
for key in self.KEYS:
|
||||
setattr(self, key, data[key])
|
||||
setattr(self, key, data.get(key))
|
||||
|
||||
@callback
|
||||
def async_set_preferences(
|
||||
self,
|
||||
*,
|
||||
gen_data_entity_id: str | None | UndefinedType = UNDEFINED,
|
||||
gen_image_entity_id: str | None | UndefinedType = UNDEFINED,
|
||||
) -> None:
|
||||
"""Set the preferences."""
|
||||
changed = False
|
||||
for key, value in (("gen_data_entity_id", gen_data_entity_id),):
|
||||
for key, value in (
|
||||
("gen_data_entity_id", gen_data_entity_id),
|
||||
("gen_image_entity_id", gen_image_entity_id),
|
||||
):
|
||||
if value is not UNDEFINED:
|
||||
if getattr(self, key) != value:
|
||||
setattr(self, key, value)
|
||||
|
||||
@@ -37,6 +37,7 @@ def websocket_get_preferences(
|
||||
{
|
||||
vol.Required("type"): "ai_task/preferences/set",
|
||||
vol.Optional("gen_data_entity_id"): vol.Any(str, None),
|
||||
vol.Optional("gen_image_entity_id"): vol.Any(str, None),
|
||||
}
|
||||
)
|
||||
@websocket_api.require_admin
|
||||
|
||||
@@ -177,11 +177,17 @@ async def async_generate_image(
|
||||
hass: HomeAssistant,
|
||||
*,
|
||||
task_name: str,
|
||||
entity_id: str,
|
||||
entity_id: str | None = None,
|
||||
instructions: str,
|
||||
attachments: list[dict] | None = None,
|
||||
) -> ServiceResponse:
|
||||
"""Run an image generation task in the AI Task integration."""
|
||||
if entity_id is None:
|
||||
entity_id = hass.data[DATA_PREFERENCES].gen_image_entity_id
|
||||
|
||||
if entity_id is None:
|
||||
raise HomeAssistantError("No entity_id provided and no preferred entity set")
|
||||
|
||||
entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
|
||||
if entity is None:
|
||||
raise HomeAssistantError(f"AI Task entity {entity_id} not found")
|
||||
|
||||
@@ -19,6 +19,7 @@ async def test_ws_preferences(
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_data_entity_id": None,
|
||||
"gen_image_entity_id": None,
|
||||
}
|
||||
|
||||
# Set preferences
|
||||
@@ -32,6 +33,7 @@ async def test_ws_preferences(
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_data_entity_id": "ai_task.summary_1",
|
||||
"gen_image_entity_id": None,
|
||||
}
|
||||
|
||||
# Get updated preferences
|
||||
@@ -40,6 +42,7 @@ async def test_ws_preferences(
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_data_entity_id": "ai_task.summary_1",
|
||||
"gen_image_entity_id": None,
|
||||
}
|
||||
|
||||
# Update an existing preference
|
||||
@@ -53,6 +56,7 @@ async def test_ws_preferences(
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_data_entity_id": "ai_task.summary_2",
|
||||
"gen_image_entity_id": None,
|
||||
}
|
||||
|
||||
# Get updated preferences
|
||||
@@ -61,6 +65,7 @@ async def test_ws_preferences(
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_data_entity_id": "ai_task.summary_2",
|
||||
"gen_image_entity_id": None,
|
||||
}
|
||||
|
||||
# No preferences set will preserve existing preferences
|
||||
@@ -73,6 +78,7 @@ async def test_ws_preferences(
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_data_entity_id": "ai_task.summary_2",
|
||||
"gen_image_entity_id": None,
|
||||
}
|
||||
|
||||
# Get updated preferences
|
||||
@@ -81,4 +87,43 @@ async def test_ws_preferences(
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_data_entity_id": "ai_task.summary_2",
|
||||
"gen_image_entity_id": None,
|
||||
}
|
||||
|
||||
# Set gen_image_entity_id preference
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/preferences/set",
|
||||
"gen_image_entity_id": "ai_task.image_gen_1",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_data_entity_id": "ai_task.summary_2",
|
||||
"gen_image_entity_id": "ai_task.image_gen_1",
|
||||
}
|
||||
|
||||
# Update both preferences
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/preferences/set",
|
||||
"gen_data_entity_id": "ai_task.summary_3",
|
||||
"gen_image_entity_id": "ai_task.image_gen_2",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_data_entity_id": "ai_task.summary_3",
|
||||
"gen_image_entity_id": "ai_task.image_gen_2",
|
||||
}
|
||||
|
||||
# Get final preferences
|
||||
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_data_entity_id": "ai_task.summary_3",
|
||||
"gen_image_entity_id": "ai_task.image_gen_2",
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ from homeassistant.components import media_source
|
||||
from homeassistant.components.ai_task import AITaskPreferences
|
||||
from homeassistant.components.ai_task.const import DATA_PREFERENCES
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import selector
|
||||
|
||||
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
||||
@@ -277,3 +278,70 @@ async def test_generate_data_service_invalid_structure(
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("set_preferences", "msg_extra"),
|
||||
[
|
||||
({}, {"entity_id": TEST_ENTITY_ID}),
|
||||
({"gen_image_entity_id": TEST_ENTITY_ID}, {}),
|
||||
(
|
||||
{"gen_image_entity_id": "ai_task.other_entity"},
|
||||
{"entity_id": TEST_ENTITY_ID},
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_generate_image_service(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
set_preferences: dict[str, str | None],
|
||||
msg_extra: dict[str, str],
|
||||
mock_ai_task_entity: MockAITaskEntity,
|
||||
) -> None:
|
||||
"""Test the generate image service."""
|
||||
preferences = hass.data[DATA_PREFERENCES]
|
||||
preferences.async_set_preferences(**set_preferences)
|
||||
|
||||
result = await hass.services.async_call(
|
||||
"ai_task",
|
||||
"generate_image",
|
||||
{
|
||||
"task_name": "Test Image",
|
||||
"instructions": "Generate a test image",
|
||||
}
|
||||
| msg_extra,
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
assert "image_data" not in result
|
||||
assert result["media_source_id"].startswith("media-source://ai_task/images/")
|
||||
assert result["url"].startswith("http://10.10.10.10:8123/api/ai_task/images/")
|
||||
assert result["mime_type"] == "image/png"
|
||||
assert result["model"] == "mock_model"
|
||||
assert result["revised_prompt"] == "mock_revised_prompt"
|
||||
|
||||
assert len(mock_ai_task_entity.mock_generate_image_tasks) == 1
|
||||
task = mock_ai_task_entity.mock_generate_image_tasks[0]
|
||||
assert task.instructions == "Generate a test image"
|
||||
|
||||
|
||||
async def test_generate_image_service_no_entity(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
) -> None:
|
||||
"""Test the generate image service with no entity specified."""
|
||||
with pytest.raises(
|
||||
HomeAssistantError,
|
||||
match="No entity_id provided and no preferred entity set",
|
||||
):
|
||||
await hass.services.async_call(
|
||||
"ai_task",
|
||||
"generate_image",
|
||||
{
|
||||
"task_name": "Test Image",
|
||||
"instructions": "Generate a test image",
|
||||
},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user