diff --git a/homeassistant/components/frontend/pr_download.py b/homeassistant/components/frontend/pr_download.py index 4de28d7c405..1d4c28a0471 100644 --- a/homeassistant/components/frontend/pr_download.py +++ b/homeassistant/components/frontend/pr_download.py @@ -43,13 +43,13 @@ ERROR_RATE_LIMIT = ( ) -async def _get_pr_head_sha(client: GitHubAPI, pr_number: int) -> str: - """Get the head SHA for the PR.""" +async def _get_pr_shas(client: GitHubAPI, pr_number: int) -> tuple[str, str]: + """Get the head and base SHAs for a PR.""" try: response = await client.generic( endpoint=f"/repos/home-assistant/frontend/pulls/{pr_number}", ) - return str(response.data["head"]["sha"]) + return str(response.data["head"]["sha"]), str(response.data["base"]["sha"]) except GitHubAuthenticationException as err: raise HomeAssistantError(ERROR_INVALID_TOKEN) from err except (GitHubRatelimitException, GitHubPermissionException) as err: @@ -137,9 +137,9 @@ async def _download_artifact_data( def _extract_artifact( artifact_data: bytes, cache_dir: pathlib.Path, - head_sha: str, + cache_key: str, ) -> None: - """Extract artifact and save SHA (runs in executor).""" + """Extract artifact and save cache key (runs in executor).""" frontend_dir = cache_dir / "hass_frontend" if cache_dir.exists(): @@ -163,9 +163,8 @@ def _extract_artifact( ) zip_file.extractall(str(frontend_dir)) - # Save the commit SHA for cache validation sha_file = cache_dir / ".sha" - sha_file.write_text(head_sha) + sha_file.write_text(cache_key) async def download_pr_artifact( @@ -186,27 +185,29 @@ async def download_pr_artifact( client = GitHubAPI(token=github_token, session=session) - head_sha = await _get_pr_head_sha(client, pr_number) + head_sha, base_sha = await _get_pr_shas(client, pr_number) + cache_key = f"{head_sha}:{base_sha}" frontend_dir = tmp_dir / "hass_frontend" sha_file = tmp_dir / ".sha" if frontend_dir.exists() and sha_file.exists(): try: - cached_sha = await hass.async_add_executor_job(sha_file.read_text) 
- if cached_sha.strip() == head_sha: + cached_key = await hass.async_add_executor_job(sha_file.read_text) + cached_key = cached_key.strip() + if cached_key == cache_key: _LOGGER.info( "Using cached PR #%s (commit %s) from %s", pr_number, - head_sha[:8], + cache_key, tmp_dir, ) return tmp_dir _LOGGER.info( - "PR #%s has new commits (cached: %s, current: %s), re-downloading", + "PR #%s cache outdated (cached: %s, current: %s), re-downloading", pr_number, - cached_sha[:8], - head_sha[:8], + cached_key, + cache_key, ) except OSError as err: _LOGGER.debug("Failed to read cache SHA file: %s", err) @@ -218,7 +219,7 @@ async def download_pr_artifact( try: await hass.async_add_executor_job( - _extract_artifact, artifact_data, tmp_dir, head_sha + _extract_artifact, artifact_data, tmp_dir, cache_key ) except zipfile.BadZipFile as err: raise HomeAssistantError( diff --git a/tests/components/frontend/conftest.py b/tests/components/frontend/conftest.py index 7ec108a316b..5191d6bbfa6 100644 --- a/tests/components/frontend/conftest.py +++ b/tests/components/frontend/conftest.py @@ -19,7 +19,10 @@ def mock_github_api() -> Generator[AsyncMock]: # Mock PR response pr_response = AsyncMock() - pr_response.data = {"head": {"sha": "abc123def456"}} + pr_response.data = { + "head": {"sha": "abc123def456"}, + "base": {"sha": "base789abc012"}, + } # Mock workflow runs response workflow_response = AsyncMock() diff --git a/tests/components/frontend/test_pr_download.py b/tests/components/frontend/test_pr_download.py index 0af85a66a15..352040917b1 100644 --- a/tests/components/frontend/test_pr_download.py +++ b/tests/components/frontend/test_pr_download.py @@ -64,7 +64,7 @@ async def test_pr_download_uses_cache( frontend_dir = pr_cache_dir / "hass_frontend" frontend_dir.mkdir(parents=True) (frontend_dir / "index.html").write_text("test") - (pr_cache_dir / ".sha").write_text("abc123def456") + (pr_cache_dir / ".sha").write_text("abc123def456:base789abc012") with patch( 
"homeassistant.components.frontend.pr_download.GitHubAPI" @@ -73,7 +73,10 @@ async def test_pr_download_uses_cache( mock_gh_class.return_value = mock_client pr_response = AsyncMock() - pr_response.data = {"head": {"sha": "abc123def456"}} + pr_response.data = { + "head": {"sha": "abc123def456"}, + "base": {"sha": "base789abc012"}, + } mock_client.generic.return_value = pr_response config = { @@ -93,21 +96,29 @@ assert "pulls" in str(calls[0]) +@pytest.mark.parametrize( + ("cache_key"), + [ + ("old_head_sha:base789abc012"), + ("abc123def456:old_base_sha"), + ], +) async def test_pr_download_cache_invalidated( hass: HomeAssistant, tmp_path: Path, mock_github_api, aioclient_mock: AiohttpClientMocker, mock_zipfile, + cache_key: str, ) -> None: - """Test that cache is invalidated when commit changes.""" + """Test that cache is invalidated when the head or base commit changes.""" hass.config.config_dir = str(tmp_path) pr_cache_dir = tmp_path / ".cache" / "frontend" / "development_artifacts" frontend_dir = pr_cache_dir / "hass_frontend" frontend_dir.mkdir(parents=True) (frontend_dir / "index.html").write_text("test") - (pr_cache_dir / ".sha").write_text("old_commit_sha") + (pr_cache_dir / ".sha").write_text(cache_key) aioclient_mock.get( "https://api.github.com/artifact/download", @@ -124,7 +135,7 @@ async def test_pr_download_cache_invalidated( assert await async_setup_component(hass, DOMAIN, config) await hass.async_block_till_done() - # Should download - commit changed + # Should download - head or base commit changed assert len(aioclient_mock.mock_calls) == 1 @@ -261,7 +272,10 @@ async def test_pr_download_artifact_search_github_errors( mock_gh_class.return_value = mock_client pr_response = AsyncMock() - pr_response.data = {"head": {"sha": "abc123def456"}} + pr_response.data = { + "head": {"sha": "abc123def456"}, + "base": {"sha": "base789abc012"}, + } async def generic_side_effect(endpoint, **_kwargs): if "pulls" in endpoint: @@ -299,7 +313,10 @@ 
async def test_pr_download_artifact_not_found( mock_gh_class.return_value = mock_client pr_response = AsyncMock() - pr_response.data = {"head": {"sha": "abc123def456"}} + pr_response.data = { + "head": {"sha": "abc123def456"}, + "base": {"sha": "base789abc012"}, + } workflow_response = AsyncMock() workflow_response.data = {"workflow_runs": []}