Skip to content

Commit 672944d

Browse files
committed
Feat: allow s3 and/or az dependencies
1 parent 9f740f6 commit 672944d

File tree

5 files changed

+115
-123
lines changed

5 files changed

+115
-123
lines changed

odc/geo/_interop.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ def datacube(self) -> bool:
4343
def tifffile(self) -> bool:
4444
return self._check("tifffile")
4545

46+
@property
47+
def azure(self) -> bool:
48+
return self._check("azure.storage.blob")
49+
50+
@property
51+
def s3(self) -> bool:
52+
return self._check("boto3")
53+
4654
@staticmethod
4755
def _check(lib_name: str) -> bool:
4856
return importlib.util.find_spec(lib_name) is not None

odc/geo/cog/_az.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import base64
22
from typing import Any, Union
33

4-
from azure.storage.blob import BlobBlock, BlobServiceClient
5-
from dask.delayed import Delayed
4+
import dask
65

76
from ._mpu import mpu_write
87
from ._multipart import MultiPartUploadBase
@@ -51,6 +50,9 @@ def __init__(
5150
self.credential = credential
5251

5352
# Initialise Azure Blob service client
53+
# pylint: disable=import-outside-toplevel,import-error
54+
from azure.storage.blob import BlobServiceClient
55+
5456
self.blob_service_client = BlobServiceClient(
5557
account_url=account_url, credential=credential
5658
)
@@ -85,6 +87,9 @@ def finalise(self, parts: list[dict[str, Any]]) -> str:
8587
:param parts: List of uploaded parts metadata.
8688
:return: The ETag of the finalised blob.
8789
"""
90+
# pylint: disable=import-outside-toplevel,import-error
91+
from azure.storage.blob import BlobBlock
92+
8893
block_list = [BlobBlock(block_id=part["BlockId"]) for part in parts]
8994
self.blob_client.commit_block_list(block_list)
9095
return self.blob_client.get_blob_properties().etag
@@ -121,7 +126,7 @@ def writer(self, kw: dict[str, Any], client: Any = None):
121126

122127
def upload(
123128
self,
124-
chunks: Union["dask.bag.Bag", list["dask.bag.Bag"]],
129+
chunks: Union[dask.bag.Bag, list[dask.bag.Bag]],
125130
*,
126131
mk_header: Any = None,
127132
mk_footer: Any = None,
@@ -130,7 +135,7 @@ def upload(
130135
spill_sz: int = 20 * (1 << 20),
131136
client: Any = None,
132137
**kw,
133-
) -> "Delayed":
138+
) -> dask.delayed.Delayed:
134139
"""
135140
Upload chunks to Azure Blob Storage with multipart uploads.
136141

odc/geo/cog/_multipart.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
"""
22
Multipart upload interface.
3+
4+
Defines the `MultiPartUploadBase` class for implementing multipart upload functionality.
5+
This interface standardises methods for initiating, uploading, and finalising
6+
multipart uploads across storage backends.
37
"""
48

59
from abc import ABC, abstractmethod

odc/geo/cog/_tifffile.py

Lines changed: 25 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from functools import partial
1212
from io import BytesIO
1313
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
14+
from urllib.parse import urlparse
1415
from xml.sax.saxutils import escape as xml_escape
1516

1617
import numpy as np
@@ -23,22 +24,6 @@
2324
from ._mpu import mpu_write
2425
from ._mpu_fs import MPUFileSink
2526

26-
try:
27-
from ._az import AzMultiPartUpload
28-
29-
HAVE_AZURE = True
30-
except ImportError:
31-
AzMultiPartUpload = None
32-
HAVE_AZURE = False
33-
try:
34-
from ._s3 import S3MultiPartUpload, s3_parse_url
35-
36-
HAVE_S3 = True
37-
except ImportError:
38-
S3MultiPartUpload = None
39-
s3_parse_url = None
40-
HAVE_S3 = False
41-
4227
from ._shared import (
4328
GDAL_COMP,
4429
GEOTIFF_TAGS,
@@ -641,6 +626,7 @@ def save_cog_with_dask(
641626
bigtiff: bool = True,
642627
overview_resampling: Union[int, str] = "nearest",
643628
aws: Optional[dict[str, Any]] = None,
629+
azure: Optional[dict[str, Any]] = None,
644630
client: Any = None,
645631
stats: bool | int = True,
646632
**kw,
@@ -669,13 +655,12 @@ def save_cog_with_dask(
669655

670656
from ..xr import ODCExtensionDa
671657

672-
if aws is None:
673-
aws = {}
658+
aws = aws or {}
659+
azure = azure or {}
674660

675-
upload_params = {k: kw.pop(k) for k in ["writes_per_chunk", "spill_sz"] if k in kw}
676-
upload_params.update(
677-
{k: aws.pop(k) for k in ["writes_per_chunk", "spill_sz"] if k in aws}
678-
)
661+
upload_params = {
662+
k: kw.pop(k, None) for k in ["writes_per_chunk", "spill_sz"] if k in kw
663+
}
679664
parts_base = kw.pop("parts_base", None)
680665

681666
# Normalise compression settings and remove GDAL compat options from kw
@@ -750,19 +735,25 @@ def save_cog_with_dask(
750735
# Determine output type and initiate uploader
751736
parsed_url = urlparse(dst)
752737
if parsed_url.scheme == "s3":
753-
if not HAVE_S3:
754-
raise ImportError("Install `boto3` to enable S3 support.")
755-
bucket, key = s3_parse_url(dst)
756-
uploader = S3MultiPartUpload(bucket, key, **aws)
738+
if have.s3:
739+
from ._s3 import S3MultiPartUpload, s3_parse_url
740+
741+
bucket, key = s3_parse_url(dst)
742+
uploader = S3MultiPartUpload(bucket, key, **aws)
743+
else:
744+
raise RuntimeError("Please install `boto3` to use S3")
757745
elif parsed_url.scheme == "az":
758-
if not HAVE_AZURE:
759-
raise ImportError("Install azure-storage-blob` to enable Azure support.")
760-
uploader = AzMultiPartUpload(
761-
account_url=azure.get("account_url"),
762-
container=parsed_url.netloc,
763-
blob=parsed_url.path.lstrip("/"),
764-
credential=azure.get("credential"),
765-
)
746+
if have.azure:
747+
from ._az import AzMultiPartUpload
748+
749+
uploader = AzMultiPartUpload(
750+
account_url=azure.get("account_url"),
751+
container=parsed_url.netloc,
752+
blob=parsed_url.path.lstrip("/"),
753+
credential=azure.get("credential"),
754+
)
755+
else:
756+
raise RuntimeError("Please install `azure-storage-blob` to use Azure")
766757
else:
767758
# Assume local disk
768759
write = MPUFileSink(dst, parts_base=parts_base)

tests/test_az.py

Lines changed: 69 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,96 +1,80 @@
11
"""Tests for the Azure AzMultiPartUpload class."""
22

33
import base64
4-
import unittest
54
from unittest.mock import MagicMock, patch
65

7-
# Conditional import for Azure support
8-
try:
9-
from odc.geo.cog._az import AzMultiPartUpload
6+
import pytest
107

11-
HAVE_AZURE = True
12-
except ImportError:
13-
AzMultiPartUpload = None
14-
HAVE_AZURE = False
8+
pytest.importorskip("odc.geo.cog._az")
9+
from odc.geo.cog._az import AzMultiPartUpload # noqa: E402
1510

1611

17-
def require_azure(test_func):
18-
"""Decorator to skip tests if Azure dependencies are not installed."""
19-
return unittest.skipUnless(HAVE_AZURE, "Azure dependencies are not installed")(
20-
test_func
12+
@pytest.fixture
13+
def azure_mpu():
14+
"""Fixture for initializing AzMultiPartUpload."""
15+
account_url = "https://account_name.blob.core.windows.net"
16+
return AzMultiPartUpload(account_url, "container", "some.blob", None)
17+
18+
19+
def test_mpu_init(azure_mpu):
20+
"""Basic test for AzMultiPartUpload initialization."""
21+
assert azure_mpu.account_url == "https://account_name.blob.core.windows.net"
22+
assert azure_mpu.container == "container"
23+
assert azure_mpu.blob == "some.blob"
24+
assert azure_mpu.credential is None
25+
26+
27+
@patch("odc.geo.cog._az.BlobServiceClient")
28+
def test_azure_multipart_upload(mock_blob_service_client):
29+
"""Test the full Azure AzMultiPartUpload functionality."""
30+
# Mock Azure Blob SDK client structure
31+
mock_blob_client = MagicMock()
32+
mock_container_client = MagicMock()
33+
mock_blob_service_client.return_value.get_container_client.return_value = (
34+
mock_container_client
35+
)
36+
mock_container_client.get_blob_client.return_value = mock_blob_client
37+
38+
# Simulate return values for Azure Blob SDK methods
39+
mock_blob_client.get_blob_properties.return_value.etag = "mock-etag"
40+
41+
# Test parameters
42+
account_url = "https://mockaccount.blob.core.windows.net"
43+
container = "mock-container"
44+
blob = "mock-blob"
45+
credential = "mock-sas-token"
46+
47+
# Create an instance of AzMultiPartUpload and call its methods
48+
azure_upload = AzMultiPartUpload(account_url, container, blob, credential)
49+
upload_id = azure_upload.initiate()
50+
part1 = azure_upload.write_part(1, b"first chunk of data")
51+
part2 = azure_upload.write_part(2, b"second chunk of data")
52+
etag = azure_upload.finalise([part1, part2])
53+
54+
# Define block IDs
55+
block_id1 = base64.b64encode(b"block-1").decode("utf-8")
56+
block_id2 = base64.b64encode(b"block-2").decode("utf-8")
57+
58+
# Verify the results
59+
assert upload_id == "azure-block-upload"
60+
assert etag == "mock-etag"
61+
62+
# Verify BlobServiceClient instantiation
63+
mock_blob_service_client.assert_called_once_with(
64+
account_url=account_url, credential=credential
2165
)
2266

67+
# Verify stage_block calls
68+
mock_blob_client.stage_block.assert_any_call(
69+
block_id=block_id1, data=b"first chunk of data"
70+
)
71+
mock_blob_client.stage_block.assert_any_call(
72+
block_id=block_id2, data=b"second chunk of data"
73+
)
2374

24-
class TestAzMultiPartUpload(unittest.TestCase):
25-
"""Test the AzMultiPartUpload class."""
26-
27-
@require_azure
28-
def test_mpu_init(self):
29-
"""Basic test for AzMultiPartUpload initialization."""
30-
account_url = "https://account_name.blob.core.windows.net"
31-
mpu = AzMultiPartUpload(account_url, "container", "some.blob", None)
32-
33-
self.assertEqual(mpu.account_url, account_url)
34-
self.assertEqual(mpu.container, "container")
35-
self.assertEqual(mpu.blob, "some.blob")
36-
self.assertIsNone(mpu.credential)
37-
38-
@require_azure
39-
@patch("odc.geo.cog._az.BlobServiceClient")
40-
def test_azure_multipart_upload(self, mock_blob_service_client):
41-
"""Test the full Azure AzMultiPartUpload functionality."""
42-
# Arrange - Mock Azure Blob SDK client structure
43-
mock_blob_client = MagicMock()
44-
mock_container_client = MagicMock()
45-
mock_blob_service_client.return_value.get_container_client.return_value = (
46-
mock_container_client
47-
)
48-
mock_container_client.get_blob_client.return_value = mock_blob_client
49-
50-
# Simulate return values for Azure Blob SDK methods
51-
mock_blob_client.get_blob_properties.return_value.etag = "mock-etag"
52-
53-
# Test parameters
54-
account_url = "https://mockaccount.blob.core.windows.net"
55-
container = "mock-container"
56-
blob = "mock-blob"
57-
credential = "mock-sas-token"
58-
59-
# Act
60-
azure_upload = AzMultiPartUpload(account_url, container, blob, credential)
61-
upload_id = azure_upload.initiate()
62-
part1 = azure_upload.write_part(1, b"first chunk of data")
63-
part2 = azure_upload.write_part(2, b"second chunk of data")
64-
etag = azure_upload.finalise([part1, part2])
65-
66-
# Correctly calculate block IDs
67-
block_id1 = base64.b64encode(b"block-1").decode("utf-8")
68-
block_id2 = base64.b64encode(b"block-2").decode("utf-8")
69-
70-
# Assert
71-
self.assertEqual(upload_id, "azure-block-upload")
72-
self.assertEqual(etag, "mock-etag")
73-
74-
# Verify BlobServiceClient instantiation
75-
mock_blob_service_client.assert_called_once_with(
76-
account_url=account_url, credential=credential
77-
)
78-
79-
# Verify stage_block calls
80-
mock_blob_client.stage_block.assert_any_call(
81-
block_id=block_id1, data=b"first chunk of data"
82-
)
83-
mock_blob_client.stage_block.assert_any_call(
84-
block_id=block_id2, data=b"second chunk of data"
85-
)
86-
87-
# Verify commit_block_list was called correctly
88-
block_list = mock_blob_client.commit_block_list.call_args[0][0]
89-
self.assertEqual(len(block_list), 2)
90-
self.assertEqual(block_list[0].id, block_id1)
91-
self.assertEqual(block_list[1].id, block_id2)
92-
mock_blob_client.commit_block_list.assert_called_once()
93-
94-
95-
if __name__ == "__main__":
96-
unittest.main()
75+
# Verify commit_block_list was called correctly
76+
block_list = mock_blob_client.commit_block_list.call_args[0][0]
77+
assert len(block_list) == 2
78+
assert block_list[0].id == block_id1
79+
assert block_list[1].id == block_id2
80+
mock_blob_client.commit_block_list.assert_called_once()

0 commit comments

Comments
 (0)