Skip to content

Commit f2e31c7

Browse files
angel-coreOrbax Authors
authored andcommitted
Create v1 ocp.load_checkpointables backwards compatibility tests against static v0 and v1 checkpoints.
PiperOrigin-RevId: 875725609
1 parent 74e1b99 commit f2e31c7

File tree

523 files changed

+2441
-306
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

523 files changed

+2441
-306
lines changed

checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,9 +380,9 @@ def create_with_handlers(
380380
) -> CheckpointablesOptions:
381381
registry = registration.local_registry(include_global_registry=True)
382382
for handler in handlers:
383-
registry.add(handler, None)
383+
registry.add(handler, checkpointable_name=None)
384384
for name, handler in named_handlers.items():
385-
registry.add(handler, name)
385+
registry.add(handler, checkpointable_name=name)
386386
return cls(registry=registry)
387387

388388

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/global_registration.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
given checkpointable will be used.
2121
"""
2222

23-
from typing import Type
23+
from typing import Sequence, Type
2424

2525
from orbax.checkpoint.experimental.v1._src.handlers import json_handler
2626
from orbax.checkpoint.experimental.v1._src.handlers import proto_handler
@@ -34,23 +34,45 @@
3434
def _try_register_handler(
3535
handler_type: Type[handler_types.CheckpointableHandler],
3636
name: str | None = None,
37+
secondary_typestrs: Sequence[str] | None = None,
3738
):
39+
"""Tries to register handler globally with name and secondary typestrs."""
3840
try:
39-
registration.global_registry().add(handler_type, name)
41+
registration.global_registry().add(
42+
handler_type,
43+
checkpointable_name=name,
44+
secondary_typestrs=secondary_typestrs,
45+
)
4046
except registration.AlreadyExistsError:
4147
pass
4248

4349

44-
_try_register_handler(proto_handler.ProtoHandler)
45-
_try_register_handler(json_handler.JsonHandler)
50+
_try_register_handler(
51+
proto_handler.ProtoHandler,
52+
secondary_typestrs=[
53+
'orbax.checkpoint._src.handlers.proto_checkpoint_handler.ProtoCheckpointHandler',
54+
],
55+
)
56+
_try_register_handler(
57+
json_handler.JsonHandler,
58+
secondary_typestrs=[
59+
'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler',
60+
],
61+
)
4662
_try_register_handler(
4763
stateful_checkpointable_handler.StatefulCheckpointableHandler
4864
)
4965
_try_register_handler(
5066
json_handler.MetricsHandler,
5167
checkpoint_layout.METRICS_CHECKPOINTABLE_KEY,
5268
)
53-
_try_register_handler(pytree_handler.PyTreeHandler)
69+
_try_register_handler(
70+
pytree_handler.PyTreeHandler,
71+
secondary_typestrs=[
72+
'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler',
73+
'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler',
74+
],
75+
)
5476
_try_register_handler(
5577
pytree_handler.PyTreeHandler, checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY
5678
)

0 commit comments

Comments
 (0)