mirror of
https://github.com/Electric-Special/ha-core.git
synced 2026-03-21 08:06:00 +01:00
Fix user store not loaded on restart (#157616)
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
@@ -15,7 +16,9 @@ from homeassistant.helpers import singleton
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
DATA_STORAGE: HassKey[dict[str, UserStore]] = HassKey("frontend_storage")
|
||||
DATA_STORAGE: HassKey[dict[str, asyncio.Future[UserStore]]] = HassKey(
|
||||
"frontend_storage"
|
||||
)
|
||||
DATA_SYSTEM_STORAGE: HassKey[SystemStore] = HassKey("frontend_system_storage")
|
||||
STORAGE_VERSION_USER_DATA = 1
|
||||
STORAGE_VERSION_SYSTEM_DATA = 1
|
||||
@@ -34,11 +37,18 @@ async def async_setup_frontend_storage(hass: HomeAssistant) -> None:
|
||||
async def async_user_store(hass: HomeAssistant, user_id: str) -> UserStore:
|
||||
"""Access a user store."""
|
||||
stores = hass.data.setdefault(DATA_STORAGE, {})
|
||||
if (store := stores.get(user_id)) is None:
|
||||
store = stores[user_id] = UserStore(hass, user_id)
|
||||
await store.async_load()
|
||||
if (future := stores.get(user_id)) is None:
|
||||
future = stores[user_id] = hass.loop.create_future()
|
||||
store = UserStore(hass, user_id)
|
||||
try:
|
||||
await store.async_load()
|
||||
except BaseException as ex:
|
||||
del stores[user_id]
|
||||
future.set_exception(ex)
|
||||
raise
|
||||
future.set_result(store)
|
||||
|
||||
return store
|
||||
return await future
|
||||
|
||||
|
||||
class UserStore:
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
"""The tests for frontend storage."""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.frontend import DOMAIN
|
||||
from homeassistant.components.frontend.storage import async_user_store
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockUser
|
||||
@@ -572,3 +576,92 @@ async def test_set_system_data_requires_admin(
|
||||
assert not res["success"], res
|
||||
assert res["error"]["code"] == "unauthorized"
|
||||
assert res["error"]["message"] == "Unauthorized"
|
||||
|
||||
|
||||
async def test_user_store_concurrent_access(
|
||||
hass: HomeAssistant,
|
||||
hass_admin_user: MockUser,
|
||||
hass_storage: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test that concurrent access to user store returns loaded data."""
|
||||
storage_key = f"{DOMAIN}.user_data_{hass_admin_user.id}"
|
||||
hass_storage[storage_key] = {
|
||||
"version": 1,
|
||||
"data": {"test-key": "test-value"},
|
||||
}
|
||||
|
||||
load_count = 0
|
||||
original_async_load = Store.async_load
|
||||
|
||||
async def slow_async_load(self: Store) -> Any:
|
||||
"""Simulate slow loading to trigger race condition."""
|
||||
nonlocal load_count
|
||||
load_count += 1
|
||||
await asyncio.sleep(0) # Yield to allow other coroutines to run
|
||||
return await original_async_load(self)
|
||||
|
||||
with patch.object(Store, "async_load", slow_async_load):
|
||||
# Request the same user store concurrently
|
||||
results = await asyncio.gather(
|
||||
async_user_store(hass, hass_admin_user.id),
|
||||
async_user_store(hass, hass_admin_user.id),
|
||||
async_user_store(hass, hass_admin_user.id),
|
||||
)
|
||||
|
||||
# All results should be the same store instance with loaded data
|
||||
assert results[0] is results[1] is results[2]
|
||||
assert results[0].data == {"test-key": "test-value"}
|
||||
# Store should only be loaded once due to Future synchronization
|
||||
assert load_count == 1
|
||||
|
||||
|
||||
async def test_user_store_load_error(
|
||||
hass: HomeAssistant,
|
||||
hass_admin_user: MockUser,
|
||||
) -> None:
|
||||
"""Test that load errors are propagated and allow retry."""
|
||||
|
||||
async def failing_async_load(self: Store) -> Any:
|
||||
"""Simulate a load failure."""
|
||||
raise OSError("Storage read error")
|
||||
|
||||
with (
|
||||
patch.object(Store, "async_load", failing_async_load),
|
||||
pytest.raises(OSError, match="Storage read error"),
|
||||
):
|
||||
await async_user_store(hass, hass_admin_user.id)
|
||||
|
||||
# After error, the future should be removed, allowing retry
|
||||
# This time without the patch, it should work (empty store)
|
||||
store = await async_user_store(hass, hass_admin_user.id)
|
||||
assert store.data == {}
|
||||
|
||||
|
||||
async def test_user_store_concurrent_load_error(
|
||||
hass: HomeAssistant,
|
||||
hass_admin_user: MockUser,
|
||||
) -> None:
|
||||
"""Test that concurrent callers all receive the same error."""
|
||||
|
||||
async def failing_async_load(self: Store) -> Any:
|
||||
"""Simulate a slow load failure."""
|
||||
await asyncio.sleep(0) # Yield to allow other coroutines to run
|
||||
raise OSError("Storage read error")
|
||||
|
||||
with patch.object(Store, "async_load", failing_async_load):
|
||||
results = await asyncio.gather(
|
||||
async_user_store(hass, hass_admin_user.id),
|
||||
async_user_store(hass, hass_admin_user.id),
|
||||
async_user_store(hass, hass_admin_user.id),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
# All callers should receive the same OSError
|
||||
assert len(results) == 3
|
||||
for result in results:
|
||||
assert isinstance(result, OSError)
|
||||
assert str(result) == "Storage read error"
|
||||
|
||||
# After error, retry should work
|
||||
store = await async_user_store(hass, hass_admin_user.id)
|
||||
assert store.data == {}
|
||||
|
||||
Reference in New Issue
Block a user