Skip to content

Commit db84474

Browse files
BlaziusMaximusOrbax Authors
authored andcommitted
Add CNS tests for partial saving. Modify async file writes to unlink before writing. Consolidate tests into base test library file.
PiperOrigin-RevId: 875910607
1 parent d45294a commit db84474

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

checkpoint/orbax/checkpoint/_src/path/async_path.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,31 @@ def _mkdir_sync(**thread_kwargs):
5050

5151

5252
async def write_bytes(path: epath.Path, data: Any) -> int:
53-
return await asyncio.to_thread(path.write_bytes, data)
53+
54+
def _write():
55+
try:
56+
path.unlink()
57+
except OSError:
58+
pass
59+
return path.write_bytes(data)
60+
61+
return await asyncio.to_thread(_write)
5462

5563

5664
async def read_bytes(path: epath.Path) -> bytes:
5765
return await asyncio.to_thread(path.read_bytes)
5866

5967

6068
async def write_text(path: epath.Path, text: str) -> int:
61-
return await asyncio.to_thread(path.write_text, text)
69+
70+
def _write():
71+
try:
72+
path.unlink()
73+
except OSError:
74+
pass
75+
return path.write_text(text)
76+
77+
return await asyncio.to_thread(_write)
6278

6379

6480
async def read_text(path: epath.Path) -> str:

0 commit comments

Comments
 (0)