Commit c478127

adam fix
1 parent de44dff commit c478127

3 files changed: 8 additions, 7 deletions

docs/source/implementing.rst

Lines changed: 5 additions & 4 deletions
@@ -82,7 +82,7 @@ Here is a ready to use Adam implementation through overwriting :code:`_single_te
     bias_correction1 = 1 - beta1**step
     bias_correction2 = 1 - beta2**step

-    denom = exp_avg_sq.sqrt().div_(bias_correction2**0.5 + eps)
+    denom = exp_avg_sq.sqrt().div_(bias_correction2**0.5) + eps

     state['step'] += 1
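The change above moves :code:`eps` out of the bias-correction divisor and onto the denominator itself, which is where the standard Adam formula places it (denom = sqrt(v_hat) + eps). A minimal standalone sketch in plain PyTorch (not torchzero code; the tensor values are made up for illustration) of why the placement matters:

    import torch

    # When eps is folded into the divisor, a zero second-moment estimate makes
    # the denominator exactly zero and the update blows up to inf.
    exp_avg = torch.tensor([0.1, 0.2])
    exp_avg_sq = torch.tensor([0.0, 0.04])      # first element has zero second moment
    bias_correction2 = 1 - 0.999**10
    eps = 1e-8

    denom_old = exp_avg_sq.sqrt().div(bias_correction2**0.5 + eps)  # eps inside the divisor
    denom_new = exp_avg_sq.sqrt().div(bias_correction2**0.5) + eps  # eps added afterwards (the fix)

    print(exp_avg / denom_old)  # tensor([inf, 0.0998]) -- division by zero
    print(exp_avg / denom_new)  # finite, matches the sqrt(v_hat) + eps form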
@@ -144,7 +144,8 @@ Here is a ready to use Adam implementation through overwriting :code:`_update` u
     bias_correction2 = [1 - i**self.current_step for i in beta2]

     denom = torch._foreach_sqrt(exp_avg_sq)
-    torch._foreach_div_(denom, [c ** 0.5 + e for c, e in zip(bias_correction2, eps)])
+    torch._foreach_div_(denom, [c ** 0.5 for c in bias_correction2])
+    torch._foreach_add_(denom, eps)

     ret = torch._foreach_div(exp_avg, denom)
     torch._foreach_mul_(ret, [a/d for a,d in zip(alpha, bias_correction1)])
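A similar standalone sketch of the same fix in the multi-tensor path, using the same private :code:`torch._foreach_*` ops that the snippet above relies on (values again made up for illustration):

    import torch

    exp_avg_sq = [torch.tensor([0.0, 0.04]), torch.tensor([1e-4])]
    bias_correction2 = [1 - 0.999**10, 1 - 0.999**10]
    eps = [1e-8, 1e-8]

    denom = torch._foreach_sqrt(exp_avg_sq)
    torch._foreach_div_(denom, [c ** 0.5 for c in bias_correction2])  # bias-correct first
    torch._foreach_add_(denom, eps)                                   # then add eps, as in the fix
    print(denom)  # every entry is at least eps, so no zero denominators downstream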
@@ -159,6 +160,6 @@ Method 3. Overwriting step
 +++++++++++++++++++++++++++++++++++++++++++++
 The :code:`step` method gives you the most control, but it requires the most understanding of torchzero's internals. You can reevaluate the closure multiple times, which is usually necessary for line searches and gradient approximation. You can step with multiple modules, skip an update, or update parameters directly; basically anything is possible.

-There are also helper classes: :py:mod:`GradientApproximatorBase<tz.modules.gradient_approximation.GradientApproximatorBase>` allows you to define a gradient approximation module in a more convenient way by overwriting the :code:`_make_ascent` method. :py:mod:`LineSearchBase<tz.modules.line_search.LineSearchBase>` is an easy way to define line searches by overwriting :code:`_find_best_lr`. I will be making a tutorial on those soon.
+There are also helper classes: :py:mod:`GradientApproximatorBase<tz.modules.gradient_approximation.GradientApproximatorBase>` allows you to define a gradient approximation module in a more convenient way by overwriting the :code:`_make_ascent` method. :py:mod:`LineSearchBase<tz.modules.line_search.LineSearchBase>` is an easy way to define line searches by overwriting :code:`_find_best_lr`.

-WIP
+This section is WIP.
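The "reevaluate the closure multiple times" pattern mentioned in the prose above is what makes line searches possible. Here is a hypothetical illustration written against plain :code:`torch.optim.Optimizer` rather than torchzero's own base classes (whose exact API is not reproduced here): a toy backtracking line search that calls the closure repeatedly.

    import torch

    class BacktrackingSGD(torch.optim.Optimizer):
        """Toy optimizer: shrink the step size until the closure reports a lower loss."""
        def __init__(self, params, lr=1.0, shrink=0.5, max_tries=10):
            super().__init__(params, dict(lr=lr, shrink=shrink, max_tries=max_tries))

        @torch.no_grad()
        def step(self, closure):
            closure = torch.enable_grad()(closure)    # the closure calls backward()
            loss = closure()                           # first evaluation: loss and gradients
            group = self.param_groups[0]
            params = [p for p in group['params'] if p.grad is not None]
            grads = [p.grad.clone() for p in params]
            lr = group['lr']
            for _ in range(group['max_tries']):
                for p, g in zip(params, grads):
                    p.sub_(g, alpha=lr)                # trial step
                if closure() < loss:                   # reevaluate the closure at the trial point
                    break
                for p, g in zip(params, grads):
                    p.add_(g, alpha=lr)                # undo and try a smaller lr
                lr *= group['shrink']
            return loss

As with :code:`torch.optim.LBFGS`, the caller passes a closure that zeroes the gradients, computes the loss, calls :code:`backward()`, and returns the loss.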

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
 name = "torchzero"
 description = "Modular optimization library for PyTorch."

-version = "0.1.4"
+version = "0.1.5"
 dependencies = [
     "torch",
     "numpy",

src/torchzero/modules/optimizers/adam.py

Lines changed: 2 additions & 2 deletions
@@ -15,9 +15,9 @@ def _adam_step(ascent: TensorList, exp_avg: TensorList, exp_avg_sq: TensorList,

     if max_exp_avg_sqs is not None:
         max_exp_avg_sqs.maximum_(exp_avg_sq)
-        denom = max_exp_avg_sqs.sqrt().div_(bias_correction2**0.5 + eps)
+        denom = max_exp_avg_sqs.sqrt().div_(bias_correction2**0.5).add_(eps)
     else:
-        denom = exp_avg_sq.sqrt().div_(bias_correction2**0.5 + eps)
+        denom = exp_avg_sq.sqrt().div_(bias_correction2**0.5).add_(eps)

     if params is None:
         return (exp_avg / denom).mul_(alpha / bias_correction1)
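This change applies the same relocation of :code:`eps` to both the AMSGrad (:code:`max_exp_avg_sqs`) branch and the plain branch. A quick standalone check, assuming nothing beyond plain PyTorch tensors, that the new chained form equals the textbook sqrt(v_hat) + eps denominator:

    import torch

    v = torch.tensor([0.0, 1e-4, 0.04])          # stand-in for (max_)exp_avg_sq
    bias_correction2 = 1 - 0.999**10
    eps = 1e-8

    fixed = v.sqrt().div(bias_correction2**0.5).add(eps)   # new chained form
    textbook = (v / bias_correction2).sqrt().add(eps)      # sqrt(v_hat) + eps
    print(torch.allclose(fixed, textbook))                  # True, up to rounding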
