Skip to content

Commit f997b2b

Browse files
Revert "Add MaskedTensor passthrough: unfold, F.Unfold, F.Fold, stack (pytorch#125262)"
This reverts commit f685018. Reverted pytorch#125262 on behalf of https://github.com/ZainRizvi due to Hi, this PR appears to be calling maskedtensor tests to fail on main. Please rebase your changes onto the latest trunk build to repro the failure. test_maskedtensor.py::TestOperatorsCUDA::test_like_empty_like_layout1_cuda_bool [GH job link](https://github.com/pytorch/pytorch/actions/runs/10604716811/job/29393256312) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/f685018ea9d08f98cbd7106028db134f967f74d3) ([comment](pytorch#125262 (comment)))
1 parent 6dd3f81 commit f997b2b

File tree

8 files changed

+14
-56
lines changed

8 files changed

+14
-56
lines changed

aten/src/ATen/native/Col2Im.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ static void col2im_out_cpu_template(
144144

145145
output.resize_({batch_size, n_output_plane, output_height, output_width});
146146

147-
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool,
147+
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf,
148148
input.scalar_type(), "col2im_out_cpu", [&] {
149149
Tensor input_n = Tensor();
150150
Tensor output_n = Tensor();

aten/src/ATen/native/Im2Col.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ static void im2col_out_cpu_template(
9494

9595
output.resize_({batch_size, n_output_plane, output_length});
9696

97-
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool,
97+
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf,
9898
input.scalar_type(), "im2col_out_cpu", [&] {
9999
Tensor input_n;
100100
Tensor output_n;

aten/src/ATen/native/cuda/Col2Im.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ void col2im_out_cuda_template(
102102
output.resize_({batch_size, n_output_plane, output_height, output_width});
103103
int64_t output_batch_stride = output.stride(0);
104104

105-
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool,
105+
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
106106
input.scalar_type(), "col2im_out_cuda", [&] {
107107
int64_t height_col = (output_height + 2 * pad_height -
108108
(dilation_height * (kernel_height - 1) + 1)) /

aten/src/ATen/native/cuda/Im2Col.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ static void im2col_out_cuda_template(
103103
output.resize_({batch_size, n_output_plane, output_length});
104104

105105
// Launch kernel
106-
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool,
106+
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
107107
input.scalar_type(), "im2col_out_cuda", [&] {
108108
Tensor input_n;
109109
Tensor output_n;

docs/source/masked.rst

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,11 +283,9 @@ The following ops are currently supported:
283283
kron
284284
meshgrid
285285
narrow
286-
nn.functional.unfold
287286
ravel
288287
select
289288
split
290-
stack
291289
t
292290
transpose
293291
vsplit
@@ -296,7 +294,6 @@ The following ops are currently supported:
296294
Tensor.expand_as
297295
Tensor.reshape
298296
Tensor.reshape_as
299-
Tensor.unfold
300297
Tensor.view
301298

302299
Other functions

test/test_maskedtensor.py

Lines changed: 8 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,6 @@ def _compare_mts(mt1, mt2, rtol=1e-05, atol=1e-08):
6868
if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol):
6969
raise ValueError("The data in MaskedTensor mt1 and MaskedTensor mt2 do not match")
7070

71-
def _compare_forward_backward(data, mask, fn):
72-
mt = masked_tensor(data, mask, requires_grad=True)
73-
masked_res = fn(mt)
74-
masked_res.sum().backward()
75-
76-
t = data.masked_fill(~mask, float("-inf")).detach().clone().requires_grad_()
77-
tensor_res = fn(t)
78-
tensor_res.sum().backward()
79-
80-
_compare_mt_t(masked_res, tensor_res)
81-
_compare_mt_t(mt.grad, t.grad, atol=1e-06)
82-
8371

8472
def _create_random_mask(shape, device):
8573
return make_tensor(shape, device=device, dtype=torch.bool)
@@ -178,8 +166,15 @@ def test_softmax(self, device):
178166
],
179167
device=device
180168
)
169+
mt = masked_tensor(data, mask, requires_grad=True)
170+
masked_res = torch.softmax(mt, -1)
171+
masked_res.sum().backward()
172+
xinf = data.masked_fill(~mask, float("-inf")).detach().clone().requires_grad_()
173+
tensor_res = torch.softmax(xinf, -1)
174+
tensor_res.sum().backward()
181175

182-
_compare_forward_backward(data, mask, lambda t: torch.softmax(t, -1))
176+
_compare_mt_t(masked_res, tensor_res)
177+
_compare_mt_t(mt.grad, xinf.grad, atol=1e-06)
183178

184179
def test_where(self, device):
185180
data = torch.tensor([-10.0, -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], device=device)
@@ -199,35 +194,6 @@ def test_where(self, device):
199194
_compare_mt_t(mx.grad, x.grad)
200195
_compare_mt_t(my.grad, y.grad)
201196

202-
def test_unfold(self, device):
203-
data = torch.rand(5, 5, device=device)
204-
mask = torch.rand(5, 5, device=device) > 0.5
205-
_compare_forward_backward(data, mask, lambda t: t.unfold(1, 2, 2))
206-
207-
def test_nn_unfold(self, device):
208-
data = torch.rand(2, 5, 3, 4, device=device)
209-
mask = torch.rand(2, 5, 3, 4, device=device) > 0.5
210-
_compare_forward_backward(data, mask, lambda t: torch.nn.functional.unfold(t, kernel_size=(2, 3)))
211-
212-
def test_stack(self, device):
213-
masked_tensors = [
214-
masked_tensor(
215-
torch.rand(2, 5, 3, 4, device=device),
216-
torch.rand(2, 5, 3, 4, device=device) > 0.5,
217-
requires_grad=True,
218-
) for _ in range(3)
219-
]
220-
221-
data_tensors = [mt.get_data().detach().clone().requires_grad_() for mt in masked_tensors]
222-
masked_res = torch.stack(masked_tensors)
223-
tensor_res = torch.stack(data_tensors)
224-
225-
masked_res.sum().backward()
226-
tensor_res.sum().backward()
227-
_compare_mt_t(masked_res, tensor_res)
228-
for mt, t in zip(masked_tensors, data_tensors):
229-
_compare_mt_t(mt.grad, t.grad, atol=1e-06)
230-
231197
def test_to_sparse(self, device):
232198
for sample in _generate_sample_data(device=device):
233199
data = sample.input

torch/masked/maskedtensor/passthrough.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,6 @@
3030
torch.ops.aten._reshape_alias,
3131
torch.ops.aten.cat,
3232
torch.ops.aten.unsqueeze,
33-
torch.ops.aten.unfold,
34-
torch.ops.aten.unfold_backward,
35-
torch.ops.aten.im2col,
36-
torch.ops.aten.col2im,
37-
torch.ops.aten.stack,
3833
]
3934

4035

torch/testing/_internal/common_methods_invocations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15266,8 +15266,8 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
1526615266
autodiff_nonfusible_nodes=["aten::hardswish"]),
1526715267
OpInfo('nn.functional.unfold',
1526815268
aten_name='im2col',
15269-
dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool),
15270-
dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool),
15269+
dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
15270+
dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
1527115271
sample_inputs_func=sample_inputs_nn_unfold,
1527215272
# Runs very slowly on slow gradcheck - alternatively reduce input sizes
1527315273
gradcheck_fast_mode=True,

0 commit comments

Comments
 (0)