Skip to content

Commit a3a2d44

Browse files
committed
increased allclose tolerance in test
1 parent 29f7710 commit a3a2d44

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

tests/test_tensorlist.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,7 +1301,7 @@ def test_reduction_ops(simple_tl: TensorList, reduction_method, dim, keepdim):
13011301
expected_tl = TensorList(expected_list)
13021302
assert isinstance(result, TensorList)
13031303
assert len(result) == len(expected_tl)
1304-
assert_tl_allclose(result, expected_tl, atol=1e-6) # Use allclose due to potential float variations
1304+
assert_tl_allclose(result, expected_tl, atol=1e-3) # Use allclose due to potential float variations
13051305

13061306
# --- Grafting, Rescaling, Normalizing, Clipping ---
13071307

@@ -1381,8 +1381,8 @@ def test_rescale(simple_tl: TensorList, dim):
13811381
assert torch.allclose(rescaled_scalar.global_min(), torch.tensor(min_val))
13821382
assert torch.allclose(rescaled_scalar.global_max(), torch.tensor(max_val))
13831383
else:
1384-
assert_tl_allclose(rescaled_scalar_min, TensorList([torch.full_like(t, min_val) for t in rescaled_scalar_min]),atol=1e-4)
1385-
assert_tl_allclose(rescaled_scalar_max, TensorList([torch.full_like(t, max_val) for t in rescaled_scalar_max]),atol=1e-4)
1384+
assert_tl_allclose(rescaled_scalar_min, TensorList([torch.full_like(t, min_val) for t in rescaled_scalar_min]),atol=1e-3)
1385+
assert_tl_allclose(rescaled_scalar_max, TensorList([torch.full_like(t, max_val) for t in rescaled_scalar_max]),atol=1e-3)
13861386

13871387

13881388
# Rescale list
@@ -1402,8 +1402,8 @@ def test_rescale(simple_tl: TensorList, dim):
14021402
assert global_max_rescaled < avg_max + 1.0 # Loose check
14031403

14041404
else:
1405-
assert_tl_allclose(rescaled_list_min, TensorList([torch.full_like(t, mn) for t, mn in zip(rescaled_list_min, min_list)]),atol=1e-4)
1406-
assert_tl_allclose(rescaled_list_max, TensorList([torch.full_like(t, mx) for t, mx in zip(rescaled_list_max, max_list)]),atol=1e-4)
1405+
assert_tl_allclose(rescaled_list_min, TensorList([torch.full_like(t, mn) for t, mn in zip(rescaled_list_min, min_list)]),atol=1e-3)
1406+
assert_tl_allclose(rescaled_list_max, TensorList([torch.full_like(t, mx) for t, mx in zip(rescaled_list_max, max_list)]),atol=1e-3)
14071407

14081408
# Rescale to 01 helper
14091409
rescaled_01 = simple_tl.rescale_to_01(dim=dim, eps=eps)
@@ -1413,8 +1413,8 @@ def test_rescale(simple_tl: TensorList, dim):
14131413
assert torch.allclose(rescaled_01.global_min(), torch.tensor(0.0))
14141414
assert torch.allclose(rescaled_01.global_max(), torch.tensor(1.0))
14151415
else:
1416-
assert_tl_allclose(rescaled_01_min, TensorList([torch.zeros_like(t) for t in rescaled_01_min]), atol=1e-4)
1417-
assert_tl_allclose(rescaled_01_max, TensorList([torch.ones_like(t) for t in rescaled_01_max]), atol=1e-4)
1416+
assert_tl_allclose(rescaled_01_min, TensorList([torch.zeros_like(t) for t in rescaled_01_min]), atol=1e-3)
1417+
assert_tl_allclose(rescaled_01_max, TensorList([torch.ones_like(t) for t in rescaled_01_max]), atol=1e-3)
14181418

14191419

14201420
# Test inplace
@@ -1454,11 +1454,11 @@ def test_normalize(big_tl: TensorList, dim):
14541454
normalized_scalar_var = normalized_scalar.var(dim=dim if dim != 'global' else None)
14551455

14561456
if dim == 'global':
1457-
assert torch.allclose(normalized_scalar.global_mean(), torch.tensor(mean_val), atol=1e-4)
1458-
assert torch.allclose(normalized_scalar.global_var(), torch.tensor(var_val), atol=1e-4)
1457+
assert torch.allclose(normalized_scalar.global_mean(), torch.tensor(mean_val), atol=1e-3)
1458+
assert torch.allclose(normalized_scalar.global_var(), torch.tensor(var_val), atol=1e-3)
14591459
else:
1460-
assert_tl_allclose(normalized_scalar_mean, TensorList([torch.full_like(t, mean_val) for t in normalized_scalar_mean]), atol=1e-4)
1461-
assert_tl_allclose(normalized_scalar_var, TensorList([torch.full_like(t, var_val) for t in normalized_scalar_var]), atol=1e-4)
1460+
assert_tl_allclose(normalized_scalar_mean, TensorList([torch.full_like(t, mean_val) for t in normalized_scalar_mean]), atol=1e-3)
1461+
assert_tl_allclose(normalized_scalar_var, TensorList([torch.full_like(t, var_val) for t in normalized_scalar_var]), atol=1e-3)
14621462

14631463
# Normalize list mean/var
14641464
normalized_list = simple_tl.normalize(mean_list, var_list, dim=dim)
@@ -1476,19 +1476,19 @@ def test_normalize(big_tl: TensorList, dim):
14761476
# assert torch.allclose(global_mean_rescaled, torch.tensor(avg_mean), rtol=1e-1, atol=1e-1) # Loose check
14771477
# assert torch.allclose(global_var_rescaled, torch.tensor(avg_var), rtol=1e-1, atol=1e-1) # Loose check
14781478
else:
1479-
assert_tl_allclose(normalized_list_mean, TensorList([torch.full_like(t, m) for t, m in zip(normalized_list_mean, mean_list)]), atol=1e-4)
1480-
assert_tl_allclose(normalized_list_var, TensorList([torch.full_like(t, v) for t, v in zip(normalized_list_var, var_list)]), atol=1e-4)
1479+
assert_tl_allclose(normalized_list_mean, TensorList([torch.full_like(t, m) for t, m in zip(normalized_list_mean, mean_list)]), atol=1e-3)
1480+
assert_tl_allclose(normalized_list_var, TensorList([torch.full_like(t, v) for t, v in zip(normalized_list_var, var_list)]), atol=1e-3)
14811481

14821482
# Z-normalize helper
14831483
znorm = simple_tl.znormalize(dim=dim, eps=1e-10)
14841484
znorm_mean = znorm.mean(dim=dim if dim != 'global' else None)
14851485
znorm_var = znorm.var(dim=dim if dim != 'global' else None)
14861486
if dim == 'global':
1487-
assert torch.allclose(znorm.global_mean(), torch.tensor(0.0), atol=1e-4)
1488-
assert torch.allclose(znorm.global_var(), torch.tensor(1.0), atol=1e-4)
1487+
assert torch.allclose(znorm.global_mean(), torch.tensor(0.0), atol=1e-3)
1488+
assert torch.allclose(znorm.global_var(), torch.tensor(1.0), atol=1e-3)
14891489
else:
1490-
assert_tl_allclose(znorm_mean, TensorList([torch.zeros_like(t) for t in znorm_mean]), atol=1e-4)
1491-
assert_tl_allclose(znorm_var, TensorList([torch.ones_like(t) for t in znorm_var]), atol=1e-4)
1490+
assert_tl_allclose(znorm_mean, TensorList([torch.zeros_like(t) for t in znorm_mean]), atol=1e-3)
1491+
assert_tl_allclose(znorm_var, TensorList([torch.ones_like(t) for t in znorm_var]), atol=1e-3)
14921492

14931493

14941494
# Test inplace

0 commit comments

Comments
 (0)