Acknowledge if targets in same area (#150655)

Co-authored-by: Artur Pragacz <49985303+arturpragacz@users.noreply.github.com>
This commit is contained in:
Michael Hansen
2025-09-12 13:40:16 -05:00
committed by GitHub
parent bfe1dd65b3
commit 3de701a9ab
7 changed files with 405 additions and 14 deletions

View File

@@ -1,5 +1,7 @@
"""Constants for the Assist pipeline integration."""
from pathlib import Path
DOMAIN = "assist_pipeline"
DATA_CONFIG = f"{DOMAIN}.config"
@@ -23,3 +25,5 @@ SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit
OPTION_PREFERRED = "preferred"
ACKNOWLEDGE_PATH = Path(__file__).parent / "acknowledge.mp3"

View File

@@ -23,7 +23,12 @@ from homeassistant.components import conversation, stt, tts, wake_word, websocke
from homeassistant.const import ATTR_SUPPORTED_FEATURES, MATCH_ALL
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, intent
from homeassistant.helpers import (
chat_session,
device_registry as dr,
entity_registry as er,
intent,
)
from homeassistant.helpers.collection import (
CHANGE_UPDATED,
CollectionError,
@@ -45,6 +50,7 @@ from homeassistant.util.limited_size_dict import LimitedSizeDict
from .audio_enhancer import AudioEnhancer, EnhancedAudioChunk, MicroVadSpeexEnhancer
from .const import (
ACKNOWLEDGE_PATH,
BYTES_PER_CHUNK,
CONF_DEBUG_RECORDING_DIR,
DATA_CONFIG,
@@ -113,6 +119,7 @@ PIPELINE_FIELDS: VolDictType = {
vol.Required("wake_word_entity"): vol.Any(str, None),
vol.Required("wake_word_id"): vol.Any(str, None),
vol.Optional("prefer_local_intents"): bool,
vol.Optional("acknowledge_media_id"): str,
}
STORED_PIPELINE_RUNS = 10
@@ -1066,8 +1073,11 @@ class PipelineRun:
intent_input: str,
conversation_id: str,
conversation_extra_system_prompt: str | None,
) -> str:
"""Run intent recognition portion of pipeline. Returns text to speak."""
) -> tuple[str, bool]:
"""Run intent recognition portion of pipeline.
Returns (speech, all_targets_in_satellite_area).
"""
if self.intent_agent is None or self._conversation_data is None:
raise RuntimeError("Recognize intent was not prepared")
@@ -1116,6 +1126,7 @@ class PipelineRun:
agent_id = self.intent_agent.id
processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT
all_targets_in_satellite_area = False
intent_response: intent.IntentResponse | None = None
if not processed_locally and not self._intent_agent_only:
# Sentence triggers override conversation agent
@@ -1290,6 +1301,17 @@ class PipelineRun:
if tts_input_stream and self._streamed_response_text:
tts_input_stream.put_nowait(None)
if agent_id == conversation.HOME_ASSISTANT_AGENT:
# Check if all targeted entities were in the same area as
# the satellite device.
# If so, the satellite should respond with an acknowledge beep
# instead of a full response.
all_targets_in_satellite_area = (
self._get_all_targets_in_satellite_area(
conversation_result.response, self._device_id
)
)
except Exception as src_error:
_LOGGER.exception("Unexpected error during intent recognition")
raise IntentRecognitionError(
@@ -1312,7 +1334,45 @@ class PipelineRun:
if conversation_result.continue_conversation:
self._conversation_data.continue_conversation_agent = agent_id
return speech
return (speech, all_targets_in_satellite_area)
def _get_all_targets_in_satellite_area(
    self, intent_response: intent.IntentResponse, device_id: str | None
) -> bool:
    """Return True if all targeted entities were in the same area as the device.

    Used to decide whether the satellite may play a short acknowledge sound
    instead of speaking the full response: only when the intent actually
    completed an action and every affected entity shares the satellite's area.
    """
    # Only a completed action with known target states can qualify; without a
    # device_id there is no satellite area to compare against.
    if (
        (intent_response.response_type != intent.IntentResponseType.ACTION_DONE)
        or (not intent_response.matched_states)
        or (not device_id)
    ):
        return False

    device_registry = dr.async_get(self.hass)

    # The satellite device itself must exist and be assigned to an area.
    if (not (device := device_registry.async_get(device_id))) or (
        not device.area_id
    ):
        return False

    entity_registry = er.async_get(self.hass)

    for state in intent_response.matched_states:
        entity = entity_registry.async_get(state.entity_id)
        if not entity:
            # Unregistered entities carry no area information.
            return False

        if (entity_area_id := entity.area_id) is None:
            # No area directly on the entity: fall back to the area of the
            # entity's device, if it has one in the device registry.
            if (entity.device_id is None) or (
                (entity_device := device_registry.async_get(entity.device_id))
                is None
            ):
                return False

            entity_area_id = entity_device.area_id

        if entity_area_id != device.area_id:
            return False

    return True
async def prepare_text_to_speech(self) -> None:
"""Prepare text-to-speech."""
@@ -1350,7 +1410,9 @@ class PipelineRun:
),
) from err
async def text_to_speech(self, tts_input: str) -> None:
async def text_to_speech(
self, tts_input: str, override_media_path: Path | None = None
) -> None:
"""Run text-to-speech portion of pipeline."""
assert self.tts_stream is not None
@@ -1362,11 +1424,14 @@ class PipelineRun:
"language": self.pipeline.tts_language,
"voice": self.pipeline.tts_voice,
"tts_input": tts_input,
"acknowledge_override": override_media_path is not None,
},
)
)
if not self._streamed_response_text:
if override_media_path:
self.tts_stream.async_override_result(override_media_path)
elif not self._streamed_response_text:
self.tts_stream.async_set_message(tts_input)
tts_output = {
@@ -1664,16 +1729,20 @@ class PipelineInput:
if self.run.end_stage != PipelineStage.STT:
tts_input = self.tts_input
all_targets_in_satellite_area = False
if current_stage == PipelineStage.INTENT:
# intent-recognition
assert intent_input is not None
tts_input = await self.run.recognize_intent(
(
tts_input,
all_targets_in_satellite_area,
) = await self.run.recognize_intent(
intent_input,
self.session.conversation_id,
self.conversation_extra_system_prompt,
)
if tts_input.strip():
if all_targets_in_satellite_area or tts_input.strip():
current_stage = PipelineStage.TTS
else:
# Skip TTS
@@ -1682,8 +1751,14 @@ class PipelineInput:
if self.run.end_stage != PipelineStage.INTENT:
# text-to-speech
if current_stage == PipelineStage.TTS:
assert tts_input is not None
await self.run.text_to_speech(tts_input)
if all_targets_in_satellite_area:
# Use acknowledge media instead of full response
await self.run.text_to_speech(
tts_input or "", override_media_path=ACKNOWLEDGE_PATH
)
else:
assert tts_input is not None
await self.run.text_to_speech(tts_input)
except PipelineError as err:
self.run.process_event(

View File

@@ -76,6 +76,7 @@
}),
dict({
'data': dict({
'acknowledge_override': False,
'engine': 'tts.test',
'language': 'en_US',
'tts_input': "Sorry, I couldn't understand that",
@@ -177,6 +178,7 @@
}),
dict({
'data': dict({
'acknowledge_override': False,
'engine': 'test',
'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that",
@@ -278,6 +280,7 @@
}),
dict({
'data': dict({
'acknowledge_override': False,
'engine': 'test',
'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that",
@@ -403,6 +406,7 @@
}),
dict({
'data': dict({
'acknowledge_override': False,
'engine': 'tts.test',
'language': 'en_US',
'tts_input': "Sorry, I couldn't understand that",

View File

@@ -131,6 +131,7 @@
}),
dict({
'data': dict({
'acknowledge_override': False,
'engine': 'tts.test',
'language': 'en_US',
'tts_input': 'hello, how are you?',
@@ -365,6 +366,7 @@
}),
dict({
'data': dict({
'acknowledge_override': False,
'engine': 'tts.test',
'language': 'en_US',
'tts_input': "hello, how are you? I'm doing well, thank you. What about you?!",
@@ -595,6 +597,7 @@
}),
dict({
'data': dict({
'acknowledge_override': False,
'engine': 'tts.test',
'language': 'en_US',
'tts_input': "I'm doing well, thank you.",

View File

@@ -73,6 +73,7 @@
# ---
# name: test_audio_pipeline.5
dict({
'acknowledge_override': False,
'engine': 'tts.test',
'language': 'en_US',
'tts_input': "Sorry, I couldn't understand that",
@@ -166,6 +167,7 @@
# ---
# name: test_audio_pipeline_debug.5
dict({
'acknowledge_override': False,
'engine': 'tts.test',
'language': 'en_US',
'tts_input': "Sorry, I couldn't understand that",
@@ -271,6 +273,7 @@
# ---
# name: test_audio_pipeline_with_enhancements.5
dict({
'acknowledge_override': False,
'engine': 'tts.test',
'language': 'en_US',
'tts_input': "Sorry, I couldn't understand that",
@@ -386,6 +389,7 @@
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.7
dict({
'acknowledge_override': False,
'engine': 'tts.test',
'language': 'en_US',
'tts_input': "Sorry, I couldn't understand that",

View File

@@ -16,13 +16,14 @@ from homeassistant.components import (
stt,
tts,
)
from homeassistant.components.assist_pipeline.const import DOMAIN
from homeassistant.components.assist_pipeline.const import ACKNOWLEDGE_PATH, DOMAIN
from homeassistant.components.assist_pipeline.pipeline import (
STORAGE_KEY,
STORAGE_VERSION,
STORAGE_VERSION_MINOR,
Pipeline,
PipelineData,
PipelineEventType,
PipelineStorageCollection,
PipelineStore,
_async_local_fallback_intent_filter,
@@ -31,9 +32,16 @@ from homeassistant.components.assist_pipeline.pipeline import (
async_get_pipelines,
async_update_pipeline,
)
from homeassistant.const import MATCH_ALL
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import chat_session, intent, llm
from homeassistant.helpers import (
area_registry as ar,
chat_session,
device_registry as dr,
entity_registry as er,
intent,
llm,
)
from homeassistant.setup import async_setup_component
from . import MANY_LANGUAGES, process_events
@@ -46,7 +54,7 @@ from .conftest import (
make_10ms_chunk,
)
from tests.common import flush_store
from tests.common import MockConfigEntry, async_mock_service, flush_store
from tests.typing import ClientSessionGenerator, WebSocketGenerator
@@ -1787,3 +1795,296 @@ async def test_chat_log_tts_streaming(
assert "".join(received_tts) == chunk_text
assert process_events(events) == snapshot
async def test_acknowledge(
    hass: HomeAssistant,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    mock_chat_session: chat_session.ChatSession,
    entity_registry: er.EntityRegistry,
    area_registry: ar.AreaRegistry,
    device_registry: dr.DeviceRegistry,
) -> None:
    """Test that acknowledge sound is played when targets are in the same area."""
    area_1 = area_registry.async_get_or_create("area_1")

    # Two lights, both initially in area_1 (same area as the satellite below).
    light_1 = entity_registry.async_get_or_create("light", "demo", "1234")
    hass.states.async_set(light_1.entity_id, "off", {ATTR_FRIENDLY_NAME: "light 1"})
    light_1 = entity_registry.async_update_entity(light_1.entity_id, area_id=area_1.id)

    light_2 = entity_registry.async_get_or_create("light", "demo", "5678")
    hass.states.async_set(light_2.entity_id, "off", {ATTR_FRIENDLY_NAME: "light 2"})
    light_2 = entity_registry.async_update_entity(light_2.entity_id, area_id=area_1.id)

    # The satellite device issuing the commands, placed in area_1.
    entry = MockConfigEntry()
    entry.add_to_hass(hass)
    satellite = device_registry.async_get_or_create(
        config_entry_id=entry.entry_id,
        connections=set(),
        identifiers={("demo", "id-1234")},
    )
    device_registry.async_update_device(satellite.id, area_id=area_1.id)

    events: list[assist_pipeline.PipelineEvent] = []
    turn_on = async_mock_service(hass, "light", "turn_on")

    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    async def _run(text: str) -> None:
        # Run an INTENT -> TTS pipeline for the given text from the satellite.
        pipeline_input = assist_pipeline.pipeline.PipelineInput(
            intent_input=text,
            session=mock_chat_session,
            device_id=satellite.id,
            run=assist_pipeline.pipeline.PipelineRun(
                hass,
                context=Context(),
                pipeline=pipeline,
                start_stage=assist_pipeline.PipelineStage.INTENT,
                end_stage=assist_pipeline.PipelineStage.TTS,
                event_callback=events.append,
            ),
        )

        await pipeline_input.validate()
        await pipeline_input.execute()

    with patch(
        "homeassistant.components.assist_pipeline.PipelineRun.text_to_speech"
    ) as text_to_speech:

        def _reset() -> None:
            # Clear captured events, TTS mock calls, and service calls
            # between scenarios.
            events.clear()
            text_to_speech.reset_mock()
            turn_on.clear()

        # 1. All targets in same area
        await _run("turn on the lights")

        # Acknowledgment sound should be played (same area)
        text_to_speech.assert_called_once()
        assert (
            text_to_speech.call_args.kwargs["override_media_path"] == ACKNOWLEDGE_PATH
        )
        assert len(turn_on) == 2

        # 2. One light in a different area
        area_2 = area_registry.async_get_or_create("area_2")
        light_2 = entity_registry.async_update_entity(
            light_2.entity_id, area_id=area_2.id
        )

        _reset()
        await _run("turn on light 2")

        # Acknowledgment sound should not be played (different area)
        text_to_speech.assert_called_once()
        assert text_to_speech.call_args.kwargs.get("override_media_path") is None
        assert len(turn_on) == 1

        # Restore
        light_2 = entity_registry.async_update_entity(
            light_2.entity_id, area_id=area_1.id
        )

        # 3. Remove satellite device area
        device_registry.async_update_device(satellite.id, area_id=None)
        _reset()
        await _run("turn on light 1")

        # Acknowledgment sound should not be played (no satellite area)
        text_to_speech.assert_called_once()
        assert text_to_speech.call_args.kwargs.get("override_media_path") is None
        assert len(turn_on) == 1

        # Restore
        device_registry.async_update_device(satellite.id, area_id=area_1.id)

        # 4. Check device area instead of entity area
        light_device = device_registry.async_get_or_create(
            config_entry_id=entry.entry_id,
            connections=set(),
            identifiers={("demo", "id-5678")},
        )
        device_registry.async_update_device(light_device.id, area_id=area_1.id)

        # Entity has no area of its own; its device's area must be used.
        light_2 = entity_registry.async_update_entity(
            light_2.entity_id, area_id=None, device_id=light_device.id
        )
        _reset()
        await _run("turn on the lights")

        # Acknowledgment sound should be played (same area)
        text_to_speech.assert_called_once()
        assert (
            text_to_speech.call_args.kwargs["override_media_path"] == ACKNOWLEDGE_PATH
        )
        assert len(turn_on) == 2

        # 5. Move device to different area
        device_registry.async_update_device(light_device.id, area_id=area_2.id)
        _reset()
        await _run("turn on light 2")

        # Acknowledgment sound should not be played (different device area)
        text_to_speech.assert_called_once()
        assert text_to_speech.call_args.kwargs.get("override_media_path") is None
        assert len(turn_on) == 1

        # 6. No device or area
        light_2 = entity_registry.async_update_entity(
            light_2.entity_id, area_id=None, device_id=None
        )
        _reset()
        await _run("turn on light 2")

        # Acknowledgment sound should not be played (no area)
        text_to_speech.assert_called_once()
        assert text_to_speech.call_args.kwargs.get("override_media_path") is None
        assert len(turn_on) == 1

        # 7. Not in entity registry
        hass.states.async_set("light.light_3", "off", {ATTR_FRIENDLY_NAME: "light 3"})
        _reset()
        await _run("turn on light 3")

        # Acknowledgment sound should not be played (not in entity registry)
        text_to_speech.assert_called_once()
        assert text_to_speech.call_args.kwargs.get("override_media_path") is None
        assert len(turn_on) == 1

        # Check TTS event
        # The TTS_START event must advertise the override via the
        # "acknowledge_override" data flag.
        events.clear()
        await _run("turn on light 1")
        has_acknowledge_override: bool | None = None
        for event in events:
            if event.type == PipelineEventType.TTS_START:
                assert event.data
                has_acknowledge_override = event.data["acknowledge_override"]
                break

        assert has_acknowledge_override
async def test_acknowledge_other_agents(
    hass: HomeAssistant,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    mock_chat_session: chat_session.ChatSession,
    entity_registry: er.EntityRegistry,
    area_registry: ar.AreaRegistry,
    device_registry: dr.DeviceRegistry,
) -> None:
    """Test that acknowledge sound is only played when intents are processed locally for other agents."""
    area_1 = area_registry.async_get_or_create("area_1")
    light_1 = entity_registry.async_get_or_create("light", "demo", "1234")
    hass.states.async_set(light_1.entity_id, "off", {ATTR_FRIENDLY_NAME: "light 1"})
    light_1 = entity_registry.async_update_entity(light_1.entity_id, area_id=area_1.id)

    light_2 = entity_registry.async_get_or_create("light", "demo", "5678")
    hass.states.async_set(light_2.entity_id, "off", {ATTR_FRIENDLY_NAME: "light 2"})
    light_2 = entity_registry.async_update_entity(light_2.entity_id, area_id=area_1.id)

    # Satellite device in the same area as both lights.
    entry = MockConfigEntry()
    entry.add_to_hass(hass)
    satellite = device_registry.async_get_or_create(
        config_entry_id=entry.entry_id,
        connections=set(),
        identifiers={("demo", "id-1234")},
    )
    device_registry.async_update_device(satellite.id, area_id=area_1.id)

    events: list[assist_pipeline.PipelineEvent] = []
    async_mock_service(hass, "light", "turn_on")

    # Pipeline uses a non-default conversation agent but prefers local
    # intent processing ("prefer_local_intents": True).
    pipeline_store = pipeline_data.pipeline_store
    pipeline = await pipeline_store.async_create_item(
        {
            "name": "Test 1",
            "language": "en-US",
            "conversation_engine": "test agent",
            "conversation_language": "en-US",
            "tts_engine": "test tts",
            "tts_language": "en-US",
            "tts_voice": "test voice",
            "stt_engine": "test stt",
            "stt_language": "en-US",
            "wake_word_entity": None,
            "wake_word_id": None,
            "prefer_local_intents": True,
        }
    )

    with (
        patch(
            "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
            return_value=conversation.AgentInfo(
                id="test-agent",
                name="Test Agent",
                supports_streaming=False,
            ),
        ),
        patch(
            "homeassistant.components.assist_pipeline.PipelineRun.prepare_text_to_speech"
        ),
        patch(
            "homeassistant.components.assist_pipeline.PipelineRun.text_to_speech"
        ) as text_to_speech,
        patch(
            "homeassistant.components.conversation.async_converse", return_value=None
        ) as async_converse,
        patch(
            "homeassistant.components.assist_pipeline.PipelineRun._get_all_targets_in_satellite_area"
        ) as get_all_targets_in_satellite_area,
    ):
        pipeline_input = assist_pipeline.pipeline.PipelineInput(
            intent_input="turn on the lights",
            session=mock_chat_session,
            device_id=satellite.id,
            run=assist_pipeline.pipeline.PipelineRun(
                hass,
                context=Context(),
                pipeline=pipeline,
                start_stage=assist_pipeline.PipelineStage.INTENT,
                end_stage=assist_pipeline.PipelineStage.TTS,
                event_callback=events.append,
            ),
        )

        await pipeline_input.validate()
        await pipeline_input.execute()

        # Processed locally
        # The local (default) agent handled it, so the configured
        # conversation agent was never consulted.
        async_converse.assert_not_called()

        # Not processed locally
        text_to_speech.reset_mock()
        get_all_targets_in_satellite_area.reset_mock()
        pipeline_input = assist_pipeline.pipeline.PipelineInput(
            intent_input="not processed locally",
            session=mock_chat_session,
            device_id=satellite.id,
            run=assist_pipeline.pipeline.PipelineRun(
                hass,
                context=Context(),
                pipeline=pipeline,
                start_stage=assist_pipeline.PipelineStage.INTENT,
                end_stage=assist_pipeline.PipelineStage.TTS,
                event_callback=events.append,
            ),
        )

        await pipeline_input.validate()
        await pipeline_input.execute()

        # The acknowledgment should not have even been checked for because the
        # default agent didn't handle the intent.
        text_to_speech.assert_not_called()
        async_converse.assert_called_once()
        get_all_targets_in_satellite_area.assert_not_called()