Skip to content

Commit 6773ad0

Browse files
YodaEmbedding authored and fracape committed
refactor: simplify RasterScanLatentCodec
1 parent 00f1f2c commit 6773ad0

1 file changed

Lines changed: 124 additions & 156 deletions

File tree

compressai/latent_codecs/rasterscan.py

Lines changed: 124 additions & 156 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
2828
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929

30-
from typing import Any, Callable, Dict, List, Tuple, TypeVar
30+
from typing import Any, Dict, List, Tuple, TypeVar
3131

3232
import torch
3333
import torch.nn as nn
@@ -94,7 +94,9 @@ def __init__(
9494
self.gaussian_conditional = gaussian_conditional
9595
self.entropy_parameters = entropy_parameters
9696
self.context_prediction = context_prediction
97-
self.kernel_size = _to_single(self.context_prediction.kernel_size)
97+
[k, *ks] = self.context_prediction.kernel_size
98+
assert all(k == k_ for k_ in ks)
99+
self.kernel_size = k
98100
self.padding = (self.kernel_size - 1) // 2
99101

100102
def forward(self, y: Tensor, params: Tensor) -> Dict[str, Any]:
@@ -112,18 +114,12 @@ def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]:
112114
ds = []
113115
for i in range(n):
114116
encoder = BufferedRansEncoder()
115-
y_hat = raster_scan_compress_single_stream(
117+
y_hat = self._compress_single_stream(
116118
encoder=encoder,
117119
y=y[i : i + 1, :, :, :],
118120
params=ctx_params[i : i + 1, :, :, :],
119-
gaussian_conditional=self.gaussian_conditional,
120-
entropy_parameters=self.entropy_parameters,
121-
context_prediction=self.context_prediction,
122121
height=y_height,
123122
width=y_width,
124-
padding=self.padding,
125-
kernel_size=self.kernel_size,
126-
merge=self.merge,
127123
)
128124
y_strings = encoder.flush()
129125
ds.append({"strings": [y_strings], "y_hat": y_hat.squeeze(0)})
@@ -142,170 +138,142 @@ def decompress(
142138
for i in range(len(y_strings)):
143139
decoder = RansDecoder()
144140
decoder.set_stream(y_strings[i])
145-
y_hat = raster_scan_decompress_single_stream(
141+
y_hat = self._decompress_single_stream(
146142
decoder=decoder,
147143
params=ctx_params[i : i + 1, :, :, :],
148-
gaussian_conditional=self.gaussian_conditional,
149-
entropy_parameters=self.entropy_parameters,
150-
context_prediction=self.context_prediction,
151144
height=y_height,
152145
width=y_width,
153-
padding=self.padding,
154-
kernel_size=self.kernel_size,
155146
device=ctx_params.device,
156-
merge=self.merge,
157147
)
158148
ds.append({"y_hat": y_hat.squeeze(0)})
159149
return default_collate(ds)
160150

161-
@staticmethod
162-
def merge(*args):
163-
return torch.cat(args, dim=1)
164-
165-
166-
def raster_scan_compress_single_stream(
167-
encoder: BufferedRansEncoder,
168-
y: Tensor,
169-
params: Tensor,
170-
*,
171-
gaussian_conditional: GaussianConditional,
172-
entropy_parameters: nn.Module,
173-
context_prediction: MaskedConv2d,
174-
height: int,
175-
width: int,
176-
padding: int,
177-
kernel_size: int,
178-
merge: Callable[..., Tensor] = lambda *args: torch.cat(args, dim=1),
179-
) -> Tensor:
180-
"""Compresses y and writes to encoder bitstream.
181-
182-
Returns:
183-
The y_hat that will be reconstructed at the decoder.
184-
"""
185-
assert height == y.shape[-2]
186-
assert width == y.shape[-1]
187-
188-
cdf = gaussian_conditional.quantized_cdf.tolist()
189-
cdf_lengths = gaussian_conditional.cdf_length.tolist()
190-
offsets = gaussian_conditional.offset.tolist()
191-
masked_weight = context_prediction.weight * context_prediction.mask
192-
193-
y_hat = _pad_2d(y, padding)
194-
195-
symbols_list = []
196-
indexes_list = []
197-
198-
# Warning, this is slow...
199-
# TODO: profile the calls to the bindings...
200-
for h in range(height):
201-
for w in range(width):
202-
# only perform the mask convolution on a cropped tensor
203-
# centered in (h, w)
204-
y_crop = y_hat[:, :, h : h + kernel_size, w : w + kernel_size]
205-
ctx_p = F.conv2d(
206-
y_crop,
207-
masked_weight,
208-
context_prediction.bias,
209-
)
210-
211-
# 1x1 conv for the entropy parameters prediction network, so
212-
# we only keep the elements in the "center"
213-
p = params[:, :, h : h + 1, w : w + 1]
214-
gaussian_params = entropy_parameters(merge(p, ctx_p))
215-
gaussian_params = gaussian_params.squeeze(3).squeeze(2)
216-
scales_hat, means_hat = gaussian_params.chunk(2, 1)
217-
indexes = gaussian_conditional.build_indexes(scales_hat)
218-
219-
y_crop = y_crop[:, :, padding, padding]
220-
symbols = gaussian_conditional.quantize(y_crop, "symbols", means_hat)
221-
y_hat_item = symbols + means_hat
222-
223-
hp = h + padding
224-
wp = w + padding
225-
y_hat[:, :, hp, wp] = y_hat_item
226-
227-
symbols_list.extend(symbols.squeeze().tolist())
228-
indexes_list.extend(indexes.squeeze().tolist())
229-
230-
encoder.encode_with_indexes(symbols_list, indexes_list, cdf, cdf_lengths, offsets)
231-
232-
y_hat = _pad_2d(y_hat, -padding)
233-
return y_hat
234-
235-
236-
def raster_scan_decompress_single_stream(
237-
decoder: RansDecoder,
238-
params: Tensor,
239-
*,
240-
gaussian_conditional: GaussianConditional,
241-
entropy_parameters: nn.Module,
242-
context_prediction: MaskedConv2d,
243-
height: int,
244-
width: int,
245-
padding: int,
246-
kernel_size: int,
247-
device,
248-
merge: Callable[..., Tensor] = lambda *args: torch.cat(args, dim=1),
249-
) -> Tensor:
250-
"""Decodes y_hat from decoder bitstream.
251-
252-
Returns:
253-
The reconstructed y_hat.
254-
"""
255-
cdf = gaussian_conditional.quantized_cdf.tolist()
256-
cdf_lengths = gaussian_conditional.cdf_length.tolist()
257-
offsets = gaussian_conditional.offset.tolist()
258-
masked_weight = context_prediction.weight * context_prediction.mask
259-
260-
c = context_prediction.in_channels
261-
shape = (1, c, height + 2 * padding, width + 2 * padding)
262-
y_hat = torch.zeros(shape, device=device)
263-
264-
# Warning: this is slow due to the auto-regressive nature of the
265-
# decoding... See more recent publication where they use an
266-
# auto-regressive module on chunks of channels for faster decoding...
267-
for h in range(height):
268-
for w in range(width):
269-
# only perform the mask convolution on a cropped tensor
270-
# centered in (h, w)
271-
y_crop = y_hat[:, :, h : h + kernel_size, w : w + kernel_size]
272-
ctx_p = F.conv2d(
273-
y_crop,
274-
masked_weight,
275-
context_prediction.bias,
276-
)
277-
278-
# 1x1 conv for the entropy parameters prediction network, so
279-
# we only keep the elements in the "center"
280-
p = params[:, :, h : h + 1, w : w + 1]
281-
gaussian_params = entropy_parameters(merge(p, ctx_p))
282-
gaussian_params = gaussian_params.squeeze(3).squeeze(2)
283-
scales_hat, means_hat = gaussian_params.chunk(2, 1)
284-
indexes = gaussian_conditional.build_indexes(scales_hat)
285-
286-
symbols = decoder.decode_stream(
287-
indexes.squeeze().tolist(), cdf, cdf_lengths, offsets
288-
)
289-
symbols = Tensor(symbols).reshape(1, -1)
290-
y_hat_item = gaussian_conditional.dequantize(symbols, means_hat)
151+
def _compress_single_stream(
152+
self,
153+
encoder: BufferedRansEncoder,
154+
y: Tensor,
155+
params: Tensor,
156+
*,
157+
height: int,
158+
width: int,
159+
) -> Tensor:
160+
"""Compresses y and writes to encoder bitstream.
161+
162+
Returns:
163+
The y_hat that will be reconstructed at the decoder.
164+
"""
165+
assert height == y.shape[-2]
166+
assert width == y.shape[-1]
167+
168+
cdf = self.gaussian_conditional.quantized_cdf.tolist()
169+
cdf_lengths = self.gaussian_conditional.cdf_length.tolist()
170+
offsets = self.gaussian_conditional.offset.tolist()
171+
masked_weight = self.context_prediction.weight * self.context_prediction.mask
172+
173+
y_hat = _pad_2d(y, self.padding)
174+
175+
symbols_list = []
176+
indexes_list = []
177+
178+
# Warning, this is slow...
179+
# TODO: profile the calls to the bindings...
180+
for h in range(height):
181+
for w in range(width):
182+
# only perform the mask convolution on a cropped tensor
183+
# centered in (h, w)
184+
y_crop = y_hat[:, :, h : h + self.kernel_size, w : w + self.kernel_size]
185+
ctx_p = F.conv2d(y_crop, masked_weight, self.context_prediction.bias)
186+
187+
# 1x1 conv for the entropy parameters prediction network, so
188+
# we only keep the elements in the "center"
189+
p = params[:, :, h : h + 1, w : w + 1]
190+
gaussian_params = self.entropy_parameters(self.merge(p, ctx_p))
191+
gaussian_params = gaussian_params.squeeze(3).squeeze(2)
192+
scales_hat, means_hat = gaussian_params.chunk(2, 1)
193+
indexes = self.gaussian_conditional.build_indexes(scales_hat)
194+
195+
y_crop = y_crop[:, :, self.padding, self.padding]
196+
symbols = self.gaussian_conditional.quantize(
197+
y_crop, "symbols", means_hat
198+
)
199+
y_hat_item = symbols + means_hat
200+
201+
hp = h + self.padding
202+
wp = w + self.padding
203+
y_hat[:, :, hp, wp] = y_hat_item
204+
205+
symbols_list.extend(symbols.squeeze().tolist())
206+
indexes_list.extend(indexes.squeeze().tolist())
207+
208+
encoder.encode_with_indexes(
209+
symbols_list, indexes_list, cdf, cdf_lengths, offsets
210+
)
291211

292-
hp = h + padding
293-
wp = w + padding
294-
y_hat[:, :, hp, wp] = y_hat_item
212+
y_hat = _pad_2d(y_hat, -self.padding)
213+
return y_hat
295214

296-
y_hat = _pad_2d(y_hat, -padding)
297-
return y_hat
215+
def _decompress_single_stream(
216+
self,
217+
decoder: RansDecoder,
218+
params: Tensor,
219+
*,
220+
height: int,
221+
width: int,
222+
device,
223+
) -> Tensor:
224+
"""Decodes y_hat from decoder bitstream.
225+
226+
Returns:
227+
The reconstructed y_hat.
228+
"""
229+
cdf = self.gaussian_conditional.quantized_cdf.tolist()
230+
cdf_lengths = self.gaussian_conditional.cdf_length.tolist()
231+
offsets = self.gaussian_conditional.offset.tolist()
232+
masked_weight = self.context_prediction.weight * self.context_prediction.mask
233+
234+
c = self.context_prediction.in_channels
235+
shape = (1, c, height + 2 * self.padding, width + 2 * self.padding)
236+
y_hat = torch.zeros(shape, device=device)
237+
238+
# Warning: this is slow due to the auto-regressive nature of the
239+
# decoding... See more recent publication where they use an
240+
# auto-regressive module on chunks of channels for faster decoding...
241+
for h in range(height):
242+
for w in range(width):
243+
# only perform the mask convolution on a cropped tensor
244+
# centered in (h, w)
245+
y_crop = y_hat[:, :, h : h + self.kernel_size, w : w + self.kernel_size]
246+
ctx_p = F.conv2d(y_crop, masked_weight, self.context_prediction.bias)
247+
248+
# 1x1 conv for the entropy parameters prediction network, so
249+
# we only keep the elements in the "center"
250+
p = params[:, :, h : h + 1, w : w + 1]
251+
gaussian_params = self.entropy_parameters(self.merge(p, ctx_p))
252+
gaussian_params = gaussian_params.squeeze(3).squeeze(2)
253+
scales_hat, means_hat = gaussian_params.chunk(2, 1)
254+
indexes = self.gaussian_conditional.build_indexes(scales_hat)
255+
256+
symbols = decoder.decode_stream(
257+
indexes.squeeze().tolist(), cdf, cdf_lengths, offsets
258+
)
259+
symbols = Tensor(symbols).reshape(1, -1)
260+
y_hat_item = self.gaussian_conditional.dequantize(symbols, means_hat)
261+
262+
hp = h + self.padding
263+
wp = w + self.padding
264+
y_hat[:, :, hp, wp] = y_hat_item
265+
266+
y_hat = _pad_2d(y_hat, -self.padding)
267+
return y_hat
268+
269+
def merge(self, *args):
270+
return torch.cat(args, dim=1)
298271

299272

300273
def _pad_2d(x: Tensor, padding: int) -> Tensor:
301274
return F.pad(x, (padding, padding, padding, padding))
302275

303276

304-
def _to_single(xs):
305-
assert all(x == xs[0] for x in xs)
306-
return xs[0]
307-
308-
309277
def default_collate(batch: List[Dict[K, V]]) -> Dict[K, List[V]]:
310278
"""Combines a list of dictionaries into a single dictionary.
311279

0 commit comments

Comments (0)