diff --git a/homeassistant/components/assist_pipeline/acknowledge.mp3 b/homeassistant/components/assist_pipeline/acknowledge.mp3 new file mode 100644 index 00000000000..1709ff20bc2 Binary files /dev/null and b/homeassistant/components/assist_pipeline/acknowledge.mp3 differ diff --git a/homeassistant/components/assist_pipeline/const.py b/homeassistant/components/assist_pipeline/const.py index 52583cf21a4..54829a48f88 100644 --- a/homeassistant/components/assist_pipeline/const.py +++ b/homeassistant/components/assist_pipeline/const.py @@ -1,5 +1,7 @@ """Constants for the Assist pipeline integration.""" +from pathlib import Path + DOMAIN = "assist_pipeline" DATA_CONFIG = f"{DOMAIN}.config" @@ -23,3 +25,5 @@ SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit OPTION_PREFERRED = "preferred" + +ACKNOWLEDGE_PATH = Path(__file__).parent / "acknowledge.mp3" diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index ad291b3427b..8af0c9157b5 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -23,7 +23,12 @@ from homeassistant.components import conversation, stt, tts, wake_word, websocke from homeassistant.const import ATTR_SUPPORTED_FEATURES, MATCH_ALL from homeassistant.core import Context, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import chat_session, intent +from homeassistant.helpers import ( + chat_session, + device_registry as dr, + entity_registry as er, + intent, +) from homeassistant.helpers.collection import ( CHANGE_UPDATED, CollectionError, @@ -45,6 +50,7 @@ from homeassistant.util.limited_size_dict import LimitedSizeDict from .audio_enhancer import AudioEnhancer, EnhancedAudioChunk, MicroVadSpeexEnhancer from .const import ( + ACKNOWLEDGE_PATH, BYTES_PER_CHUNK, CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, @@ -113,6 +119,7 @@ PIPELINE_FIELDS: VolDictType = { vol.Required("wake_word_entity"): vol.Any(str, None), vol.Required("wake_word_id"): vol.Any(str, None), vol.Optional("prefer_local_intents"): bool, + vol.Optional("acknowledge_media_id"): str, } STORED_PIPELINE_RUNS = 10 @@ -1066,8 +1073,11 @@ class PipelineRun: intent_input: str, conversation_id: str, conversation_extra_system_prompt: str | None, - ) -> str: - """Run intent recognition portion of pipeline. Returns text to speak.""" + ) -> tuple[str, bool]: + """Run intent recognition portion of pipeline. + + Returns (speech, all_targets_in_satellite_area). + """ if self.intent_agent is None or self._conversation_data is None: raise RuntimeError("Recognize intent was not prepared") @@ -1116,6 +1126,7 @@ class PipelineRun: agent_id = self.intent_agent.id processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT + all_targets_in_satellite_area = False intent_response: intent.IntentResponse | None = None if not processed_locally and not self._intent_agent_only: # Sentence triggers override conversation agent @@ -1290,6 +1301,17 @@ class PipelineRun: if tts_input_stream and self._streamed_response_text: tts_input_stream.put_nowait(None) + if agent_id == conversation.HOME_ASSISTANT_AGENT: + # Check if all targeted entities were in the same area as + # the satellite device. + # If so, the satellite should respond with an acknowledge beep + # instead of a full response. + all_targets_in_satellite_area = ( + self._get_all_targets_in_satellite_area( + conversation_result.response, self._device_id + ) + ) + except Exception as src_error: _LOGGER.exception("Unexpected error during intent recognition") raise IntentRecognitionError( @@ -1312,7 +1334,45 @@ class PipelineRun: if conversation_result.continue_conversation: self._conversation_data.continue_conversation_agent = agent_id - return speech + return (speech, all_targets_in_satellite_area) + + def _get_all_targets_in_satellite_area( + self, intent_response: intent.IntentResponse, device_id: str | None + ) -> bool: + """Return true if all targeted entities were in the same area as the device.""" + if ( + (intent_response.response_type != intent.IntentResponseType.ACTION_DONE) + or (not intent_response.matched_states) + or (not device_id) + ): + return False + + device_registry = dr.async_get(self.hass) + + if (not (device := device_registry.async_get(device_id))) or ( + not device.area_id + ): + return False + + entity_registry = er.async_get(self.hass) + for state in intent_response.matched_states: + entity = entity_registry.async_get(state.entity_id) + if not entity: + return False + + if (entity_area_id := entity.area_id) is None: + if (entity.device_id is None) or ( + (entity_device := device_registry.async_get(entity.device_id)) + is None + ): + return False + + entity_area_id = entity_device.area_id + + if entity_area_id != device.area_id: + return False + + return True async def prepare_text_to_speech(self) -> None: """Prepare text-to-speech.""" @@ -1350,7 +1410,9 @@ class PipelineRun: ), ) from err - async def text_to_speech(self, tts_input: str) -> None: + async def text_to_speech( + self, tts_input: str, override_media_path: Path | None = None + ) -> None: """Run text-to-speech portion of pipeline.""" assert self.tts_stream is not None @@ -1362,11 +1424,14 @@ class PipelineRun: "language": self.pipeline.tts_language, "voice": self.pipeline.tts_voice, "tts_input": tts_input, + "acknowledge_override": override_media_path is not None, }, ) ) - if not self._streamed_response_text: + if override_media_path: + self.tts_stream.async_override_result(override_media_path) + elif not self._streamed_response_text: self.tts_stream.async_set_message(tts_input) tts_output = { @@ -1664,16 +1729,20 @@ class PipelineInput: if self.run.end_stage != PipelineStage.STT: tts_input = self.tts_input + all_targets_in_satellite_area = False if current_stage == PipelineStage.INTENT: # intent-recognition assert intent_input is not None - tts_input = await self.run.recognize_intent( + ( + tts_input, + all_targets_in_satellite_area, + ) = await self.run.recognize_intent( intent_input, self.session.conversation_id, self.conversation_extra_system_prompt, ) - if tts_input.strip(): + if all_targets_in_satellite_area or tts_input.strip(): current_stage = PipelineStage.TTS else: # Skip TTS @@ -1682,8 +1751,14 @@ class PipelineInput: if self.run.end_stage != PipelineStage.INTENT: # text-to-speech if current_stage == PipelineStage.TTS: - assert tts_input is not None - await self.run.text_to_speech(tts_input) + if all_targets_in_satellite_area: + # Use acknowledge media instead of full response + await self.run.text_to_speech( + tts_input or "", override_media_path=ACKNOWLEDGE_PATH + ) + else: + assert tts_input is not None + await self.run.text_to_speech(tts_input) except PipelineError as err: self.run.process_event( diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index 56ca8bde0ba..5e77b7e9291 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -76,6 +76,7 @@ }), dict({ 'data': dict({ + 'acknowledge_override': False, 'engine': 'tts.test', 'language': 'en_US', 'tts_input': "Sorry, I couldn't understand that", @@ -177,6 +178,7 @@ }), dict({ 'data': dict({ + 'acknowledge_override': False, 'engine': 'test', 'language': 'en-US', 'tts_input': "Sorry, I couldn't understand that", @@ -278,6 +280,7 @@ }), dict({ 'data': dict({ + 'acknowledge_override': False, 'engine': 'test', 'language': 'en-US', 'tts_input': "Sorry, I couldn't understand that", @@ -403,6 +406,7 @@ }), dict({ 'data': dict({ + 'acknowledge_override': False, 'engine': 'tts.test', 'language': 'en_US', 'tts_input': "Sorry, I couldn't understand that", diff --git a/tests/components/assist_pipeline/snapshots/test_pipeline.ambr b/tests/components/assist_pipeline/snapshots/test_pipeline.ambr index 7a51eddf8d6..e92f3aec3fb 100644 --- a/tests/components/assist_pipeline/snapshots/test_pipeline.ambr +++ b/tests/components/assist_pipeline/snapshots/test_pipeline.ambr @@ -131,6 +131,7 @@ }), dict({ 'data': dict({ + 'acknowledge_override': False, 'engine': 'tts.test', 'language': 'en_US', 'tts_input': 'hello, how are you?', @@ -365,6 +366,7 @@ }), dict({ 'data': dict({ + 'acknowledge_override': False, 'engine': 'tts.test', 'language': 'en_US', 'tts_input': "hello, how are you? I'm doing well, thank you. What about you?!", @@ -595,6 +597,7 @@ }), dict({ 'data': dict({ + 'acknowledge_override': False, 'engine': 'tts.test', 'language': 'en_US', 'tts_input': "I'm doing well, thank you.", diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index 5e0d915a77e..5b5ed44e24d 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -73,6 +73,7 @@ # --- # name: test_audio_pipeline.5 dict({ + 'acknowledge_override': False, 'engine': 'tts.test', 'language': 'en_US', 'tts_input': "Sorry, I couldn't understand that", @@ -166,6 +167,7 @@ # --- # name: test_audio_pipeline_debug.5 dict({ + 'acknowledge_override': False, 'engine': 'tts.test', 'language': 'en_US', 'tts_input': "Sorry, I couldn't understand that", @@ -271,6 +273,7 @@ # --- # name: test_audio_pipeline_with_enhancements.5 dict({ + 'acknowledge_override': False, 'engine': 'tts.test', 'language': 'en_US', 'tts_input': "Sorry, I couldn't understand that", @@ -386,6 +389,7 @@ # --- # name: test_audio_pipeline_with_wake_word_no_timeout.7 dict({ + 'acknowledge_override': False, 'engine': 'tts.test', 'language': 'en_US', 'tts_input': "Sorry, I couldn't understand that", diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index 75234122368..fe82f693fde 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -16,13 +16,14 @@ from homeassistant.components import ( stt, tts, ) -from homeassistant.components.assist_pipeline.const import DOMAIN +from homeassistant.components.assist_pipeline.const import ACKNOWLEDGE_PATH, DOMAIN from homeassistant.components.assist_pipeline.pipeline import ( STORAGE_KEY, STORAGE_VERSION, STORAGE_VERSION_MINOR, Pipeline, PipelineData, + PipelineEventType, PipelineStorageCollection, PipelineStore, _async_local_fallback_intent_filter, @@ -31,9 +32,16 @@ from homeassistant.components.assist_pipeline.pipeline import ( async_get_pipelines, async_update_pipeline, ) -from homeassistant.const import MATCH_ALL +from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL from homeassistant.core import Context, HomeAssistant -from homeassistant.helpers import chat_session, intent, llm +from homeassistant.helpers import ( + area_registry as ar, + chat_session, + device_registry as dr, + entity_registry as er, + intent, + llm, +) from homeassistant.setup import async_setup_component from . import MANY_LANGUAGES, process_events @@ -46,7 +54,7 @@ from .conftest import ( make_10ms_chunk, ) -from tests.common import flush_store +from tests.common import MockConfigEntry, async_mock_service, flush_store from tests.typing import ClientSessionGenerator, WebSocketGenerator @@ -1787,3 +1795,296 @@ async def test_chat_log_tts_streaming( assert "".join(received_tts) == chunk_text assert process_events(events) == snapshot + + +async def test_acknowledge( + hass: HomeAssistant, + init_components, + pipeline_data: assist_pipeline.pipeline.PipelineData, + mock_chat_session: chat_session.ChatSession, + entity_registry: er.EntityRegistry, + area_registry: ar.AreaRegistry, + device_registry: dr.DeviceRegistry, +) -> None: + """Test that acknowledge sound is played when targets are in the same area.""" + area_1 = area_registry.async_get_or_create("area_1") + + light_1 = entity_registry.async_get_or_create("light", "demo", "1234") + hass.states.async_set(light_1.entity_id, "off", {ATTR_FRIENDLY_NAME: "light 1"}) + light_1 = entity_registry.async_update_entity(light_1.entity_id, area_id=area_1.id) + + light_2 = entity_registry.async_get_or_create("light", "demo", "5678") + hass.states.async_set(light_2.entity_id, "off", {ATTR_FRIENDLY_NAME: "light 2"}) + light_2 = entity_registry.async_update_entity(light_2.entity_id, area_id=area_1.id) + + entry = MockConfigEntry() + entry.add_to_hass(hass) + satellite = device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections=set(), + identifiers={("demo", "id-1234")}, + ) + device_registry.async_update_device(satellite.id, area_id=area_1.id) + + events: list[assist_pipeline.PipelineEvent] = [] + turn_on = async_mock_service(hass, "light", "turn_on") + + pipeline_store = pipeline_data.pipeline_store + pipeline_id = pipeline_store.async_get_preferred_item() + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) + + async def _run(text: str) -> None: + pipeline_input = assist_pipeline.pipeline.PipelineInput( + intent_input=text, + session=mock_chat_session, + device_id=satellite.id, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.INTENT, + end_stage=assist_pipeline.PipelineStage.TTS, + event_callback=events.append, + ), + ) + await pipeline_input.validate() + await pipeline_input.execute() + + with patch( + "homeassistant.components.assist_pipeline.PipelineRun.text_to_speech" + ) as text_to_speech: + + def _reset() -> None: + events.clear() + text_to_speech.reset_mock() + turn_on.clear() + + # 1. All targets in same area + await _run("turn on the lights") + + # Acknowledgment sound should be played (same area) + text_to_speech.assert_called_once() + assert ( + text_to_speech.call_args.kwargs["override_media_path"] == ACKNOWLEDGE_PATH + ) + assert len(turn_on) == 2 + + # 2. One light in a different area + area_2 = area_registry.async_get_or_create("area_2") + light_2 = entity_registry.async_update_entity( + light_2.entity_id, area_id=area_2.id + ) + + _reset() + await _run("turn on light 2") + + # Acknowledgment sound should be not played (different area) + text_to_speech.assert_called_once() + assert text_to_speech.call_args.kwargs.get("override_media_path") is None + assert len(turn_on) == 1 + + # Restore + light_2 = entity_registry.async_update_entity( + light_2.entity_id, area_id=area_1.id + ) + + # 3. Remove satellite device area + device_registry.async_update_device(satellite.id, area_id=None) + + _reset() + await _run("turn on light 1") + + # Acknowledgment sound should be not played (no satellite area) + text_to_speech.assert_called_once() + assert text_to_speech.call_args.kwargs.get("override_media_path") is None + assert len(turn_on) == 1 + + # Restore + device_registry.async_update_device(satellite.id, area_id=area_1.id) + + # 4. Check device area instead of entity area + light_device = device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections=set(), + identifiers={("demo", "id-5678")}, + ) + device_registry.async_update_device(light_device.id, area_id=area_1.id) + light_2 = entity_registry.async_update_entity( + light_2.entity_id, area_id=None, device_id=light_device.id + ) + + _reset() + await _run("turn on the lights") + + # Acknowledgment sound should be played (same area) + text_to_speech.assert_called_once() + assert ( + text_to_speech.call_args.kwargs["override_media_path"] == ACKNOWLEDGE_PATH + ) + assert len(turn_on) == 2 + + # 5. Move device to different area + device_registry.async_update_device(light_device.id, area_id=area_2.id) + + _reset() + await _run("turn on light 2") + + # Acknowledgment sound should be not played (different device area) + text_to_speech.assert_called_once() + assert text_to_speech.call_args.kwargs.get("override_media_path") is None + assert len(turn_on) == 1 + + # 6. No device or area + light_2 = entity_registry.async_update_entity( + light_2.entity_id, area_id=None, device_id=None + ) + + _reset() + await _run("turn on light 2") + + # Acknowledgment sound should be not played (no area) + text_to_speech.assert_called_once() + assert text_to_speech.call_args.kwargs.get("override_media_path") is None + assert len(turn_on) == 1 + + # 7. Not in entity registry + hass.states.async_set("light.light_3", "off", {ATTR_FRIENDLY_NAME: "light 3"}) + + _reset() + await _run("turn on light 3") + + # Acknowledgment sound should be not played (not in entity registry) + text_to_speech.assert_called_once() + assert text_to_speech.call_args.kwargs.get("override_media_path") is None + assert len(turn_on) == 1 + + # Check TTS event + events.clear() + await _run("turn on light 1") + + has_acknowledge_override: bool | None = None + for event in events: + if event.type == PipelineEventType.TTS_START: + assert event.data + has_acknowledge_override = event.data["acknowledge_override"] + break + + assert has_acknowledge_override + + +async def test_acknowledge_other_agents( + hass: HomeAssistant, + init_components, + pipeline_data: assist_pipeline.pipeline.PipelineData, + mock_chat_session: chat_session.ChatSession, + entity_registry: er.EntityRegistry, + area_registry: ar.AreaRegistry, + device_registry: dr.DeviceRegistry, +) -> None: + """Test that acknowledge sound is only played when intents are processed locally for other agents.""" + area_1 = area_registry.async_get_or_create("area_1") + + light_1 = entity_registry.async_get_or_create("light", "demo", "1234") + hass.states.async_set(light_1.entity_id, "off", {ATTR_FRIENDLY_NAME: "light 1"}) + light_1 = entity_registry.async_update_entity(light_1.entity_id, area_id=area_1.id) + + light_2 = entity_registry.async_get_or_create("light", "demo", "5678") + hass.states.async_set(light_2.entity_id, "off", {ATTR_FRIENDLY_NAME: "light 2"}) + light_2 = entity_registry.async_update_entity(light_2.entity_id, area_id=area_1.id) + + entry = MockConfigEntry() + entry.add_to_hass(hass) + satellite = device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections=set(), + identifiers={("demo", "id-1234")}, + ) + device_registry.async_update_device(satellite.id, area_id=area_1.id) + + events: list[assist_pipeline.PipelineEvent] = [] + async_mock_service(hass, "light", "turn_on") + + pipeline_store = pipeline_data.pipeline_store + pipeline = await pipeline_store.async_create_item( + { + "name": "Test 1", + "language": "en-US", + "conversation_engine": "test agent", + "conversation_language": "en-US", + "tts_engine": "test tts", + "tts_language": "en-US", + "tts_voice": "test voice", + "stt_engine": "test stt", + "stt_language": "en-US", + "wake_word_entity": None, + "wake_word_id": None, + "prefer_local_intents": True, + } + ) + + with ( + patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", + return_value=conversation.AgentInfo( + id="test-agent", + name="Test Agent", + supports_streaming=False, + ), + ), + patch( + "homeassistant.components.assist_pipeline.PipelineRun.prepare_text_to_speech" + ), + patch( + "homeassistant.components.assist_pipeline.PipelineRun.text_to_speech" + ) as text_to_speech, + patch( + "homeassistant.components.conversation.async_converse", return_value=None + ) as async_converse, + patch( + "homeassistant.components.assist_pipeline.PipelineRun._get_all_targets_in_satellite_area" + ) as get_all_targets_in_satellite_area, + ): + pipeline_input = assist_pipeline.pipeline.PipelineInput( + intent_input="turn on the lights", + session=mock_chat_session, + device_id=satellite.id, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.INTENT, + end_stage=assist_pipeline.PipelineStage.TTS, + event_callback=events.append, + ), + ) + await pipeline_input.validate() + await pipeline_input.execute() + + # Processed locally + async_converse.assert_not_called() + + # Not processed locally + text_to_speech.reset_mock() + get_all_targets_in_satellite_area.reset_mock() + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + intent_input="not processed locally", + session=mock_chat_session, + device_id=satellite.id, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.INTENT, + end_stage=assist_pipeline.PipelineStage.TTS, + event_callback=events.append, + ), + ) + await pipeline_input.validate() + await pipeline_input.execute() + + # The acknowledgment should not have even been checked for because the + # default agent didn't handle the intent. + text_to_speech.assert_not_called() + async_converse.assert_called_once() + get_all_targets_in_satellite_area.assert_not_called()