diff --git a/pytorch360convert/pytorch360convert.py b/pytorch360convert/pytorch360convert.py index 8e9a616..0458a86 100644 --- a/pytorch360convert/pytorch360convert.py +++ b/pytorch360convert/pytorch360convert.py @@ -740,25 +740,33 @@ def c2e( NotImplementedError: If an unknown cube_format is provided. """ - if cube_format == "stack": - assert ( - isinstance(cubemap, torch.Tensor) - and len(cubemap.shape) == 4 - and cubemap.shape[0] == 6 - ) - cubemap = [cubemap[i] for i in range(cubemap.shape[0])] + if cubemap[0].dim() == 4 or cubemap[0].dim() == 5: + if cubemap[0].dim() == 4: + assert ( + isinstance(cubemap, torch.Tensor) + and len(cubemap.shape) == 4 + and cubemap.shape[0] == 6 + ) + cubemap = [cubemap[i] for i in range(cubemap.shape[0])] + else: + assert ( + isinstance(cubemap, torch.Tensor) + and len(cubemap.shape) == 5 + and cubemap.shape[1] == 6 + ) + cubemap = [cubemap[:, i] for i in range(cubemap.shape[1])] cube_format = "list" # Ensure input is in HWC format for processing if channels_first: if cube_format == "list" and isinstance(cubemap, (list, tuple)): - cubemap = [r.permute(1, 2, 0) for r in cubemap] + cubemap = [_nchw2nhwc(r) for r in cubemap] elif cube_format == "dict" and torch.jit.isinstance( cubemap, Dict[str, torch.Tensor] ): - cubemap = {k: v.permute(1, 2, 0) for k, v in cubemap.items()} # type: ignore + cubemap = {k: _nchw2nhwc(v) for k, v in cubemap.items()} # type: ignore elif cube_format in ["horizon", "dice"] and isinstance(cubemap, torch.Tensor): - cubemap = cubemap.permute(1, 2, 0) + cubemap = _nchw2nhwc(cubemap) else: raise NotImplementedError("unknown cube_format and cubemap type")