mirror of
https://github.com/Electric-Special/ha-core.git
synced 2026-03-26 03:04:09 +01:00
Add secondary wake word and pipeline to ESPHome voice satellites (#151710)
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import replace
|
||||
|
||||
from homeassistant.components.select import SelectEntity, SelectEntityDescription
|
||||
from homeassistant.const import EntityCategory, Platform
|
||||
@@ -64,15 +65,36 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
|
||||
translation_key="pipeline",
|
||||
entity_category=EntityCategory.CONFIG,
|
||||
)
|
||||
|
||||
_attr_should_poll = False
|
||||
_attr_current_option = OPTION_PREFERRED
|
||||
_attr_options = [OPTION_PREFERRED]
|
||||
|
||||
def __init__(self, hass: HomeAssistant, domain: str, unique_id_prefix: str) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
domain: str,
|
||||
unique_id_prefix: str,
|
||||
index: int = 0,
|
||||
) -> None:
|
||||
"""Initialize a pipeline selector."""
|
||||
if index < 1:
|
||||
# Keep compatibility
|
||||
key_suffix = ""
|
||||
placeholder = ""
|
||||
else:
|
||||
key_suffix = f"_{index + 1}"
|
||||
placeholder = f" {index + 1}"
|
||||
|
||||
self.entity_description = replace(
|
||||
self.entity_description,
|
||||
key=f"pipeline{key_suffix}",
|
||||
translation_placeholders={"index": placeholder},
|
||||
)
|
||||
|
||||
self._domain = domain
|
||||
self._unique_id_prefix = unique_id_prefix
|
||||
self._attr_unique_id = f"{unique_id_prefix}-pipeline"
|
||||
self._attr_unique_id = f"{unique_id_prefix}-{self.entity_description.key}"
|
||||
self.hass = hass
|
||||
self._update_options()
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
},
|
||||
"select": {
|
||||
"pipeline": {
|
||||
"name": "Assistant",
|
||||
"name": "Assistant{index}",
|
||||
"state": {
|
||||
"preferred": "Preferred"
|
||||
}
|
||||
|
||||
@@ -127,27 +127,39 @@ class EsphomeAssistSatellite(
|
||||
available_wake_words=[], active_wake_words=[], max_active_wake_words=1
|
||||
)
|
||||
|
||||
@property
|
||||
def pipeline_entity_id(self) -> str | None:
|
||||
"""Return the entity ID of the pipeline to use for the next conversation."""
|
||||
assert self._entry_data.device_info is not None
|
||||
self._active_pipeline_index = 0
|
||||
|
||||
def _get_entity_id(self, suffix: str) -> str | None:
|
||||
"""Return the entity id for pipeline select, etc."""
|
||||
if self._entry_data.device_info is None:
|
||||
return None
|
||||
|
||||
ent_reg = er.async_get(self.hass)
|
||||
return ent_reg.async_get_entity_id(
|
||||
Platform.SELECT,
|
||||
DOMAIN,
|
||||
f"{self._entry_data.device_info.mac_address}-pipeline",
|
||||
f"{self._entry_data.device_info.mac_address}-{suffix}",
|
||||
)
|
||||
|
||||
@property
|
||||
def pipeline_entity_id(self) -> str | None:
|
||||
"""Return the entity ID of the primary pipeline to use for the next conversation."""
|
||||
return self.get_pipeline_entity(self._active_pipeline_index)
|
||||
|
||||
def get_pipeline_entity(self, index: int) -> str | None:
|
||||
"""Return the entity ID of a pipeline by index."""
|
||||
id_suffix = "" if index < 1 else f"_{index + 1}"
|
||||
return self._get_entity_id(f"pipeline{id_suffix}")
|
||||
|
||||
def get_wake_word_entity(self, index: int) -> str | None:
|
||||
"""Return the entity ID of a wake word by index."""
|
||||
id_suffix = "" if index < 1 else f"_{index + 1}"
|
||||
return self._get_entity_id(f"wake_word{id_suffix}")
|
||||
|
||||
@property
|
||||
def vad_sensitivity_entity_id(self) -> str | None:
|
||||
"""Return the entity ID of the VAD sensitivity to use for the next conversation."""
|
||||
assert self._entry_data.device_info is not None
|
||||
ent_reg = er.async_get(self.hass)
|
||||
return ent_reg.async_get_entity_id(
|
||||
Platform.SELECT,
|
||||
DOMAIN,
|
||||
f"{self._entry_data.device_info.mac_address}-vad_sensitivity",
|
||||
)
|
||||
return self._get_entity_id("vad_sensitivity")
|
||||
|
||||
@callback
|
||||
def async_get_configuration(
|
||||
@@ -235,6 +247,7 @@ class EsphomeAssistSatellite(
|
||||
)
|
||||
)
|
||||
|
||||
assert self._attr_supported_features is not None
|
||||
if feature_flags & VoiceAssistantFeature.ANNOUNCE:
|
||||
# Device supports announcements
|
||||
self._attr_supported_features |= (
|
||||
@@ -257,8 +270,8 @@ class EsphomeAssistSatellite(
|
||||
|
||||
# Update wake word select when config is updated
|
||||
self.async_on_remove(
|
||||
self._entry_data.async_register_assist_satellite_set_wake_word_callback(
|
||||
self.async_set_wake_word
|
||||
self._entry_data.async_register_assist_satellite_set_wake_words_callback(
|
||||
self.async_set_wake_words
|
||||
)
|
||||
)
|
||||
|
||||
@@ -482,8 +495,31 @@ class EsphomeAssistSatellite(
|
||||
# ANNOUNCEMENT format from media player
|
||||
self._update_tts_format()
|
||||
|
||||
# Run the pipeline
|
||||
_LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage)
|
||||
# Run the appropriate pipeline.
|
||||
self._active_pipeline_index = 0
|
||||
|
||||
maybe_pipeline_index = 0
|
||||
while True:
|
||||
if not (ww_entity_id := self.get_wake_word_entity(maybe_pipeline_index)):
|
||||
break
|
||||
|
||||
if not (ww_state := self.hass.states.get(ww_entity_id)):
|
||||
continue
|
||||
|
||||
if ww_state.state == wake_word_phrase:
|
||||
# First match
|
||||
self._active_pipeline_index = maybe_pipeline_index
|
||||
break
|
||||
|
||||
# Try next wake word select
|
||||
maybe_pipeline_index += 1
|
||||
|
||||
_LOGGER.debug(
|
||||
"Running pipeline %s from %s to %s",
|
||||
self._active_pipeline_index + 1,
|
||||
start_stage,
|
||||
end_stage,
|
||||
)
|
||||
self._pipeline_task = self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self.async_accept_pipeline_from_satellite(
|
||||
@@ -514,6 +550,7 @@ class EsphomeAssistSatellite(
|
||||
def handle_pipeline_finished(self) -> None:
|
||||
"""Handle when pipeline has finished running."""
|
||||
self._stop_udp_server()
|
||||
self._active_pipeline_index = 0
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
|
||||
def handle_timer_event(
|
||||
@@ -542,15 +579,15 @@ class EsphomeAssistSatellite(
|
||||
self.tts_response_finished()
|
||||
|
||||
@callback
|
||||
def async_set_wake_word(self, wake_word_id: str) -> None:
|
||||
"""Set active wake word and update config on satellite."""
|
||||
self._satellite_config.active_wake_words = [wake_word_id]
|
||||
def async_set_wake_words(self, wake_word_ids: list[str]) -> None:
|
||||
"""Set active wake words and update config on satellite."""
|
||||
self._satellite_config.active_wake_words = wake_word_ids
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self.async_set_configuration(self._satellite_config),
|
||||
"esphome_voice_assistant_set_config",
|
||||
)
|
||||
_LOGGER.debug("Setting active wake word: %s", wake_word_id)
|
||||
_LOGGER.debug("Setting active wake word(s): %s", wake_word_ids)
|
||||
|
||||
def _update_tts_format(self) -> None:
|
||||
"""Update the TTS format from the first media player."""
|
||||
|
||||
@@ -25,3 +25,5 @@ PROJECT_URLS = {
|
||||
# ESPHome always uses .0 for the changelog URL
|
||||
STABLE_BLE_URL_VERSION = f"{STABLE_BLE_VERSION.major}.{STABLE_BLE_VERSION.minor}.0"
|
||||
DEFAULT_URL = f"https://esphome.io/changelog/{STABLE_BLE_URL_VERSION}.html"
|
||||
|
||||
NO_WAKE_WORD: Final[str] = "no_wake_word"
|
||||
|
||||
@@ -177,9 +177,10 @@ class RuntimeEntryData:
|
||||
assist_satellite_config_update_callbacks: list[
|
||||
Callable[[AssistSatelliteConfiguration], None]
|
||||
] = field(default_factory=list)
|
||||
assist_satellite_set_wake_word_callbacks: list[Callable[[str], None]] = field(
|
||||
default_factory=list
|
||||
assist_satellite_set_wake_words_callbacks: list[Callable[[list[str]], None]] = (
|
||||
field(default_factory=list)
|
||||
)
|
||||
assist_satellite_wake_words: dict[int, str] = field(default_factory=dict)
|
||||
device_id_to_name: dict[int, str] = field(default_factory=dict)
|
||||
entity_removal_callbacks: dict[EntityInfoKey, list[CALLBACK_TYPE]] = field(
|
||||
default_factory=dict
|
||||
@@ -501,19 +502,28 @@ class RuntimeEntryData:
|
||||
callback_(config)
|
||||
|
||||
@callback
|
||||
def async_register_assist_satellite_set_wake_word_callback(
|
||||
def async_register_assist_satellite_set_wake_words_callback(
|
||||
self,
|
||||
callback_: Callable[[str], None],
|
||||
callback_: Callable[[list[str]], None],
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Register to receive callbacks when the Assist satellite's wake word is set."""
|
||||
self.assist_satellite_set_wake_word_callbacks.append(callback_)
|
||||
return partial(self.assist_satellite_set_wake_word_callbacks.remove, callback_)
|
||||
self.assist_satellite_set_wake_words_callbacks.append(callback_)
|
||||
return partial(self.assist_satellite_set_wake_words_callbacks.remove, callback_)
|
||||
|
||||
@callback
|
||||
def async_assist_satellite_set_wake_word(self, wake_word_id: str) -> None:
|
||||
"""Notify listeners that the Assist satellite wake word has been set."""
|
||||
for callback_ in self.assist_satellite_set_wake_word_callbacks.copy():
|
||||
callback_(wake_word_id)
|
||||
def async_assist_satellite_set_wake_word(
|
||||
self, wake_word_index: int, wake_word_id: str | None
|
||||
) -> None:
|
||||
"""Notify listeners that the Assist satellite wake words have been set."""
|
||||
if wake_word_id:
|
||||
self.assist_satellite_wake_words[wake_word_index] = wake_word_id
|
||||
else:
|
||||
self.assist_satellite_wake_words.pop(wake_word_index, None)
|
||||
|
||||
wake_word_ids = list(self.assist_satellite_wake_words.values())
|
||||
|
||||
for callback_ in self.assist_satellite_set_wake_words_callbacks.copy():
|
||||
callback_(wake_word_ids)
|
||||
|
||||
@callback
|
||||
def async_register_entity_removal_callback(
|
||||
|
||||
@@ -9,11 +9,17 @@
|
||||
"pipeline": {
|
||||
"default": "mdi:filter-outline"
|
||||
},
|
||||
"pipeline_2": {
|
||||
"default": "mdi:filter-outline"
|
||||
},
|
||||
"vad_sensitivity": {
|
||||
"default": "mdi:volume-high"
|
||||
},
|
||||
"wake_word": {
|
||||
"default": "mdi:microphone"
|
||||
},
|
||||
"wake_word_2": {
|
||||
"default": "mdi:microphone"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import replace
|
||||
|
||||
from aioesphomeapi import EntityInfo, SelectInfo, SelectState
|
||||
|
||||
from homeassistant.components.assist_pipeline.select import (
|
||||
@@ -15,7 +17,7 @@ from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import restore_state
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .const import DOMAIN, NO_WAKE_WORD
|
||||
from .entity import (
|
||||
EsphomeAssistEntity,
|
||||
EsphomeEntity,
|
||||
@@ -50,9 +52,11 @@ async def async_setup_entry(
|
||||
):
|
||||
async_add_entities(
|
||||
[
|
||||
EsphomeAssistPipelineSelect(hass, entry_data),
|
||||
EsphomeAssistPipelineSelect(hass, entry_data, index=0),
|
||||
EsphomeAssistPipelineSelect(hass, entry_data, index=1),
|
||||
EsphomeVadSensitivitySelect(hass, entry_data),
|
||||
EsphomeAssistSatelliteWakeWordSelect(entry_data),
|
||||
EsphomeAssistSatelliteWakeWordSelect(entry_data, index=0),
|
||||
EsphomeAssistSatelliteWakeWordSelect(entry_data, index=1),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -84,10 +88,14 @@ class EsphomeSelect(EsphomeEntity[SelectInfo, SelectState], SelectEntity):
|
||||
class EsphomeAssistPipelineSelect(EsphomeAssistEntity, AssistPipelineSelect):
|
||||
"""Pipeline selector for esphome devices."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant, entry_data: RuntimeEntryData) -> None:
|
||||
def __init__(
|
||||
self, hass: HomeAssistant, entry_data: RuntimeEntryData, index: int = 0
|
||||
) -> None:
|
||||
"""Initialize a pipeline selector."""
|
||||
EsphomeAssistEntity.__init__(self, entry_data)
|
||||
AssistPipelineSelect.__init__(self, hass, DOMAIN, self._device_info.mac_address)
|
||||
AssistPipelineSelect.__init__(
|
||||
self, hass, DOMAIN, self._device_info.mac_address, index=index
|
||||
)
|
||||
|
||||
|
||||
class EsphomeVadSensitivitySelect(EsphomeAssistEntity, VadSensitivitySelect):
|
||||
@@ -109,28 +117,47 @@ class EsphomeAssistSatelliteWakeWordSelect(
|
||||
translation_key="wake_word",
|
||||
entity_category=EntityCategory.CONFIG,
|
||||
)
|
||||
_attr_current_option: str | None = None
|
||||
_attr_options: list[str] = []
|
||||
|
||||
def __init__(self, entry_data: RuntimeEntryData) -> None:
|
||||
_attr_current_option: str | None = None
|
||||
_attr_options: list[str] = [NO_WAKE_WORD]
|
||||
|
||||
def __init__(self, entry_data: RuntimeEntryData, index: int = 0) -> None:
|
||||
"""Initialize a wake word selector."""
|
||||
if index < 1:
|
||||
# Keep compatibility
|
||||
key_suffix = ""
|
||||
placeholder = ""
|
||||
else:
|
||||
key_suffix = f"_{index + 1}"
|
||||
placeholder = f" {index + 1}"
|
||||
|
||||
self.entity_description = replace(
|
||||
self.entity_description,
|
||||
key=f"wake_word{key_suffix}",
|
||||
translation_placeholders={"index": placeholder},
|
||||
)
|
||||
|
||||
EsphomeAssistEntity.__init__(self, entry_data)
|
||||
|
||||
unique_id_prefix = self._device_info.mac_address
|
||||
self._attr_unique_id = f"{unique_id_prefix}-wake_word"
|
||||
self._attr_unique_id = f"{unique_id_prefix}-{self.entity_description.key}"
|
||||
|
||||
# name -> id
|
||||
self._wake_words: dict[str, str] = {}
|
||||
self._wake_word_index = index
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
"""Return if entity is available."""
|
||||
return bool(self._attr_options)
|
||||
return len(self._attr_options) > 1 # more than just NO_WAKE_WORD
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when entity about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
if last_state := await self.async_get_last_state():
|
||||
self._attr_current_option = last_state.state
|
||||
|
||||
# Update options when config is updated
|
||||
self.async_on_remove(
|
||||
self._entry_data.async_register_assist_satellite_config_updated_callback(
|
||||
@@ -140,33 +167,49 @@ class EsphomeAssistSatelliteWakeWordSelect(
|
||||
|
||||
async def async_select_option(self, option: str) -> None:
|
||||
"""Select an option."""
|
||||
if wake_word_id := self._wake_words.get(option):
|
||||
# _attr_current_option will be updated on
|
||||
# async_satellite_config_updated after the device sets the wake
|
||||
# word.
|
||||
self._entry_data.async_assist_satellite_set_wake_word(wake_word_id)
|
||||
self._attr_current_option = option
|
||||
self.async_write_ha_state()
|
||||
|
||||
wake_word_id = self._wake_words.get(option)
|
||||
self._entry_data.async_assist_satellite_set_wake_word(
|
||||
self._wake_word_index, wake_word_id
|
||||
)
|
||||
|
||||
def async_satellite_config_updated(
|
||||
self, config: AssistSatelliteConfiguration
|
||||
) -> None:
|
||||
"""Update options with available wake words."""
|
||||
if (not config.available_wake_words) or (config.max_active_wake_words < 1):
|
||||
self._attr_current_option = None
|
||||
# No wake words
|
||||
self._wake_words.clear()
|
||||
self._attr_current_option = NO_WAKE_WORD
|
||||
self._attr_options = [NO_WAKE_WORD]
|
||||
self._entry_data.assist_satellite_wake_words.pop(
|
||||
self._wake_word_index, None
|
||||
)
|
||||
self.async_write_ha_state()
|
||||
return
|
||||
|
||||
self._wake_words = {w.wake_word: w.id for w in config.available_wake_words}
|
||||
self._attr_options = sorted(self._wake_words)
|
||||
self._attr_options = [NO_WAKE_WORD, *sorted(self._wake_words)]
|
||||
|
||||
if config.active_wake_words:
|
||||
# Select first active wake word
|
||||
wake_word_id = config.active_wake_words[0]
|
||||
for wake_word in config.available_wake_words:
|
||||
if wake_word.id == wake_word_id:
|
||||
self._attr_current_option = wake_word.wake_word
|
||||
else:
|
||||
# Select first available wake word
|
||||
self._attr_current_option = config.available_wake_words[0].wake_word
|
||||
option = self._attr_current_option
|
||||
if (
|
||||
(option is None)
|
||||
or ((wake_word_id := self._wake_words.get(option)) is None)
|
||||
or (wake_word_id not in config.active_wake_words)
|
||||
):
|
||||
option = NO_WAKE_WORD
|
||||
|
||||
self._attr_current_option = option
|
||||
self.async_write_ha_state()
|
||||
|
||||
# Keep entry data in sync
|
||||
if wake_word_id := self._wake_words.get(option):
|
||||
self._entry_data.assist_satellite_wake_words[self._wake_word_index] = (
|
||||
wake_word_id
|
||||
)
|
||||
else:
|
||||
self._entry_data.assist_satellite_wake_words.pop(
|
||||
self._wake_word_index, None
|
||||
)
|
||||
|
||||
@@ -119,8 +119,9 @@
|
||||
}
|
||||
},
|
||||
"wake_word": {
|
||||
"name": "Wake word",
|
||||
"name": "Wake word{index}",
|
||||
"state": {
|
||||
"no_wake_word": "No wake word",
|
||||
"okay_nabu": "Okay Nabu"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ from homeassistant.components import (
|
||||
tts,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
|
||||
from homeassistant.components.assist_pipeline.pipeline import KEY_ASSIST_PIPELINE
|
||||
from homeassistant.components.assist_satellite import (
|
||||
AssistSatelliteConfiguration,
|
||||
AssistSatelliteEntityFeature,
|
||||
@@ -37,6 +38,7 @@ from homeassistant.components.assist_satellite import (
|
||||
# pylint: disable-next=hass-component-root-import
|
||||
from homeassistant.components.assist_satellite.entity import AssistSatelliteState
|
||||
from homeassistant.components.esphome.assist_satellite import VoiceAssistantUDPServer
|
||||
from homeassistant.components.esphome.const import NO_WAKE_WORD
|
||||
from homeassistant.components.select import (
|
||||
DOMAIN as SELECT_DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
@@ -45,6 +47,7 @@ from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr, intent as intent_helper
|
||||
from homeassistant.helpers.network import get_url
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from .common import get_satellite_entity
|
||||
from .conftest import MockESPHomeDeviceType
|
||||
@@ -1737,7 +1740,7 @@ async def test_get_set_configuration(
|
||||
AssistSatelliteWakeWord("5678", "hey jarvis", ["en"]),
|
||||
],
|
||||
active_wake_words=["1234"],
|
||||
max_active_wake_words=1,
|
||||
max_active_wake_words=2,
|
||||
)
|
||||
mock_client.get_voice_assistant_configuration.return_value = expected_config
|
||||
|
||||
@@ -1857,7 +1860,7 @@ async def test_wake_word_select(
|
||||
AssistSatelliteWakeWord("hey_mycroft", "Hey Mycroft", ["en"]),
|
||||
],
|
||||
active_wake_words=["hey_jarvis"],
|
||||
max_active_wake_words=1,
|
||||
max_active_wake_words=2,
|
||||
)
|
||||
mock_client.get_voice_assistant_configuration.return_value = device_config
|
||||
|
||||
@@ -1884,10 +1887,10 @@ async def test_wake_word_select(
|
||||
assert satellite is not None
|
||||
assert satellite.async_get_configuration().active_wake_words == ["hey_jarvis"]
|
||||
|
||||
# Active wake word should be selected
|
||||
# No wake word should be selected by default
|
||||
state = hass.states.get("select.test_wake_word")
|
||||
assert state is not None
|
||||
assert state.state == "Hey Jarvis"
|
||||
assert state.state == NO_WAKE_WORD
|
||||
|
||||
# Changing the select should set the active wake word
|
||||
await hass.services.async_call(
|
||||
@@ -1908,3 +1911,162 @@ async def test_wake_word_select(
|
||||
|
||||
# Satellite config should have been updated
|
||||
assert satellite.async_get_configuration().active_wake_words == ["okay_nabu"]
|
||||
|
||||
# No secondary wake word should be selected by default
|
||||
state = hass.states.get("select.test_wake_word_2")
|
||||
assert state is not None
|
||||
assert state.state == NO_WAKE_WORD
|
||||
|
||||
# Changing the secondary select should add an active wake word
|
||||
await hass.services.async_call(
|
||||
SELECT_DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
{ATTR_ENTITY_ID: "select.test_wake_word_2", "option": "Hey Jarvis"},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.get("select.test_wake_word_2")
|
||||
assert state is not None
|
||||
assert state.state == "Hey Jarvis"
|
||||
|
||||
# Wait for device config to be updated
|
||||
async with asyncio.timeout(1):
|
||||
await configuration_set.wait()
|
||||
|
||||
# Satellite config should have been updated
|
||||
assert set(satellite.async_get_configuration().active_wake_words) == {
|
||||
"okay_nabu",
|
||||
"hey_jarvis",
|
||||
}
|
||||
|
||||
# Remove the secondary wake word
|
||||
await hass.services.async_call(
|
||||
SELECT_DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
{ATTR_ENTITY_ID: "select.test_wake_word_2", "option": NO_WAKE_WORD},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await configuration_set.wait()
|
||||
|
||||
# Only primary wake word remains
|
||||
assert satellite.async_get_configuration().active_wake_words == ["okay_nabu"]
|
||||
|
||||
|
||||
async def test_secondary_pipeline(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: MockESPHomeDeviceType,
|
||||
) -> None:
|
||||
"""Test that the secondary pipeline is used when the secondary wake word is given."""
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
|
||||
pipeline_id_to_name: dict[str, str] = {}
|
||||
for pipeline_name in ("Primary Pipeline", "Secondary Pipeline"):
|
||||
pipeline = await pipeline_data.pipeline_store.async_create_item(
|
||||
{
|
||||
"name": pipeline_name,
|
||||
"language": "en-US",
|
||||
"conversation_engine": None,
|
||||
"conversation_language": "en-US",
|
||||
"tts_engine": None,
|
||||
"tts_language": None,
|
||||
"tts_voice": None,
|
||||
"stt_engine": None,
|
||||
"stt_language": None,
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
}
|
||||
)
|
||||
pipeline_id_to_name[pipeline.id] = pipeline_name
|
||||
|
||||
device_config = AssistSatelliteConfiguration(
|
||||
available_wake_words=[
|
||||
AssistSatelliteWakeWord("okay_nabu", "Okay Nabu", ["en"]),
|
||||
AssistSatelliteWakeWord("hey_jarvis", "Hey Jarvis", ["en"]),
|
||||
AssistSatelliteWakeWord("hey_mycroft", "Hey Mycroft", ["en"]),
|
||||
],
|
||||
active_wake_words=["hey_jarvis"],
|
||||
max_active_wake_words=2,
|
||||
)
|
||||
mock_client.get_voice_assistant_configuration.return_value = device_config
|
||||
|
||||
# Wrap mock so we can tell when it's done
|
||||
configuration_set = asyncio.Event()
|
||||
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Update device config because entity will request it after update
|
||||
device_config.active_wake_words = kwargs["active_wake_words"]
|
||||
configuration_set.set()
|
||||
|
||||
mock_client.set_voice_assistant_configuration = AsyncMock(side_effect=wrapper)
|
||||
|
||||
mock_device = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.ANNOUNCE
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||
assert satellite is not None
|
||||
|
||||
# Set primary/secondary wake words and assistants
|
||||
await hass.services.async_call(
|
||||
SELECT_DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
{ATTR_ENTITY_ID: "select.test_wake_word", "option": "Okay Nabu"},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.services.async_call(
|
||||
SELECT_DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
{ATTR_ENTITY_ID: "select.test_assistant", "option": "Primary Pipeline"},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.services.async_call(
|
||||
SELECT_DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
{ATTR_ENTITY_ID: "select.test_wake_word_2", "option": "Hey Jarvis"},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.services.async_call(
|
||||
SELECT_DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
{
|
||||
ATTR_ENTITY_ID: "select.test_assistant_2",
|
||||
"option": "Secondary Pipeline",
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
async def get_pipeline(wake_word_phrase):
|
||||
with patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
) as mock_pipeline_from_audio_stream:
|
||||
await satellite.handle_pipeline_start(
|
||||
conversation_id="",
|
||||
flags=0,
|
||||
audio_settings=VoiceAssistantAudioSettings(),
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
)
|
||||
|
||||
mock_pipeline_from_audio_stream.assert_called_once()
|
||||
kwargs = mock_pipeline_from_audio_stream.call_args_list[0].kwargs
|
||||
return pipeline_id_to_name[kwargs["pipeline_id"]]
|
||||
|
||||
# Primary pipeline is the default
|
||||
for wake_word_phrase in (None, "Okay Nabu"):
|
||||
assert (await get_pipeline(wake_word_phrase)) == "Primary Pipeline"
|
||||
|
||||
# Secondary pipeline requires secondary wake word
|
||||
assert (await get_pipeline("Hey Jarvis")) == "Secondary Pipeline"
|
||||
|
||||
# Primary pipeline should be restored after
|
||||
assert (await get_pipeline(None)) == "Primary Pipeline"
|
||||
|
||||
@@ -9,6 +9,7 @@ from homeassistant.components.assist_satellite import (
|
||||
AssistSatelliteConfiguration,
|
||||
AssistSatelliteWakeWord,
|
||||
)
|
||||
from homeassistant.components.esphome.const import NO_WAKE_WORD
|
||||
from homeassistant.components.select import (
|
||||
ATTR_OPTION,
|
||||
DOMAIN as SELECT_DOMAIN,
|
||||
@@ -32,6 +33,17 @@ async def test_pipeline_selector(
|
||||
assert state.state == "preferred"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_voice_assistant_v1_entry")
|
||||
async def test_secondary_pipeline_selector(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test secondary assist pipeline selector."""
|
||||
|
||||
state = hass.states.get("select.test_assistant_2")
|
||||
assert state is not None
|
||||
assert state.state == "preferred"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_voice_assistant_v1_entry")
|
||||
async def test_vad_sensitivity_select(
|
||||
hass: HomeAssistant,
|
||||
@@ -56,6 +68,16 @@ async def test_wake_word_select(
|
||||
assert state.state == STATE_UNAVAILABLE
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_voice_assistant_v1_entry")
|
||||
async def test_secondary_wake_word_select(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test that secondary wake word select is unavailable initially."""
|
||||
state = hass.states.get("select.test_wake_word_2")
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNAVAILABLE
|
||||
|
||||
|
||||
async def test_select_generic_entity(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
@@ -117,10 +139,11 @@ async def test_wake_word_select_no_wake_words(
|
||||
assert satellite is not None
|
||||
assert not satellite.async_get_configuration().available_wake_words
|
||||
|
||||
# Select should be unavailable
|
||||
state = hass.states.get("select.test_wake_word")
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNAVAILABLE
|
||||
# Selects should be unavailable
|
||||
for entity_id in ("select.test_wake_word", "select.test_wake_word_2"):
|
||||
state = hass.states.get(entity_id)
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNAVAILABLE
|
||||
|
||||
|
||||
async def test_wake_word_select_zero_max_wake_words(
|
||||
@@ -151,10 +174,11 @@ async def test_wake_word_select_zero_max_wake_words(
|
||||
assert satellite is not None
|
||||
assert satellite.async_get_configuration().max_active_wake_words == 0
|
||||
|
||||
# Select should be unavailable
|
||||
state = hass.states.get("select.test_wake_word")
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNAVAILABLE
|
||||
# Selects should be unavailable
|
||||
for entity_id in ("select.test_wake_word", "select.test_wake_word_2"):
|
||||
state = hass.states.get(entity_id)
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNAVAILABLE
|
||||
|
||||
|
||||
async def test_wake_word_select_no_active_wake_words(
|
||||
@@ -186,7 +210,8 @@ async def test_wake_word_select_no_active_wake_words(
|
||||
assert satellite is not None
|
||||
assert not satellite.async_get_configuration().active_wake_words
|
||||
|
||||
# First available wake word should be selected
|
||||
state = hass.states.get("select.test_wake_word")
|
||||
assert state is not None
|
||||
assert state.state == "Okay Nabu"
|
||||
# No wake words should be selected
|
||||
for entity_id in ("select.test_wake_word", "select.test_wake_word_2"):
|
||||
state = hass.states.get(entity_id)
|
||||
assert state is not None
|
||||
assert state.state == NO_WAKE_WORD
|
||||
|
||||
Reference in New Issue
Block a user