diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 89544e83562..dbae4c5af89 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -7,7 +7,7 @@ import asyncio from collections.abc import Callable, Mapping from dataclasses import dataclass import logging -from typing import Any, Literal, Protocol, cast +from typing import Any, Protocol, cast from propcache.api import cached_property import voluptuous as vol @@ -25,18 +25,11 @@ from homeassistant.const import ( CONF_ACTIONS, CONF_ALIAS, CONF_CONDITIONS, - CONF_DEVICE_ID, - CONF_ENTITY_ID, - CONF_EVENT_DATA, CONF_ID, CONF_MODE, - CONF_OPTIONS, CONF_PATH, - CONF_PLATFORM, - CONF_TARGET, CONF_TRIGGERS, CONF_VARIABLES, - CONF_ZONE, EVENT_HOMEASSISTANT_STARTED, SERVICE_RELOAD, SERVICE_TOGGLE, @@ -53,10 +46,13 @@ from homeassistant.core import ( ServiceCall, callback, split_entity_id, - valid_entity_id, ) from homeassistant.exceptions import HomeAssistantError, ServiceNotFound, TemplateError -from homeassistant.helpers import condition as condition_helper, config_validation as cv +from homeassistant.helpers import ( + condition as condition_helper, + config_validation as cv, + trigger as trigger_helper, +) from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.issue_registry import ( @@ -86,7 +82,6 @@ from homeassistant.helpers.trace import ( trace_get, trace_path, ) -from homeassistant.helpers.trigger import async_initialize_triggers from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass from homeassistant.util.dt import parse_datetime @@ -618,7 +613,7 @@ class AutomationEntity(BaseAutomationEntity, RestoreEntity): ) for conf in self._trigger_config: - referenced |= set(_get_targets_from_trigger_config(conf, ATTR_LABEL_ID)) + referenced |= set(trigger_helper.async_extract_targets(conf, ATTR_LABEL_ID)) return referenced @cached_property @@ -633,7 +628,7 @@ class AutomationEntity(BaseAutomationEntity, RestoreEntity): ) for conf in self._trigger_config: - referenced |= set(_get_targets_from_trigger_config(conf, ATTR_FLOOR_ID)) + referenced |= set(trigger_helper.async_extract_targets(conf, ATTR_FLOOR_ID)) return referenced @cached_property @@ -646,7 +641,7 @@ class AutomationEntity(BaseAutomationEntity, RestoreEntity): referenced |= condition_helper.async_extract_targets(conf, ATTR_AREA_ID) for conf in self._trigger_config: - referenced |= set(_get_targets_from_trigger_config(conf, ATTR_AREA_ID)) + referenced |= set(trigger_helper.async_extract_targets(conf, ATTR_AREA_ID)) return referenced @property @@ -666,7 +661,7 @@ class AutomationEntity(BaseAutomationEntity, RestoreEntity): referenced |= condition_helper.async_extract_devices(conf) for conf in self._trigger_config: - referenced |= set(_trigger_extract_devices(conf)) + referenced |= set(trigger_helper.async_extract_devices(conf)) return referenced @@ -680,7 +675,7 @@ class AutomationEntity(BaseAutomationEntity, RestoreEntity): referenced |= condition_helper.async_extract_entities(conf) for conf in self._trigger_config: - for entity_id in _trigger_extract_entities(conf): + for entity_id in trigger_helper.async_extract_entities(conf): referenced.add(entity_id) return referenced @@ -954,7 +949,7 @@ class AutomationEntity(BaseAutomationEntity, RestoreEntity): self._logger.error("Error rendering trigger variables: %s", err) return None - return await async_initialize_triggers( + return await trigger_helper.async_initialize_triggers( self.hass, self._trigger_config, self._async_trigger_if_enabled, @@ -1238,78 +1233,6 @@ async def _async_process_if( return result -@callback -def _trigger_extract_devices(trigger_conf: dict) -> list[str]: - """Extract devices from a trigger config.""" - if trigger_conf[CONF_PLATFORM] == "device": - return [trigger_conf[CONF_DEVICE_ID]] - - if ( - trigger_conf[CONF_PLATFORM] == "event" - and CONF_EVENT_DATA in trigger_conf - and CONF_DEVICE_ID in trigger_conf[CONF_EVENT_DATA] - and isinstance(trigger_conf[CONF_EVENT_DATA][CONF_DEVICE_ID], str) - ): - return [trigger_conf[CONF_EVENT_DATA][CONF_DEVICE_ID]] - - if trigger_conf[CONF_PLATFORM] == "tag" and CONF_DEVICE_ID in trigger_conf: - return trigger_conf[CONF_DEVICE_ID] # type: ignore[no-any-return] - - if target_devices := _get_targets_from_trigger_config(trigger_conf, CONF_DEVICE_ID): - return target_devices - - return [] - - -@callback -def _trigger_extract_entities(trigger_conf: dict) -> list[str]: - """Extract entities from a trigger config.""" - if trigger_conf[CONF_PLATFORM] in ("state", "numeric_state"): - return trigger_conf[CONF_ENTITY_ID] # type: ignore[no-any-return] - - if trigger_conf[CONF_PLATFORM] == "calendar": - return [trigger_conf[CONF_OPTIONS][CONF_ENTITY_ID]] - - if trigger_conf[CONF_PLATFORM] == "zone": - return trigger_conf[CONF_ENTITY_ID] + [trigger_conf[CONF_ZONE]] # type: ignore[no-any-return] - - if trigger_conf[CONF_PLATFORM] == "geo_location": - return [trigger_conf[CONF_ZONE]] - - if trigger_conf[CONF_PLATFORM] == "sun": - return ["sun.sun"] - - if ( - trigger_conf[CONF_PLATFORM] == "event" - and CONF_EVENT_DATA in trigger_conf - and CONF_ENTITY_ID in trigger_conf[CONF_EVENT_DATA] - and isinstance(trigger_conf[CONF_EVENT_DATA][CONF_ENTITY_ID], str) - and valid_entity_id(trigger_conf[CONF_EVENT_DATA][CONF_ENTITY_ID]) - ): - return [trigger_conf[CONF_EVENT_DATA][CONF_ENTITY_ID]] - - if target_entities := _get_targets_from_trigger_config( - trigger_conf, CONF_ENTITY_ID - ): - return target_entities - - return [] - - -@callback -def _get_targets_from_trigger_config( - config: dict, - target: Literal["entity_id", "device_id", "area_id", "floor_id", "label_id"], -) -> list[str]: - """Extract targets from a target config.""" - if not (target_conf := config.get(CONF_TARGET)): - return [] - if not (targets := target_conf.get(target)): - return [] - - return [targets] if isinstance(targets, str) else targets - - @websocket_api.websocket_command({"type": "automation/config", "entity_id": str}) def websocket_config( hass: HomeAssistant, diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 81d91ac8042..98f23ecd47e 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -85,7 +85,13 @@ from homeassistant.util.dt import utcnow from homeassistant.util.hass_dict import HassKey from homeassistant.util.signal_type import SignalType, SignalTypeFormat -from . import condition, config_validation as cv, service, template +from . import ( + condition, + config_validation as cv, + service, + template, + trigger as trigger_helper, +) from .condition import ConditionCheckerTypeOptional, trace_condition_function from .dispatcher import async_dispatcher_connect, async_dispatcher_send_internal from .event import async_call_later, async_track_template @@ -107,7 +113,6 @@ from .trace import ( trace_stack_top, trace_update_result, ) -from .trigger import async_initialize_triggers, async_validate_trigger_config from .typing import UNDEFINED, ConfigType, TemplateVarsType, UndefinedType SCRIPT_MODE_PARALLEL = "parallel" @@ -319,7 +324,9 @@ async def async_validate_action_config( config = await condition.async_validate_condition_config(hass, config) elif action_type == cv.SCRIPT_ACTION_WAIT_FOR_TRIGGER: - config[CONF_WAIT_FOR_TRIGGER] = await async_validate_trigger_config( + config[ + CONF_WAIT_FOR_TRIGGER + ] = await trigger_helper.async_validate_trigger_config( hass, config[CONF_WAIT_FOR_TRIGGER] ) @@ -1232,7 +1239,7 @@ class _ScriptRun: def log_cb(level: int, msg: str, **kwargs: Any) -> None: self._log(msg, level=level, **kwargs) - remove_triggers = await async_initialize_triggers( + remove_triggers = await trigger_helper.async_initialize_triggers( self._hass, self._action[CONF_WAIT_FOR_TRIGGER], async_done, @@ -1604,6 +1611,12 @@ class Script: elif action == cv.SCRIPT_ACTION_CHECK_CONDITION: referenced |= condition.async_extract_targets(step, target) + elif action == cv.SCRIPT_ACTION_WAIT_FOR_TRIGGER: + for trigger in step[CONF_WAIT_FOR_TRIGGER]: + referenced |= set( + trigger_helper.async_extract_targets(trigger, target) + ) + elif action == cv.SCRIPT_ACTION_CHOOSE: for choice in step[CONF_CHOOSE]: for cond in choice[CONF_CONDITIONS]: @@ -1657,6 +1670,10 @@ class Script: elif action == cv.SCRIPT_ACTION_CHECK_CONDITION: referenced |= condition.async_extract_devices(step) + elif action == cv.SCRIPT_ACTION_WAIT_FOR_TRIGGER: + for trigger in step[CONF_WAIT_FOR_TRIGGER]: + referenced |= set(trigger_helper.async_extract_devices(trigger)) + elif action == cv.SCRIPT_ACTION_DEVICE_AUTOMATION: referenced.add(step[CONF_DEVICE_ID]) @@ -1708,6 +1725,10 @@ class Script: elif action == cv.SCRIPT_ACTION_CHECK_CONDITION: referenced |= condition.async_extract_entities(step) + elif action == cv.SCRIPT_ACTION_WAIT_FOR_TRIGGER: + for trigger in step[CONF_WAIT_FOR_TRIGGER]: + referenced |= set(trigger_helper.async_extract_entities(trigger)) + elif action == cv.SCRIPT_ACTION_ACTIVATE_SCENE: referenced.add(step[CONF_SCENE]) diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 225c6bcbc66..8c8ea61870e 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -11,7 +11,16 @@ from enum import StrEnum import functools import inspect import logging -from typing import TYPE_CHECKING, Any, Final, Protocol, TypedDict, cast, override +from typing import ( + TYPE_CHECKING, + Any, + Final, + Literal, + Protocol, + TypedDict, + cast, + override, +) import voluptuous as vol @@ -20,13 +29,17 @@ from homeassistant.const import ( CONF_ABOVE, CONF_ALIAS, CONF_BELOW, + CONF_DEVICE_ID, CONF_ENABLED, + CONF_ENTITY_ID, + CONF_EVENT_DATA, CONF_ID, CONF_OPTIONS, CONF_PLATFORM, CONF_SELECTOR, CONF_TARGET, CONF_VARIABLES, + CONF_ZONE, STATE_UNAVAILABLE, STATE_UNKNOWN, ) @@ -41,6 +54,7 @@ from homeassistant.core import ( get_hassjob_callable_job_type, is_callback, split_entity_id, + valid_entity_id, ) from homeassistant.exceptions import HomeAssistantError, TemplateError from homeassistant.loader import ( @@ -1440,3 +1454,73 @@ async def async_get_all_descriptions( new_descriptions_cache[missing_trigger] = description hass.data[TRIGGER_DESCRIPTION_CACHE] = new_descriptions_cache return new_descriptions_cache + + +@callback +def async_extract_devices(trigger_conf: dict) -> list[str]: + """Extract devices from a trigger config.""" + if trigger_conf[CONF_PLATFORM] == "device": + return [trigger_conf[CONF_DEVICE_ID]] + + if ( + trigger_conf[CONF_PLATFORM] == "event" + and CONF_EVENT_DATA in trigger_conf + and CONF_DEVICE_ID in trigger_conf[CONF_EVENT_DATA] + and isinstance(trigger_conf[CONF_EVENT_DATA][CONF_DEVICE_ID], str) + ): + return [trigger_conf[CONF_EVENT_DATA][CONF_DEVICE_ID]] + + if trigger_conf[CONF_PLATFORM] == "tag" and CONF_DEVICE_ID in trigger_conf: + return trigger_conf[CONF_DEVICE_ID] # type: ignore[no-any-return] + + if target_devices := async_extract_targets(trigger_conf, CONF_DEVICE_ID): + return target_devices + + return [] + + +@callback +def async_extract_entities(trigger_conf: dict) -> list[str]: + """Extract entities from a trigger config.""" + if trigger_conf[CONF_PLATFORM] in ("state", "numeric_state"): + return trigger_conf[CONF_ENTITY_ID] # type: ignore[no-any-return] + + if trigger_conf[CONF_PLATFORM] == "calendar": + return [trigger_conf[CONF_OPTIONS][CONF_ENTITY_ID]] + + if trigger_conf[CONF_PLATFORM] == "zone": + return trigger_conf[CONF_ENTITY_ID] + [trigger_conf[CONF_ZONE]] # type: ignore[no-any-return] + + if trigger_conf[CONF_PLATFORM] == "geo_location": + return [trigger_conf[CONF_ZONE]] + + if trigger_conf[CONF_PLATFORM] == "sun": + return ["sun.sun"] + + if ( + trigger_conf[CONF_PLATFORM] == "event" + and CONF_EVENT_DATA in trigger_conf + and CONF_ENTITY_ID in trigger_conf[CONF_EVENT_DATA] + and isinstance(trigger_conf[CONF_EVENT_DATA][CONF_ENTITY_ID], str) + and valid_entity_id(trigger_conf[CONF_EVENT_DATA][CONF_ENTITY_ID]) + ): + return [trigger_conf[CONF_EVENT_DATA][CONF_ENTITY_ID]] + + if target_entities := async_extract_targets(trigger_conf, CONF_ENTITY_ID): + return target_entities + + return [] + + +@callback +def async_extract_targets( + config: dict, + target: Literal["entity_id", "device_id", "area_id", "floor_id", "label_id"], +) -> list[str]: + """Extract targets from a target config.""" + if not (target_conf := config.get(CONF_TARGET)): + return [] + if not (targets := target_conf.get(target)): + return [] + + return [targets] if isinstance(targets, str) else targets diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index d58a5e03f42..08fedae528e 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -4286,6 +4286,25 @@ async def test_referenced_labels(hass: HomeAssistant) -> None: } ], }, + { + "wait_for_trigger": { + "platform": "state", + "entity_id": "sensor.test", + "target": {"label_id": "label_wait_trigger"}, + }, + }, + { + "wait_for_trigger": { + "platform": "state", + "entity_id": "sensor.test", + "target": { + "label_id": [ + "label_wait_trigger_list_1", + "label_wait_trigger_list_2", + ] + }, + }, + }, ] ), "Test Name", @@ -4309,6 +4328,9 @@ async def test_referenced_labels(hass: HomeAssistant) -> None: "label_service_list_1", "label_service_list_2", "label_service_not_list", + "label_wait_trigger", + "label_wait_trigger_list_1", + "label_wait_trigger_list_2", } # Test we cache results. assert script_obj.referenced_labels is script_obj.referenced_labels @@ -4418,6 +4440,25 @@ async def test_referenced_floors(hass: HomeAssistant) -> None: } ], }, + { + "wait_for_trigger": { + "platform": "state", + "entity_id": "sensor.test", + "target": {"floor_id": "floor_wait_trigger"}, + }, + }, + { + "wait_for_trigger": { + "platform": "state", + "entity_id": "sensor.test", + "target": { + "floor_id": [ + "floor_wait_trigger_list_1", + "floor_wait_trigger_list_2", + ] + }, + }, + }, ] ), "Test Name", @@ -4440,6 +4481,9 @@ async def test_referenced_floors(hass: HomeAssistant) -> None: "floor_sequence", "floor_service_list", "floor_service_not_list", + "floor_wait_trigger", + "floor_wait_trigger_list_1", + "floor_wait_trigger_list_2", } # Test we cache results. assert script_obj.referenced_floors is script_obj.referenced_floors @@ -4549,6 +4593,25 @@ async def test_referenced_areas(hass: HomeAssistant) -> None: } ], }, + { + "wait_for_trigger": { + "platform": "state", + "entity_id": "sensor.test", + "target": {"area_id": "area_wait_trigger"}, + }, + }, + { + "wait_for_trigger": { + "platform": "state", + "entity_id": "sensor.test", + "target": { + "area_id": [ + "area_wait_trigger_list_1", + "area_wait_trigger_list_2", + ] + }, + }, + }, ] ), "Test Name", @@ -4571,6 +4634,9 @@ async def test_referenced_areas(hass: HomeAssistant) -> None: "area_sequence", "area_service_list", "area_service_not_list", + "area_wait_trigger", + "area_wait_trigger_list_1", + "area_wait_trigger_list_2", # 'area_service_template', # no area extraction from template } # Test we cache results. @@ -4692,6 +4758,21 @@ async def test_referenced_entities(hass: HomeAssistant) -> None: } ], }, + { + "wait_for_trigger": { + "platform": "state", + "entity_id": ["sensor.wait_trigger_state"], + }, + }, + { + "wait_for_trigger": { + "platform": "state", + "entity_id": [ + "sensor.wait_trigger_state_list_1", + "sensor.wait_trigger_state_list_2", + ], + }, + }, ] ), "Test Name", @@ -4718,6 +4799,9 @@ async def test_referenced_entities(hass: HomeAssistant) -> None: # "light.service_template", # no entity extraction from template "scene.hello", "sensor.condition", + "sensor.wait_trigger_state", + "sensor.wait_trigger_state_list_1", + "sensor.wait_trigger_state_list_2", } # Test we cache results. assert script_obj.referenced_entities is script_obj.referenced_entities @@ -4834,6 +4918,32 @@ async def test_referenced_devices(hass: HomeAssistant) -> None: } ], }, + { + "wait_for_trigger": { + "platform": "device", + "device_id": "wait-trigger-device", + "domain": "switch", + }, + }, + { + "wait_for_trigger": { + "platform": "state", + "entity_id": "sensor.test", + "target": {"device_id": "wait-trigger-target"}, + }, + }, + { + "wait_for_trigger": { + "platform": "state", + "entity_id": "sensor.test", + "target": { + "device_id": [ + "wait-trigger-target-list-1", + "wait-trigger-target-list-2", + ] + }, + }, + }, ] ), "Test Name", @@ -4859,6 +4969,10 @@ async def test_referenced_devices(hass: HomeAssistant) -> None: "if-else", "parallel-device", "sequence-device", + "wait-trigger-device", + "wait-trigger-target", + "wait-trigger-target-list-1", + "wait-trigger-target-list-2", } # Test we cache results. assert script_obj.referenced_devices is script_obj.referenced_devices