diff --git a/homeassistant/components/tuya/__init__.py b/homeassistant/components/tuya/__init__.py index 70c517a3cc3..0555f8a145a 100644 --- a/homeassistant/components/tuya/__init__.py +++ b/homeassistant/components/tuya/__init__.py @@ -82,7 +82,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: TuyaConfigEntry) -> bool entry.runtime_data = HomeAssistantTuyaData(manager=manager, listener=listener) # Cleanup device registry - await cleanup_device_registry(hass, manager) + await cleanup_device_registry(hass, manager, entry) # Register known device IDs device_registry = dr.async_get(hass) @@ -114,13 +114,19 @@ async def async_setup_entry(hass: HomeAssistant, entry: TuyaConfigEntry) -> bool return True -async def cleanup_device_registry(hass: HomeAssistant, device_manager: Manager) -> None: - """Remove deleted device registry entry if there are no remaining entities.""" +async def cleanup_device_registry( + hass: HomeAssistant, device_manager: Manager, entry: TuyaConfigEntry +) -> None: + """Unlink device registry entry if there are no remaining entities.""" device_registry = dr.async_get(hass) - for dev_id, device_entry in list(device_registry.devices.items()): + for device_entry in dr.async_entries_for_config_entry( + device_registry, entry.entry_id + ): for item in device_entry.identifiers: if item[0] == DOMAIN and item[1] not in device_manager.device_map: - device_registry.async_remove_device(dev_id) + device_registry.async_update_device( + device_entry.id, remove_config_entry_id=entry.entry_id + ) break diff --git a/tests/components/tuya/__init__.py b/tests/components/tuya/__init__.py index 80fa925b4d0..21815ddb99c 100644 --- a/tests/components/tuya/__init__.py +++ b/tests/components/tuya/__init__.py @@ -4,15 +4,23 @@ from __future__ import annotations import pathlib from typing import Any -from unittest.mock import patch +from unittest.mock import MagicMock, patch from freezegun.api import FrozenDateTimeFactory -from tuya_sharing import CustomerDevice, Manager +from tuya_sharing import ( + CustomerApi, + CustomerDevice, + DeviceFunction, + DeviceStatusRange, + Manager, +) -from homeassistant.components.tuya import DeviceListener +from homeassistant.components.tuya import DOMAIN, DeviceListener from homeassistant.core import HomeAssistant +from homeassistant.helpers.json import json_dumps +from homeassistant.util import dt as dt_util -from tests.common import MockConfigEntry +from tests.common import MockConfigEntry, async_load_json_object_fixture FIXTURES_DIR = pathlib.Path(__file__).parent / "fixtures" DEVICE_MOCKS = sorted( @@ -45,6 +53,98 @@ class MockDeviceListener(DeviceListener): await hass.async_block_till_done() +async def create_device(hass: HomeAssistant, mock_device_code: str) -> CustomerDevice: + """Create a CustomerDevice for testing.""" + details = await async_load_json_object_fixture( + hass, f"{mock_device_code}.json", DOMAIN + ) + device = MagicMock(spec=CustomerDevice) + + # Use reverse of the product_id for testing + device.id = mock_device_code.replace("_", "")[::-1] + + device.name = details["name"] + device.category = details["category"] + device.product_id = details["product_id"] + device.product_name = details["product_name"] + device.online = details["online"] + device.sub = details.get("sub") + device.time_zone = details.get("time_zone") + device.active_time = details.get("active_time") + if device.active_time: + device.active_time = int(dt_util.as_timestamp(device.active_time)) + device.create_time = details.get("create_time") + if device.create_time: + device.create_time = int(dt_util.as_timestamp(device.create_time)) + device.update_time = details.get("update_time") + if device.update_time: + device.update_time = int(dt_util.as_timestamp(device.update_time)) + device.support_local = details.get("support_local") + device.local_strategy = details.get("local_strategy") + device.mqtt_connected = details.get("mqtt_connected") + + device.function = { + key: DeviceFunction( + code=key, + type=value["type"], + values=( + values + if isinstance(values := value["value"], str) + else json_dumps(values) + ), + ) + for key, value in details["function"].items() + } + device.status_range = { + key: DeviceStatusRange( + code=key, + report_type=value.get("report_type"), + type=value["type"], + values=( + values + if isinstance(values := value["value"], str) + else json_dumps(values) + ), + ) + for key, value in details["status_range"].items() + } + device.status = details["status"] + for key, value in device.status.items(): + # Some devices do not provide a status_range for all status DPs + # Others set the type as String in status_range and as Json in function + if ((dp_type := device.status_range.get(key)) and dp_type.type == "Json") or ( + (dp_type := device.function.get(key)) and dp_type.type == "Json" + ): + device.status[key] = json_dumps(value) + if value == "**REDACTED**": + # It was redacted, which may cause issue with b64decode + device.status[key] = "" + return device + + +def create_listener(hass: HomeAssistant, manager: Manager) -> MockDeviceListener: + """Create a DeviceListener for testing.""" + listener = MockDeviceListener(hass, manager) + manager.add_device_listener(listener) + return listener + + +def create_manager( + terminal_id: str = "7cd96aff-6ec8-4006-b093-3dbff7947591", +) -> Manager: + """Create a Manager for testing.""" + manager = MagicMock(spec=Manager) + manager.device_map = {} + manager.mq = MagicMock() + manager.mq.client = MagicMock() + manager.mq.client.is_connected = MagicMock(return_value=True) + manager.customer_api = MagicMock(spec=CustomerApi) + # Meaningless URL / UUIDs + manager.customer_api.endpoint = "https://apigw.tuyaeu.com" + manager.terminal_id = terminal_id + return manager + + async def initialize_entry( hass: HomeAssistant, mock_manager: Manager, diff --git a/tests/components/tuya/conftest.py b/tests/components/tuya/conftest.py index b0df4e881f3..41101b45045 100644 --- a/tests/components/tuya/conftest.py +++ b/tests/components/tuya/conftest.py @@ -6,13 +6,7 @@ from collections.abc import Generator from unittest.mock import MagicMock, patch import pytest -from tuya_sharing import ( - CustomerApi, - CustomerDevice, - DeviceFunction, - DeviceStatusRange, - Manager, -) +from tuya_sharing import CustomerDevice, Manager from homeassistant.components.tuya.const import ( CONF_ENDPOINT, @@ -22,12 +16,16 @@ from homeassistant.components.tuya.const import ( DOMAIN, ) from homeassistant.core import HomeAssistant -from homeassistant.helpers.json import json_dumps -from homeassistant.util import dt as dt_util -from . import DEVICE_MOCKS, MockDeviceListener +from . import ( + DEVICE_MOCKS, + MockDeviceListener, + create_device, + create_listener, + create_manager, +) -from tests.common import MockConfigEntry, async_load_json_object_fixture +from tests.common import MockConfigEntry @pytest.fixture @@ -108,17 +106,8 @@ def mock_tuya_login_control() -> Generator[MagicMock]: @pytest.fixture def mock_manager() -> Manager: - """Mock Tuya Manager.""" - manager = MagicMock(spec=Manager) - manager.device_map = {} - manager.mq = MagicMock() - manager.mq.client = MagicMock() - manager.mq.client.is_connected = MagicMock(return_value=True) - manager.customer_api = MagicMock(spec=CustomerApi) - # Meaningless URL / UUIDs - manager.customer_api.endpoint = "https://apigw.tuyaeu.com" - manager.terminal_id = "7cd96aff-6ec8-4006-b093-3dbff7947591" - return manager + """Fixture for Tuya Manager.""" + return create_manager() @pytest.fixture @@ -137,7 +126,7 @@ async def mock_devices(hass: HomeAssistant) -> list[CustomerDevice]: Use this to generate global snapshots for each platform. """ - return [await _create_device(hass, device_code) for device_code in DEVICE_MOCKS] + return [await create_device(hass, device_code) for device_code in DEVICE_MOCKS] @pytest.fixture @@ -146,81 +135,10 @@ async def mock_device(hass: HomeAssistant, mock_device_code: str) -> CustomerDev Use this for testing behavior on a specific device. """ - return await _create_device(hass, mock_device_code) - - -async def _create_device(hass: HomeAssistant, mock_device_code: str) -> CustomerDevice: - """Mock a Tuya CustomerDevice.""" - details = await async_load_json_object_fixture( - hass, f"{mock_device_code}.json", DOMAIN - ) - device = MagicMock(spec=CustomerDevice) - - # Use reverse of the product_id for testing - device.id = mock_device_code.replace("_", "")[::-1] - - device.name = details["name"] - device.category = details["category"] - device.product_id = details["product_id"] - device.product_name = details["product_name"] - device.online = details["online"] - device.sub = details.get("sub") - device.time_zone = details.get("time_zone") - device.active_time = details.get("active_time") - if device.active_time: - device.active_time = int(dt_util.as_timestamp(device.active_time)) - device.create_time = details.get("create_time") - if device.create_time: - device.create_time = int(dt_util.as_timestamp(device.create_time)) - device.update_time = details.get("update_time") - if device.update_time: - device.update_time = int(dt_util.as_timestamp(device.update_time)) - device.support_local = details.get("support_local") - device.local_strategy = details.get("local_strategy") - device.mqtt_connected = details.get("mqtt_connected") - - device.function = { - key: DeviceFunction( - code=key, - type=value["type"], - values=( - values - if isinstance(values := value["value"], str) - else json_dumps(values) - ), - ) - for key, value in details["function"].items() - } - device.status_range = { - key: DeviceStatusRange( - code=key, - report_type=value.get("report_type"), - type=value["type"], - values=( - values - if isinstance(values := value["value"], str) - else json_dumps(values) - ), - ) - for key, value in details["status_range"].items() - } - device.status = details["status"] - for key, value in device.status.items(): - # Some devices do not provide a status_range for all status DPs - # Others set the type as String in status_range and as Json in function - if ((dp_type := device.status_range.get(key)) and dp_type.type == "Json") or ( - (dp_type := device.function.get(key)) and dp_type.type == "Json" - ): - device.status[key] = json_dumps(value) - if value == "**REDACTED**": - # It was redacted, which may cause issue with b64decode - device.status[key] = "" - return device + return await create_device(hass, mock_device_code) @pytest.fixture def mock_listener(hass: HomeAssistant, mock_manager: Manager) -> MockDeviceListener: - """Create a DeviceListener for testing.""" - listener = MockDeviceListener(hass, mock_manager) - mock_manager.add_device_listener(listener) - return listener + """Fixture for Tuya DeviceListener.""" + return create_listener(hass, mock_manager) diff --git a/tests/components/tuya/test_init.py b/tests/components/tuya/test_init.py index 9bd08a746a1..3045cca3349 100644 --- a/tests/components/tuya/test_init.py +++ b/tests/components/tuya/test_init.py @@ -2,24 +2,108 @@ from __future__ import annotations +from unittest.mock import patch + from syrupy.assertion import SnapshotAssertion from tuya_sharing import CustomerDevice, Manager -from homeassistant.components.tuya.const import DOMAIN +from homeassistant.components.tuya.const import ( + CONF_ENDPOINT, + CONF_TERMINAL_ID, + CONF_TOKEN_INFO, + CONF_USER_CODE, + DOMAIN, +) from homeassistant.components.tuya.diagnostics import _REDACTED_DPCODES from homeassistant.core import HomeAssistant from homeassistant.helpers import device_registry as dr, entity_registry as er -from . import DEVICE_MOCKS, initialize_entry +from . import DEVICE_MOCKS, create_device, create_manager, initialize_entry from tests.common import MockConfigEntry, async_load_json_object_fixture +async def test_registry_cleanup( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + device_registry: dr.DeviceRegistry, +) -> None: + """Ensure no-longer-present devices are removed from the device registry.""" + # Initialize with two devices + main_manager = create_manager() + main_device = await create_device(hass, "mcs_8yhypbo7") + second_device = await create_device(hass, "clkg_y7j64p60glp8qpx7") + await initialize_entry( + hass, main_manager, mock_config_entry, [main_device, second_device] + ) + + # Initialize should have two devices + all_entries = dr.async_entries_for_config_entry( + device_registry, mock_config_entry.entry_id + ) + assert len(all_entries) == 2 + + # Now remove the second device from the manager and re-initialize + del main_manager.device_map[second_device.id] + with patch("homeassistant.components.tuya.Manager", return_value=main_manager): + await hass.config_entries.async_reload(mock_config_entry.entry_id) + await hass.async_block_till_done() + + # Only the main device should remain + all_entries = dr.async_entries_for_config_entry( + device_registry, mock_config_entry.entry_id + ) + assert len(all_entries) == 1 + assert all_entries[0].identifiers == {(DOMAIN, main_device.id)} + + +async def test_registry_cleanup_multiple_entries( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + entity_registry: er.EntityRegistry, +) -> None: + """Ensure multiple config entries do not remove items from other entries.""" + main_entity_id = "sensor.boite_aux_lettres_arriere_battery" + second_entity_id = "binary_sensor.window_downstairs_door" + + main_manager = create_manager() + main_device = await create_device(hass, "mcs_8yhypbo7") + await initialize_entry(hass, main_manager, mock_config_entry, main_device) + + # Ensure initial setup is correct (main present, second absent) + assert hass.states.get(main_entity_id) + assert entity_registry.async_get(main_entity_id) + assert not hass.states.get(second_entity_id) + assert not entity_registry.async_get(second_entity_id) + + # Create a second config entry + second_config_entry = MockConfigEntry( + title="Test Tuya entry", + domain=DOMAIN, + data={ + CONF_ENDPOINT: "test_endpoint", + CONF_TERMINAL_ID: "test_terminal", + CONF_TOKEN_INFO: "test_token", + CONF_USER_CODE: "test_user_code", + }, + unique_id="56789", + ) + second_manager = create_manager() + second_device = await create_device(hass, "mcs_oxslv1c9") + await initialize_entry(hass, second_manager, second_config_entry, second_device) + + # Ensure setup is correct (both present) + assert hass.states.get(main_entity_id) + assert entity_registry.async_get(main_entity_id) + assert hass.states.get(second_entity_id) + assert entity_registry.async_get(second_entity_id) + + async def test_device_registry( hass: HomeAssistant, mock_manager: Manager, mock_config_entry: MockConfigEntry, - mock_devices: CustomerDevice, + mock_devices: list[CustomerDevice], device_registry: dr.DeviceRegistry, entity_registry: er.EntityRegistry, snapshot: SnapshotAssertion,