61 changes: 42 additions & 19 deletions thunder/tests/test_transformer_engine_executor.py
@@ -53,6 +53,25 @@
     recipe_ids += ["nvfp4_e2m1"]


+# Asserts closeness with the estimated numerical-error tolerances for a given dtype, as per the TE spec here:
+# https://github.com/NVIDIA/TransformerEngine/blob/7ad130efd52c3aa4a386d25f1d42b28d5aa20090/tests/pytorch/test_numerics.py#L155-L167
+def te_assert_close(actual, expected, **kwargs):
+    tolerances = {}
+
+    if not isinstance(actual, torch.Tensor) and isinstance(actual, float):
+        tolerances = dict(rtol=1.3e-6, atol=1e-5)
+    elif actual.dtype == torch.float32:
+        tolerances = dict(rtol=1.3e-6, atol=1e-5)
+    elif actual.dtype == torch.float16:
+        tolerances = dict(rtol=1e-3, atol=1e-5)
+    elif actual.dtype == torch.bfloat16:
+        tolerances = dict(rtol=1.6e-2, atol=1e-5)
+
+    kwargs.update(tolerances)
+
+    assert_close(actual, expected, **kwargs)
+
+
 @requiresCUDA
 @pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids)
 @skip_on_sm120_and_sm121
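(As a reading aid, not part of the diff: a minimal sketch of how the new helper dispatches tolerances, assuming `te_assert_close` from the hunk above is in scope. The tensor values are made up for illustration.)

```python
import torch
from torch.testing import assert_close

# `actual`'s dtype selects the tolerances; everything else is forwarded as-is.
x = torch.randn(8, dtype=torch.bfloat16)
te_assert_close(x, x.clone())     # bfloat16 path: rtol=1.6e-2, atol=1e-5
te_assert_close(1.0, 1.0 + 1e-6)  # plain-float path: rtol=1.3e-6, atol=1e-5

# Because kwargs.update(tolerances) runs after the dispatch, the dtype-based
# tolerances override any rtol/atol the caller passes in:
te_assert_close(x, x.clone(), rtol=0.0, atol=0.0)  # replaced by 1.6e-2 / 1e-5
```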
@@ -69,6 +88,7 @@ def test_te_linear_forward_backward(fp8_recipe: recipe.Recipe):
     # Verify that `torch.nn.functional.linear` is replaced with `te_linear_*`
     # and that the output as well as the gradients match for thunder-compiled code.
     dtype = torch.bfloat16
+
     device = "cuda"

     # TE inputs (3D input)
@@ -100,15 +120,15 @@ def fn(x, w1, w2):
     te_result = te_linear2(inter_result + x_te)

     # Verifies the result is close to TE
-    assert_close(thunder_result, te_result)
+    te_assert_close(thunder_result, te_result)

     grad_output = torch.randn_like(te_result)
     te_result.backward(grad_output)
     thunder_result.backward(grad_output)

-    assert_close(x.grad, x_te.grad)
-    assert_close(w1.grad, te_linear1.weight.grad)
-    assert_close(w2.grad, te_linear2.weight.grad)
+    te_assert_close(x.grad, x_te.grad)
+    te_assert_close(w1.grad, te_linear1.weight.grad)
+    te_assert_close(w2.grad, te_linear2.weight.grad)

     # Verifies te_linear was called
     forward_trace = thunder.last_traces(cfn)
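(Aside: the `thunder.last_traces` check at the end of this hunk is how the test confirms the TE executor actually claimed the linears. A minimal sketch, assuming the compiled callable is named `cfn` and that traces expose `bound_symbols`, as in current Lightning Thunder:)

```python
import thunder

# The last trace in the list is the one that executes; look for te_linear
# symbols to confirm torch.nn.functional.linear was swapped out.
forward_trace = thunder.last_traces(cfn)
assert any("te_linear" in str(bsym) for bsym in forward_trace[-1].bound_symbols)
```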
@@ -143,6 +163,7 @@ def test_te_linear_forward_backward_multiple_iteration(fp8_recipe: recipe.Recipe):
     # Since the FP8 operations are stateful, we want to verify that
     # our output matches over multiple iterations (where state handling comes into the picture).
     dtype = torch.bfloat16
+
     device = "cuda"
     # Running more iterations leads to `nan` for both eager and thunder
     # with BlockScaling.
@@ -200,10 +221,10 @@ def thunder_model(x):
     train_model(thunder_model, thunder_sgd_optimizer)

     # Verify that the weights and biases converge to the same values after a few iterations.
-    assert_close(w1, te_linear1.weight)
-    assert_close(w2, te_linear2.weight)
-    assert_close(b1, te_linear1.bias)
-    assert_close(b2, te_linear2.bias)
+    te_assert_close(w1, te_linear1.weight)
+    te_assert_close(w2, te_linear2.weight)
+    te_assert_close(b1, te_linear1.bias)
+    te_assert_close(b2, te_linear2.bias)


 @requiresCUDA
@@ -220,6 +241,7 @@ def test_te_linear_forward_backward_multiple_iteration_multiple_recipes():
         pytest.skip("platform does not support two different recipes")

     dtype = torch.bfloat16
+
     device = "cuda"
     # Running more iterations leads to `nan` for both eager and thunder
     # with BlockScaling.
@@ -278,10 +300,10 @@ def thunder_model(x, fp8_recipe):
     train_model(thunder_model, thunder_sgd_optimizer)

     # Verify that the weights and biases converge to the same values after a few iterations.
-    assert_close(w1, te_linear1.weight)
-    assert_close(w2, te_linear2.weight)
-    assert_close(b1, te_linear1.bias)
-    assert_close(b2, te_linear2.bias)
+    te_assert_close(w1, te_linear1.weight)
+    te_assert_close(w2, te_linear2.weight)
+    te_assert_close(b1, te_linear1.bias)
+    te_assert_close(b2, te_linear2.bias)


 @requiresCUDA
@@ -562,6 +584,7 @@ def test_te_activation_checkpointing_correctness(fp8_recipe: recipe.Recipe, comp
         pytest.skip(msg_nvfp4)

     dtype = torch.bfloat16
+
     device = "cuda"
     iterations = 6

@@ -648,12 +671,12 @@ def thunder_model(x):
     train_model(thunder_model, thunder_sgd_optimizer, thunder_loss_hist)

     for loss, te_loss in zip(thunder_loss_hist, te_loss_hist):
-        assert_close(loss, te_loss)
+        te_assert_close(loss, te_loss)

-    assert_close(w1, te_linear1.weight)
-    assert_close(w2, te_linear2.weight)
-    assert_close(b1, te_linear1.bias)
-    assert_close(b2, te_linear2.bias)
+    te_assert_close(w1, te_linear1.weight)
+    te_assert_close(w2, te_linear2.weight)
+    te_assert_close(b1, te_linear1.bias)
+    te_assert_close(b2, te_linear2.bias)

     # TE does not expose the scales for MXFP8
     if fp8_recipe.delayed():
@@ -687,8 +710,8 @@ def thunder_model(x):

     # check the scales are the same except for the last dimension, which is always on in TE
     for te_scale, th_scale in zip(te_scales, th_scales):
-        assert_close(te_scale[:-1], th_scale)
+        te_assert_close(te_scale[:-1], th_scale)

     # check that the amax history is the same as TE's
     for te_amax, th_amax in zip(te_amax_hist, th_amax_hist):
-        assert_close(te_amax[:, :-1], th_amax)
+        te_assert_close(te_amax[:, :-1], th_amax)
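(For context on the multiple-iteration tests above: FP8 execution is stateful, since per-tensor scales and the amax history update every step, so a single forward/backward pass cannot catch state-handling divergence. Below is a rough sketch of the training-loop pattern those tests rely on; the body of `train_model` and the step count are assumptions for illustration, not the test file's actual helper.)

```python
import torch

def train_model(model, optimizer, x, steps=5):
    # Repeat forward/backward/step so the FP8 state (scales, amax history)
    # evolves; the TE reference and the thunder-compiled model must see the
    # same `x` for their weights to converge to the same values.
    for _ in range(steps):
        loss = model(x).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```

After both loops run, the weight comparisons go through `te_assert_close`, so the dtype-appropriate tolerances apply.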