mirror of
https://github.com/Electric-Special/ha-core.git
synced 2026-03-21 04:05:20 +01:00
Cloudflare R2 backup - Improved buffer handling (#162958)
This commit is contained in:
@@ -5,7 +5,7 @@ import functools
|
||||
import json
|
||||
import logging
|
||||
from time import time
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from botocore.exceptions import BotoCoreError
|
||||
|
||||
@@ -190,58 +190,77 @@ class R2BackupAgent(BackupAgent):
|
||||
:param open_stream: A function returning an async iterator that yields bytes.
|
||||
"""
|
||||
_LOGGER.debug("Starting multipart upload for %s", tar_filename)
|
||||
key = self._with_prefix(tar_filename)
|
||||
multipart_upload = await self._client.create_multipart_upload(
|
||||
Bucket=self._bucket,
|
||||
Key=self._with_prefix(tar_filename),
|
||||
Key=key,
|
||||
)
|
||||
upload_id = multipart_upload["UploadId"]
|
||||
try:
|
||||
parts: list[dict[str, Any]] = []
|
||||
part_number = 1
|
||||
buffer = bytearray() # bytes buffer to store the data
|
||||
offset = 0 # start index of unread data inside buffer
|
||||
|
||||
stream = await open_stream()
|
||||
async for chunk in stream:
|
||||
buffer.extend(chunk)
|
||||
|
||||
# upload parts of exactly MULTIPART_MIN_PART_SIZE_BYTES to ensure
|
||||
# all non-trailing parts have the same size (required by S3/R2)
|
||||
while len(buffer) >= MULTIPART_MIN_PART_SIZE_BYTES:
|
||||
part_data = bytes(buffer[:MULTIPART_MIN_PART_SIZE_BYTES])
|
||||
del buffer[:MULTIPART_MIN_PART_SIZE_BYTES]
|
||||
# Upload parts of exactly MULTIPART_MIN_PART_SIZE_BYTES to ensure
|
||||
# all non-trailing parts have the same size (defensive implementation)
|
||||
view = memoryview(buffer)
|
||||
try:
|
||||
while len(buffer) - offset >= MULTIPART_MIN_PART_SIZE_BYTES:
|
||||
start = offset
|
||||
end = offset + MULTIPART_MIN_PART_SIZE_BYTES
|
||||
part_data = view[start:end]
|
||||
offset = end
|
||||
|
||||
_LOGGER.debug(
|
||||
"Uploading part number %d, size %d",
|
||||
part_number,
|
||||
len(part_data),
|
||||
)
|
||||
part = await self._client.upload_part(
|
||||
Bucket=self._bucket,
|
||||
Key=self._with_prefix(tar_filename),
|
||||
PartNumber=part_number,
|
||||
UploadId=upload_id,
|
||||
Body=part_data,
|
||||
)
|
||||
parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
|
||||
part_number += 1
|
||||
_LOGGER.debug(
|
||||
"Uploading part number %d, size %d",
|
||||
part_number,
|
||||
len(part_data),
|
||||
)
|
||||
part = await cast(Any, self._client).upload_part(
|
||||
Bucket=self._bucket,
|
||||
Key=key,
|
||||
PartNumber=part_number,
|
||||
UploadId=upload_id,
|
||||
Body=part_data.tobytes(),
|
||||
)
|
||||
parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
|
||||
part_number += 1
|
||||
finally:
|
||||
view.release()
|
||||
|
||||
# Compact the buffer if the consumed offset has grown large enough. This
|
||||
# avoids unnecessary memory copies when compacting after every part upload.
|
||||
if offset and offset >= MULTIPART_MIN_PART_SIZE_BYTES:
|
||||
buffer = bytearray(buffer[offset:])
|
||||
offset = 0
|
||||
|
||||
# Upload the final buffer as the last part (no minimum size requirement)
|
||||
if buffer:
|
||||
# Offset should be 0 after the last compaction, but we use it as the start
|
||||
# index to be defensive in case the buffer was not compacted.
|
||||
if offset < len(buffer):
|
||||
remaining_data = memoryview(buffer)[offset:]
|
||||
_LOGGER.debug(
|
||||
"Uploading final part number %d, size %d", part_number, len(buffer)
|
||||
"Uploading final part number %d, size %d",
|
||||
part_number,
|
||||
len(remaining_data),
|
||||
)
|
||||
part = await self._client.upload_part(
|
||||
part = await cast(Any, self._client).upload_part(
|
||||
Bucket=self._bucket,
|
||||
Key=self._with_prefix(tar_filename),
|
||||
Key=key,
|
||||
PartNumber=part_number,
|
||||
UploadId=upload_id,
|
||||
Body=bytes(buffer),
|
||||
Body=remaining_data.tobytes(),
|
||||
)
|
||||
parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
|
||||
|
||||
await self._client.complete_multipart_upload(
|
||||
await cast(Any, self._client).complete_multipart_upload(
|
||||
Bucket=self._bucket,
|
||||
Key=self._with_prefix(tar_filename),
|
||||
Key=key,
|
||||
UploadId=upload_id,
|
||||
MultipartUpload={"Parts": parts},
|
||||
)
|
||||
@@ -250,7 +269,7 @@ class R2BackupAgent(BackupAgent):
|
||||
try:
|
||||
await self._client.abort_multipart_upload(
|
||||
Bucket=self._bucket,
|
||||
Key=self._with_prefix(tar_filename),
|
||||
Key=key,
|
||||
UploadId=upload_id,
|
||||
)
|
||||
except BotoCoreError:
|
||||
|
||||
@@ -11,7 +11,7 @@ from homeassistant.components.cloudflare_r2.backup import (
|
||||
MULTIPART_MIN_PART_SIZE_BYTES,
|
||||
suggested_filenames,
|
||||
)
|
||||
from homeassistant.components.cloudflare_r2.const import DOMAIN
|
||||
from homeassistant.components.cloudflare_r2.const import CONF_PREFIX, DOMAIN
|
||||
|
||||
from .const import USER_INPUT
|
||||
|
||||
@@ -79,3 +79,19 @@ def mock_config_entry() -> MockConfigEntry:
|
||||
domain=DOMAIN,
|
||||
data=USER_INPUT,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_entry_with_prefix(
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> MockConfigEntry:
|
||||
"""Return a mocked config entry with a prefix configured."""
|
||||
data = dict(mock_config_entry.data)
|
||||
data[CONF_PREFIX] = "ha/backups"
|
||||
|
||||
return MockConfigEntry(
|
||||
entry_id=mock_config_entry.entry_id,
|
||||
title=mock_config_entry.title,
|
||||
domain=mock_config_entry.domain,
|
||||
data=data,
|
||||
)
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
"""Test the Cloudflare R2 backup platform."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from io import StringIO
|
||||
import json
|
||||
from time import time
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from botocore.exceptions import ConnectTimeoutError
|
||||
from botocore.exceptions import BotoCoreError, ConnectTimeoutError
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.backup import DOMAIN as BACKUP_DOMAIN, AgentBackup
|
||||
from homeassistant.components.cloudflare_r2.backup import (
|
||||
MULTIPART_MIN_PART_SIZE_BYTES,
|
||||
BotoCoreError,
|
||||
R2BackupAgent,
|
||||
async_register_backup_agents_listener,
|
||||
suggested_filenames,
|
||||
@@ -521,3 +520,57 @@ async def test_listeners_get_cleaned_up(hass: HomeAssistant) -> None:
|
||||
remove_listener()
|
||||
|
||||
assert DATA_BACKUP_AGENT_LISTENERS not in hass.data
|
||||
|
||||
|
||||
async def test_multipart_upload_uses_prefix_for_all_calls(
|
||||
hass: HomeAssistant,
|
||||
mock_client: MagicMock,
|
||||
mock_config_entry_with_prefix: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test multipart upload uses the configured prefix for all S3 calls."""
|
||||
mock_config_entry_with_prefix.runtime_data = mock_client
|
||||
agent = R2BackupAgent(hass, mock_config_entry_with_prefix)
|
||||
|
||||
async def stream() -> AsyncIterator[bytes]:
|
||||
# Force multipart: > MIN_PART_SIZE
|
||||
yield b"x" * (MULTIPART_MIN_PART_SIZE_BYTES + 1)
|
||||
|
||||
async def open_stream():
|
||||
return stream()
|
||||
|
||||
await agent._upload_multipart("test.tar", open_stream)
|
||||
|
||||
prefixed_key = "ha/backups/test.tar"
|
||||
|
||||
assert mock_client.create_multipart_upload.await_args.kwargs["Key"] == prefixed_key
|
||||
|
||||
for call in mock_client.upload_part.await_args_list:
|
||||
assert call.kwargs["Key"] == prefixed_key
|
||||
|
||||
assert (
|
||||
mock_client.complete_multipart_upload.await_args.kwargs["Key"] == prefixed_key
|
||||
)
|
||||
|
||||
|
||||
async def test_list_backups_passes_prefix_to_list_objects(
|
||||
hass: HomeAssistant,
|
||||
mock_client: MagicMock,
|
||||
mock_config_entry_with_prefix: MockConfigEntry,
|
||||
test_backup: AgentBackup,
|
||||
) -> None:
|
||||
"""Test list_objects_v2 is called with Prefix when configured."""
|
||||
mock_config_entry_with_prefix.runtime_data = mock_client
|
||||
agent = R2BackupAgent(hass, mock_config_entry_with_prefix)
|
||||
|
||||
# Make the listing for a prefixed bucket
|
||||
tar_filename, metadata_filename = suggested_filenames(test_backup)
|
||||
mock_client.list_objects_v2.return_value = {
|
||||
"Contents": [
|
||||
{"Key": f"ha/backups/{metadata_filename}"},
|
||||
{"Key": f"ha/backups/{tar_filename}"},
|
||||
]
|
||||
}
|
||||
|
||||
await agent.async_list_backups()
|
||||
|
||||
assert mock_client.list_objects_v2.call_args.kwargs["Prefix"] == "ha/backups/"
|
||||
|
||||
Reference in New Issue
Block a user