
Commit 6cbcb87

style(all): reformat all the files (#10)
1 parent 465358e commit 6cbcb87


45 files changed (+1154 additions, -785 deletions)

Makefile

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+print-% : ; @echo $* = $($*)
+SHELL = /bin/bash
+PROJECT_NAME = TorchOpt
+PYTHON_FILES = $(shell find . -type f -name "*.py")
+CPP_FILES = $(shell find . -type f -name "*.h" -o -name "*.cpp")
+COMMIT_HASH = $(shell git log -1 --format=%h)
+
+
+# installation
+
+check_install = python3 -c "import $(1)" || (cd && pip3 install $(1) --upgrade && cd -)
+check_install_extra = python3 -c "import $(1)" || (cd && pip3 install $(2) --upgrade && cd -)
+
+
+flake8-install:
+	$(call check_install, flake8)
+	$(call check_install_extra, bugbear, flake8_bugbear)
+
+py-format-install:
+	$(call check_install, isort)
+	$(call check_install, yapf)
+
+mypy-install:
+	$(call check_install, mypy)
+
+cpplint-install:
+	$(call check_install, cpplint)
+
+clang-format-install:
+	command -v clang-format-11 || sudo apt-get install -y clang-format-11
+
+clang-tidy-install:
+	command -v clang-tidy || sudo apt-get install -y clang-tidy
+
+
+doc-install:
+	$(call check_install, pydocstyle)
+	$(call check_install, doc8)
+	$(call check_install, sphinx)
+	$(call check_install, sphinx_rtd_theme)
+	$(call check_install_extra, sphinxcontrib.spelling, sphinxcontrib.spelling pyenchant)
+
+# python linter
+
+flake8: flake8-install
+	flake8 $(PYTHON_FILES) --count --show-source --statistics
+
+py-format: py-format-install
+	isort --check $(PYTHON_FILES) && yapf -r -d $(PYTHON_FILES)
+
+mypy: mypy-install
+	mypy $(PROJECT_NAME)
+
+# c++ linter
+
+cpplint: cpplint-install
+	cpplint $(CPP_FILES)
+
+clang-format: clang-format-install
+	clang-format-11 --style=file -i $(CPP_FILES) -n --Werror
+
+# documentation
+
+docstyle: doc-install
+	pydocstyle $(PROJECT_NAME) && doc8 docs && cd docs && make html SPHINXOPTS="-W"
+
+doc: doc-install
+	cd docs && make html && cd _build/html && python3 -m http.server
+
+spelling: doc-install
+	cd docs && make spelling SPHINXOPTS="-W"
+
+doc-clean:
+	cd docs && make clean
+
+lint: flake8 py-format clang-format cpplint mypy docstyle spelling
+
+format: py-format-install clang-format-install
+	isort $(PYTHON_FILES)
+	yapf -ir $(PYTHON_FILES)
+	clang-format-11 -style=file -i $(CPP_FILES)
+
+
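
A minimal Python sketch (not part of this commit) of what the `check_install` macro above does: it probes for a package with `python3 -c "import <pkg>"` and pip-installs it only when the import fails. The helper name and the upgrade-in-place behaviour are illustrative assumptions.

import importlib
import subprocess
import sys
from typing import Optional


def check_install(module_name: str, pip_name: Optional[str] = None) -> None:
    """Install `pip_name` (defaults to `module_name`) if the module cannot be imported."""
    try:
        importlib.import_module(module_name)
    except ImportError:
        # Mirrors `pip3 install <pkg> --upgrade` in the Makefile macro.
        subprocess.check_call([
            sys.executable, "-m", "pip", "install", "--upgrade",
            pip_name or module_name,
        ])


# Usage mirroring `$(call check_install_extra, bugbear, flake8_bugbear)`:
# check_install("bugbear", "flake8_bugbear")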

TorchOpt/__init__.py

Lines changed: 7 additions & 10 deletions
@@ -13,15 +13,12 @@
 # limitations under the License.
 # ==============================================================================
 
-from ._src import combine
-from ._src import clip
-from ._src import visual
-from ._src import hook
-from ._src import schedule
-from ._src.MetaOptimizer import MetaOptimizer, MetaSGD, MetaAdam, MetaRMSProp
-from ._src.Optimizer import Optimizer, SGD, Adam, RMSProp
+from ._src import (accelerated_op_available, clip, combine, hook, schedule,
+                   visual)
+from ._src.alias import adam, rmsprop, sgd
+from ._src.MetaOptimizer import MetaAdam, MetaOptimizer, MetaRMSProp, MetaSGD
+from ._src.Optimizer import SGD, Adam, Optimizer, RMSProp
 from ._src.update import apply_updates
-from ._src.alias import sgd, adam, rmsprop
-from ._src.utils import stop_gradient, extract_state_dict, recover_state_dict
-from ._src import accelerated_op_available
+from ._src.utils import extract_state_dict, recover_state_dict, stop_gradient
+
 __version__ = "0.4.1"
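
A short usage sketch (not from this commit) of the flat API these reordered imports re-export at the package top level; the toy `nn.Linear` model and tensor shapes below are illustrative assumptions.

import torch
from torch import nn

import TorchOpt

net = nn.Linear(4, 1)
optim = TorchOpt.MetaSGD(net, lr=0.1)        # re-exported from _src.MetaOptimizer
loss = net(torch.randn(8, 4)).pow(2).mean()
optim.step(loss)                             # differentiable parameter update

state = TorchOpt.extract_state_dict(net)     # re-exported from _src.utils
TorchOpt.recover_state_dict(net, state)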

TorchOpt/_lib/adam_op.py

Lines changed: 15 additions & 7 deletions
@@ -13,22 +13,30 @@
 # limitations under the License.
 # ==============================================================================
 
-def forward_(updates, mu, nu, lr, b1, b2, eps, eps_root, count): ...
 
+def forward_(updates, mu, nu, lr, b1, b2, eps, eps_root, count):
+    ...
 
-def forwardMu(updates, mu, b1): ...
 
+def forwardMu(updates, mu, b1):
+    ...
 
-def forwardNu(updates, nu, b2): ...
 
+def forwardNu(updates, nu, b2):
+    ...
 
-def forwardUpdates(new_mu, new_nu, lr, b1, b2, eps, eps_root, count): ...
 
+def forwardUpdates(new_mu, new_nu, lr, b1, b2, eps, eps_root, count):
+    ...
 
-def backwardMu(dmu, updates, mu, b1): ...
 
+def backwardMu(dmu, updates, mu, b1):
+    ...
 
-def backwardNu(dnu, updates, nu, b2): ...
 
+def backwardNu(dnu, updates, nu, b2):
+    ...
 
-def backwardUpdates(dupdates, updates, new_mu, new_nu, lr, b1, b2, count): ...
+
+def backwardUpdates(dupdates, updates, new_mu, new_nu, lr, b1, b2, count):
+    ...
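
These stubs only declare the signatures of the compiled accelerated Adam op. As a point of reference, here is a hedged pure-Python sketch of the moment updates the forward kernels presumably compute, written from the standard Adam formulas rather than from the extension source:

import torch


def forward_mu_reference(updates, mu, b1):
    # First-moment EMA: mu_t = b1 * mu_{t-1} + (1 - b1) * g_t
    return b1 * mu + (1.0 - b1) * updates


def forward_nu_reference(updates, nu, b2):
    # Second-moment EMA: nu_t = b2 * nu_{t-1} + (1 - b2) * g_t ** 2
    return b2 * nu + (1.0 - b2) * updates ** 2


g = torch.ones(3)
print(forward_mu_reference(g, torch.zeros(3), b1=0.9))    # ~[0.1, 0.1, 0.1]
print(forward_nu_reference(g, torch.zeros(3), b2=0.999))  # ~[0.001, 0.001, 0.001]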

TorchOpt/_src/MetaOptimizer.py

Lines changed: 37 additions & 34 deletions
@@ -13,19 +13,18 @@
 # limitations under the License.
 # ==============================================================================
 
+import jax
 import torch
 from torch import nn
-import jax
 
 import TorchOpt
-from TorchOpt._src.alias import sgd, adam, rmsprop
 from TorchOpt._src import base
+from TorchOpt._src.alias import adam, rmsprop, sgd
 from TorchOpt._src.pytypes import ScalarOrSchedule
 
 
 class MetaOptimizer(object):
     """A high-level optimizer base class for meta learning."""
-
     def __init__(self, net: nn.Module, impl: base.GradientTransformation):
         """
         Args:
@@ -51,18 +50,23 @@ def step(self, loss: torch.Tensor):
             loss (torch.Tensor): the loss that is used to compute the gradients to the network parameters.
         """
         # step parameter only
-        for idx, (state, param_containers) in enumerate(zip(self.state_groups, self.param_containers_groups)):
+        for idx, (state, param_containers) in enumerate(
+                zip(self.state_groups, self.param_containers_groups)):
             flatten_params, containers_tree = jax.tree_util.tree_flatten(
                 param_containers)
             flatten_params = tuple(flatten_params)
-            grad = torch.autograd.grad(
-                loss, flatten_params, create_graph=True, allow_unused=True)
+            grad = torch.autograd.grad(loss,
+                                       flatten_params,
+                                       create_graph=True,
+                                       allow_unused=True)
             updates, state = self.impl.update(grad, state, False)
             self.state_groups[idx] = state
-            new_params = TorchOpt.apply_updates(
-                flatten_params, updates, inplace=False)
+            new_params = TorchOpt.apply_updates(flatten_params,
+                                                updates,
+                                                inplace=False)
             unflatten_new_params = containers_tree.unflatten(new_params)
-            for (container, unflatten_param) in zip(param_containers, unflatten_new_params):
+            for (container, unflatten_param) in zip(param_containers,
+                                                    unflatten_new_params):
                 container.update(unflatten_param)
 
     def add_param_group(self, net):
@@ -89,7 +93,6 @@ def load_state_dict(self, state_dict):
 
 class MetaSGD(MetaOptimizer):
     """A canonical Stochastic Gradient Descent optimiser."""
-
     def __init__(self,
                  net,
                  lr: ScalarOrSchedule,
@@ -102,17 +105,16 @@ def __init__(self,
            args: other arguments see `alias.sgd`, here we set `moment_requires_grad=True`
               to make tensors like momentum be differentiable.
         """
-        super().__init__(net,
-                         sgd(lr=lr,
-                             momentum=momentum,
-                             nesterov=nesterov,
-                             moment_requires_grad=moment_requires_grad)
-                         )
+        super().__init__(
+            net,
+            sgd(lr=lr,
+                momentum=momentum,
+                nesterov=nesterov,
+                moment_requires_grad=moment_requires_grad))
 
 
 class MetaAdam(MetaOptimizer):
     """The classic Adam optimiser."""
-
     def __init__(self,
                  net,
                  lr: ScalarOrSchedule,
@@ -128,20 +130,19 @@ def __init__(self,
           args: other arguments see `alias.adam`, here we set `moment_requires_grad=True`
              to make tensors like momentum be differentiable.
        """
-        super().__init__(net,
-                         adam(lr=lr,
-                              b1=b1,
-                              b2=b2,
-                              eps=eps,
-                              eps_root=eps_root,
-                              moment_requires_grad=moment_requires_grad,
-                              use_accelerated_op=use_accelerated_op)
-                         )
+        super().__init__(
+            net,
+            adam(lr=lr,
+                 b1=b1,
+                 b2=b2,
+                 eps=eps,
+                 eps_root=eps_root,
+                 moment_requires_grad=moment_requires_grad,
+                 use_accelerated_op=use_accelerated_op))
 
 
 class MetaRMSProp(MetaOptimizer):
     """The classic RMSProp optimiser."""
-
     def __init__(self,
                  net,
                  lr: ScalarOrSchedule,
@@ -157,10 +158,12 @@ def __init__(self,
          args: other arguments see `alias.adam`, here we set `moment_requires_grad=True`
             to make tensors like momentum be differentiable.
        """
-        super().__init__(net, rmsprop(lr=lr,
-                                      decay=decay,
-                                      eps=eps,
-                                      initial_scale=initial_scale,
-                                      centered=centered,
-                                      momentum=momentum,
-                                      nesterov=nesterov))
+        super().__init__(
+            net,
+            rmsprop(lr=lr,
+                    decay=decay,
+                    eps=eps,
+                    initial_scale=initial_scale,
+                    centered=centered,
+                    momentum=momentum,
+                    nesterov=nesterov))
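
Because `step()` calls `torch.autograd.grad(..., create_graph=True)` and applies the updates out of place, the inner update stays differentiable. A toy sketch (not from this commit; `scale` is a hypothetical meta-parameter) of an outer loss back-propagating through one `MetaSGD` step:

import torch
from torch import nn

import TorchOpt

net = nn.Linear(2, 1)
scale = torch.tensor(1.0, requires_grad=True)  # hypothetical meta-parameter
optim = TorchOpt.MetaSGD(net, lr=0.1)

x = torch.randn(8, 2)
inner_loss = scale * net(x).pow(2).mean()
optim.step(inner_loss)             # in-graph update of net's parameters

outer_loss = net(x).pow(2).mean()  # evaluated with the updated parameters
outer_loss.backward()              # gradient flows back through the inner step
print(scale.grad)                  # not None: the update is differentiable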
