-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathload_data.py
82 lines (69 loc) · 3.38 KB
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import copy
import torch.utils
import numpy as np
# split train and test
def split_ids(ids, folds):
n = len(ids)
stride = int(np.ceil(n / float(folds)))
test_ids = [ids[i: i + stride] for i in range(0, n, stride)]
assert np.all(
np.unique(np.concatenate(test_ids)) == sorted(ids)), 'some graphs are missing in the test sets'
assert len(test_ids) == folds, 'invalid test sets'
train_ids = []
for fold in range(folds):
train_ids.append(np.array([e for e in ids if e not in test_ids[fold]]))
assert len(train_ids[fold]) + len(test_ids[fold]) == len(
np.unique(list(train_ids[fold]) + list(test_ids[fold]))) == n, 'invalid splits'
return train_ids, test_ids
# Universal training_data loader and reader
class GraphData(torch.utils.data.Dataset):
def __init__(self, datareader, fold_id, split):
self.fold_id = fold_id
self.split = split
self.rnd_state = datareader.rnd_state
self.set_fold(datareader.data, fold_id)
def set_fold(self, data, fold_id):
self.total = len(data['targets'])
self.N_nodes_max = data['N_nodes_max']
self.num_classes = data['num_classes']
self.num_features = data['num_features']
self.idx = data['splits'][fold_id][self.split]
# use deepcopy to make sure we don't alter objects in folds
self.labels = copy.deepcopy([data['targets'][i] for i in self.idx])
self.ids = copy.deepcopy([data['ids'][i] for i in self.idx])
self.adj_list = copy.deepcopy([data['adj_list'][i] for i in self.idx])
self.features_onehot = copy.deepcopy([data['features_onehot'][i] for i in self.idx])
print('%s: %d/%d' % (self.split.upper(), len(self.labels), len(data['targets'])))
def __len__(self):
return len(self.labels)
def __getitem__(self, index):
# convert to torch
return [torch.from_numpy(self.features_onehot[index]).float(), # node_features
torch.from_numpy(self.adj_list[index]).float(), # adjacency matrix
int(self.labels[index]), # graph labels
int(self.ids[index]) # graph id
]
def collate_batch(batch):
"""
Creates a batch of same size graphs by zero-padding node features and adjacency matrices up to
the maximum number of nodes in the CURRENT batch rather than in the entire dataset.
Graphs in the batches are usually much smaller than the largest graph in the dataset, so this method is fast.
:param batch: [node_features * batch_size, A * batch_size, label * batch_size]
:return: [node_features, A, graph_support, N_nodes, label]
"""
B = len(batch)
N_nodes = [len(batch[b][1]) for b in range(B)]
C = batch[0][0].shape[1]
N_nodes_max = int(np.max(N_nodes))
graph_support = torch.zeros(B, N_nodes_max)
A = torch.zeros(B, N_nodes_max, N_nodes_max)
x = torch.zeros(B, N_nodes_max, C)
for b in range(B):
x[b, :N_nodes[b]] = batch[b][0]
A[b, :N_nodes[b], :N_nodes[b]] = batch[b][1]
graph_support[b][:N_nodes[b]] = 1 # mask with values of 0 for dummy (zero padded) nodes, otherwise 1
N_nodes = torch.from_numpy(np.array(N_nodes)).long()
labels = torch.from_numpy(np.array([batch[b][2] for b in range(B)])).long()
ids = torch.from_numpy(np.array([batch[b][3] for b in range(B)])).long()
return [x, A, graph_support, N_nodes, labels, ids]