Skip to content

Commit f3c89cc

Browse files
NicolasHug and pmeier
authored
Remove cutmix and mixup from prototype (#7787)
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
1 parent cab9fba commit f3c89cc

File tree

3 files changed

+3
-142
lines changed

3 files changed

+3
-142
lines changed

test/test_prototype_transforms.py

Lines changed: 1 addition & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,10 @@
1212
make_bounding_box,
1313
make_detection_mask,
1414
make_image,
15-
make_images,
16-
make_segmentation_mask,
1715
make_video,
18-
make_videos,
1916
)
2017

21-
from prototype_common_utils import make_label, make_one_hot_labels
18+
from prototype_common_utils import make_label
2219

2320
from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
2421
from torchvision.prototype import datapoints, transforms
@@ -44,49 +41,6 @@ def parametrize(transforms_with_inputs):
4441
)
4542

4643

47-
@parametrize(
48-
[
49-
(
50-
transform,
51-
[
52-
dict(inpt=inpt, one_hot_label=one_hot_label)
53-
for inpt, one_hot_label in itertools.product(
54-
itertools.chain(
55-
make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
56-
make_videos(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
57-
),
58-
make_one_hot_labels(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
59-
)
60-
],
61-
)
62-
for transform in [
63-
transforms.RandomMixUp(alpha=1.0),
64-
transforms.RandomCutMix(alpha=1.0),
65-
]
66-
]
67-
)
68-
def test_mixup_cutmix(transform, input):
69-
transform(input)
70-
71-
input_copy = dict(input)
72-
input_copy["path"] = "/path/to/somewhere"
73-
input_copy["num"] = 1234
74-
transform(input_copy)
75-
76-
# Check if we raise an error if sample contains bbox or mask or label
77-
err_msg = "does not support PIL images, bounding boxes, masks and plain labels"
78-
input_copy = dict(input)
79-
for unsup_data in [
80-
make_label(),
81-
make_bounding_box(format="XYXY"),
82-
make_detection_mask(),
83-
make_segmentation_mask(),
84-
]:
85-
input_copy["unsupported"] = unsup_data
86-
with pytest.raises(TypeError, match=err_msg):
87-
transform(input_copy)
88-
89-
9044
class TestSimpleCopyPaste:
9145
def create_fake_image(self, mocker, image_type):
9246
if image_type == PIL.Image.Image:
torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from ._presets import StereoMatching # usort: skip
22

3-
from ._augment import RandomCutMix, RandomMixUp, SimpleCopyPaste
3+
from ._augment import SimpleCopyPaste
44
from ._geometry import FixedSizeCrop
55
from ._misc import PermuteDimensions, TransposeDimensions
66
from ._type_conversion import LabelToOneHot

torchvision/prototype/transforms/_augment.py

Lines changed: 1 addition & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import math
21
from typing import Any, cast, Dict, List, Optional, Tuple, Union
32

43
import PIL.Image
@@ -9,100 +8,8 @@
98
from torchvision.prototype import datapoints as proto_datapoints
109
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
1110

12-
from torchvision.transforms.v2._transform import _RandomApplyTransform
1311
from torchvision.transforms.v2.functional._geometry import _check_interpolation
14-
from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_size
15-
16-
17-
class _BaseMixUpCutMix(_RandomApplyTransform):
18-
def __init__(self, alpha: float, p: float = 0.5) -> None:
19-
super().__init__(p=p)
20-
self.alpha = alpha
21-
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
22-
23-
def _check_inputs(self, flat_inputs: List[Any]) -> None:
24-
if not (
25-
has_any(flat_inputs, datapoints.Image, datapoints.Video, is_simple_tensor)
26-
and has_any(flat_inputs, proto_datapoints.OneHotLabel)
27-
):
28-
raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
29-
if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBoxes, datapoints.Mask, proto_datapoints.Label):
30-
raise TypeError(
31-
f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
32-
)
33-
34-
def _mixup_onehotlabel(self, inpt: proto_datapoints.OneHotLabel, lam: float) -> proto_datapoints.OneHotLabel:
35-
if inpt.ndim < 2:
36-
raise ValueError("Need a batch of one hot labels")
37-
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
38-
return proto_datapoints.OneHotLabel.wrap_like(inpt, output)
39-
40-
41-
class RandomMixUp(_BaseMixUpCutMix):
42-
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
43-
return dict(lam=float(self._dist.sample(()))) # type: ignore[arg-type]
44-
45-
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
46-
lam = params["lam"]
47-
if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
48-
expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
49-
if inpt.ndim < expected_ndim:
50-
raise ValueError("The transform expects a batched input")
51-
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
52-
53-
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
54-
output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type]
55-
56-
return output
57-
elif isinstance(inpt, proto_datapoints.OneHotLabel):
58-
return self._mixup_onehotlabel(inpt, lam)
59-
else:
60-
return inpt
61-
62-
63-
class RandomCutMix(_BaseMixUpCutMix):
64-
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
65-
lam = float(self._dist.sample(())) # type: ignore[arg-type]
66-
67-
H, W = query_size(flat_inputs)
68-
69-
r_x = torch.randint(W, ())
70-
r_y = torch.randint(H, ())
71-
72-
r = 0.5 * math.sqrt(1.0 - lam)
73-
r_w_half = int(r * W)
74-
r_h_half = int(r * H)
75-
76-
x1 = int(torch.clamp(r_x - r_w_half, min=0))
77-
y1 = int(torch.clamp(r_y - r_h_half, min=0))
78-
x2 = int(torch.clamp(r_x + r_w_half, max=W))
79-
y2 = int(torch.clamp(r_y + r_h_half, max=H))
80-
box = (x1, y1, x2, y2)
81-
82-
lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
83-
84-
return dict(box=box, lam_adjusted=lam_adjusted)
85-
86-
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
87-
if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
88-
box = params["box"]
89-
expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
90-
if inpt.ndim < expected_ndim:
91-
raise ValueError("The transform expects a batched input")
92-
x1, y1, x2, y2 = box
93-
rolled = inpt.roll(1, 0)
94-
output = inpt.clone()
95-
output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]
96-
97-
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
98-
output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]
99-
100-
return output
101-
elif isinstance(inpt, proto_datapoints.OneHotLabel):
102-
lam_adjusted = params["lam_adjusted"]
103-
return self._mixup_onehotlabel(inpt, lam_adjusted)
104-
else:
105-
return inpt
12+
from torchvision.transforms.v2.utils import is_simple_tensor
10613

10714

10815
class SimpleCopyPaste(Transform):

0 commit comments

Comments (0)