Skip to content

Commit 4700a8c

Browse files
committed
release 0.0.1
1 parent 28f3b2f commit 4700a8c

File tree

5 files changed

+225
-0
lines changed

5 files changed

+225
-0
lines changed

.github/workflows/python-publish.yml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

# Runs only when a GitHub release is published.
on:
  release:
    types: [published]

# The job only needs to read the repository; the PyPI upload authenticates
# with the PYPI_API_TOKEN secret, not with the GitHub token.
permissions:
  contents: read

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    # v2 of checkout / setup-python targets a retired Node runtime; use current majors.
    - uses: actions/checkout@v4
    - name: Set up Python
      uses: actions/setup-python@v5
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    # Builds the sdist and wheel into dist/ using the backend declared in pyproject.toml.
    - name: Build package
      run: python -m build
    # Third-party action pinned to a full commit SHA so the published code cannot change underneath us.
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}

README.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,41 @@ Implementation of the proposed <a href="https://arxiv.org/abs/2407.05872">Adam-a
44

55
A multi-million dollar paper out of Google DeepMind basically proposes a small change to Adam (using `atan2`) for greater stability.
66

7+
## Install
8+
9+
```bash
10+
$ pip install adam-atan2-pytorch
11+
```
12+
13+
## Usage
14+
15+
```python
16+
# toy model
17+
18+
import torch
19+
from torch import nn
20+
21+
model = nn.Linear(10, 1)
22+
23+
# import AdamAtan2 and instantiate with parameters
24+
25+
from adam_atan2_pytorch import AdamAtan2
26+
27+
opt = AdamAtan2(model.parameters(), lr = 1e-4)
28+
29+
# forward and backwards
30+
31+
for _ in range(100):
32+
loss = model(torch.randn(10))
33+
loss.backward()
34+
35+
# optimizer step
36+
37+
opt.step()
38+
opt.zero_grad()
39+
40+
```
41+
742
## Citations
843

944
```bibtex

adam_atan2_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from adam_atan2_pytorch.adam_atan2 import AdamAtan2

adam_atan2_pytorch/adam_atan2.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from __future__ import annotations
2+
from typing import Tuple, Callable
3+
4+
import torch
5+
from torch import atan2, sqrt
6+
from torch.optim.optimizer import Optimizer
7+
8+
# functions
9+
10+
def exists(val):
    # anything other than None counts as present, including falsy values such as 0 or ''
    return not (val is None)
12+
13+
# class
14+
15+
class AdamAtan2(Optimizer):
    """Adam-style optimizer whose update is `a * atan2(m_hat, b * sqrt(v_hat))`.

    `m_hat` and `v_hat` are the bias-corrected running averages of the gradient
    and of the squared gradient. Using `atan2` replaces the usual
    `m_hat / (sqrt(v_hat) + eps)` division, so no epsilon is needed.

    Args:
        params: iterable of parameters or param-group dicts, forwarded to `Optimizer`.
        lr: learning rate, must be > 0. Also remembered as the reference value
            against which the weight decay is rescaled.
        betas: decay rates for the first and second moment averages, each in [0, 1).
        weight_decay: decoupled weight decay factor, must be >= 0.
        a: multiplier applied to the result of `atan2`.
        b: multiplier applied to the second argument of `atan2`.

    Raises:
        AssertionError: if `lr`, `betas` or `weight_decay` is out of range.
    """

    def __init__(
        self,
        params,
        lr = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay = 0.,
        a = 1.,
        b = 1.
    ):
        assert lr > 0.
        # a beta of exactly 1 makes the bias correction `1 - beta ** steps` equal to 0,
        # turning the corrected moments into 0 / 0 = nan, so the upper bound is exclusive
        assert all([0. <= beta < 1. for beta in betas])
        assert weight_decay >= 0.

        # learning rate given at construction, used to rescale weight decay in `step`
        self._init_lr = lr

        defaults = dict(
            lr = lr,
            betas = betas,
            a = a,
            b = b,
            weight_decay = weight_decay
        )

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(
        self,
        closure: Callable | None = None
    ):
        """Apply one update to every parameter that currently has a gradient.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is invoked with gradients enabled.

        Returns:
            Whatever `closure` returned, or None when no closure was given.
        """

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        init_lr = self._init_lr

        for group in self.param_groups:

            # hyperparameters are per group, so read them once rather than once per parameter

            lr, wd, a, b = group['lr'], group['weight_decay'], group['a'], group['b']
            beta1, beta2 = group['betas']

            for p in group['params']:

                # parameters without a gradient are left untouched

                if p.grad is None:
                    continue

                grad, state = p.grad, self.state[p]

                # decoupled weight decay, scaled by the ratio of the current lr to the construction-time lr

                if wd > 0.:
                    p.mul_(1. - lr / init_lr * wd)

                # init state if needed

                if len(state) == 0:
                    state['steps'] = 0
                    state['exp_avg'] = torch.zeros_like(grad)
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                # get some of the states

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                steps = state['steps'] + 1

                # bias corrections

                bias_correct1 = 1. - beta1 ** steps
                bias_correct2 = 1. - beta2 ** steps

                # decay running averages (in place, so the stored state is updated)

                exp_avg.lerp_(grad, 1. - beta1)
                exp_avg_sq.lerp_(grad * grad, 1. - beta2)

                # the following line is the proposed change to the update rule
                # using atan2 instead of a division with epsilons - they also suggest hyperparameters `a` and `b` should be explored beyond its default of 1.

                update = a * atan2(exp_avg / bias_correct1, b * sqrt(exp_avg_sq / bias_correct2))

                p.add_(update, alpha = -lr)

                # increment steps

                state['steps'] = steps

        return loss

pyproject.toml

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
[project]
2+
name = "adam-atan2-pytorch"
3+
version = "0.0.1"
4+
description = "Adam-atan2 for Pytorch"
5+
authors = [
6+
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
7+
]
8+
readme = "README.md"
9+
requires-python = ">= 3.9"
10+
license = { file = "LICENSE" }
11+
keywords = [
12+
'artificial intelligence',
13+
'deep learning',
14+
'adam',
15+
'optimizers'
16+
]
17+
18+
classifiers=[
19+
'Development Status :: 4 - Beta',
20+
'Intended Audience :: Developers',
21+
'Topic :: Scientific/Engineering :: Artificial Intelligence',
22+
'License :: OSI Approved :: MIT License',
23+
'Programming Language :: Python :: 3.9',
24+
]
25+
26+
dependencies = [
27+
"torch>=2.0",
28+
]
29+
30+
[project.urls]
31+
Homepage = "https://pypi.org/project/adam_atan2_pytorch/"
32+
Repository = "https://github.com/lucidrains/adam_atan2_pytorch"
33+
34+
[project.optional-dependencies]
35+
examples = []
36+
test = [
37+
"pytest"
38+
]
39+
40+
[tool.pytest.ini_options]
41+
pythonpath = [
42+
"."
43+
]
44+
45+
[build-system]
46+
requires = ["hatchling"]
47+
build-backend = "hatchling.build"
48+
49+
[tool.rye]
50+
managed = true
51+
dev-dependencies = []
52+
53+
[tool.hatch.metadata]
54+
allow-direct-references = true
55+
56+
[tool.hatch.build.targets.wheel]
57+
packages = ["adam_atan2_pytorch"]

0 commit comments

Comments
 (0)