Cloudflare R2 backup - Improved buffer handling (#162958)

Author: Patrick Vorgers
Date: 2026-02-15 16:16:10 +01:00
Committed by: GitHub
Parent: 32092c73c6
Commit: 7a52d71b40

3 changed files with 121 additions and 33 deletions
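The change replaces the old copy-and-shrink scheme (`bytes(buffer[:N])` followed by `del buffer[:N]`) with zero-copy reads through a `memoryview` plus an `offset` cursor, compacting the buffer at most once per incoming chunk instead of once per uploaded part. A minimal sketch of that buffering strategy in isolation, with a hypothetical `upload` coroutine standing in for the R2 client and `min_part` for `MULTIPART_MIN_PART_SIZE_BYTES`:

async def upload_in_parts(stream, upload, min_part):
    """Upload fixed-size parts from an async byte stream, compacting once per chunk."""
    buffer = bytearray()
    offset = 0  # start index of unread data inside buffer
    async for chunk in stream:
        buffer.extend(chunk)
        view = memoryview(buffer)
        try:
            while len(buffer) - offset >= min_part:
                # Slicing the view copies nothing; tobytes() copies once, at send time
                await upload(view[offset : offset + min_part].tobytes())
                offset += min_part
        finally:
            view.release()  # drop the export so the bytearray may be resized again
        if offset:
            buffer = bytearray(buffer[offset:])  # single compaction per chunk
            offset = 0
    if buffer:
        await upload(bytes(buffer))  # trailing part; no minimum size applies

Compacting once per chunk bounds the buffer at roughly one chunk plus one part while avoiding a memory copy per uploaded part.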

homeassistant/components/cloudflare_r2/backup.py

@@ -5,7 +5,7 @@
 import functools
 import json
 import logging
 from time import time
-from typing import Any
+from typing import Any, cast

 from botocore.exceptions import BotoCoreError
@@ -190,58 +190,77 @@ class R2BackupAgent(BackupAgent):
         :param open_stream: A function returning an async iterator that yields bytes.
         """
         _LOGGER.debug("Starting multipart upload for %s", tar_filename)
+        key = self._with_prefix(tar_filename)
         multipart_upload = await self._client.create_multipart_upload(
             Bucket=self._bucket,
-            Key=self._with_prefix(tar_filename),
+            Key=key,
         )
         upload_id = multipart_upload["UploadId"]
         try:
             parts: list[dict[str, Any]] = []
             part_number = 1
             buffer = bytearray()  # bytes buffer to store the data
+            offset = 0  # start index of unread data inside buffer
             stream = await open_stream()
             async for chunk in stream:
                 buffer.extend(chunk)
-                # upload parts of exactly MULTIPART_MIN_PART_SIZE_BYTES to ensure
-                # all non-trailing parts have the same size (required by S3/R2)
-                while len(buffer) >= MULTIPART_MIN_PART_SIZE_BYTES:
-                    part_data = bytes(buffer[:MULTIPART_MIN_PART_SIZE_BYTES])
-                    del buffer[:MULTIPART_MIN_PART_SIZE_BYTES]
-                    _LOGGER.debug(
-                        "Uploading part number %d, size %d",
-                        part_number,
-                        len(part_data),
-                    )
-                    part = await self._client.upload_part(
-                        Bucket=self._bucket,
-                        Key=self._with_prefix(tar_filename),
-                        PartNumber=part_number,
-                        UploadId=upload_id,
-                        Body=part_data,
-                    )
-                    parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
-                    part_number += 1
+                # Upload parts of exactly MULTIPART_MIN_PART_SIZE_BYTES to ensure
+                # all non-trailing parts have the same size (defensive implementation)
+                view = memoryview(buffer)
+                try:
+                    while len(buffer) - offset >= MULTIPART_MIN_PART_SIZE_BYTES:
+                        start = offset
+                        end = offset + MULTIPART_MIN_PART_SIZE_BYTES
+                        part_data = view[start:end]
+                        offset = end
+                        _LOGGER.debug(
+                            "Uploading part number %d, size %d",
+                            part_number,
+                            len(part_data),
+                        )
+                        part = await cast(Any, self._client).upload_part(
+                            Bucket=self._bucket,
+                            Key=key,
+                            PartNumber=part_number,
+                            UploadId=upload_id,
+                            Body=part_data.tobytes(),
+                        )
+                        parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
+                        part_number += 1
+                finally:
+                    view.release()
+                # Compact the buffer if the consumed offset has grown large enough. This
+                # avoids unnecessary memory copies when compacting after every part upload.
+                if offset and offset >= MULTIPART_MIN_PART_SIZE_BYTES:
+                    buffer = bytearray(buffer[offset:])
+                    offset = 0
             # Upload the final buffer as the last part (no minimum size requirement)
-            if buffer:
+            # Offset should be 0 after the last compaction, but we use it as the start
+            # index to be defensive in case the buffer was not compacted.
+            if offset < len(buffer):
+                remaining_data = memoryview(buffer)[offset:]
                 _LOGGER.debug(
-                    "Uploading final part number %d, size %d", part_number, len(buffer)
+                    "Uploading final part number %d, size %d",
+                    part_number,
+                    len(remaining_data),
                 )
-                part = await self._client.upload_part(
+                part = await cast(Any, self._client).upload_part(
                     Bucket=self._bucket,
-                    Key=self._with_prefix(tar_filename),
+                    Key=key,
                     PartNumber=part_number,
                     UploadId=upload_id,
-                    Body=bytes(buffer),
+                    Body=remaining_data.tobytes(),
                 )
                 parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
-            await self._client.complete_multipart_upload(
+            await cast(Any, self._client).complete_multipart_upload(
                 Bucket=self._bucket,
-                Key=self._with_prefix(tar_filename),
+                Key=key,
                 UploadId=upload_id,
                 MultipartUpload={"Parts": parts},
             )
@@ -250,7 +269,7 @@ class R2BackupAgent(BackupAgent):
             try:
                 await self._client.abort_multipart_upload(
                     Bucket=self._bucket,
-                    Key=self._with_prefix(tar_filename),
+                    Key=key,
                     UploadId=upload_id,
                 )
             except BotoCoreError:
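Why the `try`/`finally` around `view.release()`: as long as a `memoryview` export is alive, CPython refuses to resize the underlying `bytearray`, so the next `buffer.extend(chunk)` would fail. A standalone illustration, independent of this integration:

buf = bytearray(b"abc")
view = memoryview(buf)
try:
    buf.extend(b"def")  # resize attempt while the buffer is exported
except BufferError:
    print("cannot resize a bytearray with live memoryview exports")
view.release()
buf.extend(b"def")  # succeeds once the export is released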

tests/components/cloudflare_r2/conftest.py

@@ -11,7 +11,7 @@
 from homeassistant.components.cloudflare_r2.backup import (
     MULTIPART_MIN_PART_SIZE_BYTES,
     suggested_filenames,
 )
-from homeassistant.components.cloudflare_r2.const import DOMAIN
+from homeassistant.components.cloudflare_r2.const import CONF_PREFIX, DOMAIN

 from .const import USER_INPUT
@@ -79,3 +79,19 @@ def mock_config_entry() -> MockConfigEntry:
         domain=DOMAIN,
         data=USER_INPUT,
     )
+
+
+@pytest.fixture
+def mock_config_entry_with_prefix(
+    mock_config_entry: MockConfigEntry,
+) -> MockConfigEntry:
+    """Return a mocked config entry with a prefix configured."""
+    data = dict(mock_config_entry.data)
+    data[CONF_PREFIX] = "ha/backups"
+    return MockConfigEntry(
+        entry_id=mock_config_entry.entry_id,
+        title=mock_config_entry.title,
+        domain=mock_config_entry.domain,
+        data=data,
+    )

tests/components/cloudflare_r2/test_backup.py

@@ -1,18 +1,17 @@
 """Test the Cloudflare R2 backup platform."""

-from collections.abc import AsyncGenerator
+from collections.abc import AsyncGenerator, AsyncIterator
 from io import StringIO
 import json
 from time import time
-from unittest.mock import AsyncMock, Mock, patch
+from unittest.mock import AsyncMock, MagicMock, Mock, patch

-from botocore.exceptions import ConnectTimeoutError
+from botocore.exceptions import BotoCoreError, ConnectTimeoutError
 import pytest

 from homeassistant.components.backup import DOMAIN as BACKUP_DOMAIN, AgentBackup
 from homeassistant.components.cloudflare_r2.backup import (
     MULTIPART_MIN_PART_SIZE_BYTES,
-    BotoCoreError,
     R2BackupAgent,
     async_register_backup_agents_listener,
     suggested_filenames,
@@ -521,3 +520,57 @@ async def test_listeners_get_cleaned_up(hass: HomeAssistant) -> None:
     remove_listener()

     assert DATA_BACKUP_AGENT_LISTENERS not in hass.data
+
+
+async def test_multipart_upload_uses_prefix_for_all_calls(
+    hass: HomeAssistant,
+    mock_client: MagicMock,
+    mock_config_entry_with_prefix: MockConfigEntry,
+) -> None:
+    """Test multipart upload uses the configured prefix for all S3 calls."""
+    mock_config_entry_with_prefix.runtime_data = mock_client
+    agent = R2BackupAgent(hass, mock_config_entry_with_prefix)
+
+    async def stream() -> AsyncIterator[bytes]:
+        # Force a multipart upload: yield more than MULTIPART_MIN_PART_SIZE_BYTES
+        yield b"x" * (MULTIPART_MIN_PART_SIZE_BYTES + 1)
+
+    async def open_stream() -> AsyncIterator[bytes]:
+        return stream()
+
+    await agent._upload_multipart("test.tar", open_stream)
+
+    prefixed_key = "ha/backups/test.tar"
+    assert mock_client.create_multipart_upload.await_args.kwargs["Key"] == prefixed_key
+    for call in mock_client.upload_part.await_args_list:
+        assert call.kwargs["Key"] == prefixed_key
+    assert (
+        mock_client.complete_multipart_upload.await_args.kwargs["Key"] == prefixed_key
+    )
+
+
+async def test_list_backups_passes_prefix_to_list_objects(
+    hass: HomeAssistant,
+    mock_client: MagicMock,
+    mock_config_entry_with_prefix: MockConfigEntry,
+    test_backup: AgentBackup,
+) -> None:
+    """Test list_objects_v2 is called with Prefix when configured."""
+    mock_config_entry_with_prefix.runtime_data = mock_client
+    agent = R2BackupAgent(hass, mock_config_entry_with_prefix)
+
+    # Have the listing return keys under the configured prefix
+    tar_filename, metadata_filename = suggested_filenames(test_backup)
+    mock_client.list_objects_v2.return_value = {
+        "Contents": [
+            {"Key": f"ha/backups/{metadata_filename}"},
+            {"Key": f"ha/backups/{tar_filename}"},
+        ]
+    }
+
+    await agent.async_list_backups()
+
+    assert mock_client.list_objects_v2.call_args.kwargs["Prefix"] == "ha/backups/"
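The expected keys (`ha/backups/test.tar`, `Prefix == "ha/backups/"`) come from the agent's `_with_prefix` helper, which this diff calls but does not show. Presumably it joins the configured prefix and the object name along these lines (a sketch under that assumption; `self._prefix` is a hypothetical attribute holding the CONF_PREFIX value):

def _with_prefix(self, filename: str) -> str:
    """Prepend the configured key prefix, if any, to an object name."""
    if not self._prefix:
        return filename
    return f"{self._prefix.rstrip('/')}/{filename}"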