2727# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
2828# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929
30- from typing import Any , Callable , Dict , List , Tuple , TypeVar
30+ from typing import Any , Dict , List , Tuple , TypeVar
3131
3232import torch
3333import torch .nn as nn
@@ -94,7 +94,9 @@ def __init__(
9494 self .gaussian_conditional = gaussian_conditional
9595 self .entropy_parameters = entropy_parameters
9696 self .context_prediction = context_prediction
97- self .kernel_size = _to_single (self .context_prediction .kernel_size )
97+ [k , * ks ] = self .context_prediction .kernel_size
98+ assert all (k == k_ for k_ in ks )
99+ self .kernel_size = k
98100 self .padding = (self .kernel_size - 1 ) // 2
99101
100102 def forward (self , y : Tensor , params : Tensor ) -> Dict [str , Any ]:
@@ -112,18 +114,12 @@ def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]:
112114 ds = []
113115 for i in range (n ):
114116 encoder = BufferedRansEncoder ()
115- y_hat = raster_scan_compress_single_stream (
117+ y_hat = self . _compress_single_stream (
116118 encoder = encoder ,
117119 y = y [i : i + 1 , :, :, :],
118120 params = ctx_params [i : i + 1 , :, :, :],
119- gaussian_conditional = self .gaussian_conditional ,
120- entropy_parameters = self .entropy_parameters ,
121- context_prediction = self .context_prediction ,
122121 height = y_height ,
123122 width = y_width ,
124- padding = self .padding ,
125- kernel_size = self .kernel_size ,
126- merge = self .merge ,
127123 )
128124 y_strings = encoder .flush ()
129125 ds .append ({"strings" : [y_strings ], "y_hat" : y_hat .squeeze (0 )})
@@ -142,170 +138,142 @@ def decompress(
142138 for i in range (len (y_strings )):
143139 decoder = RansDecoder ()
144140 decoder .set_stream (y_strings [i ])
145- y_hat = raster_scan_decompress_single_stream (
141+ y_hat = self . _decompress_single_stream (
146142 decoder = decoder ,
147143 params = ctx_params [i : i + 1 , :, :, :],
148- gaussian_conditional = self .gaussian_conditional ,
149- entropy_parameters = self .entropy_parameters ,
150- context_prediction = self .context_prediction ,
151144 height = y_height ,
152145 width = y_width ,
153- padding = self .padding ,
154- kernel_size = self .kernel_size ,
155146 device = ctx_params .device ,
156- merge = self .merge ,
157147 )
158148 ds .append ({"y_hat" : y_hat .squeeze (0 )})
159149 return default_collate (ds )
160150
161- @staticmethod
162- def merge (* args ):
163- return torch .cat (args , dim = 1 )
164-
165-
166- def raster_scan_compress_single_stream (
167- encoder : BufferedRansEncoder ,
168- y : Tensor ,
169- params : Tensor ,
170- * ,
171- gaussian_conditional : GaussianConditional ,
172- entropy_parameters : nn .Module ,
173- context_prediction : MaskedConv2d ,
174- height : int ,
175- width : int ,
176- padding : int ,
177- kernel_size : int ,
178- merge : Callable [..., Tensor ] = lambda * args : torch .cat (args , dim = 1 ),
179- ) -> Tensor :
180- """Compresses y and writes to encoder bitstream.
181-
182- Returns:
183- The y_hat that will be reconstructed at the decoder.
184- """
185- assert height == y .shape [- 2 ]
186- assert width == y .shape [- 1 ]
187-
188- cdf = gaussian_conditional .quantized_cdf .tolist ()
189- cdf_lengths = gaussian_conditional .cdf_length .tolist ()
190- offsets = gaussian_conditional .offset .tolist ()
191- masked_weight = context_prediction .weight * context_prediction .mask
192-
193- y_hat = _pad_2d (y , padding )
194-
195- symbols_list = []
196- indexes_list = []
197-
198- # Warning, this is slow...
199- # TODO: profile the calls to the bindings...
200- for h in range (height ):
201- for w in range (width ):
202- # only perform the mask convolution on a cropped tensor
203- # centered in (h, w)
204- y_crop = y_hat [:, :, h : h + kernel_size , w : w + kernel_size ]
205- ctx_p = F .conv2d (
206- y_crop ,
207- masked_weight ,
208- context_prediction .bias ,
209- )
210-
211- # 1x1 conv for the entropy parameters prediction network, so
212- # we only keep the elements in the "center"
213- p = params [:, :, h : h + 1 , w : w + 1 ]
214- gaussian_params = entropy_parameters (merge (p , ctx_p ))
215- gaussian_params = gaussian_params .squeeze (3 ).squeeze (2 )
216- scales_hat , means_hat = gaussian_params .chunk (2 , 1 )
217- indexes = gaussian_conditional .build_indexes (scales_hat )
218-
219- y_crop = y_crop [:, :, padding , padding ]
220- symbols = gaussian_conditional .quantize (y_crop , "symbols" , means_hat )
221- y_hat_item = symbols + means_hat
222-
223- hp = h + padding
224- wp = w + padding
225- y_hat [:, :, hp , wp ] = y_hat_item
226-
227- symbols_list .extend (symbols .squeeze ().tolist ())
228- indexes_list .extend (indexes .squeeze ().tolist ())
229-
230- encoder .encode_with_indexes (symbols_list , indexes_list , cdf , cdf_lengths , offsets )
231-
232- y_hat = _pad_2d (y_hat , - padding )
233- return y_hat
234-
235-
236- def raster_scan_decompress_single_stream (
237- decoder : RansDecoder ,
238- params : Tensor ,
239- * ,
240- gaussian_conditional : GaussianConditional ,
241- entropy_parameters : nn .Module ,
242- context_prediction : MaskedConv2d ,
243- height : int ,
244- width : int ,
245- padding : int ,
246- kernel_size : int ,
247- device ,
248- merge : Callable [..., Tensor ] = lambda * args : torch .cat (args , dim = 1 ),
249- ) -> Tensor :
250- """Decodes y_hat from decoder bitstream.
251-
252- Returns:
253- The reconstructed y_hat.
254- """
255- cdf = gaussian_conditional .quantized_cdf .tolist ()
256- cdf_lengths = gaussian_conditional .cdf_length .tolist ()
257- offsets = gaussian_conditional .offset .tolist ()
258- masked_weight = context_prediction .weight * context_prediction .mask
259-
260- c = context_prediction .in_channels
261- shape = (1 , c , height + 2 * padding , width + 2 * padding )
262- y_hat = torch .zeros (shape , device = device )
263-
264- # Warning: this is slow due to the auto-regressive nature of the
265- # decoding... See more recent publication where they use an
266- # auto-regressive module on chunks of channels for faster decoding...
267- for h in range (height ):
268- for w in range (width ):
269- # only perform the mask convolution on a cropped tensor
270- # centered in (h, w)
271- y_crop = y_hat [:, :, h : h + kernel_size , w : w + kernel_size ]
272- ctx_p = F .conv2d (
273- y_crop ,
274- masked_weight ,
275- context_prediction .bias ,
276- )
277-
278- # 1x1 conv for the entropy parameters prediction network, so
279- # we only keep the elements in the "center"
280- p = params [:, :, h : h + 1 , w : w + 1 ]
281- gaussian_params = entropy_parameters (merge (p , ctx_p ))
282- gaussian_params = gaussian_params .squeeze (3 ).squeeze (2 )
283- scales_hat , means_hat = gaussian_params .chunk (2 , 1 )
284- indexes = gaussian_conditional .build_indexes (scales_hat )
285-
286- symbols = decoder .decode_stream (
287- indexes .squeeze ().tolist (), cdf , cdf_lengths , offsets
288- )
289- symbols = Tensor (symbols ).reshape (1 , - 1 )
290- y_hat_item = gaussian_conditional .dequantize (symbols , means_hat )
def _compress_single_stream(
    self,
    encoder: BufferedRansEncoder,
    y: Tensor,
    params: Tensor,
    *,
    height: int,
    width: int,
) -> Tensor:
    """Encode ``y`` into the encoder bitstream in raster-scan order.

    Args:
        encoder: entropy encoder that the quantized symbols are written to.
        y: latent tensor to compress; its last two dims must be
            ``(height, width)``.
        params: hyperprior parameters; only the element at the current
            ``(h, w)`` position is used at each step (1x1 support).
        height: spatial height of ``y``.
        width: spatial width of ``y``.

    Returns:
        The quantized ``y_hat`` that the decoder will reconstruct.
    """
    assert y.shape[-2] == height
    assert y.shape[-1] == width

    gc = self.gaussian_conditional
    cdf = gc.quantized_cdf.tolist()
    cdf_lengths = gc.cdf_length.tolist()
    offsets = gc.offset.tolist()
    # Apply the causal mask to the context-prediction weights once, up front.
    masked_weight = self.context_prediction.weight * self.context_prediction.mask

    pad = self.padding
    ksize = self.kernel_size
    y_hat = _pad_2d(y, pad)

    symbols_list = []
    indexes_list = []

    # Warning: raster-scan coding is inherently sequential and slow.
    # TODO: profile the calls to the bindings...
    for row in range(height):
        for col in range(width):
            # Restrict the masked convolution to the receptive field
            # centered at (row, col) rather than the whole tensor.
            window = y_hat[:, :, row : row + ksize, col : col + ksize]
            ctx_p = F.conv2d(window, masked_weight, self.context_prediction.bias)

            # The entropy-parameters network is 1x1, so only the center
            # element of `params` matters for this position.
            p = params[:, :, row : row + 1, col : col + 1]
            gaussian_params = self.entropy_parameters(self.merge(p, ctx_p))
            gaussian_params = gaussian_params.squeeze(3).squeeze(2)
            scales_hat, means_hat = gaussian_params.chunk(2, 1)
            indexes = gc.build_indexes(scales_hat)

            center = window[:, :, pad, pad]
            symbols = gc.quantize(center, "symbols", means_hat)

            # Write the dequantized value back into y_hat so that later
            # positions condition on exactly what the decoder will see.
            y_hat[:, :, row + pad, col + pad] = symbols + means_hat

            symbols_list.extend(symbols.squeeze().tolist())
            indexes_list.extend(indexes.squeeze().tolist())

    encoder.encode_with_indexes(symbols_list, indexes_list, cdf, cdf_lengths, offsets)

    # Strip the padding before returning the reconstruction.
    return _pad_2d(y_hat, -pad)
295214
296- y_hat = _pad_2d (y_hat , - padding )
297- return y_hat
def _decompress_single_stream(
    self,
    decoder: RansDecoder,
    params: Tensor,
    *,
    height: int,
    width: int,
    device,
) -> Tensor:
    """Decode one latent tensor from the decoder bitstream, raster-scan.

    Args:
        decoder: entropy decoder whose stream is consumed.
        params: hyperprior parameters (1x1 support per position).
        height: spatial height of the latent to reconstruct.
        width: spatial width of the latent to reconstruct.
        device: device on which the reconstruction buffer is allocated.

    Returns:
        The reconstructed ``y_hat``.
    """
    gc = self.gaussian_conditional
    cdf = gc.quantized_cdf.tolist()
    cdf_lengths = gc.cdf_length.tolist()
    offsets = gc.offset.tolist()
    # Apply the causal mask to the context-prediction weights once, up front.
    masked_weight = self.context_prediction.weight * self.context_prediction.mask

    pad = self.padding
    ksize = self.kernel_size
    channels = self.context_prediction.in_channels
    y_hat = torch.zeros(
        (1, channels, height + 2 * pad, width + 2 * pad), device=device
    )

    # Warning: this is slow due to the auto-regressive nature of the
    # decoding... See more recent publications where an auto-regressive
    # module over chunks of channels speeds this up.
    for row in range(height):
        for col in range(width):
            # Restrict the masked convolution to the receptive field
            # centered at (row, col) rather than the whole tensor.
            window = y_hat[:, :, row : row + ksize, col : col + ksize]
            ctx_p = F.conv2d(window, masked_weight, self.context_prediction.bias)

            # The entropy-parameters network is 1x1, so only the center
            # element of `params` matters for this position.
            p = params[:, :, row : row + 1, col : col + 1]
            gaussian_params = self.entropy_parameters(self.merge(p, ctx_p))
            gaussian_params = gaussian_params.squeeze(3).squeeze(2)
            scales_hat, means_hat = gaussian_params.chunk(2, 1)
            indexes = gc.build_indexes(scales_hat)

            raw = decoder.decode_stream(
                indexes.squeeze().tolist(), cdf, cdf_lengths, offsets
            )
            # NOTE(review): Tensor(raw) is allocated on the default (CPU)
            # device; this presumably only works when `device` matches —
            # confirm behavior when decoding on a GPU.
            symbols = Tensor(raw).reshape(1, -1)
            y_hat[:, :, row + pad, col + pad] = gc.dequantize(symbols, means_hat)

    # Strip the padding before returning the reconstruction.
    return _pad_2d(y_hat, -pad)
268+
def merge(self, *tensors):
    """Concatenate the given feature tensors along the channel axis (dim=1)."""
    merged = torch.cat(tensors, dim=1)
    return merged
298271
299272
300273def _pad_2d (x : Tensor , padding : int ) -> Tensor :
301274 return F .pad (x , (padding , padding , padding , padding ))
302275
303276
304- def _to_single (xs ):
305- assert all (x == xs [0 ] for x in xs )
306- return xs [0 ]
307-
308-
309277def default_collate (batch : List [Dict [K , V ]]) -> Dict [K , List [V ]]:
310278 """Combines a list of dictionaries into a single dictionary.
311279
0 commit comments