@@ -338,3 +338,70 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
338
338
assert not hasattr (model .layer .mlp_3 , "weight_orig" )
339
339
model = TestModel .load_from_checkpoint (trainer .checkpoint_callback .last_model_path )
340
340
assert not hasattr (model .layer .mlp_3 , "weight_orig" )
341
+
342
+
343
+ def test_sanitize_parameters_explicit_check ():
344
+ """Test the sanitize_parameters_to_prune method with various attribute types."""
345
+
346
+ class TestModule (nn .Module ):
347
+ def __init__ (self ):
348
+ super ().__init__ ()
349
+ self .weight = nn .Parameter (torch .randn (5 , 5 ))
350
+ self .bias = nn .Parameter (torch .randn (5 ))
351
+ self .some_bool = True
352
+ self .some_tensor = torch .randn (3 , 3 ) # Regular tensor, not parameter
353
+ self .some_string = "test"
354
+ self .some_none = None
355
+
356
+ class TestModel (BoringModel ):
357
+ def __init__ (self ):
358
+ super ().__init__ ()
359
+ self .test_module = TestModule ()
360
+
361
+ model = TestModel ()
362
+
363
+ parameters_to_prune = ModelPruning .sanitize_parameters_to_prune (
364
+ model ,
365
+ parameters_to_prune = (),
366
+ parameter_names = ["weight" , "bias" , "some_bool" , "some_tensor" , "some_string" , "some_none" ],
367
+ )
368
+
369
+ param_names_found = set ()
370
+ for module , param_name in parameters_to_prune :
371
+ param = getattr (module , param_name )
372
+ assert isinstance (param , nn .Parameter ), f"Expected Parameter, got { type (param )} "
373
+ param_names_found .add (param_name )
374
+
375
+ assert "weight" in param_names_found
376
+ assert "bias" in param_names_found
377
+ assert "some_bool" not in param_names_found
378
+ assert "some_tensor" not in param_names_found
379
+ assert "some_string" not in param_names_found
380
+ assert "some_none" not in param_names_found
381
+
382
+
383
+ def test_original_issue_reproduction ():
384
+ """Issue: https://github.com/Lightning-AI/pytorch-lightning/issues/10835."""
385
+
386
+ class ProblematicModel (BoringModel ):
387
+ def __init__ (self ):
388
+ super ().__init__ ()
389
+ self .layer = Sequential (
390
+ OrderedDict ([
391
+ ("mlp_1" , nn .Linear (32 , 32 )),
392
+ ("mlp_2" , nn .Linear (32 , 2 )),
393
+ ])
394
+ )
395
+ # Add boolean attributes that would cause the original error
396
+ self .layer .mlp_1 .training = True
397
+ self .layer .mlp_2 .requires_grad = True
398
+
399
+ model = ProblematicModel ()
400
+
401
+ parameters_to_prune = ModelPruning .sanitize_parameters_to_prune (
402
+ model , parameters_to_prune = (), parameter_names = ["weight" , "bias" , "training" , "requires_grad" ]
403
+ )
404
+
405
+ for module , param_name in parameters_to_prune :
406
+ param = getattr (module , param_name )
407
+ assert isinstance (param , nn .Parameter ), f"Non-parameter found: { type (param )} "
0 commit comments