@@ -1,4 +1,3 @@
-import math
 from typing import Any, cast, Dict, List, Optional, Tuple, Union
 
 import PIL.Image
@@ -9,100 +8,8 @@
 from torchvision.prototype import datapoints as proto_datapoints
 from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
 
-from torchvision.transforms.v2._transform import _RandomApplyTransform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
-from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_size
-
-
-class _BaseMixUpCutMix(_RandomApplyTransform):
-    def __init__(self, alpha: float, p: float = 0.5) -> None:
-        super().__init__(p=p)
-        self.alpha = alpha
-        self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
-
-    def _check_inputs(self, flat_inputs: List[Any]) -> None:
-        if not (
-            has_any(flat_inputs, datapoints.Image, datapoints.Video, is_simple_tensor)
-            and has_any(flat_inputs, proto_datapoints.OneHotLabel)
-        ):
-            raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
-        if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBoxes, datapoints.Mask, proto_datapoints.Label):
-            raise TypeError(
-                f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
-            )
-
-    def _mixup_onehotlabel(self, inpt: proto_datapoints.OneHotLabel, lam: float) -> proto_datapoints.OneHotLabel:
-        if inpt.ndim < 2:
-            raise ValueError("Need a batch of one hot labels")
-        output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
-        return proto_datapoints.OneHotLabel.wrap_like(inpt, output)
-
-
-class RandomMixUp(_BaseMixUpCutMix):
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        return dict(lam=float(self._dist.sample(())))  # type: ignore[arg-type]
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        lam = params["lam"]
-        if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
-            expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
-            if inpt.ndim < expected_ndim:
-                raise ValueError("The transform expects a batched input")
-            output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
-
-            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-                output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
-
-            return output
-        elif isinstance(inpt, proto_datapoints.OneHotLabel):
-            return self._mixup_onehotlabel(inpt, lam)
-        else:
-            return inpt
-
-
-class RandomCutMix(_BaseMixUpCutMix):
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        lam = float(self._dist.sample(()))  # type: ignore[arg-type]
-
-        H, W = query_size(flat_inputs)
-
-        r_x = torch.randint(W, ())
-        r_y = torch.randint(H, ())
-
-        r = 0.5 * math.sqrt(1.0 - lam)
-        r_w_half = int(r * W)
-        r_h_half = int(r * H)
-
-        x1 = int(torch.clamp(r_x - r_w_half, min=0))
-        y1 = int(torch.clamp(r_y - r_h_half, min=0))
-        x2 = int(torch.clamp(r_x + r_w_half, max=W))
-        y2 = int(torch.clamp(r_y + r_h_half, max=H))
-        box = (x1, y1, x2, y2)
-
-        lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
-
-        return dict(box=box, lam_adjusted=lam_adjusted)
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
-            box = params["box"]
-            expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
-            if inpt.ndim < expected_ndim:
-                raise ValueError("The transform expects a batched input")
-            x1, y1, x2, y2 = box
-            rolled = inpt.roll(1, 0)
-            output = inpt.clone()
-            output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]
-
-            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-                output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
-
-            return output
-        elif isinstance(inpt, proto_datapoints.OneHotLabel):
-            lam_adjusted = params["lam_adjusted"]
-            return self._mixup_onehotlabel(inpt, lam_adjusted)
-        else:
-            return inpt
+from torchvision.transforms.v2.utils import is_simple_tensor
 
 
 class SimpleCopyPaste(Transform):
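
For context on what this diff removes: RandomMixUp drew a weight lam from Beta(alpha, alpha) and blended each batch element with its neighbor (the batch rolled by one along dim 0), applying the same convex combination to images and one-hot labels. A minimal standalone sketch of that blend on plain tensors (illustrative only; the shapes and variable names below are assumptions, not part of the diff):

import torch

# Sketch of the MixUp blend from the removed RandomMixUp (hypothetical shapes).
alpha = 1.0
lam = float(torch.distributions.Beta(alpha, alpha).sample())

images = torch.rand(8, 3, 224, 224)  # batched (N, C, H, W) input
targets = torch.nn.functional.one_hot(
    torch.randint(10, (8,)), num_classes=10
).float()  # one-hot labels, shape (N, num_classes)

# roll(1, 0) pairs each sample with the next one in the batch; the in-place
# ops are safe because they run on the fresh tensors returned by roll()/mul().
mixed_images = images.roll(1, 0).mul_(1.0 - lam).add_(images.mul(lam))
mixed_targets = targets.roll(1, 0).mul_(1.0 - lam).add_(targets.mul(lam))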
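RandomCutMix instead pasted a random box from the rolled batch into each image and re-weighted the labels by the fraction of area actually replaced (lam_adjusted). The box half-extent scales with sqrt(1 - lam), so before clamping the box covers roughly a (1 - lam) fraction of the image. A standalone sketch of that box sampling and paste, under the same illustrative assumptions as above:

import math
import torch

images = torch.rand(8, 3, 224, 224)
targets = torch.nn.functional.one_hot(torch.randint(10, (8,)), num_classes=10).float()
H, W = images.shape[-2:]

lam = float(torch.distributions.Beta(1.0, 1.0).sample())

# Uniform box center; half-extent proportional to sqrt(1 - lam).
r_x, r_y = int(torch.randint(W, ())), int(torch.randint(H, ()))
r = 0.5 * math.sqrt(1.0 - lam)
x1, y1 = max(r_x - int(r * W), 0), max(r_y - int(r * H), 0)
x2, y2 = min(r_x + int(r * W), W), min(r_y + int(r * H), H)

# Paste the box from the rolled batch, then re-weight the labels by the
# clamped area that was actually replaced.
output = images.clone()
output[..., y1:y2, x1:x2] = images.roll(1, 0)[..., y1:y2, x1:x2]
lam_adjusted = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)
mixed_targets = targets.roll(1, 0).mul_(1.0 - lam_adjusted).add_(targets.mul(lam_adjusted))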