@@ -26,7 +26,7 @@
 from tests.testing_utils import requires_accelerate, requires_gpu


-@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard", "random-matrix"))
 @pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("head_dim", (None, 2, 4))
 @pytest.mark.parametrize("input_batch_size", (1, 5, 17))
@@ -57,7 +57,7 @@ def test_correctness_linear(type, randomize, head_dim, input_batch_size): |
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)


-@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard", "random-matrix"))
 @pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("embed_loc", ("weight_output", "output"))
 @pytest.mark.parametrize("linear_loc", ("input", "weight_input"))
@@ -89,7 +89,7 @@ def test_correctness_embedding(type, randomize, embed_loc, linear_loc): |
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)


-@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard", "random-matrix"))
 @pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("input_batch_size", (1, 5, 17))
 def test_correctness_model(
@@ -121,14 +121,14 @@ def test_correctness_model( |

 @requires_gpu
 @requires_accelerate()
-@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard", "random-matrix"))
 @pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("input_batch_size", (1, 5, 17))
 def test_correctness_model_offload(type, randomize, input_batch_size, model_apply):
     test_correctness_model(type, randomize, input_batch_size, model_apply, offload=True)


-@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard", "random-matrix"))
 @pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("head_dim", (4, 8))
 @pytest.mark.parametrize("input_batch_size", (1, 5, 17))
@@ -164,3 +164,37 @@ def test_correctness_attention_heads(type, randomize, head_dim, input_batch_size

     output = attention(input)
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
+
+
+@requires_gpu
+@pytest.mark.parametrize("cuda_default", (True, False))
+def test_random_matrix_device_handling(cuda_default):
+    """
+    Test that random-matrix transforms are created and applied on CUDA
+    modules regardless of the torch default device
+    """
+    seed = 0
+    size = (4, 8)
+
+    cur_default = torch.get_default_device()
+    try:
+        if cuda_default:
+            torch.set_default_device("cuda")
+        module = torch.nn.Linear(*size, bias=False).cuda()
+        scheme = TransformScheme(type="random-matrix", randomize=True)
+        factory = TransformFactory.from_scheme(scheme, name="", seed=seed)
+
+        # creating the transform should work despite a CPU generator and a CUDA module
+        input_tfm = factory.create_transform(
+            module, TransformArgs(targets="Linear", location="input", inverse=True)
+        )
+
+        # verify the transform runs on CUDA inputs
+        input = torch.rand((5, 3, size[0])).cuda()
+        input_tfm(input)
+
+        # verify the transform weights were created on CUDA
+        assert input_tfm.weight.device.type == "cuda"
+    finally:
+        # restore the default device even if an assertion failed above
+        torch.set_default_device(cur_default)
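
For reference, the usage pattern that the new "random-matrix" parametrization exercises looks roughly like the sketch below. This is not part of the diff: the compressed_tensors.transform import path and the round-trip assertion are assumptions layered on the TransformScheme / TransformFactory / TransformArgs API that the tests above already use.

import torch
# assumed import path; the test module imports these names from this repo
from compressed_tensors.transform import (
    TransformArgs,
    TransformFactory,
    TransformScheme,
)

# build a factory for the new "random-matrix" transform type
scheme = TransformScheme(type="random-matrix", randomize=True)
factory = TransformFactory.from_scheme(scheme, name="", seed=0)

# create a forward transform and its inverse for the input of a small Linear
module = torch.nn.Linear(4, 8, bias=False)
forward_tfm = factory.create_transform(
    module, TransformArgs(targets="Linear", location="input")
)
inverse_tfm = factory.create_transform(
    module, TransformArgs(targets="Linear", location="input", inverse=True)
)

# a random invertible matrix composed with its inverse should round-trip the
# input; the tolerance mirrors the correctness tests above and may need
# loosening for ill-conditioned random draws
x = torch.rand((5, 3, 4))
assert torch.allclose(inverse_tfm(forward_tfm(x)), x, atol=1e-5, rtol=0.0)

The new device-handling test follows the same pattern, except that the module and inputs live on CUDA while the factory's random generator stays on CPU.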