mirror of
https://github.com/Electric-Special/ha-core.git
synced 2026-03-21 03:03:17 +01:00
Fix Tuya device registry cleanup (#161268)
This commit is contained in:
@@ -82,7 +82,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: TuyaConfigEntry) -> bool
|
||||
entry.runtime_data = HomeAssistantTuyaData(manager=manager, listener=listener)
|
||||
|
||||
# Cleanup device registry
|
||||
await cleanup_device_registry(hass, manager)
|
||||
await cleanup_device_registry(hass, manager, entry)
|
||||
|
||||
# Register known device IDs
|
||||
device_registry = dr.async_get(hass)
|
||||
@@ -114,13 +114,19 @@ async def async_setup_entry(hass: HomeAssistant, entry: TuyaConfigEntry) -> bool
|
||||
return True
|
||||
|
||||
|
||||
async def cleanup_device_registry(hass: HomeAssistant, device_manager: Manager) -> None:
|
||||
"""Remove deleted device registry entry if there are no remaining entities."""
|
||||
async def cleanup_device_registry(
|
||||
hass: HomeAssistant, device_manager: Manager, entry: TuyaConfigEntry
|
||||
) -> None:
|
||||
"""Unlink device registry entry if there are no remaining entities."""
|
||||
device_registry = dr.async_get(hass)
|
||||
for dev_id, device_entry in list(device_registry.devices.items()):
|
||||
for device_entry in dr.async_entries_for_config_entry(
|
||||
device_registry, entry.entry_id
|
||||
):
|
||||
for item in device_entry.identifiers:
|
||||
if item[0] == DOMAIN and item[1] not in device_manager.device_map:
|
||||
device_registry.async_remove_device(dev_id)
|
||||
device_registry.async_update_device(
|
||||
device_entry.id, remove_config_entry_id=entry.entry_id
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
|
||||
@@ -4,15 +4,23 @@ from __future__ import annotations
|
||||
|
||||
import pathlib
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from freezegun.api import FrozenDateTimeFactory
|
||||
from tuya_sharing import CustomerDevice, Manager
|
||||
from tuya_sharing import (
|
||||
CustomerApi,
|
||||
CustomerDevice,
|
||||
DeviceFunction,
|
||||
DeviceStatusRange,
|
||||
Manager,
|
||||
)
|
||||
|
||||
from homeassistant.components.tuya import DeviceListener
|
||||
from homeassistant.components.tuya import DOMAIN, DeviceListener
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.json import json_dumps
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
from tests.common import MockConfigEntry, async_load_json_object_fixture
|
||||
|
||||
FIXTURES_DIR = pathlib.Path(__file__).parent / "fixtures"
|
||||
DEVICE_MOCKS = sorted(
|
||||
@@ -45,6 +53,98 @@ class MockDeviceListener(DeviceListener):
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def create_device(hass: HomeAssistant, mock_device_code: str) -> CustomerDevice:
|
||||
"""Create a CustomerDevice for testing."""
|
||||
details = await async_load_json_object_fixture(
|
||||
hass, f"{mock_device_code}.json", DOMAIN
|
||||
)
|
||||
device = MagicMock(spec=CustomerDevice)
|
||||
|
||||
# Use reverse of the product_id for testing
|
||||
device.id = mock_device_code.replace("_", "")[::-1]
|
||||
|
||||
device.name = details["name"]
|
||||
device.category = details["category"]
|
||||
device.product_id = details["product_id"]
|
||||
device.product_name = details["product_name"]
|
||||
device.online = details["online"]
|
||||
device.sub = details.get("sub")
|
||||
device.time_zone = details.get("time_zone")
|
||||
device.active_time = details.get("active_time")
|
||||
if device.active_time:
|
||||
device.active_time = int(dt_util.as_timestamp(device.active_time))
|
||||
device.create_time = details.get("create_time")
|
||||
if device.create_time:
|
||||
device.create_time = int(dt_util.as_timestamp(device.create_time))
|
||||
device.update_time = details.get("update_time")
|
||||
if device.update_time:
|
||||
device.update_time = int(dt_util.as_timestamp(device.update_time))
|
||||
device.support_local = details.get("support_local")
|
||||
device.local_strategy = details.get("local_strategy")
|
||||
device.mqtt_connected = details.get("mqtt_connected")
|
||||
|
||||
device.function = {
|
||||
key: DeviceFunction(
|
||||
code=key,
|
||||
type=value["type"],
|
||||
values=(
|
||||
values
|
||||
if isinstance(values := value["value"], str)
|
||||
else json_dumps(values)
|
||||
),
|
||||
)
|
||||
for key, value in details["function"].items()
|
||||
}
|
||||
device.status_range = {
|
||||
key: DeviceStatusRange(
|
||||
code=key,
|
||||
report_type=value.get("report_type"),
|
||||
type=value["type"],
|
||||
values=(
|
||||
values
|
||||
if isinstance(values := value["value"], str)
|
||||
else json_dumps(values)
|
||||
),
|
||||
)
|
||||
for key, value in details["status_range"].items()
|
||||
}
|
||||
device.status = details["status"]
|
||||
for key, value in device.status.items():
|
||||
# Some devices do not provide a status_range for all status DPs
|
||||
# Others set the type as String in status_range and as Json in function
|
||||
if ((dp_type := device.status_range.get(key)) and dp_type.type == "Json") or (
|
||||
(dp_type := device.function.get(key)) and dp_type.type == "Json"
|
||||
):
|
||||
device.status[key] = json_dumps(value)
|
||||
if value == "**REDACTED**":
|
||||
# It was redacted, which may cause issue with b64decode
|
||||
device.status[key] = ""
|
||||
return device
|
||||
|
||||
|
||||
def create_listener(hass: HomeAssistant, manager: Manager) -> MockDeviceListener:
|
||||
"""Create a DeviceListener for testing."""
|
||||
listener = MockDeviceListener(hass, manager)
|
||||
manager.add_device_listener(listener)
|
||||
return listener
|
||||
|
||||
|
||||
def create_manager(
|
||||
terminal_id: str = "7cd96aff-6ec8-4006-b093-3dbff7947591",
|
||||
) -> Manager:
|
||||
"""Create a Manager for testing."""
|
||||
manager = MagicMock(spec=Manager)
|
||||
manager.device_map = {}
|
||||
manager.mq = MagicMock()
|
||||
manager.mq.client = MagicMock()
|
||||
manager.mq.client.is_connected = MagicMock(return_value=True)
|
||||
manager.customer_api = MagicMock(spec=CustomerApi)
|
||||
# Meaningless URL / UUIDs
|
||||
manager.customer_api.endpoint = "https://apigw.tuyaeu.com"
|
||||
manager.terminal_id = terminal_id
|
||||
return manager
|
||||
|
||||
|
||||
async def initialize_entry(
|
||||
hass: HomeAssistant,
|
||||
mock_manager: Manager,
|
||||
|
||||
@@ -6,13 +6,7 @@ from collections.abc import Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from tuya_sharing import (
|
||||
CustomerApi,
|
||||
CustomerDevice,
|
||||
DeviceFunction,
|
||||
DeviceStatusRange,
|
||||
Manager,
|
||||
)
|
||||
from tuya_sharing import CustomerDevice, Manager
|
||||
|
||||
from homeassistant.components.tuya.const import (
|
||||
CONF_ENDPOINT,
|
||||
@@ -22,12 +16,16 @@ from homeassistant.components.tuya.const import (
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.json import json_dumps
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from . import DEVICE_MOCKS, MockDeviceListener
|
||||
from . import (
|
||||
DEVICE_MOCKS,
|
||||
MockDeviceListener,
|
||||
create_device,
|
||||
create_listener,
|
||||
create_manager,
|
||||
)
|
||||
|
||||
from tests.common import MockConfigEntry, async_load_json_object_fixture
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -108,17 +106,8 @@ def mock_tuya_login_control() -> Generator[MagicMock]:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_manager() -> Manager:
|
||||
"""Mock Tuya Manager."""
|
||||
manager = MagicMock(spec=Manager)
|
||||
manager.device_map = {}
|
||||
manager.mq = MagicMock()
|
||||
manager.mq.client = MagicMock()
|
||||
manager.mq.client.is_connected = MagicMock(return_value=True)
|
||||
manager.customer_api = MagicMock(spec=CustomerApi)
|
||||
# Meaningless URL / UUIDs
|
||||
manager.customer_api.endpoint = "https://apigw.tuyaeu.com"
|
||||
manager.terminal_id = "7cd96aff-6ec8-4006-b093-3dbff7947591"
|
||||
return manager
|
||||
"""Fixture for Tuya Manager."""
|
||||
return create_manager()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -137,7 +126,7 @@ async def mock_devices(hass: HomeAssistant) -> list[CustomerDevice]:
|
||||
|
||||
Use this to generate global snapshots for each platform.
|
||||
"""
|
||||
return [await _create_device(hass, device_code) for device_code in DEVICE_MOCKS]
|
||||
return [await create_device(hass, device_code) for device_code in DEVICE_MOCKS]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -146,81 +135,10 @@ async def mock_device(hass: HomeAssistant, mock_device_code: str) -> CustomerDev
|
||||
|
||||
Use this for testing behavior on a specific device.
|
||||
"""
|
||||
return await _create_device(hass, mock_device_code)
|
||||
|
||||
|
||||
async def _create_device(hass: HomeAssistant, mock_device_code: str) -> CustomerDevice:
|
||||
"""Mock a Tuya CustomerDevice."""
|
||||
details = await async_load_json_object_fixture(
|
||||
hass, f"{mock_device_code}.json", DOMAIN
|
||||
)
|
||||
device = MagicMock(spec=CustomerDevice)
|
||||
|
||||
# Use reverse of the product_id for testing
|
||||
device.id = mock_device_code.replace("_", "")[::-1]
|
||||
|
||||
device.name = details["name"]
|
||||
device.category = details["category"]
|
||||
device.product_id = details["product_id"]
|
||||
device.product_name = details["product_name"]
|
||||
device.online = details["online"]
|
||||
device.sub = details.get("sub")
|
||||
device.time_zone = details.get("time_zone")
|
||||
device.active_time = details.get("active_time")
|
||||
if device.active_time:
|
||||
device.active_time = int(dt_util.as_timestamp(device.active_time))
|
||||
device.create_time = details.get("create_time")
|
||||
if device.create_time:
|
||||
device.create_time = int(dt_util.as_timestamp(device.create_time))
|
||||
device.update_time = details.get("update_time")
|
||||
if device.update_time:
|
||||
device.update_time = int(dt_util.as_timestamp(device.update_time))
|
||||
device.support_local = details.get("support_local")
|
||||
device.local_strategy = details.get("local_strategy")
|
||||
device.mqtt_connected = details.get("mqtt_connected")
|
||||
|
||||
device.function = {
|
||||
key: DeviceFunction(
|
||||
code=key,
|
||||
type=value["type"],
|
||||
values=(
|
||||
values
|
||||
if isinstance(values := value["value"], str)
|
||||
else json_dumps(values)
|
||||
),
|
||||
)
|
||||
for key, value in details["function"].items()
|
||||
}
|
||||
device.status_range = {
|
||||
key: DeviceStatusRange(
|
||||
code=key,
|
||||
report_type=value.get("report_type"),
|
||||
type=value["type"],
|
||||
values=(
|
||||
values
|
||||
if isinstance(values := value["value"], str)
|
||||
else json_dumps(values)
|
||||
),
|
||||
)
|
||||
for key, value in details["status_range"].items()
|
||||
}
|
||||
device.status = details["status"]
|
||||
for key, value in device.status.items():
|
||||
# Some devices do not provide a status_range for all status DPs
|
||||
# Others set the type as String in status_range and as Json in function
|
||||
if ((dp_type := device.status_range.get(key)) and dp_type.type == "Json") or (
|
||||
(dp_type := device.function.get(key)) and dp_type.type == "Json"
|
||||
):
|
||||
device.status[key] = json_dumps(value)
|
||||
if value == "**REDACTED**":
|
||||
# It was redacted, which may cause issue with b64decode
|
||||
device.status[key] = ""
|
||||
return device
|
||||
return await create_device(hass, mock_device_code)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_listener(hass: HomeAssistant, mock_manager: Manager) -> MockDeviceListener:
|
||||
"""Create a DeviceListener for testing."""
|
||||
listener = MockDeviceListener(hass, mock_manager)
|
||||
mock_manager.add_device_listener(listener)
|
||||
return listener
|
||||
"""Fixture for Tuya DeviceListener."""
|
||||
return create_listener(hass, mock_manager)
|
||||
|
||||
@@ -2,24 +2,108 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
from tuya_sharing import CustomerDevice, Manager
|
||||
|
||||
from homeassistant.components.tuya.const import DOMAIN
|
||||
from homeassistant.components.tuya.const import (
|
||||
CONF_ENDPOINT,
|
||||
CONF_TERMINAL_ID,
|
||||
CONF_TOKEN_INFO,
|
||||
CONF_USER_CODE,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.components.tuya.diagnostics import _REDACTED_DPCODES
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||
|
||||
from . import DEVICE_MOCKS, initialize_entry
|
||||
from . import DEVICE_MOCKS, create_device, create_manager, initialize_entry
|
||||
|
||||
from tests.common import MockConfigEntry, async_load_json_object_fixture
|
||||
|
||||
|
||||
async def test_registry_cleanup(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
) -> None:
|
||||
"""Ensure no-longer-present devices are removed from the device registry."""
|
||||
# Initialize with two devices
|
||||
main_manager = create_manager()
|
||||
main_device = await create_device(hass, "mcs_8yhypbo7")
|
||||
second_device = await create_device(hass, "clkg_y7j64p60glp8qpx7")
|
||||
await initialize_entry(
|
||||
hass, main_manager, mock_config_entry, [main_device, second_device]
|
||||
)
|
||||
|
||||
# Initialize should have two devices
|
||||
all_entries = dr.async_entries_for_config_entry(
|
||||
device_registry, mock_config_entry.entry_id
|
||||
)
|
||||
assert len(all_entries) == 2
|
||||
|
||||
# Now remove the second device from the manager and re-initialize
|
||||
del main_manager.device_map[second_device.id]
|
||||
with patch("homeassistant.components.tuya.Manager", return_value=main_manager):
|
||||
await hass.config_entries.async_reload(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Only the main device should remain
|
||||
all_entries = dr.async_entries_for_config_entry(
|
||||
device_registry, mock_config_entry.entry_id
|
||||
)
|
||||
assert len(all_entries) == 1
|
||||
assert all_entries[0].identifiers == {(DOMAIN, main_device.id)}
|
||||
|
||||
|
||||
async def test_registry_cleanup_multiple_entries(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Ensure multiple config entries do not remove items from other entries."""
|
||||
main_entity_id = "sensor.boite_aux_lettres_arriere_battery"
|
||||
second_entity_id = "binary_sensor.window_downstairs_door"
|
||||
|
||||
main_manager = create_manager()
|
||||
main_device = await create_device(hass, "mcs_8yhypbo7")
|
||||
await initialize_entry(hass, main_manager, mock_config_entry, main_device)
|
||||
|
||||
# Ensure initial setup is correct (main present, second absent)
|
||||
assert hass.states.get(main_entity_id)
|
||||
assert entity_registry.async_get(main_entity_id)
|
||||
assert not hass.states.get(second_entity_id)
|
||||
assert not entity_registry.async_get(second_entity_id)
|
||||
|
||||
# Create a second config entry
|
||||
second_config_entry = MockConfigEntry(
|
||||
title="Test Tuya entry",
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_ENDPOINT: "test_endpoint",
|
||||
CONF_TERMINAL_ID: "test_terminal",
|
||||
CONF_TOKEN_INFO: "test_token",
|
||||
CONF_USER_CODE: "test_user_code",
|
||||
},
|
||||
unique_id="56789",
|
||||
)
|
||||
second_manager = create_manager()
|
||||
second_device = await create_device(hass, "mcs_oxslv1c9")
|
||||
await initialize_entry(hass, second_manager, second_config_entry, second_device)
|
||||
|
||||
# Ensure setup is correct (both present)
|
||||
assert hass.states.get(main_entity_id)
|
||||
assert entity_registry.async_get(main_entity_id)
|
||||
assert hass.states.get(second_entity_id)
|
||||
assert entity_registry.async_get(second_entity_id)
|
||||
|
||||
|
||||
async def test_device_registry(
|
||||
hass: HomeAssistant,
|
||||
mock_manager: Manager,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_devices: CustomerDevice,
|
||||
mock_devices: list[CustomerDevice],
|
||||
device_registry: dr.DeviceRegistry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
snapshot: SnapshotAssertion,
|
||||
|
||||
Reference in New Issue
Block a user