From 92e385dffaeb430b102770ba8012a64847491de5 Mon Sep 17 00:00:00 2001 From: Jeremy Sawruk Date: Tue, 10 Jun 2025 14:34:31 -0700 Subject: [PATCH] Fix pyre-fixme[6] in _test_linear_classifier.py Summary: title Differential Revision: D75990130 --- .../linear_models/_test_linear_classifier.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/utils/models/linear_models/_test_linear_classifier.py b/tests/utils/models/linear_models/_test_linear_classifier.py index b1f458e922..a3d514c229 100644 --- a/tests/utils/models/linear_models/_test_linear_classifier.py +++ b/tests/utils/models/linear_models/_test_linear_classifier.py @@ -96,16 +96,14 @@ def compare_to_sk_learn( pytorch_h = pytorch_classifier.representation() sklearn_h = sklearn_classifier.representation() + alpha_tensor = torch.tensor(alpha) + if objective == "ridge": - # pyre-fixme[6]: For 2nd argument expected `Tensor` but got `float`. - o_pytorch["l2_reg"] = alpha * pytorch_h.norm(p=2, dim=-1) - # pyre-fixme[6]: For 2nd argument expected `Tensor` but got `float`. - o_sklearn["l2_reg"] = alpha * sklearn_h.norm(p=2, dim=-1) + o_pytorch["l2_reg"] = alpha_tensor * pytorch_h.norm(p=2, dim=-1) + o_sklearn["l2_reg"] = alpha_tensor * sklearn_h.norm(p=2, dim=-1) elif objective == "lasso": - # pyre-fixme[6]: For 2nd argument expected `Tensor` but got `float`. - o_pytorch["l1_reg"] = alpha * pytorch_h.norm(p=1, dim=-1) - # pyre-fixme[6]: For 2nd argument expected `Tensor` but got `float`. - o_sklearn["l1_reg"] = alpha * sklearn_h.norm(p=1, dim=-1) + o_pytorch["l1_reg"] = alpha_tensor * pytorch_h.norm(p=1, dim=-1) + o_sklearn["l1_reg"] = alpha_tensor * sklearn_h.norm(p=1, dim=-1) rel_diff = cast( npt.NDArray,