Skip to content

Commit 1309849

Browse files
authored
Add function for input preprocessing in numerical comparator
Differential Revision: D76745314 Pull Request resolved: #11739
1 parent 962db1b commit 1309849

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,3 +690,34 @@ def map_runtime_aot_intermediate_outputs(
690690
)
691691

692692
return aot_runtime_mapping
693+
694+
695+
def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
696+
"""
697+
Convert input_data into a torch.Tensor on CPU with dtype torch.float64.
698+
This function handles the following types of input:
699+
- Scalar (int or float): Converts to a tensor with a single element.
700+
- Tensor: Converts to a float64 tensor on CPU.
701+
- List of Tensors: Stacks the tensors into a single float64 tensor on CPU.
702+
The resulting tensor is detached, moved to CPU, and cast to torch.float64.
703+
Parameters:
704+
input_data (Any): The input data to be converted to a tensor. It can be a scalar,
705+
a tensor, or a list of tensors.
706+
Returns:
707+
torch.Tensor: A tensor on CPU with dtype torch.float64.
708+
Raises:
709+
ValueError: If the input_data cannot be converted to a tensor.
710+
"""
711+
try:
712+
# Check if the input is a list of tensors
713+
if isinstance(input_data, list):
714+
input_tensor = torch.stack([convert_to_float_tensor(a) for a in input_data])
715+
# Try to convert the input to a tensor
716+
else:
717+
input_tensor = torch.as_tensor(input_data, dtype=torch.float64)
718+
except Exception as e:
719+
raise ValueError(
720+
f"Cannot convert value of type {type(input_data)} to a tensor: {e}"
721+
)
722+
input_tensor = input_tensor.detach().cpu().double()
723+
return input_tensor

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
calculate_mse,
3030
calculate_snr,
3131
calculate_time_scale_factor,
32+
convert_to_float_tensor,
3233
create_debug_handle_to_op_node_mapping,
3334
EDGE_DIALECT_GRAPH_KEY,
3435
find_populated_event,
@@ -317,6 +318,52 @@ def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
317318
expected = {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}
318319
self.assertEqual(actual, expected)
319320

321+
def test_convert_input_to_tensor_convertible_inputs(self):
322+
# Scalar -> tensor
323+
actual_output1 = convert_to_float_tensor(5)
324+
self.assertIsInstance(actual_output1, torch.Tensor)
325+
self.assertEqual(actual_output1.dtype, torch.float64)
326+
self.assertEqual(tuple(actual_output1.shape), ())
327+
self.assertTrue(
328+
torch.allclose(actual_output1, torch.tensor([5.0], dtype=torch.float64))
329+
)
330+
self.assertEqual(actual_output1.device.type, "cpu")
331+
332+
# Tensor of ints -> float32 CPU
333+
t_int = torch.tensor([4, 5, 6], dtype=torch.int32)
334+
actual_output2 = convert_to_float_tensor(t_int)
335+
self.assertIsInstance(actual_output2, torch.Tensor)
336+
self.assertEqual(actual_output2.dtype, torch.float64)
337+
self.assertTrue(
338+
torch.allclose(
339+
actual_output2, torch.tensor([4.0, 5.0, 6.0], dtype=torch.float64)
340+
)
341+
)
342+
self.assertEqual(actual_output2.device.type, "cpu")
343+
344+
# List of tensors -> stacked tensor float32 CPU
345+
t_list = [torch.tensor([1, 2]), torch.tensor([2, 3]), torch.tensor([3, 4])]
346+
actual_output3 = convert_to_float_tensor(t_list)
347+
self.assertIsInstance(actual_output3, torch.Tensor)
348+
self.assertEqual(actual_output3.dtype, torch.float64)
349+
self.assertEqual(tuple(actual_output3.shape), (3, 2))
350+
self.assertTrue(
351+
torch.allclose(
352+
actual_output3,
353+
torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], dtype=torch.float64),
354+
)
355+
)
356+
self.assertEqual(actual_output3.device.type, "cpu")
357+
358+
def test_convert_input_to_tensor_non_convertible_raises(self):
359+
class X:
360+
pass
361+
362+
with self.assertRaises(ValueError) as cm:
363+
convert_to_float_tensor(X())
364+
msg = str(cm.exception)
365+
self.assertIn("Cannot convert value of type", msg)
366+
320367

321368
def gen_mock_operator_graph_with_expected_map() -> (
322369
Tuple[OperatorGraph, Dict[int, OperatorNode]]

0 commit comments

Comments
 (0)