Skip to content

Commit e1b6187

Browse files
committed
Feat: Refactor upload method and reduce duplication
- Restored type hints in `upload` for improved type safety.
- Added `get_mpu_kwargs` to centralize shared keyword arguments.
- Simplified `upload` and `mpu_upload` implementations by reusing `get_mpu_kwargs`.
- Reduced code duplication across `_mpu.py` and `_multipart.py`.
1 parent 3b0223c commit e1b6187

File tree

5 files changed

+79
-81
lines changed

5 files changed

+79
-81
lines changed

odc/geo/cog/_az.py

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
import base64
2-
from typing import Any, Union
2+
from typing import Any
33

4-
import dask
5-
from dask.delayed import Delayed
6-
7-
from ._mpu import mpu_write
84
from ._multipart import MultiPartUploadBase
95

106

@@ -132,41 +128,9 @@ def writer(self, kw: dict[str, Any], *, client: Any = None):
132128
"""
133129
return DelayedAzureWriter(self, kw)
134130

135-
def upload(
136-
self,
137-
chunks: Union[dask.bag.Bag, list[dask.bag.Bag]],
138-
*,
139-
mk_header: Any = None,
140-
mk_footer: Any = None,
141-
user_kw: dict[str, Any] | None = None,
142-
writes_per_chunk: int = 1,
143-
spill_sz: int = 20 * (1 << 20),
144-
client: Any = None,
145-
**kw,
146-
) -> Delayed:
147-
"""
148-
Upload chunks to Azure Blob Storage with multipart uploads.
149-
150-
:param chunks: Dask bag of chunks to upload.
151-
:param mk_header: Function to create header data.
152-
:param mk_footer: Function to create footer data.
153-
:param user_kw: User-provided metadata for the upload.
154-
:param writes_per_chunk: Number of writes per chunk.
155-
:param spill_sz: Spill size for buffering data.
156-
:param client: Dask client for distributed execution.
157-
:return: A Dask delayed object representing the finalised upload.
158-
"""
159-
write = self.writer(kw, client=client) if spill_sz else None
160-
return mpu_write(
161-
chunks,
162-
write,
163-
mk_header=mk_header,
164-
mk_footer=mk_footer,
165-
user_kw=user_kw,
166-
writes_per_chunk=writes_per_chunk,
167-
spill_sz=spill_sz,
168-
dask_name_prefix="azure-finalise",
169-
)
131+
def dask_name_prefix(self) -> str:
132+
"""Return the Dask name prefix for Azure."""
133+
return "azure-finalise"
170134

171135

172136
class DelayedAzureWriter(AzureLimits):

odc/geo/cog/_mpu.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,3 +495,49 @@ def _finalizer_dask_op(
495495

496496
_, rr = _root.flush(write, leftPartId=1, finalise=True)
497497
return rr
498+
499+
500+
def get_mpu_kwargs(
501+
mk_header=None,
502+
mk_footer=None,
503+
user_kw=None,
504+
writes_per_chunk=1,
505+
spill_sz=20 * (1 << 20),
506+
client=None,
507+
) -> dict:
508+
"""
509+
Construct shared keyword arguments for multipart uploads.
510+
"""
511+
return {
512+
"mk_header": mk_header,
513+
"mk_footer": mk_footer,
514+
"user_kw": user_kw,
515+
"writes_per_chunk": writes_per_chunk,
516+
"spill_sz": spill_sz,
517+
"client": client,
518+
}
519+
520+
521+
def mpu_upload(
522+
chunks: Union[dask.bag.Bag, list[dask.bag.Bag]],
523+
*,
524+
writer: Any,
525+
dask_name_prefix: str,
526+
**kw,
527+
) -> "Delayed":
528+
"""Shared logic for multipart uploads to storage services."""
529+
client = kw.pop("client", None)
530+
writer_kw = dict(kw)
531+
if client is not None:
532+
writer_kw["client"] = client
533+
spill_sz = kw.get("spill_sz", 20 * (1 << 20))
534+
if spill_sz:
535+
write = writer(writer_kw)
536+
else:
537+
write = None
538+
return mpu_write(
539+
chunks,
540+
write,
541+
dask_name_prefix=dask_name_prefix,
542+
**kw, # everything else remains
543+
)

odc/geo/cog/_multipart.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from abc import ABC, abstractmethod
1010
from typing import Any, Union, TYPE_CHECKING
1111

12+
from dask.delayed import Delayed
13+
from ._mpu import get_mpu_kwargs, mpu_upload
14+
1215
if TYPE_CHECKING:
1316
# pylint: disable=import-outside-toplevel,import-error
1417
import dask.bag
@@ -53,6 +56,9 @@ def writer(self, kw: dict[str, Any], *, client: Any = None) -> Any:
5356
"""
5457

5558
@abstractmethod
59+
def dask_name_prefix(self) -> str:
60+
"""Return the dask name prefix specific to the backend."""
61+
5662
def upload(
5763
self,
5864
chunks: Union["dask.bag.Bag", list["dask.bag.Bag"]],
@@ -63,17 +69,19 @@ def upload(
6369
writes_per_chunk: int = 1,
6470
spill_sz: int = 20 * (1 << 20),
6571
client: Any = None,
66-
**kw,
67-
) -> Any:
68-
"""
69-
Orchestrate the upload process with multipart uploads.
70-
71-
:param chunks: Dask bag of chunks to upload.
72-
:param mk_header: Function to create header data.
73-
:param mk_footer: Function to create footer data.
74-
:param user_kw: User-provided metadata for the upload.
75-
:param writes_per_chunk: Number of writes per chunk.
76-
:param spill_sz: Spill size for buffering data.
77-
:param client: Dask client for distributed execution.
78-
:return: A Dask delayed object representing the finalised upload.
79-
"""
72+
) -> Delayed:
73+
"""High-level upload that calls mpu_upload under the hood."""
74+
kwargs = get_mpu_kwargs(
75+
mk_header=mk_header,
76+
mk_footer=mk_footer,
77+
user_kw=user_kw,
78+
writes_per_chunk=writes_per_chunk,
79+
spill_sz=spill_sz,
80+
client=client,
81+
)
82+
return mpu_upload(
83+
chunks,
84+
writer=self.writer,
85+
dask_name_prefix=self.dask_name_prefix(),
86+
**kwargs,
87+
)

odc/geo/cog/_s3.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from cachetools import cached
1111

12-
from ._mpu import PartsWriter, SomeData, mpu_write
12+
from ._mpu import PartsWriter, SomeData
1313
from ._multipart import MultiPartUploadBase
1414

1515
if TYPE_CHECKING:
@@ -197,30 +197,9 @@ def writer(self, kw, *, client: Any = None) -> PartsWriter:
197197
writer.prep_client(client)
198198
return writer
199199

200-
def upload(
201-
self,
202-
chunks: "dask.bag.Bag" | list["dask.bag.Bag"],
203-
*,
204-
mk_header: Any = None,
205-
mk_footer: Any = None,
206-
user_kw: dict[str, Any] | None = None,
207-
writes_per_chunk: int = 1,
208-
spill_sz: int = 20 * (1 << 20),
209-
client: Any = None,
210-
**kw,
211-
) -> "Delayed":
212-
"""Upload chunks to S3 with multipart uploads."""
213-
write = self.writer(kw, client=client) if spill_sz else None
214-
return mpu_write(
215-
chunks,
216-
write,
217-
mk_header=mk_header,
218-
mk_footer=mk_footer,
219-
user_kw=user_kw,
220-
writes_per_chunk=writes_per_chunk,
221-
spill_sz=spill_sz,
222-
dask_name_prefix="s3finalise",
223-
)
200+
def dask_name_prefix(self) -> str:
201+
"""Return the Dask name prefix for S3."""
202+
return "s3finalise"
224203

225204

226205
def _safe_get(v, timeout=0.1):

odc/geo/geom.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,6 @@ def boundary(self, pts_per_side: int = 2) -> "Geometry":
320320
self.crs,
321321
)
322322

323-
324323
def qr2sample(
325324
self,
326325
n: int,
@@ -358,7 +357,8 @@ def qr2sample(
358357
n_side = int(numpy.round(sample_density * min(nx, ny))) + 1
359358
n_side = max(2, n_side)
360359
edge_pts = [
361-
(float(ep[0]), float(ep[1])) for ep in list(self.boundary(n_side).coords[:-1])
360+
(float(ep[0]), float(ep[1]))
361+
for ep in list(self.boundary(n_side).coords[:-1])
362362
]
363363
if padding is None:
364364
padding = 0.3 * min(nx, ny) / (n_side - 1)
@@ -377,6 +377,7 @@ def qr2sample(
377377

378378
return multipoint(coords, self.crs)
379379

380+
380381
def wrap_shapely(method):
381382
"""
382383
Takes a method that expects shapely geometry arguments and converts it to a method that operates

0 commit comments

Comments
 (0)