diff --git a/homeassistant/helpers/selector.py b/homeassistant/helpers/selector.py index a3510726175..176c2930277 100644 --- a/homeassistant/helpers/selector.py +++ b/homeassistant/helpers/selector.py @@ -57,6 +57,10 @@ class Selector[_T: Mapping[str, Any]]: CONFIG_SCHEMA: Callable config: _T selector_type: str + # Context keys that are allowed to be used in the selector, with list of allowed selector types. + # Selectors can use the value of other fields in the same schema as context for filtering for example. + # The selector defines which context keys it supports and what selector types are allowed for each key. + allowed_context_keys: dict[str, set[str]] = {} def __init__(self, config: Mapping[str, Any] | None = None) -> None: """Instantiate a selector.""" @@ -346,6 +350,11 @@ class AttributeSelector(Selector[AttributeSelectorConfig]): selector_type = "attribute" + allowed_context_keys = { + # Filters the available attributes based on the selected entity + "filter_entity": {"entity"} + } + CONFIG_SCHEMA = make_selector_config_schema( { vol.Required("entity_id"): cv.entity_id, @@ -1039,6 +1048,11 @@ class MediaSelector(Selector[MediaSelectorConfig]): selector_type = "media" + allowed_context_keys = { + # Filters the available media based on the selected entity + "filter_entity": {EntitySelector.selector_type} + } + CONFIG_SCHEMA = make_selector_config_schema( { vol.Optional("accept"): [str], @@ -1385,6 +1399,15 @@ class StateSelector(Selector[StateSelectorConfig]): selector_type = "state" + allowed_context_keys = { + # Filters the available states based on the selected entity + "filter_entity": {EntitySelector.selector_type}, + # Filters the available states based on the selected target + "filter_target": {"target"}, + # Only show the attribute values of a specific attribute + "filter_attribute": {AttributeSelector.selector_type}, + } + CONFIG_SCHEMA = make_selector_config_schema( { vol.Optional("entity_id"): cv.entity_id, diff --git a/script/hassfest/triggers.py b/script/hassfest/triggers.py index 4eb376c435f..86e4a475475 100644 --- a/script/hassfest/triggers.py +++ b/script/hassfest/triggers.py @@ -26,21 +26,95 @@ def exists(value: Any) -> Any: return value +def validate_field_schema(trigger_schema: dict[str, Any]) -> dict[str, Any]: + """Validate a field schema including context references.""" + + for field_name, field_schema in trigger_schema.get("fields", {}).items(): + # Validate context if present + if "context" in field_schema: + if CONF_SELECTOR not in field_schema: + raise vol.Invalid( + f"Context defined without a selector in '{field_name}'" + ) + + context = field_schema["context"] + if not isinstance(context, dict): + raise vol.Invalid(f"Context must be a dictionary in '{field_name}'") + + # Determine which selector type is being used + selector_config = field_schema[CONF_SELECTOR] + selector_class = selector.selector(selector_config) + + for context_key, field_ref in context.items(): + # Check if context key is allowed for this selector type + allowed_keys = selector_class.allowed_context_keys + if context_key not in allowed_keys: + raise vol.Invalid( + f"Invalid context key '{context_key}' for selector type '{selector_class.selector_type}'. " + f"Allowed keys: {', '.join(sorted(allowed_keys)) if allowed_keys else 'none'}" + ) + + # Check if the referenced field exists in trigger schema or target + if not isinstance(field_ref, str): + raise vol.Invalid( + f"Context value for '{context_key}' must be a string field reference" + ) + + # Check if field exists in trigger schema fields or target + trigger_fields = trigger_schema["fields"] + field_exists = field_ref in trigger_fields + if field_exists and "selector" in trigger_fields[field_ref]: + # Check if the selector type is allowed for this context key + field_selector_config = trigger_fields[field_ref][CONF_SELECTOR] + field_selector_class = selector.selector(field_selector_config) + if field_selector_class.selector_type not in allowed_keys.get( + context_key, set() + ): + raise vol.Invalid( + f"The context '{context_key}' for '{field_name}' references '{field_ref}', but '{context_key}' " + f"does not allow selectors of type '{field_selector_class.selector_type}'. Allowed selector types: {', '.join(allowed_keys.get(context_key, set()))}" + ) + if not field_exists and "target" in trigger_schema: + # Target is a special field that always exists when defined + field_exists = field_ref == "target" + if field_exists and "target" not in allowed_keys.get( + context_key, set() + ): + raise vol.Invalid( + f"The context '{context_key}' for '{field_name}' references 'target', but '{context_key}' " + f"does not allow 'target'. Allowed selector types: {', '.join(allowed_keys.get(context_key, set()))}" + ) + + if not field_exists: + raise vol.Invalid( + f"Context reference '{field_ref}' for key '{context_key}' does not exist " + f"in trigger schema fields or target" + ) + + return trigger_schema + + FIELD_SCHEMA = vol.Schema( { vol.Optional("example"): exists, vol.Optional("default"): exists, vol.Optional("required"): bool, vol.Optional(CONF_SELECTOR): selector.validate_selector, + vol.Optional("context"): { + str: str # key is context key, value is field name in the schema which value should be used + }, # Will be validated in validate_field_schema } ) TRIGGER_SCHEMA = vol.Any( - vol.Schema( - { - vol.Optional("target"): selector.TargetSelector.CONFIG_SCHEMA, - vol.Optional("fields"): vol.Schema({str: FIELD_SCHEMA}), - } + vol.All( + vol.Schema( + { + vol.Optional("target"): selector.TargetSelector.CONFIG_SCHEMA, + vol.Optional("fields"): vol.Schema({str: FIELD_SCHEMA}), + } + ), + validate_field_schema, ), None, )