@@ -171,6 +171,7 @@ def _prepare_model(
         module: nn.Module,
         *,
         batch_first: bool = True,
+        max_grad_norm: Union[float, List[float]] = 1.0,
         loss_reduction: str = "mean",
         grad_sample_mode: str = "hooks",
     ) -> AbstractGradSampleModule:
@@ -194,12 +195,21 @@ def _prepare_model(
 
                 return module
         else:
-            return wrap_model(
-                module,
-                grad_sample_mode=grad_sample_mode,
-                batch_first=batch_first,
-                loss_reduction=loss_reduction,
-            )
+            if grad_sample_mode == "ghost":
+                return wrap_model(
+                    module,
+                    grad_sample_mode=grad_sample_mode,
+                    batch_first=batch_first,
+                    loss_reduction=loss_reduction,
+                    max_grad_norm=max_grad_norm,
+                )
+            else:
+                return wrap_model(
+                    module,
+                    grad_sample_mode=grad_sample_mode,
+                    batch_first=batch_first,
+                    loss_reduction=loss_reduction,
+                )
 
     def is_compatible(
         self,
@@ -355,6 +365,7 @@ def make_private(
         module = self._prepare_model(
             module,
             batch_first=batch_first,
+            max_grad_norm=max_grad_norm,
             loss_reduction=loss_reduction,
             grad_sample_mode=grad_sample_mode,
         )
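Taken together, the three hunks thread max_grad_norm (a single clipping threshold or a per-layer List[float]) from make_private through _prepare_model into wrap_model, but only when grad_sample_mode == "ghost"; the other grad-sample modes keep the original wrap_model call. Below is a minimal usage sketch of the new branch. It assumes the standard PrivacyEngine.make_private entry point with plain SGD/DataLoader wiring; the exact ghost-clipping API (e.g., whether a loss criterion must also be passed and wrapped) may differ by Opacus version.

# Hypothetical usage sketch: exercise the new "ghost" branch in _prepare_model.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

model = nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,))),
    batch_size=8,
)

privacy_engine = PrivacyEngine()
# With this change, max_grad_norm is forwarded to _prepare_model -> wrap_model
# when grad_sample_mode == "ghost", so the wrapped module knows the clipping
# threshold(s) up front instead of only the optimizer seeing them.
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,  # float, or List[float] for per-layer thresholds
    grad_sample_mode="ghost",
)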