@@ -106,7 +106,7 @@ def multi_slice_mask(starts, ends, length):
     slices = torch.cat([starts, ends])
     if slices.numel():
         assert slices.min() >= 0 and slices.max() <= length
-    mask = scatter_add(values, slices, dim_size=length + 1)[:-1]
+    mask = scatter_add(values, slices, dim=0, dim_size=length + 1)[:-1]
     mask = mask.cumsum(0).bool()
     return mask
 
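For context, multi_slice_mask builds a boolean mask by scattering +1 at every slice start and -1 at every slice end, then taking a cumulative sum; passing dim=0 explicitly pins the scatter to the slice axis instead of relying on scatter_add's default dimension. A quick sanity check of the patched function (import path assumed; the expected output follows directly from that construction):

    import torch
    from torchdrug.layers.functional import multi_slice_mask  # path assumed

    # Mark the half-open slices [1, 3) and [4, 6) of a length-7 range.
    starts = torch.tensor([1, 4])
    ends = torch.tensor([3, 6])
    mask = multi_slice_mask(starts, ends, 7)
    # tensor([False,  True,  True, False,  True,  True, False])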
@@ -230,7 +230,7 @@ def variadic_max(input, size):
     index2sample = index2sample.expand_as(input)
 
     value, index = scatter_max(input, index2sample, dim=0)
-    index = index - size.cumsum(0) + size
+    index = index + (size - size.cumsum(0)).view([-1] + [1] * (index.ndim - 1))
     return value, index
 
 
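The old correction "index - size.cumsum(0) + size" broadcasts the (N,)-shaped offset against the trailing dimensions of index, which misbehaves (or errors) as soon as input carries feature dimensions; the view reshapes the offset to (N, 1, ..., 1) so each set's start is subtracted along the batch axis. A small sketch of the intended behavior (import path assumed):

    import torch
    from torchdrug.layers import functional  # path assumed

    input = torch.tensor([[1., 5.],
                          [2., 0.],
                          [9., 3.],
                          [4., 7.]])
    size = torch.tensor([1, 3])          # set 0 = row 0, set 1 = rows 1-3
    value, index = functional.variadic_max(input, size)
    # value: tensor([[1., 5.],
    #                [9., 7.]])
    # index: argmax positions local to each set:
    # tensor([[0, 0],
    #         [1, 2]])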
@@ -314,7 +314,8 @@ def variadic_topk(input, size, k, largest=True):
     Parameters:
         input (Tensor): input of shape :math:`(B, ...)`
         size (LongTensor): size of sets of shape :math:`(N,)`
-        k (int): the k in "top-k"
+        k (int or LongTensor): the k in "top-k". Can be a fixed value for all sets,
+            or different values for different sets of shape :math:`(N,)`.
         largest (bool, optional): return largest or smallest elements
 
     Returns
@@ -326,13 +327,19 @@ def variadic_topk(input, size, k, largest=True):
     mask = ~torch.isinf(input)
     max = input[mask].max().item()
     min = input[mask].min().item()
-    safe_input = input.clamp(2 * min - max, 2 * max - min)
-    offset = (max - min) * 4
+    abs_max = input[mask].abs().max().item()
+    # special case: max = min
+    gap = max - min + abs_max * 1e-6
+    safe_input = input.clamp(min - gap, max + gap)
+    offset = gap * 4
     if largest:
         offset = -offset
     input_ext = safe_input + offset * index2graph
     index_ext = input_ext.argsort(dim=0, descending=largest)
-    num_actual = size.clamp(max=k)
+    if isinstance(k, torch.Tensor) and k.shape == size.shape:
+        num_actual = torch.min(size, k)
+    else:
+        num_actual = size.clamp(max=k)
     num_padding = k - num_actual
     starts = size.cumsum(0) - size
     ends = starts + num_actual
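Worked instance of the special case flagged in the new comment: if every finite entry equals, say, 5.0, the old code gives offset = (5.0 - 5.0) * 4 = 0, so the per-set shift vanishes and the global argsort may interleave elements from different sets. With the patch, gap = (5.0 - 5.0) + 5.0 * 1e-6 = 5e-6 and offset = 2e-5, which still strictly separates consecutive sets.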
@@ -346,9 +353,14 @@ def variadic_topk(input, size, k, largest=True):
 
     index = index_ext[mask]  # (N * k, ...)
     value = input.gather(0, index)
-    value = value.view(-1, k, *input.shape[1:])
-    index = index.view(-1, k, *input.shape[1:])
-    index = index - (size.cumsum(0) - size).view([-1] + [1] * (index.ndim - 1))
+    if isinstance(k, torch.Tensor) and k.shape == size.shape:
+        value = value.view(-1, *input.shape[1:])
+        index = index.view(-1, *input.shape[1:])
+        index = index - (size.cumsum(0) - size).repeat_interleave(k).view([-1] + [1] * (index.ndim - 1))
+    else:
+        value = value.view(-1, k, *input.shape[1:])
+        index = index.view(-1, k, *input.shape[1:])
+        index = index - (size.cumsum(0) - size).view([-1] + [1] * (index.ndim - 1))
 
     return value, index
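With these two hunks, variadic_topk accepts a per-set k: for an integer k the result keeps its (N, k, ...) layout, while for a LongTensor k the per-set results are concatenated flat along the first axis, since rows would otherwise be ragged. A sketch of the new call pattern (import path assumed; the case where a set is smaller than its k follows the existing padding logic, not exercised here):

    import torch
    from torchdrug.layers import functional  # path assumed

    input = torch.tensor([3., 1., 4., 1., 5., 9., 2.])
    size = torch.tensor([3, 4])      # sets: [3, 1, 4] and [1, 5, 9, 2]
    k = torch.tensor([2, 3])         # per-set k, same shape as size
    value, index = functional.variadic_topk(input, size, k)
    # value: tensor([4., 3., 9., 5., 2.])   # top-2 of set 0, top-3 of set 1
    # index: tensor([2, 0, 2, 1, 3])        # positions local to each set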
@@ -432,6 +444,39 @@ def variadic_sample(input, size, num_sample):
     return sample
 
 
+def variadic_meshgrid(input1, size1, input2, size2):
+    grid_size = size1 * size2
+    local_index = variadic_arange(grid_size)
+    local_inner_size = size2.repeat_interleave(grid_size)
+    offset1 = (size1.cumsum(0) - size1).repeat_interleave(grid_size)
+    offset2 = (size2.cumsum(0) - size2).repeat_interleave(grid_size)
+    index1 = local_index // local_inner_size + offset1
+    index2 = local_index % local_inner_size + offset2
+    return input1[index1], input2[index2]
+
+
+def variadic_to_padded(input, size, value=0):
+    num_sample = len(size)
+    max_size = size.max()
+    starts = torch.arange(num_sample, device=size.device) * max_size
+    ends = starts + size
+    mask = multi_slice_mask(starts, ends, num_sample * max_size)
+    mask = mask.view(num_sample, max_size)
+    shape = (num_sample, max_size) + input.shape[1:]
+    padded = torch.full(shape, value, dtype=input.dtype, device=size.device)
+    padded[mask] = input
+    return padded, mask
+
+
+def padded_to_variadic(padded, size):
+    num_sample, max_size = padded.shape[:2]
+    starts = torch.arange(num_sample, device=size.device) * max_size
+    ends = starts + size
+    mask = multi_slice_mask(starts, ends, num_sample * max_size)
+    mask = mask.view(num_sample, max_size)
+    return padded[mask]
+
+
 def one_hot(index, size):
     """
     Expand indexes into one-hot vectors.
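Of the new helpers, variadic_meshgrid forms the per-sample Cartesian product of two variadic sets by decomposing a flat per-sample counter into an outer index (integer division) and an inner index (modulo), then adding each set's start offset. A small sketch calling the new function directly:

    import torch

    a = torch.tensor([10, 20, 30])
    b = torch.tensor([1, 2, 3])
    size1 = torch.tensor([2, 1])     # a split as [10, 20] | [30]
    size2 = torch.tensor([1, 2])     # b split as [1] | [2, 3]
    x, y = variadic_meshgrid(a, size1, b, size2)
    # x: tensor([10, 20, 30, 30])    # sample 0 pairs: (10, 1), (20, 1)
    # y: tensor([ 1,  1,  2,  3])    # sample 1 pairs: (30, 2), (30, 3)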
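variadic_to_padded and padded_to_variadic are exact inverses on the valid region, which makes round-tripping a convenient check. A sketch:

    import torch

    input = torch.tensor([1, 2, 3, 4, 5])
    size = torch.tensor([2, 3])
    padded, mask = variadic_to_padded(input, size, value=0)
    # padded: tensor([[1, 2, 0],
    #                 [3, 4, 5]])
    # mask:   tensor([[ True,  True, False],
    #                 [ True,  True,  True]])
    assert torch.equal(padded_to_variadic(padded, size), input)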