Skip to content

Commit 335262f

Browse files
authored
Merge branch 'develop' into vs/sseg_tiler_mapi
2 parents 33fb8fa + 28f6529 commit 335262f

File tree

13 files changed

+145
-52
lines changed

13 files changed

+145
-52
lines changed

CHANGELOG.md

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,35 @@
22

33
All notable changes to this project will be documented in this file.
44

5-
## \[2.2.0\]
5+
## \[2.3.0\]
66

77
### New features
88

99
- Add YOLOv9 model for Object Detection
1010
(https://github.com/openvinotoolkit/training_extensions/pull/3917)
11+
- Add OV inference for keypoint detection
12+
(https://github.com/openvinotoolkit/training_extensions/pull/3970)
13+
- Add tiling for semantic segmentation
14+
(https://github.com/openvinotoolkit/training_extensions/pull/3954)
15+
16+
### Enhancements
17+
18+
- Upgrade OV, MAPI, and NNCF dependencies
19+
(https://github.com/openvinotoolkit/training_extensions/pull/3967)
20+
- Instance Segmentation Model refactoring
21+
(https://github.com/openvinotoolkit/training_extensions/pull/3865)
22+
- Bump torch and lightning to version 2.4.0
23+
(https://github.com/openvinotoolkit/training_extensions/pull/3843)
24+
25+
### Bug fixes
26+
27+
- Fix a wrong HPO log
28+
(https://github.com/openvinotoolkit/training_extensions/pull/3972)
29+
30+
## \[2.2.0\]
31+
32+
### New features
33+
1134
- Add RT-DETR model for Object Detection
1235
(https://github.com/openvinotoolkit/training_extensions/pull/3741)
1336
- Add Multi-Label & H-label Classification with torchvision models

src/otx/algo/detection/heads/yolo_head.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,8 @@ class YOLOHeadModule(BaseDenseHead):
296296
csp_args (dict[str, Any], optional): Arguments for CSP blocks. Defaults to None.
297297
aux_cfg (dict[str, Any], optional): Configuration for auxiliary head. Defaults to None.
298298
with_nms (bool, optional): Whether to use NMS. Defaults to True.
299-
min_confidence (float, optional): Minimum confidence for NMS. Defaults to 0.05.
300-
min_iou (float, optional): Minimum IoU for NMS. Defaults to 0.9.
299+
min_confidence (float, optional): Minimum confidence for NMS. Defaults to 0.1.
300+
min_iou (float, optional): Minimum IoU for NMS. Defaults to 0.65.
301301
"""
302302

303303
def __init__(
@@ -311,8 +311,8 @@ def __init__(
311311
csp_args: dict[str, Any] | None = None,
312312
aux_cfg: dict[str, Any] | None = None,
313313
with_nms: bool = True,
314-
min_confidence: float = 0.05,
315-
min_iou: float = 0.9,
314+
min_confidence: float = 0.1,
315+
min_iou: float = 0.65,
316316
) -> None:
317317
if len(csp_channels) - 1 != len(concat_sources):
318318
msg = (

src/otx/algo/detection/losses/yolov9_loss.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -371,9 +371,9 @@ def __init__(
371371
loss_dfl: nn.Module | None = None,
372372
loss_iou: nn.Module | None = None,
373373
reg_max: int = 16,
374-
cls_rate: float = 1.5,
375-
dfl_rate: float = 7.5,
376-
iou_rate: float = 0.5,
374+
cls_rate: float = 0.5,
375+
dfl_rate: float = 1.5,
376+
iou_rate: float = 7.5,
377377
aux_rate: float = 0.25,
378378
) -> None:
379379
super().__init__()
@@ -394,7 +394,7 @@ def forward(
394394
main_preds: tuple[Tensor, Tensor, Tensor],
395395
targets: Tensor,
396396
aux_preds: tuple[Tensor, Tensor, Tensor] | None = None,
397-
) -> dict[str, Tensor]:
397+
) -> dict[str, Tensor] | None:
398398
"""Forward pass of the YOLOv9 criterion module.
399399
400400
Args:
@@ -405,6 +405,10 @@ def forward(
405405
Returns:
406406
dict[str, Tensor] | None: The loss dictionary, or None when ``targets.shape[1] == 0``.
407407
"""
408+
if targets.shape[1] == 0:
409+
# TODO (sungchul): should this step be done here?
410+
return None
411+
408412
main_preds = self.vec2box(main_preds)
409413
main_iou, main_dfl, main_cls = self._forward(main_preds, targets)
410414

src/otx/algo/detection/rtdetr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def _customize_inputs(
135135

136136
def _customize_outputs(
137137
self,
138-
outputs: list[torch.Tensor] | dict,
138+
outputs: list[torch.Tensor] | dict, # type: ignore[override]
139139
inputs: DetBatchDataEntity,
140140
) -> DetBatchPredEntity | OTXBatchLossEntity:
141141
if self.training:

src/otx/algo/detection/yolov9.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from otx.core.metrics.fmeasure import MeanAveragePrecisionFMeasureCallable
2020
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
2121
from otx.core.model.detection import OTXDetectionModel
22+
from otx.core.types.export import TaskLevelExportParameters
2223

2324
if TYPE_CHECKING:
2425
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
@@ -116,7 +117,7 @@ def _exporter(self) -> OTXModelExporter:
116117
std=self.std,
117118
resize_mode="fit_to_window_letterbox",
118119
pad_value=114,
119-
swap_rgb=True,
120+
swap_rgb=False,
120121
via_onnx=True,
121122
onnx_export_configuration={
122123
"input_names": ["image"],
@@ -135,6 +136,14 @@ def _exporter(self) -> OTXModelExporter:
135136
output_names=None, # TODO (someone): support XAI
136137
)
137138

139+
@property
140+
def _export_parameters(self) -> TaskLevelExportParameters:
141+
"""Defines parameters required to export a particular model implementation."""
142+
return super()._export_parameters.wrap(
143+
confidence_threshold=self.model.bbox_head.min_confidence,
144+
iou_threshold=self.model.bbox_head.min_iou,
145+
)
146+
138147
def to(self, *args, **kwargs) -> Self:
139148
"""Sync device of the model and its components."""
140149
ret = super().to(*args, **kwargs)

src/otx/core/model/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,13 @@ def __init__(
141141
# so that it can retrieve it from the checkpoint
142142
self.save_hyperparameters(logger=False, ignore=["optimizer", "scheduler", "metric"])
143143

144-
def training_step(self, batch: T_OTXBatchDataEntity, batch_idx: int) -> Tensor:
144+
def training_step(self, batch: T_OTXBatchDataEntity, batch_idx: int) -> Tensor | None:
145145
"""Step for model training."""
146146
train_loss = self.forward(inputs=batch)
147+
if train_loss is None:
148+
# forward() returned no loss; returning None here skips the current iteration
149+
# TODO (sungchul): check this in distributed training
150+
return None if self.trainer.world_size == 1 else torch.tensor(0.0, device=self.device)
147151

148152
if isinstance(train_loss, Tensor):
149153
self.log(

src/otx/core/model/detection.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,15 @@ def _customize_inputs(
137137

138138
return inputs
139139

140-
def _customize_outputs(
140+
def _customize_outputs( # type: ignore[override]
141141
self,
142-
outputs: list[InstanceData] | dict,
142+
outputs: list[InstanceData] | dict | None,
143143
inputs: DetBatchDataEntity,
144-
) -> DetBatchPredEntity | OTXBatchLossEntity:
144+
) -> DetBatchPredEntity | OTXBatchLossEntity | None:
145145
if self.training:
146+
if outputs is None:
147+
return outputs
148+
146149
if not isinstance(outputs, dict):
147150
raise TypeError(outputs)
148151

src/otx/recipe/detection/yolov9_c.yaml

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ model:
77
optimizer:
88
class_path: torch.optim.SGD
99
init_args:
10-
lr: 0.001
10+
lr: 0.0001
1111
momentum: 0.937
1212
weight_decay: 0.0005
1313
nesterov: true
@@ -16,13 +16,13 @@ model:
1616
class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
1717
init_args:
1818
num_warmup_steps: 3
19-
warmup_interval: epoch
2019
main_scheduler_callable:
21-
class_path: torch.optim.lr_scheduler.LinearLR
20+
class_path: lightning.pytorch.cli.ReduceLROnPlateau
2221
init_args:
23-
total_iters: 200
24-
start_factor: 1
25-
end_factor: 0.01
22+
mode: max
23+
factor: 0.1
24+
patience: 4
25+
monitor: val/map_50
2626

2727
engine:
2828
task: DETECTION
@@ -42,23 +42,38 @@ overrides:
4242
input_size:
4343
- 640
4444
- 640
45-
image_color_channel: BGR
4645
train_subset:
47-
batch_size: 16
46+
batch_size: 10
4847
transforms:
4948
- class_path: otx.core.data.transform_libs.torchvision.CachedMosaic
5049
init_args:
5150
random_pop: false
5251
max_cached_images: 20
5352
img_scale: $(input_size) # (H, W)
54-
- class_path: otx.core.data.transform_libs.torchvision.RandomCrop
53+
- class_path: otx.core.data.transform_libs.torchvision.RandomAffine
5554
init_args:
56-
crop_size: $(input_size) * 0.5
55+
scaling_ratio_range:
56+
- 0.1
57+
- 2.0
58+
border: $(input_size) * -0.5
59+
- class_path: otx.core.data.transform_libs.torchvision.CachedMixUp
60+
init_args:
61+
img_scale: $(input_size) # (H, W)
62+
ratio_range:
63+
- 1.0
64+
- 1.0
65+
prob: 0.5
66+
random_pop: false
67+
max_cached_images: 10
68+
- class_path: otx.core.data.transform_libs.torchvision.YOLOXHSVRandomAug
5769
- class_path: otx.core.data.transform_libs.torchvision.Resize
5870
init_args:
5971
scale: $(input_size)
6072
keep_ratio: true
6173
transform_bbox: true
74+
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
75+
init_args:
76+
prob: 0.5
6277
- class_path: otx.core.data.transform_libs.torchvision.Pad
6378
init_args:
6479
pad_to_square: true
@@ -75,7 +90,7 @@ overrides:
7590
class_path: otx.algo.samplers.balanced_sampler.BalancedSampler
7691

7792
val_subset:
78-
batch_size: 16
93+
batch_size: 10
7994
transforms:
8095
- class_path: otx.core.data.transform_libs.torchvision.Resize
8196
init_args:
@@ -95,7 +110,7 @@ overrides:
95110
std: [255.0, 255.0, 255.0]
96111

97112
test_subset:
98-
batch_size: 16
113+
batch_size: 10
99114
transforms:
100115
- class_path: otx.core.data.transform_libs.torchvision.Resize
101116
init_args:

src/otx/recipe/detection/yolov9_m.yaml

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ model:
77
optimizer:
88
class_path: torch.optim.SGD
99
init_args:
10-
lr: 0.001
10+
lr: 0.0001
1111
momentum: 0.937
1212
weight_decay: 0.0005
1313
nesterov: true
@@ -42,23 +42,38 @@ overrides:
4242
input_size:
4343
- 640
4444
- 640
45-
image_color_channel: BGR
4645
train_subset:
47-
batch_size: 16
46+
batch_size: 12
4847
transforms:
4948
- class_path: otx.core.data.transform_libs.torchvision.CachedMosaic
5049
init_args:
5150
random_pop: false
5251
max_cached_images: 20
5352
img_scale: $(input_size) # (H, W)
54-
- class_path: otx.core.data.transform_libs.torchvision.RandomCrop
53+
- class_path: otx.core.data.transform_libs.torchvision.RandomAffine
5554
init_args:
56-
crop_size: $(input_size) * 0.5
55+
scaling_ratio_range:
56+
- 0.1
57+
- 2.0
58+
border: $(input_size) * -0.5
59+
- class_path: otx.core.data.transform_libs.torchvision.CachedMixUp
60+
init_args:
61+
img_scale: $(input_size) # (H, W)
62+
ratio_range:
63+
- 1.0
64+
- 1.0
65+
prob: 0.5
66+
random_pop: false
67+
max_cached_images: 10
68+
- class_path: otx.core.data.transform_libs.torchvision.YOLOXHSVRandomAug
5769
- class_path: otx.core.data.transform_libs.torchvision.Resize
5870
init_args:
5971
scale: $(input_size)
6072
keep_ratio: true
6173
transform_bbox: true
74+
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
75+
init_args:
76+
prob: 0.5
6277
- class_path: otx.core.data.transform_libs.torchvision.Pad
6378
init_args:
6479
pad_to_square: true
@@ -75,7 +90,7 @@ overrides:
7590
class_path: otx.algo.samplers.balanced_sampler.BalancedSampler
7691

7792
val_subset:
78-
batch_size: 16
93+
batch_size: 12
7994
transforms:
8095
- class_path: otx.core.data.transform_libs.torchvision.Resize
8196
init_args:
@@ -95,7 +110,7 @@ overrides:
95110
std: [255.0, 255.0, 255.0]
96111

97112
test_subset:
98-
batch_size: 16
113+
batch_size: 12
99114
transforms:
100115
- class_path: otx.core.data.transform_libs.torchvision.Resize
101116
init_args:

0 commit comments

Comments
 (0)