Skip to content

Commit d29d136

Browse files
vidursatijaTobyRoseman
authored andcommitted
Add test for copy_ op #918
1 parent bda440e commit d29d136

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

+11
Original file line numberDiff line numberDiff line change
@@ -4759,3 +4759,14 @@ def forward(self, x):
47594759
backend=backend,
47604760
converter_input_type=converter_input_type,
47614761
)
4762+
4763+
4764+
class TestCopy:
4765+
@pytest.mark.parametrize(
4766+
"backend, rank", itertools.product(backends, list(range(1, 6))),
4767+
)
4768+
def test_copy(self, backend, rank):
4769+
input_shape = tuple(np.random.randint(low=2, high=6, size=rank))
4770+
4771+
model = ModuleWrapper(function=lambda x: x.copy_())
4772+
run_compare_torch(input_shape, model, backend=backend)

0 commit comments

Comments
 (0)