import copy
import os

import torch
import torch.nn.functional as F
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.utils.data import DataLoader, TensorDataset
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import negative_sampling


class FoldDataset:
    def __init__(self, all_edges, fold_indices):
        """
        all_edges: torch.Tensor of shape [2, num_edges] (src, dst)
        fold_indices: list of (train_idx, val_idx) tuples, one per fold
        """
        self.all_edges = all_edges
        self.fold_indices = fold_indices

    def get_fold(self, i):
        """Return (train_edges, val_edges), each of shape [2, n], for fold i."""
        train_idx, val_idx = self.fold_indices[i]
        return self.all_edges[:, train_idx], self.all_edges[:, val_idx]


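# A minimal sketch (not part of the original module) of how `fold_indices`
# could be built for FoldDataset, assuming scikit-learn is available.
# `make_edge_folds` is a hypothetical helper name; it splits edge *positions*
# (columns of the [2, num_edges] tensor), not node IDs.
def make_edge_folds(num_edges, n_splits=5, seed=0):
    from sklearn.model_selection import KFold

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    return [
        (torch.as_tensor(train_idx, dtype=torch.long),
         torch.as_tensor(val_idx, dtype=torch.long))
        for train_idx, val_idx in kf.split(range(num_edges))
    ]

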
class LinkPredictionCVRunner:
    def __init__(self, data, test_edges, fold_dataset, model_fn, decoder_fn,
                 save_dir='cv_results', device='cpu', num_epochs=10, batch_size=1024,
                 edge_types_supervision=None):
        """
        edge_types_supervision: list of edge types (canonical (src, rel, dst)
        tuples) whose edges are used for supervision. If empty or None, the
        graph is left unchanged and the fold's train edges are used directly.
        """
        self.orig_data = data
        self.test_edges = test_edges
        self.fold_dataset = fold_dataset
        self.model_fn = model_fn
        self.decoder_fn = decoder_fn
        self.device = device
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.save_dir = save_dir
        self.edge_types_supervision = edge_types_supervision or []
        os.makedirs(save_dir, exist_ok=True)

    def _to_homo_cached(self, data):
        # Cache the homogeneous view on the (per-fold) data object so repeated
        # calls within a fold do not redo the conversion.
        if hasattr(data, '_homo_cache'):
            return data._homo_cache
        data._homo_cache = data.to_homogeneous()
        return data._homo_cache

    def run_fold(self, fold_idx):
        train_edges, val_edges = self.fold_dataset.get_fold(fold_idx)

        # Work on a per-fold copy of the graph in which the supervised edge
        # types contain only this fold's training edges (validation held out).
        data_fold = copy.deepcopy(self.orig_data)
        if self.edge_types_supervision:
            for etype in self.edge_types_supervision:
                data_fold[etype].edge_index = train_edges

        homo = self._to_homo_cached(data_fold)

        if self.edge_types_supervision:
            # Edge-type IDs in the homogeneous graph follow the order of
            # data_fold.edge_types (assumes canonical (src, rel, dst) tuples).
            etype_ids = [data_fold.edge_types.index(etype) for etype in self.edge_types_supervision]
            edge_type_mask = torch.isin(homo.edge_type, torch.tensor(etype_ids, device=homo.edge_type.device))
            pos_edge_index = homo.edge_index[:, edge_type_mask]
        else:
            pos_edge_index = train_edges

        edge_dataset = TensorDataset(pos_edge_index[0], pos_edge_index[1])
        edge_loader = DataLoader(edge_dataset, batch_size=self.batch_size, shuffle=True)

        model = self.model_fn().to(self.device)
        decoder = self.decoder_fn().to(self.device)
        optimizer = torch.optim.Adam(list(model.parameters()) + list(decoder.parameters()), lr=1e-3)

        for epoch in range(self.num_epochs):
            model.train()
            decoder.train()
            for src_pos, dst_pos in edge_loader:
                # Sample random negative destinations up front so they are part
                # of the sampled subgraph and receive embeddings as well.
                neg_dst = torch.randint(0, homo.num_nodes, (src_pos.size(0),))
                node_ids = torch.cat([src_pos, dst_pos, neg_dst]).unique()
                sub_loader = NeighborLoader(
                    homo,
                    input_nodes=node_ids,
                    num_neighbors=[15, 10],
                    batch_size=node_ids.size(0),
                    shuffle=False
                )
                sub_data = next(iter(sub_loader)).to(self.device)

                x = model(sub_data.x, sub_data.edge_index)

                # NeighborLoader relabels nodes: the seed nodes occupy the first
                # len(node_ids) rows of x, in the (sorted) order of node_ids.
                # Map the global IDs of this batch onto those local positions.
                src_l = torch.searchsorted(node_ids, src_pos).to(self.device)
                dst_l = torch.searchsorted(node_ids, dst_pos).to(self.device)
                neg_l = torch.searchsorted(node_ids, neg_dst).to(self.device)

                pos_out = decoder(x[src_l], x[dst_l])
                neg_out = decoder(x[src_l], x[neg_l])

                pos_loss = F.binary_cross_entropy_with_logits(pos_out, torch.ones_like(pos_out))
                neg_loss = F.binary_cross_entropy_with_logits(neg_out, torch.zeros_like(neg_out))
                loss = pos_loss + neg_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        model_path = os.path.join(self.save_dir, f'model_fold{fold_idx}.pt')
        torch.save({'model': model.state_dict(), 'decoder': decoder.state_dict()}, model_path)

        val_result = self.evaluate(model, decoder, homo, val_edges, name=f'val_fold{fold_idx}')
        test_result = self.evaluate(model, decoder, homo, self.test_edges, name=f'test_fold{fold_idx}')

        result_path = os.path.join(self.save_dir, f'result_fold{fold_idx}.pt')
        torch.save({'val': val_result, 'test': test_result}, result_path)
        return {'val': val_result, 'test': test_result}

    @torch.no_grad()
    def evaluate(self, model, decoder, homo, edge_index, name='val'):
        model.eval()
        decoder.eval()
        # Full-batch encoding of the homogeneous graph for evaluation.
        x = model(homo.x.to(self.device), homo.edge_index.to(self.device))

        src, dst = edge_index[0].to(self.device), edge_index[1].to(self.device)
        pos_pred = decoder(x[src], x[dst]).sigmoid()

        # One random negative destination per positive edge.
        neg_dst = torch.randint(0, homo.num_nodes, (src.size(0),), device=self.device)
        neg_pred = decoder(x[src], x[neg_dst]).sigmoid()

        y_true = torch.cat([torch.ones_like(pos_pred), torch.zeros_like(neg_pred)])
        y_score = torch.cat([pos_pred, neg_pred])

        auc = roc_auc_score(y_true.cpu(), y_score.cpu())
        ap = average_precision_score(y_true.cpu(), y_score.cpu())
        return {'auc': auc, 'ap': ap}

    def run_all_folds(self):
        results = []
        for i in range(len(self.fold_dataset.fold_indices)):
            results.append(self.run_fold(i))
        print("All folds completed.")
        return results
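

# --- Usage sketch -------------------------------------------------------------
# A minimal example of wiring the runner together on a synthetic graph, using
# the make_edge_folds sketch defined above. SAGEEncoder, DotDecoder, and the
# random HeteroData below are assumptions made for this sketch, not the
# project's actual encoder, decoder, or dataset.
if __name__ == '__main__':
    from torch_geometric.data import HeteroData
    from torch_geometric.nn import SAGEConv

    class SAGEEncoder(torch.nn.Module):
        def __init__(self, in_dim=16, hid_dim=32):
            super().__init__()
            self.conv1 = SAGEConv(in_dim, hid_dim)
            self.conv2 = SAGEConv(hid_dim, hid_dim)

        def forward(self, x, edge_index):
            return self.conv2(self.conv1(x, edge_index).relu(), edge_index)

    class DotDecoder(torch.nn.Module):
        def forward(self, src, dst):
            return (src * dst).sum(dim=-1)

    # Tiny random graph with a single node type and a single relation.
    num_nodes, num_edges = 200, 1000
    data = HeteroData()
    data['node'].x = torch.randn(num_nodes, 16)
    edges = torch.randint(0, num_nodes, (2, num_edges))
    data['node', 'to', 'node'].edge_index = edges

    # Hold out 10% of the edges as a fixed test split; cross-validate the rest.
    perm = torch.randperm(num_edges)
    test_edges = edges[:, perm[:num_edges // 10]]
    cv_edges = edges[:, perm[num_edges // 10:]]
    fold_dataset = FoldDataset(cv_edges, make_edge_folds(cv_edges.size(1), n_splits=5))

    runner = LinkPredictionCVRunner(
        data, test_edges, fold_dataset,
        model_fn=SAGEEncoder, decoder_fn=DotDecoder,
        num_epochs=2, batch_size=256,
        edge_types_supervision=[('node', 'to', 'node')],
    )
    runner.run_all_folds()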