diff --git a/homeassistant/components/cloudflare_r2/backup.py b/homeassistant/components/cloudflare_r2/backup.py index 1279c1a2ba4..cef9294182e 100644 --- a/homeassistant/components/cloudflare_r2/backup.py +++ b/homeassistant/components/cloudflare_r2/backup.py @@ -5,7 +5,7 @@ import functools import json import logging from time import time -from typing import Any +from typing import Any, cast from botocore.exceptions import BotoCoreError @@ -190,58 +190,77 @@ class R2BackupAgent(BackupAgent): :param open_stream: A function returning an async iterator that yields bytes. """ _LOGGER.debug("Starting multipart upload for %s", tar_filename) + key = self._with_prefix(tar_filename) multipart_upload = await self._client.create_multipart_upload( Bucket=self._bucket, - Key=self._with_prefix(tar_filename), + Key=key, ) upload_id = multipart_upload["UploadId"] try: parts: list[dict[str, Any]] = [] part_number = 1 buffer = bytearray() # bytes buffer to store the data + offset = 0 # start index of unread data inside buffer stream = await open_stream() async for chunk in stream: buffer.extend(chunk) - # upload parts of exactly MULTIPART_MIN_PART_SIZE_BYTES to ensure - # all non-trailing parts have the same size (required by S3/R2) - while len(buffer) >= MULTIPART_MIN_PART_SIZE_BYTES: - part_data = bytes(buffer[:MULTIPART_MIN_PART_SIZE_BYTES]) - del buffer[:MULTIPART_MIN_PART_SIZE_BYTES] + # Upload parts of exactly MULTIPART_MIN_PART_SIZE_BYTES to ensure + # all non-trailing parts have the same size (defensive implementation) + view = memoryview(buffer) + try: + while len(buffer) - offset >= MULTIPART_MIN_PART_SIZE_BYTES: + start = offset + end = offset + MULTIPART_MIN_PART_SIZE_BYTES + part_data = view[start:end] + offset = end - _LOGGER.debug( - "Uploading part number %d, size %d", - part_number, - len(part_data), - ) - part = await self._client.upload_part( - Bucket=self._bucket, - Key=self._with_prefix(tar_filename), - PartNumber=part_number, - UploadId=upload_id, - 
Body=part_data, - ) - parts.append({"PartNumber": part_number, "ETag": part["ETag"]}) - part_number += 1 + _LOGGER.debug( + "Uploading part number %d, size %d", + part_number, + len(part_data), + ) + part = await cast(Any, self._client).upload_part( + Bucket=self._bucket, + Key=key, + PartNumber=part_number, + UploadId=upload_id, + Body=part_data.tobytes(), + ) + parts.append({"PartNumber": part_number, "ETag": part["ETag"]}) + part_number += 1 + finally: + view.release() + + # Compact the buffer if the consumed offset has grown large enough. This + # avoids unnecessary memory copies when compacting after every part upload. + if offset and offset >= MULTIPART_MIN_PART_SIZE_BYTES: + buffer = bytearray(buffer[offset:]) + offset = 0 # Upload the final buffer as the last part (no minimum size requirement) - if buffer: + # Offset should be 0 after the last compaction, but we use it as the start + # index to be defensive in case the buffer was not compacted. + if offset < len(buffer): + remaining_data = memoryview(buffer)[offset:] _LOGGER.debug( - "Uploading final part number %d, size %d", part_number, len(buffer) + "Uploading final part number %d, size %d", + part_number, + len(remaining_data), ) - part = await self._client.upload_part( + part = await cast(Any, self._client).upload_part( Bucket=self._bucket, - Key=self._with_prefix(tar_filename), + Key=key, PartNumber=part_number, UploadId=upload_id, - Body=bytes(buffer), + Body=remaining_data.tobytes(), ) parts.append({"PartNumber": part_number, "ETag": part["ETag"]}) - await self._client.complete_multipart_upload( + await cast(Any, self._client).complete_multipart_upload( Bucket=self._bucket, - Key=self._with_prefix(tar_filename), + Key=key, UploadId=upload_id, MultipartUpload={"Parts": parts}, ) @@ -250,7 +269,7 @@ class R2BackupAgent(BackupAgent): try: await self._client.abort_multipart_upload( Bucket=self._bucket, - Key=self._with_prefix(tar_filename), + Key=key, UploadId=upload_id, ) except BotoCoreError: diff 
--git a/tests/components/cloudflare_r2/conftest.py b/tests/components/cloudflare_r2/conftest.py index b67039f393c..1a34e8264b3 100644 --- a/tests/components/cloudflare_r2/conftest.py +++ b/tests/components/cloudflare_r2/conftest.py @@ -11,7 +11,7 @@ from homeassistant.components.cloudflare_r2.backup import ( MULTIPART_MIN_PART_SIZE_BYTES, suggested_filenames, ) -from homeassistant.components.cloudflare_r2.const import DOMAIN +from homeassistant.components.cloudflare_r2.const import CONF_PREFIX, DOMAIN from .const import USER_INPUT @@ -79,3 +79,19 @@ def mock_config_entry() -> MockConfigEntry: domain=DOMAIN, data=USER_INPUT, ) + + +@pytest.fixture +def mock_config_entry_with_prefix( + mock_config_entry: MockConfigEntry, +) -> MockConfigEntry: + """Return a mocked config entry with a prefix configured.""" + data = dict(mock_config_entry.data) + data[CONF_PREFIX] = "ha/backups" + + return MockConfigEntry( + entry_id=mock_config_entry.entry_id, + title=mock_config_entry.title, + domain=mock_config_entry.domain, + data=data, + ) diff --git a/tests/components/cloudflare_r2/test_backup.py b/tests/components/cloudflare_r2/test_backup.py index c721468e80f..378b7b80c4a 100644 --- a/tests/components/cloudflare_r2/test_backup.py +++ b/tests/components/cloudflare_r2/test_backup.py @@ -1,18 +1,17 @@ """Test the Cloudflare R2 backup platform.""" -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, AsyncIterator from io import StringIO import json from time import time from unittest.mock import AsyncMock, Mock, patch -from botocore.exceptions import ConnectTimeoutError +from botocore.exceptions import BotoCoreError, ConnectTimeoutError import pytest from homeassistant.components.backup import DOMAIN as BACKUP_DOMAIN, AgentBackup from homeassistant.components.cloudflare_r2.backup import ( MULTIPART_MIN_PART_SIZE_BYTES, - BotoCoreError, R2BackupAgent, async_register_backup_agents_listener, suggested_filenames, @@ -521,3 +520,57 @@ async def 
test_listeners_get_cleaned_up(hass: HomeAssistant) -> None: remove_listener() assert DATA_BACKUP_AGENT_LISTENERS not in hass.data + + +async def test_multipart_upload_uses_prefix_for_all_calls( + hass: HomeAssistant, + mock_client: Mock, + mock_config_entry_with_prefix: MockConfigEntry, +) -> None: + """Test multipart upload uses the configured prefix for all S3 calls.""" + mock_config_entry_with_prefix.runtime_data = mock_client + agent = R2BackupAgent(hass, mock_config_entry_with_prefix) + + async def stream() -> AsyncIterator[bytes]: + # Force multipart: > MIN_PART_SIZE + yield b"x" * (MULTIPART_MIN_PART_SIZE_BYTES + 1) + + async def open_stream(): + return stream() + + await agent._upload_multipart("test.tar", open_stream) + + prefixed_key = "ha/backups/test.tar" + + assert mock_client.create_multipart_upload.await_args.kwargs["Key"] == prefixed_key + + for call in mock_client.upload_part.await_args_list: + assert call.kwargs["Key"] == prefixed_key + + assert ( + mock_client.complete_multipart_upload.await_args.kwargs["Key"] == prefixed_key + ) + + +async def test_list_backups_passes_prefix_to_list_objects( + hass: HomeAssistant, + mock_client: Mock, + mock_config_entry_with_prefix: MockConfigEntry, + test_backup: AgentBackup, +) -> None: + """Test list_objects_v2 is called with Prefix when configured.""" + mock_config_entry_with_prefix.runtime_data = mock_client + agent = R2BackupAgent(hass, mock_config_entry_with_prefix) + + # Make the listing for a prefixed bucket + tar_filename, metadata_filename = suggested_filenames(test_backup) + mock_client.list_objects_v2.return_value = { + "Contents": [ + {"Key": f"ha/backups/{metadata_filename}"}, + {"Key": f"ha/backups/{tar_filename}"}, + ] + } + + await agent.async_list_backups() + + assert mock_client.list_objects_v2.call_args.kwargs["Prefix"] == "ha/backups/"