Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion avalanche/evaluation/metrics/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _package_result(self, strategy) -> 'MetricResult':
self.get_global_counter())]

def after_eval_exp(self, strategy: 'BaseStrategy') -> 'MetricResult':
model_params = copy.deepcopy(strategy.model.parameters())
model_params = copy.deepcopy(list(strategy.model.parameters()))
self.update(model_params)

def __str__(self):
Expand Down
10 changes: 9 additions & 1 deletion examples/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from os.path import expanduser

import argparse

from click import Path
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
Expand Down Expand Up @@ -74,6 +76,8 @@ def main(args):

interactive_logger = InteractiveLogger()
wandb_logger = WandBLogger(project_name=args.project, run_name=args.run,
log_artifacts=args.artifacts,
path=args.path if args.path else None,
config=args)

eval_plugin = EvaluationPlugin(
Expand Down Expand Up @@ -127,8 +131,12 @@ def main(args):
parser = argparse.ArgumentParser()
parser.add_argument('--cuda', type=int, default=0,
help='Select zero-indexed cuda device. -1 to use CPU.')
parser.add_argument('--run', type=str, help='Provide a run name for WandB')
parser.add_argument('--project', type=str,
help='Define the name of the WandB project')
parser.add_argument('--run', type=str, help='Provide a run name for WandB')
parser.add_argument('--artifacts', type=bool, default=False,
help='Log Model Checkpoints as W&B Artifacts')
parser.add_argument('--path', type=str, default=None,
help='Local path to save the model checkpoints')
args = parser.parse_args()
main(args)