
Commit ed002e9

Fixing bug in matrix_multiply.py (#507)
* Fixing bug in matrix_multiply.py

  The generator is created on CPU by default, so trying to create a tensor on
  CUDA with a CPU generator would raise an error; we can just create the tensor
  on the generator's device and move it over instead. Added a test to
  demonstrate that this works and was broken before, and added random-matrix
  to the other correctness tests (test was generated with claude-code and
  verified by me).

* formatting

* handling set default device

* formatting

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
1 parent 20a63ea commit ed002e9
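
For context, a minimal sketch of the failure mode the commit message describes (illustrative only, not code from the repo): a `torch.Generator` lives on CPU by default, and PyTorch refuses to sample directly onto CUDA with it, so the tensor has to be created on the generator's own device and moved afterwards.

```python
import torch

gen = torch.Generator()  # a fresh generator lives on CPU by default

# Before the fix: sampling straight onto CUDA with a CPU generator
# raises a RuntimeError about the generator/device mismatch.
try:
    torch.rand((4, 4), generator=gen, device="cuda")
except RuntimeError as err:
    print(err)

# After the fix: sample on the generator's own device, then move the result.
data = torch.rand((4, 4), generator=gen, device=gen.device).to("cuda")
```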

File tree: 2 files changed, +38 -7 lines changed


src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -68,8 +68,8 @@ def _create_weight(self, size: int, device: device, precision: dtype) -> Parameter
             (size, size),
             generator=self.generator,
             dtype=precision,
-            device=device,
-        )
+            device=self.generator.device,
+        ).to(device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)

     def _create_inverse(self, weight: Parameter) -> Parameter:
```
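
A side note on why moving after sampling is safe (a sketch, assuming determinism is driven entirely by the generator's seed): the values are drawn from the CPU generator's stream either way, so a given seed yields the same weights regardless of the target device.

```python
import torch

g1 = torch.Generator().manual_seed(0)
g2 = torch.Generator().manual_seed(0)

on_cpu = torch.rand((4, 4), generator=g1)            # stays on CPU
moved = torch.rand((4, 4), generator=g2).to("cuda")  # sampled on CPU, then moved

# Same RNG stream, so identical values on both devices.
assert torch.allclose(on_cpu, moved.cpu())
```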

tests/test_transform/factory/test_correctness.py

Lines changed: 36 additions & 5 deletions

```diff
@@ -26,7 +26,7 @@
 from tests.testing_utils import requires_accelerate, requires_gpu


-@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard", "random-matrix"))
 @pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("head_dim", (None, 2, 4))
 @pytest.mark.parametrize("input_batch_size", (1, 5, 17))
@@ -57,7 +57,7 @@ def test_correctness_linear(type, randomize, head_dim, input_batch_size):
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)


-@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard", "random-matrix"))
 @pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("embed_loc", ("weight_output", "output"))
 @pytest.mark.parametrize("linear_loc", ("input", "weight_input"))
@@ -89,7 +89,7 @@ def test_correctness_embedding(type, randomize, embed_loc, linear_loc):
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)


-@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard", "random-matrix"))
 @pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("input_batch_size", (1, 5, 17))
 def test_correctness_model(
@@ -121,14 +121,14 @@ def test_correctness_model(

 @requires_gpu
 @requires_accelerate()
-@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard", "random-matrix"))
 @pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("input_batch_size", (1, 5, 17))
 def test_correctness_model_offload(type, randomize, input_batch_size, model_apply):
     test_correctness_model(type, randomize, input_batch_size, model_apply, offload=True)


-@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard", "random-matrix"))
 @pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("head_dim", (4, 8))
 @pytest.mark.parametrize("input_batch_size", (1, 5, 17))
@@ -164,3 +164,34 @@ def test_correctness_attention_heads(type, randomize, head_dim, input_batch_size):

     output = attention(input)
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
+
+
+@requires_gpu
+@pytest.mark.parametrize("cuda_default", (True, False))
+def test_random_matrix_device_handling(cuda_default):
+    """
+    Test that random-matrix transforms can be created
+    on CUDA.
+    """
+    seed = 0
+    size = (4, 8)
+
+    cur_default = torch.get_default_device()
+    if cuda_default:
+        torch.set_default_device("cuda")
+    module = torch.nn.Linear(*size, bias=False).cuda()
+    scheme = TransformScheme(type="random-matrix", randomize=True)
+    factory = TransformFactory.from_scheme(scheme, name="", seed=seed)
+
+    # Create transforms - this should work despite CPU generator and CUDA module
+    input_tfm = factory.create_transform(
+        module, TransformArgs(targets="Linear", location="input", inverse=True)
+    )
+
+    # Verify transforms work correctly on CUDA
+    input = torch.rand((5, 3, size[0])).cuda()
+    input_tfm(input)
+
+    # Verify that transforms were created on CUDA
+    assert input_tfm.weight.device.type == "cuda"
+    torch.set_default_device(cur_default)
```
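
The `cuda_default` parametrization exists because `torch.set_default_device("cuda")` changes where factory functions allocate: with a CUDA default, even a `torch.rand(...)` call that passes only a CPU generator hits the same mismatch, while passing the generator's device explicitly overrides the default. A rough sketch of that interaction (assumes a CUDA build of PyTorch; `torch.get_default_device` needs a recent PyTorch release):

```python
import torch

gen = torch.Generator()  # CPU generator
prev = torch.get_default_device()
torch.set_default_device("cuda")
try:
    # No explicit device: the CUDA default applies and mismatches the CPU generator.
    torch.rand((2, 2), generator=gen)
except RuntimeError as err:
    print(err)
finally:
    torch.set_default_device(prev)

# Passing the generator's device explicitly sidesteps the default, as in the fix.
x = torch.rand((2, 2), generator=gen, device=gen.device).to("cuda")
```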
