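`DilatedKnnGraph` silently accepts a badly shaped `batch` argument when using `"matrix"` mode. This happens because only `batch_size` (computed from `batch`) is used, so the length of `batch` is never checked against the number of points. Maybe an assertion or error should be raised, as PyTorch Geometric's `knn_graph` does.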
```python
import torch
from torch_geometric.nn import knn_graph
# Import path assumed from the deep_gcns_torch repository layout
from gcn_lib.sparse.torch_edge import DilatedKnnGraph

p = torch.rand((256, 3))
t = torch.cat([p, p])  # two identical clouds of 256 points each
batch = torch.cat([torch.ones(p.shape[0]) * i for i in range(2)]).type(torch.long)  # normal batch
batch2 = torch.tensor([0, 0, 1])  # wrong length: 3 entries for 512 points
batch3 = torch.tensor([0, 1, 0])  # wrong length and not sorted
dknn = DilatedKnnGraph(k=3, dilation=1)
f0 = knn_graph(t, k=3, batch=batch, loop=True)
f1 = dknn(t, batch=batch2)  # Maybe this should raise a shape error?
torch.all(f1 == f0)  # True
f2 = dknn(t, batch=batch3)  # Weird behaviour since batch[-1] is used to compute batch_size
torch.all(f2 == f0)  # False
```
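A minimal sketch of the kind of check suggested above, assuming the matrix mode reshapes the points into `(batch_size, -1, channels)` and therefore needs a `batch` vector that has one entry per point, is sorted, and describes equally sized graphs. `validate_batch` is a hypothetical helper, not part of `DilatedKnnGraph` or PyTorch Geometric:

```python
import torch


def validate_batch(x: torch.Tensor, batch: torch.Tensor) -> int:
    """Hypothetical pre-check before a dense ("matrix") kNN; returns batch_size or raises."""
    # One graph id per point, as PyG's knn_graph expects.
    if batch.dim() != 1 or batch.numel() != x.size(0):
        raise ValueError(
            f"batch must be 1-D with one entry per point: got shape "
            f"{tuple(batch.shape)} for {x.size(0)} points")
    # A reshape to (batch_size, -1, C) only groups points correctly if batch is sorted.
    if batch.numel() > 1 and bool((batch[1:] < batch[:-1]).any()):
        raise ValueError("batch must be sorted in non-decreasing order")
    # max() instead of batch[-1] so the result does not depend on the last entry alone.
    batch_size = int(batch.max()) + 1
    # Assumption: the reshape needs the same number of points in every graph.
    counts = torch.bincount(batch, minlength=batch_size)
    if bool((counts != counts[0]).any()):
        raise ValueError("every graph must contain the same number of points")
    return batch_size
```

With the tensors from the snippet, `validate_batch(t, batch)` returns `2`, while `batch2` and `batch3` both raise a `ValueError` because they have 3 entries for 512 points.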
Also, at the line `batch_size = batch[-1] + 1`, `.max()` could be used instead, i.e. `batch_size = batch.max() + 1`.
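For `batch3` from the snippet, the two ways of computing `batch_size` disagree. A short sketch (both helper names are hypothetical):

```python
import torch


def batch_size_last(batch: torch.Tensor) -> int:
    # Current approach: trusts that the last entry holds the largest graph id.
    return int(batch[-1]) + 1


def batch_size_max(batch: torch.Tensor) -> int:
    # Suggested approach: independent of the order of the entries.
    return int(batch.max()) + 1


batch3 = torch.tensor([0, 1, 0])
print(batch_size_last(batch3), batch_size_max(batch3))  # 1 2
```

`max()` makes `batch_size` independent of ordering, but on its own it does not catch a `batch` of the wrong length, so a shape check is still needed.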
A possible future issue (though I guess we can treat this as misuse by the user): if the user does not pass `batch` in sorted order (for example `[0, 1, 0, 1]` instead of `[0, 0, 1, 1]`), I'm not sure the reshape step will work.
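To illustrate the concern, assuming the matrix mode groups points with a plain `view(batch_size, -1, channels)`: with an unsorted `batch`, points from different graphs end up in the same group. `group_by_reshape` is a hypothetical stand-in for that step, and the stable sort is one possible workaround, not the library's current behaviour:

```python
import torch


def group_by_reshape(x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the matrix-mode reshape; only correct for a sorted batch.
    batch_size = int(batch.max()) + 1
    return x.view(batch_size, -1, x.size(-1))


x = torch.arange(4, dtype=torch.float).unsqueeze(-1)  # 4 points, 1 channel: values 0..3
unsorted = torch.tensor([0, 1, 0, 1])

groups = group_by_reshape(x, unsorted)
# Group 0 gets points 0 and 1, although point 1 belongs to graph 1.
print(groups[0].flatten().tolist())  # [0.0, 1.0]

# One possible fix: stable-sort the points by graph id before reshaping.
sorted_batch, perm = torch.sort(unsorted, stable=True)
groups_fixed = group_by_reshape(x[perm], sorted_batch)
print(groups_fixed[0].flatten().tolist())  # [0.0, 2.0]
```

If the points are sorted this way, neighbour indices computed on `x[perm]` would have to be mapped back through `perm` (e.g. `perm[edge_index]`) to refer to the original point order.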
lightaime