diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index dd293726484..8b1e2ce8b30 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -11,8 +11,9 @@ from typing import TYPE_CHECKING, Any, cast from propcache.api import cached_property import voluptuous as vol -from homeassistant.components import websocket_api +from homeassistant.components import automation, websocket_api from homeassistant.components.blueprint import CONF_USE_BLUEPRINT +from homeassistant.components.labs import async_listen as async_labs_listen from homeassistant.const import ( ATTR_ENTITY_ID, ATTR_MODE, @@ -280,6 +281,21 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: hass.services.async_register( DOMAIN, SERVICE_TOGGLE, toggle_service, schema=SCRIPT_TURN_ONOFF_SCHEMA ) + + @callback + def new_triggers_conditions_listener() -> None: + """Handle new_triggers_conditions flag change.""" + hass.async_create_task( + reload_service(ServiceCall(hass, DOMAIN, SERVICE_RELOAD)) + ) + + async_labs_listen( + hass, + automation.DOMAIN, + automation.NEW_TRIGGERS_CONDITIONS_FEATURE_FLAG, + new_triggers_conditions_listener, + ) + websocket_api.async_register_command(hass, websocket_config) return True diff --git a/tests/components/script/test_init.py b/tests/components/script/test_init.py index ff068d6f952..a0cf18b8785 100644 --- a/tests/components/script/test_init.py +++ b/tests/components/script/test_init.py @@ -7,7 +7,7 @@ from unittest.mock import ANY, Mock, patch import pytest -from homeassistant.components import script +from homeassistant.components import labs, script from homeassistant.components.script import DOMAIN, EVENT_SCRIPT_STARTED, ScriptEntity from homeassistant.config_entries import ConfigEntryState from homeassistant.const import ( @@ -1868,3 +1868,64 @@ async def test_script_queued_mode(hass: HomeAssistant) -> None: await hass.services.async_call("script", "test_main", blocking=True) assert calls == 4 + + +async def test_reload_when_labs_flag_changes( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test scripts are reloaded when labs flag changes.""" + event = "test_event" + hass.states.async_set("test.script", "off") + + ws_client = await hass_ws_client(hass) + + assert await async_setup_component( + hass, + "script", + { + "script": { + "test": { + "sequence": [ + {"event": event}, + {"wait_template": "{{ is_state('test.script', 'on') }}"}, + ] + } + } + }, + ) + assert await async_setup_component(hass, labs.DOMAIN, {}) + + assert hass.states.get(ENTITY_ID) is not None + assert hass.services.has_service(script.DOMAIN, "test") + + for enabled, active_object_id, inactive_object_ids in ( + (False, "test2", ("test",)), + (True, "test3", ("test", "test2")), + ): + with patch( + "homeassistant.config.load_yaml_config_file", + return_value={ + "script": {active_object_id: {"sequence": [{"delay": {"seconds": 5}}]}} + }, + ): + await ws_client.send_json_auto_id( + { + "type": "labs/update", + "domain": "automation", + "preview_feature": "new_triggers_conditions", + "enabled": enabled, + } + ) + + msg = await ws_client.receive_json() + assert msg["success"] + await hass.async_block_till_done() + + for inactive_object_id in inactive_object_ids: + state = hass.states.get(f"script.{inactive_object_id}") + assert state.attributes["restored"] is True + assert not hass.services.has_service(script.DOMAIN, inactive_object_id) + + assert hass.states.get(f"script.{active_object_id}") is not None + assert hass.services.has_service(script.DOMAIN, active_object_id)