44# See the accompanying LICENSE file for terms. #
55# #
66# Date: 25-11-2020 #
# Author(s): Diganta Misra, Andrea Cossu, Lorenzo Pellegrini                   #
88# E-mail: contact@continualai.org #
99# Website: www.continualai.org #
1010################################################################################
1111""" This module handles all the functionalities related to the logging of
1212Avalanche experiments using Weights & Biases. """
1313
14- from typing import Union , List , TYPE_CHECKING
14+ import re
15+ from typing import Optional , Union , List , TYPE_CHECKING
1516from pathlib import Path
1617import os
17- import errno
18+ import warnings
1819
1920import numpy as np
2021from numpy import array
21- import torch
2222from torch import Tensor
2323
2424from PIL .Image import Image
3737 from avalanche .training .templates import SupervisedTemplate
3838
3939
# Matches checkpoint metric names emitted by the WeightCheckpoint metric,
# e.g. "WeightCheckpoint/train_phase/train_stream/Task000/Exp000".
# Named groups: phase_name, stream_name, task_id (optional), experience_id.
CHECKPOINT_METRIC_NAME = re.compile(
    r"^WeightCheckpoint\/(?P<phase_name>\S+)_phase\/(?P<stream_name>\S+)_"
    r"stream(\/Task(?P<task_id>\d+))?\/Exp(?P<experience_id>\d+)$"
)
44+
45+
4046class WandBLogger (BaseLogger , SupervisedPlugin ):
4147 """Weights and Biases logger.
4248
@@ -60,18 +66,21 @@ def __init__(
6066 run_name : str = "Test" ,
6167 log_artifacts : bool = False ,
6268 path : Union [str , Path ] = "Checkpoints" ,
63- uri : str = None ,
69+ uri : Optional [ str ] = None ,
6470 sync_tfboard : bool = False ,
6571 save_code : bool = True ,
66- config : object = None ,
67- dir : Union [str , Path ] = None ,
68- params : dict = None ,
72+ config : Optional [ object ] = None ,
73+ dir : Optional [ Union [str , Path ] ] = None ,
74+ params : Optional [ dict ] = None ,
6975 ):
7076 """Creates an instance of the `WandBLogger`.
7177
7278 :param project_name: Name of the W&B project.
7379 :param run_name: Name of the W&B run.
7480 :param log_artifacts: Option to log model weights as W&B Artifacts.
81+ Note that, in order for model weights to be logged, the
82+ :class:`WeightCheckpoint` metric must be added to the
83+ evaluation plugin.
7584 :param path: Path to locally save the model checkpoints.
7685 :param uri: URI identifier for external storage buckets (GCS, S3).
7786 :param sync_tfboard: Syncs TensorBoard to the W&B dashboard UI.
@@ -102,6 +111,8 @@ def __init__(
102111 def import_wandb (self ):
103112 try :
104113 import wandb
114+
115+ assert hasattr (wandb , "__version__" )
105116 except ImportError :
106117 raise ImportError ('Please run "pip install wandb" to install wandb' )
107118 self .wandb = wandb
@@ -140,7 +151,7 @@ def after_training_exp(
140151 self ,
141152 strategy : "SupervisedTemplate" ,
142153 metric_values : List ["MetricValue" ],
143- ** kwargs
154+ ** kwargs ,
144155 ):
145156 for val in metric_values :
146157 self .log_metrics ([val ])
@@ -151,6 +162,11 @@ def after_training_exp(
151162 def log_single_metric (self , name , value , x_plot ):
152163 self .step = x_plot
153164
165+ if name .startswith ("WeightCheckpoint" ):
166+ if self .log_artifacts :
167+ self ._log_checkpoint (name , value , x_plot )
168+ return
169+
154170 if isinstance (value , AlternativeValues ):
155171 value = value .best_supported_value (
156172 Image ,
@@ -192,26 +208,46 @@ def log_single_metric(self, name, value, x_plot):
192208 elif isinstance (value , TensorImage ):
193209 self .wandb .log ({name : self .wandb .Image (array (value ))}, step = self .step )
194210
195- elif name .startswith ("WeightCheckpoint" ):
196- if self .log_artifacts :
197- cwd = os .getcwd ()
198- ckpt = os .path .join (cwd , self .path )
199- try :
200- os .makedirs (ckpt )
201- except OSError as e :
202- if e .errno != errno .EEXIST :
203- raise
204- suffix = ".pth"
205- dir_name = os .path .join (ckpt , name + suffix )
206- artifact_name = os .path .join ("Models" , name + suffix )
207- if isinstance (value , Tensor ):
208- torch .save (value , dir_name )
209- name = os .path .splittext (self .checkpoint )
210- artifact = self .wandb .Artifact (name , type = "model" )
211- artifact .add_file (dir_name , name = artifact_name )
212- self .wandb .run .log_artifact (artifact )
213- if self .uri is not None :
214- artifact .add_reference (self .uri , name = artifact_name )
211+ def _log_checkpoint (self , name , value , x_plot ):
212+ assert self .wandb is not None
213+
214+ # Example: 'WeightCheckpoint/train_phase/train_stream/Task000/Exp000'
215+ name_match = CHECKPOINT_METRIC_NAME .match (name )
216+ if name_match is None :
217+ warnings .warn (f"Checkpoint metric has unsupported name { name } ." )
218+ return
219+ # phase_name: str = name_match['phase_name']
220+ # stream_name: str = name_match['stream_name']
221+ task_id : Optional [int ] = (
222+ int (name_match ["task_id" ]) if name_match ["task_id" ] is not None else None
223+ )
224+ experience_id : int = int (name_match ["experience_id" ])
225+ assert experience_id >= 0
226+
227+ cwd = Path .cwd ()
228+ checkpoint_directory = cwd / self .path
229+ checkpoint_directory .mkdir (parents = True , exist_ok = True )
230+
231+ checkpoint_name = "Model_{}" .format (experience_id )
232+ checkpoint_file_name = checkpoint_name + ".pth"
233+ checkpoint_path = checkpoint_directory / checkpoint_file_name
234+ artifact_name = "Models/" + checkpoint_file_name
235+
236+ # Write the checkpoint blob
237+ with open (checkpoint_path , "wb" ) as f :
238+ f .write (value )
239+
240+ metadata = {
241+ "experience" : experience_id ,
242+ "x_step" : x_plot ,
243+ ** ({"task_id" : task_id } if task_id is not None else {}),
244+ }
245+
246+ artifact = self .wandb .Artifact (checkpoint_name , type = "model" , metadata = metadata )
247+ artifact .add_file (str (checkpoint_path ), name = artifact_name )
248+ self .wandb .run .log_artifact (artifact )
249+ if self .uri is not None :
250+ artifact .add_reference (self .uri , name = artifact_name )
215251
216252 def __getstate__ (self ):
217253 state = self .__dict__ .copy ()
0 commit comments