
Commit 4b4b30a

add variadic functions & fix rspmm small bug

1 parent 36832c6

3 files changed: +58, -12 lines

torchdrug/layers/functional/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -1,15 +1,16 @@
 from .functional import multinomial, masked_mean, mean_with_nan, shifted_softplus, multi_slice, multi_slice_mask, \
     as_mask, _size_to_index, _extend, variadic_log_softmax, variadic_softmax, variadic_sum, variadic_mean, \
     variadic_max, variadic_cross_entropy, variadic_sort, variadic_topk, variadic_arange, variadic_randperm, \
-    variadic_sample, one_hot, clipped_policy_gradient_objective, policy_gradient_objective
+    variadic_sample, variadic_meshgrid, variadic_to_padded, padded_to_variadic, one_hot, \
+    clipped_policy_gradient_objective, policy_gradient_objective
 from .embedding import transe_score, distmult_score, complex_score, simple_score, rotate_score
 from .spmm import generalized_spmm, generalized_rspmm
 
 __all__ = [
     "multinomial", "masked_mean", "mean_with_nan", "shifted_softplus", "multi_slice_mask", "as_mask",
     "variadic_log_softmax", "variadic_softmax", "variadic_sum", "variadic_mean", "variadic_max",
     "variadic_cross_entropy", "variadic_sort", "variadic_topk", "variadic_arange", "variadic_randperm",
-    "variadic_sample",
+    "variadic_sample", "variadic_meshgrid", "variadic_to_padded", "padded_to_variadic",
     "one_hot", "clipped_policy_gradient_objective", "policy_gradient_objective",
     "transe_score", "distmult_score", "complex_score", "simple_score", "rotate_score",
     "generalized_spmm", "generalized_rspmm",

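A quick way to confirm the new re-exports (a minimal sketch; assumes torchdrug with this commit is installed and importable):

from torchdrug.layers import functional

# the three new variadic helpers should now be reachable from the package namespace
assert all(hasattr(functional, name)
           for name in ("variadic_meshgrid", "variadic_to_padded", "padded_to_variadic"))
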
torchdrug/layers/functional/extension/rspmm.cpp

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ std::tuple<SparseTensor, Tensor, Tensor> rspmm_backward_cpu(
         output_arg(output_, "output", 4), output_grad_arg(output_grad_, "output_grad", 5);
 
     rspmm_backward_check(fn_name, sparse_arg, relation_arg, input_arg, output_arg, output_grad_arg);
-    checkDeviceType(fn_name, {sparse, input_, output_, output_grad_}, kCPU);
+    checkDeviceType(fn_name, {sparse, relation_, input_, output_, output_grad_}, kCPU);
 
     const Tensor relation = relation_.contiguous();
     const Tensor input = input_.contiguous();
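
Before this change, the CPU backward kernel validated every argument's device except relation_, so a relation tensor accidentally left on another device could slip past this particular check. A rough Python mirror of the patched guard (check_all_cpu and the dummy tensors below are illustrative, not part of the extension's API):

import torch

def check_all_cpu(fn_name, tensors):
    # counterpart of checkDeviceType(..., kCPU): every tensor, now including relation, must live on the CPU
    for name, tensor in tensors.items():
        if tensor.device.type != "cpu":
            raise RuntimeError(f"{fn_name}: expected {name} on CPU, found {tensor.device}")

check_all_cpu("rspmm_backward_cpu", {
    "sparse": torch.zeros(2, 2), "relation": torch.zeros(2, 2), "input": torch.zeros(2, 2),
    "output": torch.zeros(2, 2), "output_grad": torch.zeros(2, 2),
})
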

torchdrug/layers/functional/functional.py

Lines changed: 54 additions & 9 deletions
@@ -106,7 +106,7 @@ def multi_slice_mask(starts, ends, length):
     slices = torch.cat([starts, ends])
     if slices.numel():
         assert slices.min() >= 0 and slices.max() <= length
-    mask = scatter_add(values, slices, dim_size=length + 1)[:-1]
+    mask = scatter_add(values, slices, dim=0, dim_size=length + 1)[:-1]
     mask = mask.cumsum(0).bool()
     return mask
 
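The scatter_add call now passes dim=0 explicitly instead of relying on the library default. A quick usage check of the patched helper, which marks each [start, end) slice in a boolean mask of the given length (a sketch; assumes torchdrug is importable):

import torch
from torchdrug.layers import functional

mask = functional.multi_slice_mask(torch.tensor([0, 5]), torch.tensor([2, 7]), 8)
# expected: tensor([ True,  True, False, False, False,  True,  True, False])
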
@@ -230,7 +230,7 @@ def variadic_max(input, size):
     index2sample = index2sample.expand_as(input)
 
     value, index = scatter_max(input, index2sample, dim=0)
-    index = index - size.cumsum(0) + size
+    index = index + (size - size.cumsum(0)).view([-1] + [1] * (index.ndim - 1))
     return value, index
 
 
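The corrected offset is reshaped so it broadcasts over the set dimension; with the old expression, the shape-(N,) correction lined up against the last dimension of a multi-dimensional index and produced wrong per-set indices. A small check under assumed shapes:

import torch
from torchdrug.layers import functional

input = torch.tensor([[1., 5.],
                      [3., 2.],
                      [4., 0.]])
size = torch.tensor([2, 1])  # set 0 = rows 0-1, set 1 = row 2
value, index = functional.variadic_max(input, size)
# expected: value -> [[3., 5.], [4., 0.]], index -> [[1, 0], [0, 0]] (local to each set)
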
@@ -314,7 +314,8 @@ def variadic_topk(input, size, k, largest=True):
     Parameters:
         input (Tensor): input of shape :math:`(B, ...)`
         size (LongTensor): size of sets of shape :math:`(N,)`
-        k (int): the k in "top-k"
+        k (int or LongTensor): the k in "top-k". Can be a fixed value for all sets,
+            or different values for different sets of shape :math:`(N,)`.
         largest (bool, optional): return largest or smallest elements
 
     Returns
@@ -326,13 +327,19 @@ def variadic_topk(input, size, k, largest=True):
     mask = ~torch.isinf(input)
     max = input[mask].max().item()
     min = input[mask].min().item()
-    safe_input = input.clamp(2 * min - max, 2 * max - min)
-    offset = (max - min) * 4
+    abs_max = input[mask].abs().max().item()
+    # special case: max = min
+    gap = max - min + abs_max * 1e-6
+    safe_input = input.clamp(min - gap, max + gap)
+    offset = gap * 4
     if largest:
         offset = -offset
     input_ext = safe_input + offset * index2graph
     index_ext = input_ext.argsort(dim=0, descending=largest)
-    num_actual = size.clamp(max=k)
+    if isinstance(k, torch.Tensor) and k.shape == size.shape:
+        num_actual = torch.min(size, k)
+    else:
+        num_actual = size.clamp(max=k)
     num_padding = k - num_actual
     starts = size.cumsum(0) - size
     ends = starts + num_actual
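
The new gap term keeps the per-set offset nonzero even when every finite entry is identical (max == min), a case in which the old offset of (max - min) * 4 collapsed to zero and the global argsort could no longer keep sets apart. A degenerate-case check (a sketch; assumes torchdrug is importable):

import torch
from torchdrug.layers import functional

input = torch.full((5,), 2.0)  # every entry equal, so max == min
size = torch.tensor([3, 2])
value, index = functional.variadic_topk(input, size, k=2)
# expected: value -> [[2., 2.], [2., 2.]], with index staying inside each set's own range
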
@@ -346,9 +353,14 @@ def variadic_topk(input, size, k, largest=True):
 
     index = index_ext[mask]  # (N * k, ...)
     value = input.gather(0, index)
-    value = value.view(-1, k, *input.shape[1:])
-    index = index.view(-1, k, *input.shape[1:])
-    index = index - (size.cumsum(0) - size).view([-1] + [1] * (index.ndim - 1))
+    if isinstance(k, torch.Tensor) and k.shape == size.shape:
+        value = value.view(-1, *input.shape[1:])
+        index = index.view(-1, *input.shape[1:])
+        index = index - (size.cumsum(0) - size).repeat_interleave(k).view([-1] + [1] * (index.ndim - 1))
+    else:
+        value = value.view(-1, k, *input.shape[1:])
+        index = index.view(-1, k, *input.shape[1:])
+        index = index - (size.cumsum(0) - size).view([-1] + [1] * (index.ndim - 1))
 
     return value, index
 
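With these two branches, k may also be a LongTensor holding one value per set; the result then comes back in variadic (flattened) layout of length k.sum() instead of the padded (N, k, ...) layout used for a scalar k. A usage sketch, assuming each per-set k is no larger than its set size:

import torch
from torchdrug.layers import functional

input = torch.tensor([3., 1., 2., 7., 5.])
size = torch.tensor([3, 2])

# scalar k: padded output of shape (num_sets, k)
value, index = functional.variadic_topk(input, size, k=2)
# expected: value -> [[3., 2.], [7., 5.]], index -> [[0, 2], [0, 1]]

# per-set k: flattened output of length k.sum()
k = torch.tensor([2, 1])
value, index = functional.variadic_topk(input, size, k)
# expected: value -> [3., 2., 7.], index -> [0, 2, 0]
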
@@ -432,6 +444,39 @@ def variadic_sample(input, size, num_sample):
     return sample
 
 
+def variadic_meshgrid(input1, size1, input2, size2):
+    grid_size = size1 * size2
+    local_index = variadic_arange(grid_size)
+    local_inner_size = size2.repeat_interleave(grid_size)
+    offset1 = (size1.cumsum(0) - size1).repeat_interleave(grid_size)
+    offset2 = (size2.cumsum(0) - size2).repeat_interleave(grid_size)
+    index1 = local_index // local_inner_size + offset1
+    index2 = local_index % local_inner_size + offset2
+    return input1[index1], input2[index2]
+
+
+def variadic_to_padded(input, size, value=0):
+    num_sample = len(size)
+    max_size = size.max()
+    starts = torch.arange(num_sample, device=size.device) * max_size
+    ends = starts + size
+    mask = multi_slice_mask(starts, ends, num_sample * max_size)
+    mask = mask.view(num_sample, max_size)
+    shape = (num_sample, max_size) + input.shape[1:]
+    padded = torch.full(shape, value, dtype=input.dtype, device=size.device)
+    padded[mask] = input
+    return padded, mask
+
+
+def padded_to_variadic(padded, size):
+    num_sample, max_size = padded.shape[:2]
+    starts = torch.arange(num_sample, device=size.device) * max_size
+    ends = starts + size
+    mask = multi_slice_mask(starts, ends, num_sample * max_size)
+    mask = mask.view(num_sample, max_size)
+    return padded[mask]
+
+
 def one_hot(index, size):
     """
     Expand indexes into one-hot vectors.