Skip to content

Commit d80b5cd

Browse files
committed
Feat: azure backend to save cogs with dask
1 parent 831b7e2 commit d80b5cd

File tree

2 files changed

+265
-0
lines changed

2 files changed

+265
-0
lines changed

odc/geo/cog/_az.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import base64
2+
from typing import Any, Union
3+
4+
from azure.storage.blob import BlobBlock, BlobServiceClient
5+
from dask.delayed import Delayed
6+
7+
from ._mpu import mpu_write
8+
from ._multipart import MultiPartUploadBase
9+
10+
11+
class AzureLimits:
12+
"""
13+
Common Azure writer settings.
14+
"""
15+
16+
@property
17+
def min_write_sz(self) -> int:
18+
# Azure minimum write size for blocks (default is 4 MiB)
19+
return 4 * (1 << 20)
20+
21+
@property
22+
def max_write_sz(self) -> int:
23+
# Azure maximum write size for blocks (default is 100 MiB)
24+
return 100 * (1 << 20)
25+
26+
@property
27+
def min_part(self) -> int:
28+
return 1
29+
30+
@property
31+
def max_part(self) -> int:
32+
# Azure supports up to 50,000 blocks per blob
33+
return 50_000
34+
35+
36+
class MultiPartUpload(AzureLimits, MultiPartUploadBase):
37+
def __init__(
38+
self, account_url: str, container: str, blob: str, credential: Any = None
39+
):
40+
"""
41+
Initialise Azure multipart upload.
42+
43+
:param account_url: URL of the Azure storage account.
44+
:param container: Name of the container.
45+
:param blob: Name of the blob.
46+
:param credential: Authentication credentials (e.g., SAS token or key).
47+
"""
48+
self.account_url = account_url
49+
self.container = container
50+
self.blob = blob
51+
self.credential = credential
52+
53+
# Initialise Azure Blob service client
54+
self.blob_service_client = BlobServiceClient(
55+
account_url=account_url, credential=credential
56+
)
57+
self.container_client = self.blob_service_client.get_container_client(container)
58+
self.blob_client = self.container_client.get_blob_client(blob)
59+
60+
self.block_ids: list[str] = []
61+
62+
def initiate(self, **kwargs) -> str:
63+
"""
64+
Initialise the upload. No-op for Azure.
65+
"""
66+
return "azure-block-upload"
67+
68+
def write_part(self, part: int, data: bytes) -> dict[str, Any]:
69+
"""
70+
Stage a block in Azure.
71+
72+
:param part: Part number (unique).
73+
:param data: Data for this part.
74+
:return: A dictionary containing part information.
75+
"""
76+
block_id = base64.b64encode(f"block-{part}".encode()).decode()
77+
self.blob_client.stage_block(block_id=block_id, data=data)
78+
self.block_ids.append(block_id)
79+
return {"PartNumber": part, "BlockId": block_id}
80+
81+
def finalise(self, parts: list[dict[str, Any]]) -> str:
82+
"""
83+
Commit the block list to finalise the upload.
84+
85+
:param parts: List of uploaded parts metadata.
86+
:return: The ETag of the finalised blob.
87+
"""
88+
block_list = [BlobBlock(block_id=part["BlockId"]) for part in parts]
89+
self.blob_client.commit_block_list(block_list)
90+
return self.blob_client.get_blob_properties().etag
91+
92+
def cancel(self):
93+
"""
94+
Cancel the upload by clearing the block list.
95+
"""
96+
self.block_ids.clear()
97+
98+
@property
99+
def url(self) -> str:
100+
"""
101+
Get the Azure blob URL.
102+
103+
:return: The full URL of the blob.
104+
"""
105+
return self.blob_client.url
106+
107+
@property
108+
def started(self) -> bool:
109+
"""
110+
Check if any blocks have been staged.
111+
112+
:return: True if blocks have been staged, False otherwise.
113+
"""
114+
return bool(self.block_ids)
115+
116+
def writer(self, kw: dict[str, Any], client: Any = None):
117+
"""
118+
Return a stateless writer compatible with Dask.
119+
"""
120+
return DelayedAzureWriter(self, kw)
121+
122+
def upload(
123+
self,
124+
chunks: Union["dask.bag.Bag", list["dask.bag.Bag"]],
125+
*,
126+
mk_header: Any = None,
127+
mk_footer: Any = None,
128+
user_kw: dict[str, Any] = None,
129+
writes_per_chunk: int = 1,
130+
spill_sz: int = 20 * (1 << 20),
131+
client: Any = None,
132+
**kw,
133+
) -> "Delayed":
134+
"""
135+
Upload chunks to Azure Blob Storage with multipart uploads.
136+
137+
:param chunks: Dask bag of chunks to upload.
138+
:param mk_header: Function to create header data.
139+
:param mk_footer: Function to create footer data.
140+
:param user_kw: User-provided metadata for the upload.
141+
:param writes_per_chunk: Number of writes per chunk.
142+
:param spill_sz: Spill size for buffering data.
143+
:param client: Dask client for distributed execution.
144+
:return: A Dask delayed object representing the finalised upload.
145+
"""
146+
write = self.writer(kw, client=client) if spill_sz else None
147+
return mpu_write(
148+
chunks,
149+
write,
150+
mk_header=mk_header,
151+
mk_footer=mk_footer,
152+
user_kw=user_kw,
153+
writes_per_chunk=writes_per_chunk,
154+
spill_sz=spill_sz,
155+
dask_name_prefix="azure-finalise",
156+
)
157+
158+
159+
class DelayedAzureWriter(AzureLimits):
160+
"""
161+
Dask-compatible writer for Azure Blob Storage multipart uploads.
162+
"""
163+
164+
def __init__(self, mpu: MultiPartUpload, kw: dict[str, Any]):
165+
"""
166+
Initialise the Azure writer.
167+
168+
:param mpu: MultiPartUpload instance.
169+
:param kw: Additional parameters for the writer.
170+
"""
171+
self.mpu = mpu
172+
self.kw = kw # Additional metadata like ContentType
173+
174+
def __call__(self, part: int, data: bytes) -> dict[str, Any]:
175+
"""
176+
Write a single part to Azure Blob Storage.
177+
178+
:param part: Part number.
179+
:param data: Chunk data.
180+
:return: Metadata for the written part.
181+
"""
182+
return self.mpu.write_part(part, data)
183+
184+
def finalise(self, parts: list[dict[str, Any]]) -> str:
185+
"""
186+
Finalise the upload by committing the block list.
187+
188+
:param parts: List of uploaded parts metadata.
189+
:return: ETag of the finalised blob.
190+
"""
191+
return self.mpu.finalise(parts)

tests/test_az.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""Tests for the Azure MultiPartUpload class."""
2+
3+
import unittest
4+
from unittest.mock import MagicMock, patch
5+
6+
from odc.geo.cog._az import MultiPartUpload
7+
8+
9+
def test_mpu_init():
10+
"""Basic test for the MultiPartUpload class."""
11+
account_url = "https://account_name.blob.core.windows.net"
12+
mpu = MultiPartUpload(account_url, "container", "some.blob", None)
13+
if mpu.account_url != account_url:
14+
raise AssertionError(f"mpu.account_url should be '{account_url}'.")
15+
if mpu.container != "container":
16+
raise AssertionError("mpu.container should be 'container'.")
17+
if mpu.blob != "some.blob":
18+
raise AssertionError("mpu.blob should be 'some.blob'.")
19+
if mpu.credential is not None:
20+
raise AssertionError("mpu.credential should be 'None'.")
21+
22+
23+
class TestMultiPartUpload(unittest.TestCase):
24+
"""Test the MultiPartUpload class."""
25+
26+
@patch("odc.geo.cog._az.BlobServiceClient")
27+
def test_azure_multipart_upload(self, mock_blob_service_client):
28+
"""Test the MultiPartUpload class."""
29+
# Arrange - mock the Azure Blob SDK
30+
# Mock the blob client and its methods
31+
mock_blob_client = MagicMock()
32+
mock_container_client = MagicMock()
33+
mcc = mock_container_client
34+
mock_blob_service_client.return_value.get_container_client.return_value = mcc
35+
mock_container_client.get_blob_client.return_value = mock_blob_client
36+
37+
# Simulate return values for Azure Blob SDK methods
38+
mock_blob_client.get_blob_properties.return_value.etag = "mock-etag"
39+
40+
# Test parameters
41+
account_url = "https://mockaccount.blob.core.windows.net"
42+
container = "mock-container"
43+
blob = "mock-blob"
44+
credential = "mock-sas-token"
45+
46+
# Act - create an instance of MultiPartUpload and call its methods
47+
azure_upload = MultiPartUpload(account_url, container, blob, credential)
48+
upload_id = azure_upload.initiate()
49+
part1 = azure_upload.write_part(1, b"first chunk of data")
50+
part2 = azure_upload.write_part(2, b"second chunk of data")
51+
etag = azure_upload.finalise([part1, part2])
52+
53+
# Assert - check the results
54+
# Check that the initiate method behaves as expected
55+
self.assertEqual(upload_id, "azure-block-upload")
56+
57+
# Verify the calls to Azure Blob SDK methods
58+
mock_blob_service_client.assert_called_once_with(
59+
account_url=account_url, credential=credential
60+
)
61+
mock_blob_client.stage_block.assert_any_call(
62+
part1["BlockId"], b"first chunk of data"
63+
)
64+
mock_blob_client.stage_block.assert_any_call(
65+
part2["BlockId"], b"second chunk of data"
66+
)
67+
mock_blob_client.commit_block_list.assert_called_once()
68+
self.assertEqual(etag, "mock-etag")
69+
70+
# Verify block list passed during finalise
71+
block_list = mock_blob_client.commit_block_list.call_args[0][0]
72+
self.assertEqual(len(block_list), 2)
73+
self.assertEqual(block_list[0].id, part1["BlockId"])
74+
self.assertEqual(block_list[1].id, part2["BlockId"])

0 commit comments

Comments
 (0)