Skip to content

Commit a54072f

Browse files
committed
Feat: safely import azure-storage-blob and boto3
1 parent d80b5cd commit a54072f

File tree

7 files changed

+112
-48
lines changed

7 files changed

+112
-48
lines changed

dev-env.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ channels:
33
- conda-forge
44

55
dependencies:
6-
- python =3.8
6+
- python =3.10
77

88
# odc-geo dependencies
99
- pyproj

odc/geo/cog/_az.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def max_part(self) -> int:
3333
return 50_000
3434

3535

36-
class MultiPartUpload(AzureLimits, MultiPartUploadBase):
36+
class AzMultiPartUpload(AzureLimits, MultiPartUploadBase):
3737
def __init__(
3838
self, account_url: str, container: str, blob: str, credential: Any = None
3939
):
@@ -161,11 +161,11 @@ class DelayedAzureWriter(AzureLimits):
161161
Dask-compatible writer for Azure Blob Storage multipart uploads.
162162
"""
163163

164-
def __init__(self, mpu: MultiPartUpload, kw: dict[str, Any]):
164+
def __init__(self, mpu: AzMultiPartUpload, kw: dict[str, Any]):
165165
"""
166166
Initialise the Azure writer.
167167
168-
:param mpu: MultiPartUpload instance.
168+
:param mpu: AzMultiPartUpload instance.
169169
:param kw: Additional parameters for the writer.
170170
"""
171171
self.mpu = mpu

odc/geo/cog/_s3.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def max_part(self) -> int:
7070
return 10_000
7171

7272

73-
class MultiPartUpload(S3Limits, MultiPartUploadBase):
73+
class S3MultiPartUpload(S3Limits, MultiPartUploadBase):
7474
"""
7575
Dask to S3 dumper.
7676
"""
@@ -237,7 +237,7 @@ class DelayedS3Writer(S3Limits):
237237

238238
# pylint: disable=import-outside-toplevel,import-error
239239

240-
def __init__(self, mpu: MultiPartUpload, kw: dict[str, Any]):
240+
def __init__(self, mpu: S3MultiPartUpload, kw: dict[str, Any]):
241241
self.mpu = mpu
242242
self.kw = kw # mostly ContentType= kinda thing
243243
self._shared_var: Optional["distributed.Variable"] = None
@@ -263,7 +263,7 @@ def _shared(self, client: "distributed.Client") -> "distributed.Variable":
263263
self._shared_var = Variable(self._build_name("MPUpload"), client)
264264
return self._shared_var
265265

266-
def _ensure_init(self, final_write: bool = False) -> MultiPartUpload:
266+
def _ensure_init(self, final_write: bool = False) -> S3MultiPartUpload:
267267
# pylint: disable=too-many-return-statements
268268
mpu = self.mpu
269269
if mpu.started:

odc/geo/cog/_tifffile.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,29 @@
1717
import numpy as np
1818
import xarray as xr
1919

20-
2120
from .._interop import have
2221
from ..geobox import GeoBox
2322
from ..math import resolve_nodata
2423
from ..types import Shape2d, SomeNodata, Unset, shape_
25-
from ._az import MultiPartUpload as AzMultiPartUpload
2624
from ._mpu import mpu_write
2725
from ._mpu_fs import MPUFileSink
28-
from ._s3 import MultiPartUpload as S3MultiPartUpload, s3_parse_url
26+
27+
try:
28+
from ._az import AzMultiPartUpload
29+
30+
HAVE_AZURE = True
31+
except ImportError:
32+
AzMultiPartUpload = None
33+
HAVE_AZURE = False
34+
try:
35+
from ._s3 import S3MultiPartUpload, s3_parse_url
36+
37+
HAVE_S3 = True
38+
except ImportError:
39+
S3MultiPartUpload = None
40+
s3_parse_url = None
41+
HAVE_S3 = False
42+
2943
from ._shared import (
3044
GDAL_COMP,
3145
GEOTIFF_TAGS,
@@ -738,9 +752,13 @@ def save_cog_with_dask(
738752
# Determine output type and initiate uploader
739753
parsed_url = urlparse(dst)
740754
if parsed_url.scheme == "s3":
755+
if not HAVE_S3:
756+
raise ImportError("Install `boto3` to enable S3 support.")
741757
bucket, key = s3_parse_url(dst)
742758
uploader = S3MultiPartUpload(bucket, key, **aws)
743759
elif parsed_url.scheme == "az":
760+
if not HAVE_AZURE:
761+
raise ImportError("Install `azure-storage-blob` to enable Azure support.")
744762
uploader = AzMultiPartUpload(
745763
account_url=azure.get("account_url"),
746764
container=parsed_url.netloc,

setup.cfg

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,14 @@ tiff =
5454
s3 =
5555
boto3
5656

57+
az =
58+
azure-storage-blob
59+
5760
all =
5861
%(warp)s
5962
%(tiff)s
6063
%(s3)s
64+
%(az)s
6165

6266
test =
6367
pytest

tests/test_az.py

Lines changed: 55 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,50 @@
1-
"""Tests for the Azure MultiPartUpload class."""
1+
"""Tests for the Azure AzMultiPartUpload class."""
22

3+
import base64
34
import unittest
45
from unittest.mock import MagicMock, patch
56

6-
from odc.geo.cog._az import MultiPartUpload
7+
# Conditional import for Azure support
8+
try:
9+
from odc.geo.cog._az import AzMultiPartUpload
710

11+
HAVE_AZURE = True
12+
except ImportError:
13+
AzMultiPartUpload = None
14+
HAVE_AZURE = False
815

9-
def test_mpu_init():
10-
"""Basic test for the MultiPartUpload class."""
11-
account_url = "https://account_name.blob.core.windows.net"
12-
mpu = MultiPartUpload(account_url, "container", "some.blob", None)
13-
if mpu.account_url != account_url:
14-
raise AssertionError(f"mpu.account_url should be '{account_url}'.")
15-
if mpu.container != "container":
16-
raise AssertionError("mpu.container should be 'container'.")
17-
if mpu.blob != "some.blob":
18-
raise AssertionError("mpu.blob should be 'some.blob'.")
19-
if mpu.credential is not None:
20-
raise AssertionError("mpu.credential should be 'None'.")
2116

17+
def require_azure(test_func):
18+
"""Decorator to skip tests if Azure dependencies are not installed."""
19+
return unittest.skipUnless(HAVE_AZURE, "Azure dependencies are not installed")(
20+
test_func
21+
)
2222

23-
class TestMultiPartUpload(unittest.TestCase):
24-
"""Test the MultiPartUpload class."""
2523

24+
class TestAzMultiPartUpload(unittest.TestCase):
25+
"""Test the AzMultiPartUpload class."""
26+
27+
@require_azure
28+
def test_mpu_init(self):
29+
"""Basic test for AzMultiPartUpload initialization."""
30+
account_url = "https://account_name.blob.core.windows.net"
31+
mpu = AzMultiPartUpload(account_url, "container", "some.blob", None)
32+
33+
self.assertEqual(mpu.account_url, account_url)
34+
self.assertEqual(mpu.container, "container")
35+
self.assertEqual(mpu.blob, "some.blob")
36+
self.assertIsNone(mpu.credential)
37+
38+
@require_azure
2639
@patch("odc.geo.cog._az.BlobServiceClient")
2740
def test_azure_multipart_upload(self, mock_blob_service_client):
28-
"""Test the MultiPartUpload class."""
29-
# Arrange - mock the Azure Blob SDK
30-
# Mock the blob client and its methods
41+
"""Test the full Azure AzMultiPartUpload functionality."""
42+
# Arrange - Mock Azure Blob SDK client structure
3143
mock_blob_client = MagicMock()
3244
mock_container_client = MagicMock()
33-
mcc = mock_container_client
34-
mock_blob_service_client.return_value.get_container_client.return_value = mcc
45+
mock_blob_service_client.return_value.get_container_client.return_value = (
46+
mock_container_client
47+
)
3548
mock_container_client.get_blob_client.return_value = mock_blob_client
3649

3750
# Simulate return values for Azure Blob SDK methods
@@ -43,32 +56,41 @@ def test_azure_multipart_upload(self, mock_blob_service_client):
4356
blob = "mock-blob"
4457
credential = "mock-sas-token"
4558

46-
# Act - create an instance of MultiPartUpload and call its methods
47-
azure_upload = MultiPartUpload(account_url, container, blob, credential)
59+
# Act
60+
azure_upload = AzMultiPartUpload(account_url, container, blob, credential)
4861
upload_id = azure_upload.initiate()
4962
part1 = azure_upload.write_part(1, b"first chunk of data")
5063
part2 = azure_upload.write_part(2, b"second chunk of data")
5164
etag = azure_upload.finalise([part1, part2])
5265

53-
# Assert - check the results
54-
# Check that the initiate method behaves as expected
66+
# Correctly calculate block IDs
67+
block_id1 = base64.b64encode(b"block-1").decode("utf-8")
68+
block_id2 = base64.b64encode(b"block-2").decode("utf-8")
69+
70+
# Assert
5571
self.assertEqual(upload_id, "azure-block-upload")
72+
self.assertEqual(etag, "mock-etag")
5673

57-
# Verify the calls to Azure Blob SDK methods
74+
# Verify BlobServiceClient instantiation
5875
mock_blob_service_client.assert_called_once_with(
5976
account_url=account_url, credential=credential
6077
)
78+
79+
# Verify stage_block calls
6180
mock_blob_client.stage_block.assert_any_call(
62-
part1["BlockId"], b"first chunk of data"
81+
block_id=block_id1, data=b"first chunk of data"
6382
)
6483
mock_blob_client.stage_block.assert_any_call(
65-
part2["BlockId"], b"second chunk of data"
84+
block_id=block_id2, data=b"second chunk of data"
6685
)
67-
mock_blob_client.commit_block_list.assert_called_once()
68-
self.assertEqual(etag, "mock-etag")
6986

70-
# Verify block list passed during finalise
87+
# Verify commit_block_list was called correctly
7188
block_list = mock_blob_client.commit_block_list.call_args[0][0]
7289
self.assertEqual(len(block_list), 2)
73-
self.assertEqual(block_list[0].id, part1["BlockId"])
74-
self.assertEqual(block_list[1].id, part2["BlockId"])
90+
self.assertEqual(block_list[0].id, block_id1)
91+
self.assertEqual(block_list[1].id, block_id2)
92+
mock_blob_client.commit_block_list.assert_called_once()
93+
94+
95+
if __name__ == "__main__":
96+
unittest.main()

tests/test_s3.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,29 @@
1-
from odc.geo.cog._s3 import MultiPartUpload
1+
"""Tests for odc.geo.cog._s3."""
22

3-
# TODO: moto
3+
import unittest
44

5+
from odc.geo.cog._s3 import S3MultiPartUpload
56

7+
# Conditional import for S3 support
8+
try:
9+
from odc.geo.cog._s3 import S3MultiPartUpload
10+
11+
HAVE_S3 = True
12+
except ImportError:
13+
S3MultiPartUpload = None
14+
HAVE_S3 = False
15+
16+
17+
def require_s3(test_func):
18+
"""Decorator to skip tests if s3 dependencies are not installed."""
19+
return unittest.skipUnless(HAVE_S3, "s3 dependencies are not installed")(test_func)
20+
21+
22+
@require_s3
623
def test_s3_mpu():
7-
mpu = MultiPartUpload("bucket", "file.dat")
8-
assert mpu.bucket == "bucket"
9-
assert mpu.key == "file.dat"
24+
"""Test S3MultiPartUpload class initialization."""
25+
mpu = S3MultiPartUpload("bucket", "file.dat")
26+
if mpu.bucket != "bucket":
27+
raise ValueError("Invalid bucket")
28+
if mpu.key != "file.dat":
29+
raise ValueError("Invalid key")

0 commit comments

Comments
 (0)