diff --git a/homeassistant/components/ai_task/__init__.py b/homeassistant/components/ai_task/__init__.py index 1e317186ee4..daaf190fc55 100644 --- a/homeassistant/components/ai_task/__init__.py +++ b/homeassistant/components/ai_task/__init__.py @@ -126,7 +126,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: schema=vol.Schema( { vol.Required(ATTR_TASK_NAME): cv.string, - vol.Required(ATTR_ENTITY_ID): cv.entity_id, + vol.Optional(ATTR_ENTITY_ID): cv.entity_id, vol.Required(ATTR_INSTRUCTIONS): cv.string, vol.Optional(ATTR_ATTACHMENTS): vol.All( cv.ensure_list, [selector.MediaSelector({"accept": ["*/*"]})] @@ -163,9 +163,10 @@ async def async_service_generate_image(call: ServiceCall) -> ServiceResponse: class AITaskPreferences: """AI Task preferences.""" - KEYS = ("gen_data_entity_id",) + KEYS = ("gen_data_entity_id", "gen_image_entity_id") gen_data_entity_id: str | None = None + gen_image_entity_id: str | None = None def __init__(self, hass: HomeAssistant) -> None: """Initialize the preferences.""" @@ -179,17 +180,21 @@ class AITaskPreferences: if data is None: return for key in self.KEYS: - setattr(self, key, data[key]) + setattr(self, key, data.get(key)) @callback def async_set_preferences( self, *, gen_data_entity_id: str | None | UndefinedType = UNDEFINED, + gen_image_entity_id: str | None | UndefinedType = UNDEFINED, ) -> None: """Set the preferences.""" changed = False - for key, value in (("gen_data_entity_id", gen_data_entity_id),): + for key, value in ( + ("gen_data_entity_id", gen_data_entity_id), + ("gen_image_entity_id", gen_image_entity_id), + ): if value is not UNDEFINED: if getattr(self, key) != value: setattr(self, key, value) diff --git a/homeassistant/components/ai_task/http.py b/homeassistant/components/ai_task/http.py index 5deffa84008..ba6aa63415b 100644 --- a/homeassistant/components/ai_task/http.py +++ b/homeassistant/components/ai_task/http.py @@ -37,6 +37,7 @@ def websocket_get_preferences( { vol.Required("type"): "ai_task/preferences/set", vol.Optional("gen_data_entity_id"): vol.Any(str, None), + vol.Optional("gen_image_entity_id"): vol.Any(str, None), } ) @websocket_api.require_admin diff --git a/homeassistant/components/ai_task/task.py b/homeassistant/components/ai_task/task.py index cc333cc7b62..a7fd6758943 100644 --- a/homeassistant/components/ai_task/task.py +++ b/homeassistant/components/ai_task/task.py @@ -177,11 +177,17 @@ async def async_generate_image( hass: HomeAssistant, *, task_name: str, - entity_id: str, + entity_id: str | None = None, instructions: str, attachments: list[dict] | None = None, ) -> ServiceResponse: """Run an image generation task in the AI Task integration.""" + if entity_id is None: + entity_id = hass.data[DATA_PREFERENCES].gen_image_entity_id + + if entity_id is None: + raise HomeAssistantError("No entity_id provided and no preferred entity set") + entity = hass.data[DATA_COMPONENT].get_entity(entity_id) if entity is None: raise HomeAssistantError(f"AI Task entity {entity_id} not found") diff --git a/tests/components/ai_task/test_http.py b/tests/components/ai_task/test_http.py index a2eecfddf74..545dce0c1c2 100644 --- a/tests/components/ai_task/test_http.py +++ b/tests/components/ai_task/test_http.py @@ -19,6 +19,7 @@ async def test_ws_preferences( assert msg["success"] assert msg["result"] == { "gen_data_entity_id": None, + "gen_image_entity_id": None, } # Set preferences @@ -32,6 +33,7 @@ async def test_ws_preferences( assert msg["success"] assert msg["result"] == { "gen_data_entity_id": "ai_task.summary_1", + "gen_image_entity_id": None, } # Get updated preferences @@ -40,6 +42,7 @@ async def test_ws_preferences( assert msg["success"] assert msg["result"] == { "gen_data_entity_id": "ai_task.summary_1", + "gen_image_entity_id": None, } # Update an existing preference @@ -53,6 +56,7 @@ async def test_ws_preferences( assert msg["success"] assert msg["result"] == { "gen_data_entity_id": "ai_task.summary_2", + "gen_image_entity_id": None, } # Get updated preferences @@ -61,6 +65,7 @@ async def test_ws_preferences( assert msg["success"] assert msg["result"] == { "gen_data_entity_id": "ai_task.summary_2", + "gen_image_entity_id": None, } # No preferences set will preserve existing preferences @@ -73,6 +78,7 @@ async def test_ws_preferences( assert msg["success"] assert msg["result"] == { "gen_data_entity_id": "ai_task.summary_2", + "gen_image_entity_id": None, } # Get updated preferences @@ -81,4 +87,43 @@ async def test_ws_preferences( assert msg["success"] assert msg["result"] == { "gen_data_entity_id": "ai_task.summary_2", + "gen_image_entity_id": None, + } + + # Set gen_image_entity_id preference + await client.send_json_auto_id( + { + "type": "ai_task/preferences/set", + "gen_image_entity_id": "ai_task.image_gen_1", + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_data_entity_id": "ai_task.summary_2", + "gen_image_entity_id": "ai_task.image_gen_1", + } + + # Update both preferences + await client.send_json_auto_id( + { + "type": "ai_task/preferences/set", + "gen_data_entity_id": "ai_task.summary_3", + "gen_image_entity_id": "ai_task.image_gen_2", + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_data_entity_id": "ai_task.summary_3", + "gen_image_entity_id": "ai_task.image_gen_2", + } + + # Get final preferences + await client.send_json_auto_id({"type": "ai_task/preferences/get"}) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_data_entity_id": "ai_task.summary_3", + "gen_image_entity_id": "ai_task.image_gen_2", } diff --git a/tests/components/ai_task/test_init.py b/tests/components/ai_task/test_init.py index 09ee926c187..e89e4cea670 100644 --- a/tests/components/ai_task/test_init.py +++ b/tests/components/ai_task/test_init.py @@ -12,6 +12,7 @@ from homeassistant.components import media_source from homeassistant.components.ai_task import AITaskPreferences from homeassistant.components.ai_task.const import DATA_PREFERENCES from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import selector from .conftest import TEST_ENTITY_ID, MockAITaskEntity @@ -277,3 +278,70 @@ async def test_generate_data_service_invalid_structure( blocking=True, return_response=True, ) + + +@pytest.mark.parametrize( + ("set_preferences", "msg_extra"), + [ + ({}, {"entity_id": TEST_ENTITY_ID}), + ({"gen_image_entity_id": TEST_ENTITY_ID}, {}), + ( + {"gen_image_entity_id": "ai_task.other_entity"}, + {"entity_id": TEST_ENTITY_ID}, + ), + ], +) +async def test_generate_image_service( + hass: HomeAssistant, + init_components: None, + set_preferences: dict[str, str | None], + msg_extra: dict[str, str], + mock_ai_task_entity: MockAITaskEntity, +) -> None: + """Test the generate image service.""" + preferences = hass.data[DATA_PREFERENCES] + preferences.async_set_preferences(**set_preferences) + + result = await hass.services.async_call( + "ai_task", + "generate_image", + { + "task_name": "Test Image", + "instructions": "Generate a test image", + } + | msg_extra, + blocking=True, + return_response=True, + ) + + assert "image_data" not in result + assert result["media_source_id"].startswith("media-source://ai_task/images/") + assert result["url"].startswith("http://10.10.10.10:8123/api/ai_task/images/") + assert result["mime_type"] == "image/png" + assert result["model"] == "mock_model" + assert result["revised_prompt"] == "mock_revised_prompt" + + assert len(mock_ai_task_entity.mock_generate_image_tasks) == 1 + task = mock_ai_task_entity.mock_generate_image_tasks[0] + assert task.instructions == "Generate a test image" + + +async def test_generate_image_service_no_entity( + hass: HomeAssistant, + init_components: None, +) -> None: + """Test the generate image service with no entity specified.""" + with pytest.raises( + HomeAssistantError, + match="No entity_id provided and no preferred entity set", + ): + await hass.services.async_call( + "ai_task", + "generate_image", + { + "task_name": "Test Image", + "instructions": "Generate a test image", + }, + blocking=True, + return_response=True, + )