Commit de5cae4

state_dict support
1 parent 702cd5e commit de5cae4

File tree: 4 files changed, +57 -7 lines changed

  docs/source/FAQ.rst
  pyproject.toml
  src/torchzero/core/module.py
  src/torchzero/optim/modular.py

docs/source/FAQ.rst (+24 -6)

@@ -225,23 +225,41 @@ There is also a :py:class:`tz.m.WrapClosure<torchzero.modules.WrapClosure>` for
 
 How to save/serialize a modular optimizer?
 ============================================
-TODO
+Please refer to the pytorch docs: https://pytorch.org/tutorials/beginner/saving_loading_models.html.
+
+Like pytorch optimizers, torchzero modular optimizers and modules support :code:`opt.state_dict()` and :code:`opt.load_state_dict()`, which save and load the state dicts of all modules, including nested ones.
+
+So you can use the standard code for saving and loading:
+
+.. code:: python
+
+    torch.save({
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        ...
+    }, PATH)
+
+    model = TheModelClass(*args, **kwargs)
+    optimizer = tz.Modular(model.parameters(), *modules)
+
+    checkpoint = torch.load(PATH, weights_only=True)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
 
 How much overhead does a torchzero modular optimizer have compared to a normal optimizer?
 ==========================================================================================
-A thorough benchmark will be posted to this section very soon. There is no overhead other than what is described below.
-
-Since some optimizers, like Adam, have learning rate baked into the update rule, but we use LR module instead, that requires an extra add operation. Currently if :code:`tz.m.Adam` or :code:`tz.m.Wrap` are directly followed by a :code:`tz.m.LR`, they will be automatically fused (:code:`Wrap` fuses only when wrapped optimizer has an :code:`lr` parameter). However adding LR fusing to all modules with a learning rate is not a priority.
+Some optimizers, like Adam, have the learning rate baked into the update rule, but we use an LR module instead, which requires an extra add operation. Currently, if :code:`tz.m.Adam` or :code:`tz.m.Wrap` is directly followed by a :code:`tz.m.LR`, they are automatically fused to mitigate that (:code:`Wrap` fuses only when the wrapped optimizer has an :code:`lr` parameter). However, adding LR fusing to all modules with a learning rate is not a priority; from what I can tell, this overhead is negligible.
 
 Whenever possible I used `_foreach_xxx <https://pytorch.org/docs/stable/torch.html#foreach-operations>`_ operations. Those operate on all parameters at once instead of using slow python for-loops. This makes the optimizers way quicker, especially with a lot of different parameter tensors. Also, all modules change the update in-place whenever possible.
 
 Is there support for complex-valued parameters?
 =================================================
-Currently no, as I have not made the modules with complex-valued parameters in mind, although some might still work. I do use complex-valued networks so I am looking into adding support. There may actually be a way to support them automatically.
+:code:`tz.m.ViewAsReal()` and :code:`tz.m.ViewAsComplex()` modules will be added soon. This will also make it possible to use custom pytorch optimizers with complex networks (via :code:`tz.m.Wrap`), even if they don't support complex parameters natively.
 
 Is there support for optimized parameters being on different devices?
 ======================================================================
-TODO
+Maybe; I need to test this.
 
 Is there support for FSDP (FullyShardedDataParallel)?
 ======================================================
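An aside on the overhead answer above: the speedup from the `_foreach_xxx <https://pytorch.org/docs/stable/torch.html#foreach-operations>`_ operations it mentions comes from replacing a per-tensor python loop with a single call over the whole list of parameters. A minimal standalone sketch (plain PyTorch, not torchzero code; the tensors here are placeholders):

.. code:: python

    import torch

    params = [torch.zeros(3), torch.zeros(5)]
    update = [torch.ones(3), torch.ones(5)]

    # python loop: one in-place op (and one dispatch) per parameter tensor
    for p, u in zip(params, update):
        p.add_(u, alpha=-0.1)

    # foreach variant: one call applies the op to the whole list, in-place
    torch._foreach_add_(params, update, alpha=-0.1)

With many small parameter tensors, the single foreach call avoids most of the per-tensor python overhead, which is where the speedup claimed in the FAQ comes from.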

pyproject.toml (+1 -1)

@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
 name = "torchzero"
 description = "Modular optimization library for PyTorch."
 
-version = "0.1.7"
+version = "0.1.8"
 dependencies = [
     "torch",
     "numpy",

src/torchzero/core/module.py (+16)

@@ -212,6 +212,22 @@ def __repr__(self):
         if self._initialized: return super().__repr__()
         return f"uninitialized {self.__class__.__name__}()"
 
+    def state_dict(self):
+        state_dict = {}
+        state_dict['__self__'] = super().state_dict()
+        for k, v in self.children.items():
+            state_dict[k] = v.state_dict()
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        super().load_state_dict(state_dict['__self__'])
+        for k, v in self.children.items():
+            if k in state_dict:
+                v.load_state_dict(state_dict[k])
+            else:
+                warnings.warn(f"Tried to load state dict for {k}: {v.__class__.__name__}, but it is not present in state_dict with {list(state_dict.keys()) = }")
+
+
     def set_params(self, params: ParamsT):
         """
         Set parameters to this module. Use this to set per-parameter group settings.
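For orientation, a rough sketch of the nested layout `OptimizerModule.state_dict` produces, assuming the base class returns a standard torch.optim-style dict; the child name 'child' and the inner dict contents are illustrative placeholders:

.. code:: python

    # Illustrative only -- the inner dicts and the child name are placeholders.
    state = {
        '__self__': {'state': {}, 'param_groups': []},      # this module's own state dict
        'child': {                                          # one entry per key in self.children,
            '__self__': {'state': {}, 'param_groups': []},  # nested recursively
        },
    }

    # load_state_dict restores '__self__' first, then walks self.children and
    # warns about any child whose key is missing from the loaded dict.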

src/torchzero/optim/modular.py (+16)

@@ -2,6 +2,7 @@
 import warnings
 from inspect import cleandoc
 import torch
+from typing import Any
 
 from ..core import OptimizerModule, TensorListOptimizer, OptimizationVars, _Chain, _Chainable
 from ..utils.python_tools import flatten
@@ -67,6 +68,21 @@ def __init__(self, params, *modules: _Chainable):
             for hook in module.post_init_hooks:
                 hook(self, module)
 
+    def state_dict(self):
+        state_dict = {}
+        state_dict['__self__'] = super().state_dict()
+        for i, v in enumerate(self.unrolled_modules):
+            state_dict[str(i)] = v.state_dict()
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        super().load_state_dict(state_dict['__self__'])
+        for i, v in enumerate(self.unrolled_modules):
+            if str(i) in state_dict:
+                v.load_state_dict(state_dict[str(i)])
+            else:
+                warnings.warn(f"Tried to load state dict for {i}th module: {v.__class__.__name__}, but it is not present in state_dict with {list(state_dict.keys()) = }")
+
     def get_lr_module(self, last=True) -> OptimizerModule:
         """
         Retrieves the module in the chain that controls the learning rate.
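A small usage sketch of the new `Modular.state_dict` / `load_state_dict` round trip. The module chain, filename and constructor arguments below are hypothetical placeholders; the key layout ('__self__' plus one string index per unrolled module) follows from the code above:

.. code:: python

    import torch
    import torch.nn as nn
    import torchzero as tz

    model = nn.Linear(4, 2)
    # hypothetical chain; any chain of modules works the same way
    opt = tz.Modular(model.parameters(), tz.m.Adam(), tz.m.LR(1e-2))

    sd = opt.state_dict()          # keys: '__self__' plus '0', '1', ... per unrolled module
    torch.save(sd, 'opt.pt')

    # load back into an optimizer constructed with the same module chain
    opt2 = tz.Modular(model.parameters(), tz.m.Adam(), tz.m.LR(1e-2))
    opt2.load_state_dict(torch.load('opt.pt', weights_only=True))

Since entries are keyed by position in `unrolled_modules`, the optimizer should be rebuilt with the same module chain before loading; a missing entry only triggers the warning shown in `load_state_dict` above.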
