|
29 | 29 | calculate_mse,
|
30 | 30 | calculate_snr,
|
31 | 31 | calculate_time_scale_factor,
|
| 32 | + convert_to_float_tensor, |
32 | 33 | create_debug_handle_to_op_node_mapping,
|
33 | 34 | EDGE_DIALECT_GRAPH_KEY,
|
34 | 35 | find_populated_event,
|
@@ -317,6 +318,52 @@ def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
|
317 | 318 | expected = {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}
|
318 | 319 | self.assertEqual(actual, expected)
|
319 | 320 |
|
| 321 | + def test_convert_input_to_tensor_convertible_inputs(self): |
| 322 | + # Scalar -> tensor |
| 323 | + actual_output1 = convert_to_float_tensor(5) |
| 324 | + self.assertIsInstance(actual_output1, torch.Tensor) |
| 325 | + self.assertEqual(actual_output1.dtype, torch.float64) |
| 326 | + self.assertEqual(tuple(actual_output1.shape), ()) |
| 327 | + self.assertTrue( |
| 328 | + torch.allclose(actual_output1, torch.tensor([5.0], dtype=torch.float64)) |
| 329 | + ) |
| 330 | + self.assertEqual(actual_output1.device.type, "cpu") |
| 331 | + |
| 332 | + # Tensor of ints -> float32 CPU |
| 333 | + t_int = torch.tensor([4, 5, 6], dtype=torch.int32) |
| 334 | + actual_output2 = convert_to_float_tensor(t_int) |
| 335 | + self.assertIsInstance(actual_output2, torch.Tensor) |
| 336 | + self.assertEqual(actual_output2.dtype, torch.float64) |
| 337 | + self.assertTrue( |
| 338 | + torch.allclose( |
| 339 | + actual_output2, torch.tensor([4.0, 5.0, 6.0], dtype=torch.float64) |
| 340 | + ) |
| 341 | + ) |
| 342 | + self.assertEqual(actual_output2.device.type, "cpu") |
| 343 | + |
| 344 | + # List of tensors -> stacked tensor float32 CPU |
| 345 | + t_list = [torch.tensor([1, 2]), torch.tensor([2, 3]), torch.tensor([3, 4])] |
| 346 | + actual_output3 = convert_to_float_tensor(t_list) |
| 347 | + self.assertIsInstance(actual_output3, torch.Tensor) |
| 348 | + self.assertEqual(actual_output3.dtype, torch.float64) |
| 349 | + self.assertEqual(tuple(actual_output3.shape), (3, 2)) |
| 350 | + self.assertTrue( |
| 351 | + torch.allclose( |
| 352 | + actual_output3, |
| 353 | + torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], dtype=torch.float64), |
| 354 | + ) |
| 355 | + ) |
| 356 | + self.assertEqual(actual_output3.device.type, "cpu") |
| 357 | + |
| 358 | + def test_convert_input_to_tensor_non_convertible_raises(self): |
| 359 | + class X: |
| 360 | + pass |
| 361 | + |
| 362 | + with self.assertRaises(ValueError) as cm: |
| 363 | + convert_to_float_tensor(X()) |
| 364 | + msg = str(cm.exception) |
| 365 | + self.assertIn("Cannot convert value of type", msg) |
| 366 | + |
320 | 367 |
|
321 | 368 | def gen_mock_operator_graph_with_expected_map() -> (
|
322 | 369 | Tuple[OperatorGraph, Dict[int, OperatorNode]]
|
|
0 commit comments