Skip to content

Commit c1f34d1

Browse files
authored
Merge pull request #1463 from lrzpellegrini/wandb_core_fixes
Various fixes and improvements
2 parents 435b40d + abde4c2 commit c1f34d1

File tree

9 files changed

+143
-48
lines changed

9 files changed

+143
-48
lines changed

avalanche/benchmarks/classic/core50.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,20 @@ def CORe50(
156156
eval_transform=eval_transform,
157157
)
158158

159+
if scenario == "nc":
160+
n_classes_per_exp = []
161+
classes_order = []
162+
for exp in benchmark_obj.train_stream:
163+
exp_dataset = exp.dataset
164+
unique_targets = list(
165+
sorted(set(int(x) for x in exp_dataset.targets)) # type: ignore
166+
)
167+
n_classes_per_exp.append(len(unique_targets))
168+
classes_order.extend(unique_targets)
169+
setattr(benchmark_obj, "n_classes_per_exp", n_classes_per_exp)
170+
setattr(benchmark_obj, "classes_order", classes_order)
171+
setattr(benchmark_obj, "n_classes", 50 if object_lvl else 10)
172+
159173
return benchmark_obj
160174

161175

avalanche/evaluation/metrics/checkpoint.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
################################################################################
1111

1212
import copy
13-
from typing import TYPE_CHECKING
13+
import io
14+
from typing import TYPE_CHECKING, Optional
1415

1516
from torch import Tensor
17+
import torch
1618

1719
from avalanche.evaluation import PluginMetric
1820
from avalanche.evaluation.metric_results import MetricValue, MetricResult
@@ -46,9 +48,9 @@ def __init__(self):
4648
retrieved using the `result` method.
4749
"""
4850
super().__init__()
49-
self.weights = None
51+
self.weights: Optional[bytes] = None
5052

51-
def update(self, weights) -> Tensor:
53+
def update(self, weights: bytes):
5254
"""
5355
Update the weight checkpoint at the current experience.
5456
@@ -57,7 +59,7 @@ def update(self, weights) -> Tensor:
5759
"""
5860
self.weights = weights
5961

60-
def result(self) -> Tensor:
62+
def result(self) -> Optional[bytes]:
6163
"""
6264
Retrieves the weight checkpoint at the current experience.
6365
@@ -75,6 +77,9 @@ def reset(self) -> None:
7577

7678
def _package_result(self, strategy) -> "MetricResult":
7779
weights = self.result()
80+
if weights is None:
81+
return None
82+
7883
metric_name = get_metric_name(
7984
self, strategy, add_experience=True, add_task=False
8085
)
@@ -83,9 +88,13 @@ def _package_result(self, strategy) -> "MetricResult":
8388
]
8489

8590
def after_training_exp(self, strategy: "SupervisedTemplate") -> "MetricResult":
86-
model_params = copy.deepcopy(strategy.model.parameters())
87-
self.update(model_params)
88-
return None
91+
buff = io.BytesIO()
92+
model_params = copy.deepcopy(strategy.model).to("cpu")
93+
torch.save(model_params, buff)
94+
buff.seek(0)
95+
self.update(buff.read())
96+
97+
return self._package_result(strategy)
8998

9099
def __str__(self):
91100
return "WeightCheckpoint"

avalanche/logging/text_logging.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
# Website: avalanche.continualai.org #
1010
################################################################################
1111
import datetime
12-
import os.path
1312
import sys
1413
import warnings
1514
from typing import List, TYPE_CHECKING, Tuple, Type, Optional, TextIO
@@ -24,7 +23,10 @@
2423
if TYPE_CHECKING:
2524
from avalanche.training.templates import SupervisedTemplate
2625

27-
UNSUPPORTED_TYPES: Tuple[Type] = (TensorImage,)
26+
UNSUPPORTED_TYPES: Tuple[Type, ...] = (
27+
TensorImage,
28+
bytes,
29+
)
2830

2931

3032
class TextLogger(BaseLogger, SupervisedPlugin):

avalanche/logging/wandb_logger.py

Lines changed: 65 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,21 @@
44
# See the accompanying LICENSE file for terms. #
55
# #
66
# Date: 25-11-2020 #
7-
# Author(s): Diganta Misra, Andrea Cossu #
7+
# Author(s): Diganta Misra, Andrea Cossu, Lorenzo Pellegrini #
88
# E-mail: contact@continualai.org #
99
# Website: www.continualai.org #
1010
################################################################################
1111
""" This module handles all the functionalities related to the logging of
1212
Avalanche experiments using Weights & Biases. """
1313

14-
from typing import Union, List, TYPE_CHECKING
14+
import re
15+
from typing import Optional, Union, List, TYPE_CHECKING
1516
from pathlib import Path
1617
import os
17-
import errno
18+
import warnings
1819

1920
import numpy as np
2021
from numpy import array
21-
import torch
2222
from torch import Tensor
2323

2424
from PIL.Image import Image
@@ -37,6 +37,12 @@
3737
from avalanche.training.templates import SupervisedTemplate
3838

3939

40+
CHECKPOINT_METRIC_NAME = re.compile(
41+
r"^WeightCheckpoint\/(?P<phase_name>\S+)_phase\/(?P<stream_name>\S+)_"
42+
r"stream(\/Task(?P<task_id>\d+))?\/Exp(?P<experience_id>\d+)$"
43+
)
44+
45+
4046
class WandBLogger(BaseLogger, SupervisedPlugin):
4147
"""Weights and Biases logger.
4248
@@ -60,18 +66,21 @@ def __init__(
6066
run_name: str = "Test",
6167
log_artifacts: bool = False,
6268
path: Union[str, Path] = "Checkpoints",
63-
uri: str = None,
69+
uri: Optional[str] = None,
6470
sync_tfboard: bool = False,
6571
save_code: bool = True,
66-
config: object = None,
67-
dir: Union[str, Path] = None,
68-
params: dict = None,
72+
config: Optional[object] = None,
73+
dir: Optional[Union[str, Path]] = None,
74+
params: Optional[dict] = None,
6975
):
7076
"""Creates an instance of the `WandBLogger`.
7177
7278
:param project_name: Name of the W&B project.
7379
:param run_name: Name of the W&B run.
7480
:param log_artifacts: Option to log model weights as W&B Artifacts.
81+
Note that, in order for model weights to be logged, the
82+
:class:`WeightCheckpoint` metric must be added to the
83+
evaluation plugin.
7584
:param path: Path to locally save the model checkpoints.
7685
:param uri: URI identifier for external storage buckets (GCS, S3).
7786
:param sync_tfboard: Syncs TensorBoard to the W&B dashboard UI.
@@ -102,6 +111,8 @@ def __init__(
102111
def import_wandb(self):
103112
try:
104113
import wandb
114+
115+
assert hasattr(wandb, "__version__")
105116
except ImportError:
106117
raise ImportError('Please run "pip install wandb" to install wandb')
107118
self.wandb = wandb
@@ -140,7 +151,7 @@ def after_training_exp(
140151
self,
141152
strategy: "SupervisedTemplate",
142153
metric_values: List["MetricValue"],
143-
**kwargs
154+
**kwargs,
144155
):
145156
for val in metric_values:
146157
self.log_metrics([val])
@@ -151,6 +162,11 @@ def after_training_exp(
151162
def log_single_metric(self, name, value, x_plot):
152163
self.step = x_plot
153164

165+
if name.startswith("WeightCheckpoint"):
166+
if self.log_artifacts:
167+
self._log_checkpoint(name, value, x_plot)
168+
return
169+
154170
if isinstance(value, AlternativeValues):
155171
value = value.best_supported_value(
156172
Image,
@@ -192,26 +208,46 @@ def log_single_metric(self, name, value, x_plot):
192208
elif isinstance(value, TensorImage):
193209
self.wandb.log({name: self.wandb.Image(array(value))}, step=self.step)
194210

195-
elif name.startswith("WeightCheckpoint"):
196-
if self.log_artifacts:
197-
cwd = os.getcwd()
198-
ckpt = os.path.join(cwd, self.path)
199-
try:
200-
os.makedirs(ckpt)
201-
except OSError as e:
202-
if e.errno != errno.EEXIST:
203-
raise
204-
suffix = ".pth"
205-
dir_name = os.path.join(ckpt, name + suffix)
206-
artifact_name = os.path.join("Models", name + suffix)
207-
if isinstance(value, Tensor):
208-
torch.save(value, dir_name)
209-
name = os.path.splittext(self.checkpoint)
210-
artifact = self.wandb.Artifact(name, type="model")
211-
artifact.add_file(dir_name, name=artifact_name)
212-
self.wandb.run.log_artifact(artifact)
213-
if self.uri is not None:
214-
artifact.add_reference(self.uri, name=artifact_name)
211+
def _log_checkpoint(self, name, value, x_plot):
212+
assert self.wandb is not None
213+
214+
# Example: 'WeightCheckpoint/train_phase/train_stream/Task000/Exp000'
215+
name_match = CHECKPOINT_METRIC_NAME.match(name)
216+
if name_match is None:
217+
warnings.warn(f"Checkpoint metric has unsupported name {name}.")
218+
return
219+
# phase_name: str = name_match['phase_name']
220+
# stream_name: str = name_match['stream_name']
221+
task_id: Optional[int] = (
222+
int(name_match["task_id"]) if name_match["task_id"] is not None else None
223+
)
224+
experience_id: int = int(name_match["experience_id"])
225+
assert experience_id >= 0
226+
227+
cwd = Path.cwd()
228+
checkpoint_directory = cwd / self.path
229+
checkpoint_directory.mkdir(parents=True, exist_ok=True)
230+
231+
checkpoint_name = "Model_{}".format(experience_id)
232+
checkpoint_file_name = checkpoint_name + ".pth"
233+
checkpoint_path = checkpoint_directory / checkpoint_file_name
234+
artifact_name = "Models/" + checkpoint_file_name
235+
236+
# Write the checkpoint blob
237+
with open(checkpoint_path, "wb") as f:
238+
f.write(value)
239+
240+
metadata = {
241+
"experience": experience_id,
242+
"x_step": x_plot,
243+
**({"task_id": task_id} if task_id is not None else {}),
244+
}
245+
246+
artifact = self.wandb.Artifact(checkpoint_name, type="model", metadata=metadata)
247+
artifact.add_file(str(checkpoint_path), name=artifact_name)
248+
self.wandb.run.log_artifact(artifact)
249+
if self.uri is not None:
250+
artifact.add_reference(self.uri, name=artifact_name)
215251

216252
def __getstate__(self):
217253
state = self.__dict__.copy()

avalanche/training/plugins/ewc.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def after_training_exp(self, strategy, **kwargs):
121121
strategy.experience.dataset,
122122
strategy.device,
123123
strategy.train_mb_size,
124+
num_workers=kwargs.get("num_workers", 0),
124125
)
125126
self.update_importances(importances, exp_counter)
126127
self.saved_params[exp_counter] = copy_params_dict(strategy.model)
@@ -129,7 +130,7 @@ def after_training_exp(self, strategy, **kwargs):
129130
del self.saved_params[exp_counter - 1]
130131

131132
def compute_importances(
132-
self, model, criterion, optimizer, dataset, device, batch_size
133+
self, model, criterion, optimizer, dataset, device, batch_size, num_workers=0
133134
) -> Dict[str, ParamData]:
134135
"""
135136
Compute EWC importance matrix for each parameter
@@ -153,7 +154,12 @@ def compute_importances(
153154
# list of list
154155
importances = zerolike_params_dict(model)
155156
collate_fn = dataset.collate_fn if hasattr(dataset, "collate_fn") else None
156-
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
157+
dataloader = DataLoader(
158+
dataset,
159+
batch_size=batch_size,
160+
collate_fn=collate_fn,
161+
num_workers=num_workers,
162+
)
157163
for i, batch in enumerate(dataloader):
158164
# get only input, target and task_id from the batch
159165
x, y, task_labels = batch[0], batch[1], batch[-1]

examples/multihead.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def main(args):
6969

7070
# train and test loop
7171
for train_task in train_stream:
72-
strategy.train(train_task)
73-
strategy.eval(test_stream)
72+
strategy.train(train_task, num_workers=4)
73+
strategy.eval(test_stream, num_workers=4)
7474

7575

7676
if __name__ == "__main__":

examples/wandb_logger.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from avalanche.benchmarks import nc_benchmark
2626
from avalanche.benchmarks.datasets.dataset_utils import default_dataset_location
27+
from avalanche.evaluation.metrics.checkpoint import WeightCheckpoint
2728
from avalanche.logging import InteractiveLogger, WandBLogger
2829
from avalanche.training.plugins import EvaluationPlugin
2930
from avalanche.evaluation.metrics import (
@@ -83,7 +84,11 @@ def main(args):
8384

8485
interactive_logger = InteractiveLogger()
8586
wandb_logger = WandBLogger(
86-
project_name=args.project, run_name=args.run, config=vars(args)
87+
project_name=args.project,
88+
run_name=args.run,
89+
log_artifacts=args.artifacts,
90+
path=args.path if args.path else None,
91+
config=vars(args),
8792
)
8893

8994
eval_plugin = EvaluationPlugin(
@@ -120,6 +125,7 @@ def main(args):
120125
),
121126
disk_usage_metrics(minibatch=True, epoch=True, experience=True, stream=True),
122127
MAC_metrics(minibatch=True, epoch=True, experience=True),
128+
WeightCheckpoint(),
123129
loggers=[interactive_logger, wandb_logger],
124130
)
125131

@@ -157,9 +163,22 @@ def main(args):
157163
default=0,
158164
help="Select zero-indexed cuda device. -1 to use CPU.",
159165
)
160-
parser.add_argument("--run", type=str, help="Provide a run name for WandB")
161166
parser.add_argument(
162167
"--project", type=str, help="Define the name of the WandB project"
163168
)
169+
parser.add_argument("--run", type=str, help="Provide a run name for WandB")
170+
parser.add_argument(
171+
"--artifacts",
172+
default=False,
173+
action="store_true",
174+
help="Log Model Checkpoints as W&B Artifacts",
175+
)
176+
parser.add_argument(
177+
"--path",
178+
type=str,
179+
default="Checkpoint",
180+
help="Local path to save the model checkpoints",
181+
)
182+
164183
args = parser.parse_args()
165184
main(args)

tests/test_core50.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ def test_core50_nc_benchmark(self):
3838
classes_in_test = benchmark_instance.classes_in_experience["test"][0]
3939
self.assertSetEqual(set(range(50)), set(classes_in_test))
4040

41+
# Regression tests for issue #774
42+
self.assertSequenceEqual([10] + ([5] * 8), benchmark_instance.n_classes_per_exp)
43+
self.assertSetEqual(set(range(50)), set(benchmark_instance.classes_order))
44+
self.assertEqual(50, len(benchmark_instance.classes_order))
45+
4146

4247
if __name__ == "__main__":
4348
unittest.main()

tests/test_models.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sys
22
import os
33
import copy
4+
import tempfile
45

56
import unittest
67

@@ -650,10 +651,13 @@ def test_ncm_save_load(self):
650651
),
651652
}
652653
)
653-
torch.save(classifier.state_dict(), "ncm.pt")
654-
del classifier
655-
classifier = NCMClassifier()
656-
check = torch.load("ncm.pt")
654+
655+
with tempfile.TemporaryFile() as tmpfile:
656+
torch.save(classifier.state_dict(), tmpfile)
657+
del classifier
658+
classifier = NCMClassifier()
659+
tmpfile.seek(0)
660+
check = torch.load(tmpfile)
657661
classifier.load_state_dict(check)
658662
assert classifier.class_means.shape == (3, 5)
659663
assert (classifier.class_means[0] == 0).all()

0 commit comments

Comments
 (0)