1
- """Tests for the Azure MultiPartUpload class."""
1
+ """Tests for the Azure AzMultiPartUpload class."""
2
2
3
+ import base64
3
4
import unittest
4
5
from unittest .mock import MagicMock , patch
5
6
6
- from odc .geo .cog ._az import MultiPartUpload
7
+ # Conditional import for Azure support
8
+ try :
9
+ from odc .geo .cog ._az import AzMultiPartUpload
7
10
11
+ HAVE_AZURE = True
12
+ except ImportError :
13
+ AzMultiPartUpload = None
14
+ HAVE_AZURE = False
8
15
9
- def test_mpu_init ():
10
- """Basic test for the MultiPartUpload class."""
11
- account_url = "https://account_name.blob.core.windows.net"
12
- mpu = MultiPartUpload (account_url , "container" , "some.blob" , None )
13
- if mpu .account_url != account_url :
14
- raise AssertionError (f"mpu.account_url should be '{ account_url } '." )
15
- if mpu .container != "container" :
16
- raise AssertionError ("mpu.container should be 'container'." )
17
- if mpu .blob != "some.blob" :
18
- raise AssertionError ("mpu.blob should be 'some.blob'." )
19
- if mpu .credential is not None :
20
- raise AssertionError ("mpu.credential should be 'None'." )
21
16
17
+ def require_azure (test_func ):
18
+ """Decorator to skip tests if Azure dependencies are not installed."""
19
+ return unittest .skipUnless (HAVE_AZURE , "Azure dependencies are not installed" )(
20
+ test_func
21
+ )
22
22
23
- class TestMultiPartUpload (unittest .TestCase ):
24
- """Test the MultiPartUpload class."""
25
23
24
+ class TestAzMultiPartUpload (unittest .TestCase ):
25
+ """Test the AzMultiPartUpload class."""
26
+
27
+ @require_azure
28
+ def test_mpu_init (self ):
29
+ """Basic test for AzMultiPartUpload initialization."""
30
+ account_url = "https://account_name.blob.core.windows.net"
31
+ mpu = AzMultiPartUpload (account_url , "container" , "some.blob" , None )
32
+
33
+ self .assertEqual (mpu .account_url , account_url )
34
+ self .assertEqual (mpu .container , "container" )
35
+ self .assertEqual (mpu .blob , "some.blob" )
36
+ self .assertIsNone (mpu .credential )
37
+
38
+ @require_azure
26
39
@patch ("odc.geo.cog._az.BlobServiceClient" )
27
40
def test_azure_multipart_upload (self , mock_blob_service_client ):
28
- """Test the MultiPartUpload class."""
29
- # Arrange - mock the Azure Blob SDK
30
- # Mock the blob client and its methods
41
+ """Test the full Azure AzMultiPartUpload functionality."""
42
+ # Arrange - Mock Azure Blob SDK client structure
31
43
mock_blob_client = MagicMock ()
32
44
mock_container_client = MagicMock ()
33
- mcc = mock_container_client
34
- mock_blob_service_client .return_value .get_container_client .return_value = mcc
45
+ mock_blob_service_client .return_value .get_container_client .return_value = (
46
+ mock_container_client
47
+ )
35
48
mock_container_client .get_blob_client .return_value = mock_blob_client
36
49
37
50
# Simulate return values for Azure Blob SDK methods
@@ -43,32 +56,41 @@ def test_azure_multipart_upload(self, mock_blob_service_client):
43
56
blob = "mock-blob"
44
57
credential = "mock-sas-token"
45
58
46
- # Act - create an instance of MultiPartUpload and call its methods
47
- azure_upload = MultiPartUpload (account_url , container , blob , credential )
59
+ # Act
60
+ azure_upload = AzMultiPartUpload (account_url , container , blob , credential )
48
61
upload_id = azure_upload .initiate ()
49
62
part1 = azure_upload .write_part (1 , b"first chunk of data" )
50
63
part2 = azure_upload .write_part (2 , b"second chunk of data" )
51
64
etag = azure_upload .finalise ([part1 , part2 ])
52
65
53
- # Assert - check the results
54
- # Check that the initiate method behaves as expected
66
+ # Correctly calculate block IDs
67
+ block_id1 = base64 .b64encode (b"block-1" ).decode ("utf-8" )
68
+ block_id2 = base64 .b64encode (b"block-2" ).decode ("utf-8" )
69
+
70
+ # Assert
55
71
self .assertEqual (upload_id , "azure-block-upload" )
72
+ self .assertEqual (etag , "mock-etag" )
56
73
57
- # Verify the calls to Azure Blob SDK methods
74
+ # Verify BlobServiceClient instantiation
58
75
mock_blob_service_client .assert_called_once_with (
59
76
account_url = account_url , credential = credential
60
77
)
78
+
79
+ # Verify stage_block calls
61
80
mock_blob_client .stage_block .assert_any_call (
62
- part1 [ "BlockId" ], b"first chunk of data"
81
+ block_id = block_id1 , data = b"first chunk of data"
63
82
)
64
83
mock_blob_client .stage_block .assert_any_call (
65
- part2 [ "BlockId" ], b"second chunk of data"
84
+ block_id = block_id2 , data = b"second chunk of data"
66
85
)
67
- mock_blob_client .commit_block_list .assert_called_once ()
68
- self .assertEqual (etag , "mock-etag" )
69
86
70
- # Verify block list passed during finalise
87
+ # Verify commit_block_list was called correctly
71
88
block_list = mock_blob_client .commit_block_list .call_args [0 ][0 ]
72
89
self .assertEqual (len (block_list ), 2 )
73
- self .assertEqual (block_list [0 ].id , part1 ["BlockId" ])
74
- self .assertEqual (block_list [1 ].id , part2 ["BlockId" ])
90
+ self .assertEqual (block_list [0 ].id , block_id1 )
91
+ self .assertEqual (block_list [1 ].id , block_id2 )
92
+ mock_blob_client .commit_block_list .assert_called_once ()
93
+
94
+
95
+ if __name__ == "__main__" :
96
+ unittest .main ()
0 commit comments