Fix Tuya device registry cleanup (#161268)

This commit is contained in:
epenet
2026-01-27 15:23:08 +01:00
committed by GitHub
parent 6e2092b784
commit 71f17f2cf1
4 changed files with 217 additions and 109 deletions

View File

@@ -82,7 +82,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: TuyaConfigEntry) -> bool
entry.runtime_data = HomeAssistantTuyaData(manager=manager, listener=listener)
# Cleanup device registry
await cleanup_device_registry(hass, manager)
await cleanup_device_registry(hass, manager, entry)
# Register known device IDs
device_registry = dr.async_get(hass)
@@ -114,13 +114,19 @@ async def async_setup_entry(hass: HomeAssistant, entry: TuyaConfigEntry) -> bool
return True
async def cleanup_device_registry(hass: HomeAssistant, device_manager: Manager) -> None:
"""Remove deleted device registry entry if there are no remaining entities."""
async def cleanup_device_registry(
hass: HomeAssistant, device_manager: Manager, entry: TuyaConfigEntry
) -> None:
"""Unlink device registry entry if there are no remaining entities."""
device_registry = dr.async_get(hass)
for dev_id, device_entry in list(device_registry.devices.items()):
for device_entry in dr.async_entries_for_config_entry(
device_registry, entry.entry_id
):
for item in device_entry.identifiers:
if item[0] == DOMAIN and item[1] not in device_manager.device_map:
device_registry.async_remove_device(dev_id)
device_registry.async_update_device(
device_entry.id, remove_config_entry_id=entry.entry_id
)
break

View File

@@ -4,15 +4,23 @@ from __future__ import annotations
import pathlib
from typing import Any
from unittest.mock import patch
from unittest.mock import MagicMock, patch
from freezegun.api import FrozenDateTimeFactory
from tuya_sharing import CustomerDevice, Manager
from tuya_sharing import (
CustomerApi,
CustomerDevice,
DeviceFunction,
DeviceStatusRange,
Manager,
)
from homeassistant.components.tuya import DeviceListener
from homeassistant.components.tuya import DOMAIN, DeviceListener
from homeassistant.core import HomeAssistant
from homeassistant.helpers.json import json_dumps
from homeassistant.util import dt as dt_util
from tests.common import MockConfigEntry
from tests.common import MockConfigEntry, async_load_json_object_fixture
FIXTURES_DIR = pathlib.Path(__file__).parent / "fixtures"
DEVICE_MOCKS = sorted(
@@ -45,6 +53,98 @@ class MockDeviceListener(DeviceListener):
await hass.async_block_till_done()
async def create_device(hass: HomeAssistant, mock_device_code: str) -> CustomerDevice:
"""Create a CustomerDevice for testing."""
details = await async_load_json_object_fixture(
hass, f"{mock_device_code}.json", DOMAIN
)
device = MagicMock(spec=CustomerDevice)
# Use reverse of the product_id for testing
device.id = mock_device_code.replace("_", "")[::-1]
device.name = details["name"]
device.category = details["category"]
device.product_id = details["product_id"]
device.product_name = details["product_name"]
device.online = details["online"]
device.sub = details.get("sub")
device.time_zone = details.get("time_zone")
device.active_time = details.get("active_time")
if device.active_time:
device.active_time = int(dt_util.as_timestamp(device.active_time))
device.create_time = details.get("create_time")
if device.create_time:
device.create_time = int(dt_util.as_timestamp(device.create_time))
device.update_time = details.get("update_time")
if device.update_time:
device.update_time = int(dt_util.as_timestamp(device.update_time))
device.support_local = details.get("support_local")
device.local_strategy = details.get("local_strategy")
device.mqtt_connected = details.get("mqtt_connected")
device.function = {
key: DeviceFunction(
code=key,
type=value["type"],
values=(
values
if isinstance(values := value["value"], str)
else json_dumps(values)
),
)
for key, value in details["function"].items()
}
device.status_range = {
key: DeviceStatusRange(
code=key,
report_type=value.get("report_type"),
type=value["type"],
values=(
values
if isinstance(values := value["value"], str)
else json_dumps(values)
),
)
for key, value in details["status_range"].items()
}
device.status = details["status"]
for key, value in device.status.items():
# Some devices do not provide a status_range for all status DPs
# Others set the type as String in status_range and as Json in function
if ((dp_type := device.status_range.get(key)) and dp_type.type == "Json") or (
(dp_type := device.function.get(key)) and dp_type.type == "Json"
):
device.status[key] = json_dumps(value)
if value == "**REDACTED**":
# It was redacted, which may cause issue with b64decode
device.status[key] = ""
return device
def create_listener(hass: HomeAssistant, manager: Manager) -> MockDeviceListener:
"""Create a DeviceListener for testing."""
listener = MockDeviceListener(hass, manager)
manager.add_device_listener(listener)
return listener
def create_manager(
terminal_id: str = "7cd96aff-6ec8-4006-b093-3dbff7947591",
) -> Manager:
"""Create a Manager for testing."""
manager = MagicMock(spec=Manager)
manager.device_map = {}
manager.mq = MagicMock()
manager.mq.client = MagicMock()
manager.mq.client.is_connected = MagicMock(return_value=True)
manager.customer_api = MagicMock(spec=CustomerApi)
# Meaningless URL / UUIDs
manager.customer_api.endpoint = "https://apigw.tuyaeu.com"
manager.terminal_id = terminal_id
return manager
async def initialize_entry(
hass: HomeAssistant,
mock_manager: Manager,

View File

@@ -6,13 +6,7 @@ from collections.abc import Generator
from unittest.mock import MagicMock, patch
import pytest
from tuya_sharing import (
CustomerApi,
CustomerDevice,
DeviceFunction,
DeviceStatusRange,
Manager,
)
from tuya_sharing import CustomerDevice, Manager
from homeassistant.components.tuya.const import (
CONF_ENDPOINT,
@@ -22,12 +16,16 @@ from homeassistant.components.tuya.const import (
DOMAIN,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers.json import json_dumps
from homeassistant.util import dt as dt_util
from . import DEVICE_MOCKS, MockDeviceListener
from . import (
DEVICE_MOCKS,
MockDeviceListener,
create_device,
create_listener,
create_manager,
)
from tests.common import MockConfigEntry, async_load_json_object_fixture
from tests.common import MockConfigEntry
@pytest.fixture
@@ -108,17 +106,8 @@ def mock_tuya_login_control() -> Generator[MagicMock]:
@pytest.fixture
def mock_manager() -> Manager:
"""Mock Tuya Manager."""
manager = MagicMock(spec=Manager)
manager.device_map = {}
manager.mq = MagicMock()
manager.mq.client = MagicMock()
manager.mq.client.is_connected = MagicMock(return_value=True)
manager.customer_api = MagicMock(spec=CustomerApi)
# Meaningless URL / UUIDs
manager.customer_api.endpoint = "https://apigw.tuyaeu.com"
manager.terminal_id = "7cd96aff-6ec8-4006-b093-3dbff7947591"
return manager
"""Fixture for Tuya Manager."""
return create_manager()
@pytest.fixture
@@ -137,7 +126,7 @@ async def mock_devices(hass: HomeAssistant) -> list[CustomerDevice]:
Use this to generate global snapshots for each platform.
"""
return [await _create_device(hass, device_code) for device_code in DEVICE_MOCKS]
return [await create_device(hass, device_code) for device_code in DEVICE_MOCKS]
@pytest.fixture
@@ -146,81 +135,10 @@ async def mock_device(hass: HomeAssistant, mock_device_code: str) -> CustomerDev
Use this for testing behavior on a specific device.
"""
return await _create_device(hass, mock_device_code)
async def _create_device(hass: HomeAssistant, mock_device_code: str) -> CustomerDevice:
"""Mock a Tuya CustomerDevice."""
details = await async_load_json_object_fixture(
hass, f"{mock_device_code}.json", DOMAIN
)
device = MagicMock(spec=CustomerDevice)
# Use reverse of the product_id for testing
device.id = mock_device_code.replace("_", "")[::-1]
device.name = details["name"]
device.category = details["category"]
device.product_id = details["product_id"]
device.product_name = details["product_name"]
device.online = details["online"]
device.sub = details.get("sub")
device.time_zone = details.get("time_zone")
device.active_time = details.get("active_time")
if device.active_time:
device.active_time = int(dt_util.as_timestamp(device.active_time))
device.create_time = details.get("create_time")
if device.create_time:
device.create_time = int(dt_util.as_timestamp(device.create_time))
device.update_time = details.get("update_time")
if device.update_time:
device.update_time = int(dt_util.as_timestamp(device.update_time))
device.support_local = details.get("support_local")
device.local_strategy = details.get("local_strategy")
device.mqtt_connected = details.get("mqtt_connected")
device.function = {
key: DeviceFunction(
code=key,
type=value["type"],
values=(
values
if isinstance(values := value["value"], str)
else json_dumps(values)
),
)
for key, value in details["function"].items()
}
device.status_range = {
key: DeviceStatusRange(
code=key,
report_type=value.get("report_type"),
type=value["type"],
values=(
values
if isinstance(values := value["value"], str)
else json_dumps(values)
),
)
for key, value in details["status_range"].items()
}
device.status = details["status"]
for key, value in device.status.items():
# Some devices do not provide a status_range for all status DPs
# Others set the type as String in status_range and as Json in function
if ((dp_type := device.status_range.get(key)) and dp_type.type == "Json") or (
(dp_type := device.function.get(key)) and dp_type.type == "Json"
):
device.status[key] = json_dumps(value)
if value == "**REDACTED**":
# It was redacted, which may cause issue with b64decode
device.status[key] = ""
return device
return await create_device(hass, mock_device_code)
@pytest.fixture
def mock_listener(hass: HomeAssistant, mock_manager: Manager) -> MockDeviceListener:
"""Create a DeviceListener for testing."""
listener = MockDeviceListener(hass, mock_manager)
mock_manager.add_device_listener(listener)
return listener
"""Fixture for Tuya DeviceListener."""
return create_listener(hass, mock_manager)

View File

@@ -2,24 +2,108 @@
from __future__ import annotations
from unittest.mock import patch
from syrupy.assertion import SnapshotAssertion
from tuya_sharing import CustomerDevice, Manager
from homeassistant.components.tuya.const import DOMAIN
from homeassistant.components.tuya.const import (
CONF_ENDPOINT,
CONF_TERMINAL_ID,
CONF_TOKEN_INFO,
CONF_USER_CODE,
DOMAIN,
)
from homeassistant.components.tuya.diagnostics import _REDACTED_DPCODES
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr, entity_registry as er
from . import DEVICE_MOCKS, initialize_entry
from . import DEVICE_MOCKS, create_device, create_manager, initialize_entry
from tests.common import MockConfigEntry, async_load_json_object_fixture
async def test_registry_cleanup(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
device_registry: dr.DeviceRegistry,
) -> None:
"""Ensure no-longer-present devices are removed from the device registry."""
# Initialize with two devices
main_manager = create_manager()
main_device = await create_device(hass, "mcs_8yhypbo7")
second_device = await create_device(hass, "clkg_y7j64p60glp8qpx7")
await initialize_entry(
hass, main_manager, mock_config_entry, [main_device, second_device]
)
# Initialize should have two devices
all_entries = dr.async_entries_for_config_entry(
device_registry, mock_config_entry.entry_id
)
assert len(all_entries) == 2
# Now remove the second device from the manager and re-initialize
del main_manager.device_map[second_device.id]
with patch("homeassistant.components.tuya.Manager", return_value=main_manager):
await hass.config_entries.async_reload(mock_config_entry.entry_id)
await hass.async_block_till_done()
# Only the main device should remain
all_entries = dr.async_entries_for_config_entry(
device_registry, mock_config_entry.entry_id
)
assert len(all_entries) == 1
assert all_entries[0].identifiers == {(DOMAIN, main_device.id)}
async def test_registry_cleanup_multiple_entries(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
entity_registry: er.EntityRegistry,
) -> None:
"""Ensure multiple config entries do not remove items from other entries."""
main_entity_id = "sensor.boite_aux_lettres_arriere_battery"
second_entity_id = "binary_sensor.window_downstairs_door"
main_manager = create_manager()
main_device = await create_device(hass, "mcs_8yhypbo7")
await initialize_entry(hass, main_manager, mock_config_entry, main_device)
# Ensure initial setup is correct (main present, second absent)
assert hass.states.get(main_entity_id)
assert entity_registry.async_get(main_entity_id)
assert not hass.states.get(second_entity_id)
assert not entity_registry.async_get(second_entity_id)
# Create a second config entry
second_config_entry = MockConfigEntry(
title="Test Tuya entry",
domain=DOMAIN,
data={
CONF_ENDPOINT: "test_endpoint",
CONF_TERMINAL_ID: "test_terminal",
CONF_TOKEN_INFO: "test_token",
CONF_USER_CODE: "test_user_code",
},
unique_id="56789",
)
second_manager = create_manager()
second_device = await create_device(hass, "mcs_oxslv1c9")
await initialize_entry(hass, second_manager, second_config_entry, second_device)
# Ensure setup is correct (both present)
assert hass.states.get(main_entity_id)
assert entity_registry.async_get(main_entity_id)
assert hass.states.get(second_entity_id)
assert entity_registry.async_get(second_entity_id)
async def test_device_registry(
hass: HomeAssistant,
mock_manager: Manager,
mock_config_entry: MockConfigEntry,
mock_devices: CustomerDevice,
mock_devices: list[CustomerDevice],
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
snapshot: SnapshotAssertion,