|
1 | 1 | """SwinUNetR wrapper for napari_cellseg3d.""" |
2 | 2 |
|
| 3 | +import inspect |
| 4 | + |
3 | 5 | from monai.networks.nets import SwinUNETR |
4 | 6 |
|
5 | 7 | from napari_cellseg3d.utils import LOGGER |
@@ -30,30 +32,30 @@ def __init__( |
30 | 32 | use_checkpoint (bool): whether to use checkpointing during training. |
31 | 33 | **kwargs: additional arguments to SwinUNETR. |
32 | 34 | """ |
| 35 | + parent_init = super().__init__ |
| 36 | + sig = inspect.signature(parent_init) |
| 37 | + init_kwargs = dict( |
| 38 | + in_channels=in_channels, |
| 39 | + out_channels=out_channels, |
| 40 | + use_checkpoint=use_checkpoint, |
| 41 | + feature_size=48, |
| 42 | + drop_rate=0.5, |
| 43 | + attn_drop_rate=0.5, |
| 44 | + use_v2=True, |
| 45 | + **kwargs, |
| 46 | + ) |
| 47 | + if "img_size" in sig.parameters: |
| 48 | + # since MONAI API changes depending on py3.8 or py3.9 |
| 49 | + init_kwargs["img_size"] = input_img_size |
| 50 | + if "dropout_prob" in kwargs: |
| 51 | + init_kwargs["drop_rate"] = kwargs["dropout_prob"] |
| 52 | + init_kwargs.pop("dropout_prob") |
33 | 53 | try: |
34 | | - super().__init__( |
35 | | - input_img_size, |
36 | | - in_channels=in_channels, |
37 | | - out_channels=out_channels, |
38 | | - feature_size=48, |
39 | | - use_checkpoint=use_checkpoint, |
40 | | - drop_rate=0.5, |
41 | | - attn_drop_rate=0.5, |
42 | | - use_v2=True, |
43 | | - **kwargs, |
44 | | - ) |
| 54 | + parent_init(**init_kwargs) |
45 | 55 | except TypeError as e: |
46 | 56 | logger.warning(f"Caught TypeError: {e}") |
47 | | - super().__init__( |
48 | | - input_img_size, |
49 | | - in_channels=1, |
50 | | - out_channels=1, |
51 | | - feature_size=48, |
52 | | - use_checkpoint=use_checkpoint, |
53 | | - drop_rate=0.5, |
54 | | - attn_drop_rate=0.5, |
55 | | - use_v2=True, |
56 | | - ) |
| 57 | + init_kwargs["in_channels"] = 1 |
| 58 | + parent_init(**init_kwargs) |
57 | 59 |
|
58 | 60 | # def forward(self, x_in): |
59 | 61 | # y = super().forward(x_in) |
|
0 commit comments