Skip to content

Commit 01ba7a1

Browse files
Fix: TransformerEnginePrecision conversion for layers with bias=False (#20805)
* Update transformer_engine.py
* Update test_transformer_engine.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 68f2adc commit 01ba7a1

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

src/lightning/fabric/plugins/precision/transformer_engine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,9 @@ def _convert_layers(module: torch.nn.Module) -> None:
         elif isinstance(child, torch.nn.LayerNorm):
             replacement = te.LayerNorm(child.normalized_shape[0], eps=child.eps)
             replacement.weight.data = child.weight.data.clone()
-            replacement.bias.data = child.bias.data.clone()
+            # Check if bias exists before attempting to clone its data
+            if child.bias is not None and replacement.bias is not None:
+                replacement.bias.data = child.bias.data.clone()
             log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
             module.__setattr__(name, replacement)
         else:

tests/tests_fabric/plugins/precision/test_transformer_engine.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,35 @@ class TELayerNormMock(Mock): ...
     assert isinstance(model.l1, TELinearMock)
     assert isinstance(model.l2, TELayerNormMock)
     assert isinstance(model.l3.l, TELinearMock)
+
+
+def test_convert_module_handles_linear_without_bias(monkeypatch):
+    module = lightning.fabric.plugins.precision.transformer_engine  # Set up mock transformer_engine
+    monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
+
+    transformer_engine_mock = Mock()
+    monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock)
+    monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", transformer_engine_mock.pytorch)
+    monkeypatch.setitem(sys.modules, "transformer_engine.common.recipe", transformer_engine_mock.recipe)
+
+    class TELinearMock(torch.nn.Linear):  # Mock the Linear replacement class
+        def __init__(self, in_features, out_features, bias=True):
+            super().__init__(in_features, out_features, bias)
+
+    transformer_engine_mock.pytorch.Linear = TELinearMock
+    transformer_engine_mock.pytorch.LayerNorm = torch.nn.LayerNorm
+    transformer_engine_mock.recipe.DelayedScaling.return_value = None
+
+    class BiaslessModel(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(16, 32, bias=False)  # This was causing the bug
+
+    model = BiaslessModel()
+    precision = TransformerEnginePrecision(weights_dtype=torch.float16)
+    precision.replace_layers = True
+
+    precision.convert_module(model)  # This should no longer raise AttributeError
+
+    assert isinstance(model.linear, TELinearMock)
+    assert model.linear.bias is None

0 commit comments

Comments (0)