Skip to content

Commit abde4c2

Browse files
committed
Solve linter issue
1 parent 6c2ec81 commit abde4c2

File tree

1 file changed

+7
-2
lines changed
  • avalanche/training/plugins

1 file changed

+7
-2
lines changed

avalanche/training/plugins/ewc.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def after_training_exp(self, strategy, **kwargs):
121121
strategy.experience.dataset,
122122
strategy.device,
123123
strategy.train_mb_size,
124-
num_workers=kwargs.get('num_workers', 0)
124+
num_workers=kwargs.get("num_workers", 0),
125125
)
126126
self.update_importances(importances, exp_counter)
127127
self.saved_params[exp_counter] = copy_params_dict(strategy.model)
@@ -154,7 +154,12 @@ def compute_importances(
154154
# list of list
155155
importances = zerolike_params_dict(model)
156156
collate_fn = dataset.collate_fn if hasattr(dataset, "collate_fn") else None
157-
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers)
157+
dataloader = DataLoader(
158+
dataset,
159+
batch_size=batch_size,
160+
collate_fn=collate_fn,
161+
num_workers=num_workers,
162+
)
158163
for i, batch in enumerate(dataloader):
159164
# get only input, target and task_id from the batch
160165
x, y, task_labels = batch[0], batch[1], batch[-1]

0 commit comments

Comments
 (0)