Skip to content

Commit 2de738a

Browse files
paulineshocopybara-github
authored andcommitted
Add hadamard rotation tests with golden inputs.
PiperOrigin-RevId: 758431575
1 parent 7fe1d0d commit 2de738a

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,55 @@ def test_get_tensor_quant_params_basic(self):
119119
if qparams.hadamard is not None:
120120
self.assertEqual(qparams.hadamard.hadamard_size, 32)
121121

122+
def test_get_tensor_quant_params_golden_1(self):
123+
test_data = np.ones((6, 6))
124+
# expected:
125+
# [[127 0 127 0 127 0]
126+
# [127 0 127 0 127 0]
127+
# [127 0 127 0 127 0]
128+
# [127 0 127 0 127 0]
129+
# [127 0 127 0 127 0]
130+
# [127 0 127 0 127 0]]
131+
expected = np.tile([127, 0], [6, 3])
132+
qparams = hadamard_rotation.get_tensor_quant_params(
133+
self._op_info,
134+
self._op_info.op_quant_config.weight_tensor_config,
135+
test_data,
136+
self._tensor_name_to_qsv,
137+
)
138+
self.assertIsNotNone(qparams.quantized_data)
139+
np.testing.assert_array_equal(
140+
np.array(qparams.quantized_data), expected
141+
)
142+
143+
def test_get_tensor_quant_params_golden_2(self):
144+
# test_data:
145+
# [[1 2 1 2 1 2]
146+
# [3 4 3 4 3 4]
147+
# [1 2 1 2 1 2]
148+
# [3 4 3 4 3 4]
149+
# [1 2 1 2 1 2]
150+
# [3 4 3 4 3 4]]
151+
test_data = np.tile([[1, 2], [3, 4]], [3, 3])
152+
# expected:
153+
# [[127 -42 127 -42 127 -42]
154+
# [127 -18 127 -18 127 -18]
155+
# [127 -42 127 -42 127 -42]
156+
# [127 -18 127 -18 127 -18]
157+
# [127 -42 127 -42 127 -42]
158+
# [127 -18 127 -18 127 -18]]
159+
expected = np.tile([[127, -42], [127, -18]], [3, 3])
160+
qparams = hadamard_rotation.get_tensor_quant_params(
161+
self._op_info,
162+
self._op_info.op_quant_config.weight_tensor_config,
163+
test_data,
164+
self._tensor_name_to_qsv,
165+
)
166+
self.assertIsNotNone(qparams.quantized_data)
167+
np.testing.assert_array_equal(
168+
np.array(qparams.quantized_data), expected
169+
)
170+
122171
def test_raise_missing_tensor_content(self):
123172
with self.assertRaisesWithPredicateMatch(
124173
ValueError, lambda err: "weight tensor" in str(err)

0 commit comments

Comments
 (0)