@@ -115,3 +115,35 @@ class TELayerNormMock(Mock): ...
115
115
assert isinstance (model .l1 , TELinearMock )
116
116
assert isinstance (model .l2 , TELayerNormMock )
117
117
assert isinstance (model .l3 .l , TELinearMock )
118
+
119
+
120
+ def test_convert_module_handles_linear_without_bias (monkeypatch ):
121
+ module = lightning .fabric .plugins .precision .transformer_engine # Set up mock transformer_engine
122
+ monkeypatch .setattr (module , "_TRANSFORMER_ENGINE_AVAILABLE" , lambda : True )
123
+
124
+ transformer_engine_mock = Mock ()
125
+ monkeypatch .setitem (sys .modules , "transformer_engine" , transformer_engine_mock )
126
+ monkeypatch .setitem (sys .modules , "transformer_engine.pytorch" , transformer_engine_mock .pytorch )
127
+ monkeypatch .setitem (sys .modules , "transformer_engine.common.recipe" , transformer_engine_mock .recipe )
128
+
129
+ class TELinearMock (torch .nn .Linear ): # Mock the Linear replacement class
130
+ def __init__ (self , in_features , out_features , bias = True ):
131
+ super ().__init__ (in_features , out_features , bias )
132
+
133
+ transformer_engine_mock .pytorch .Linear = TELinearMock
134
+ transformer_engine_mock .pytorch .LayerNorm = torch .nn .LayerNorm
135
+ transformer_engine_mock .recipe .DelayedScaling .return_value = None
136
+
137
+ class BiaslessModel (torch .nn .Module ):
138
+ def __init__ (self ):
139
+ super ().__init__ ()
140
+ self .linear = torch .nn .Linear (16 , 32 , bias = False ) # This was causing the bug
141
+
142
+ model = BiaslessModel ()
143
+ precision = TransformerEnginePrecision (weights_dtype = torch .float16 )
144
+ precision .replace_layers = True
145
+
146
+ precision .convert_module (model ) # This should no longer raise AttributeError
147
+
148
+ assert isinstance (model .linear , TELinearMock )
149
+ assert model .linear .bias is None
0 commit comments