-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMCOrderingMain.py
More file actions
235 lines (196 loc) · 9.03 KB
/
MCOrderingMain.py
File metadata and controls
235 lines (196 loc) · 9.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torch.optim as optim
from models import GNN
import argparse
from Chebutils import create_chebyshev
from plotutils import plot_and_save_matrices,visualize_nan_case
import os
# Directory where trained model checkpoints are written.
save_directory_name = "SaveModeldir"
# Create the directory if it doesn't exist. exist_ok=True closes the
# race between the existence check and the creation (another process
# could create the directory in between); the original `else: pass`
# branch was dead code and is dropped.
if not os.path.exists(save_directory_name):
    os.makedirs(save_directory_name, exist_ok=True)
    print(f"Directory '{save_directory_name}' created.")
def check_P_properties(P):
    """Print min/max of the row and column sums of a batch of matrices P.

    For a doubly-stochastic (soft permutation) batch, all printed values
    should be close to 1. P is expected to be a 3-D tensor (batch, n, n).
    """
    per_row = P.sum(dim=2)   # sums across each row
    per_col = P.sum(dim=1)   # sums down each column
    print(f"Row sums min/max: {per_row.min():.4f}/{per_row.max():.4f}")
    print(f"Col sums min/max: {per_col.min():.4f}/{per_col.max():.4f}")
# ---------------------------------------------------------------------------
# Command-line configuration.
# ---------------------------------------------------------------------------
# Each entry is (flag, keyword arguments for add_argument). Keeping the
# option specs in one table makes the list easy to scan and extend; the
# flags, types, defaults and help strings are unchanged.
_ARG_SPECS = [
    ("--p_edge", dict(type=float, default=0.7, help="Probability of edge creation.")),
    ("--seed", dict(type=int, default=42, help='Random seed.')),
    ("--num_of_nodes", dict(type=int, default=50, help='Graph Size')),
    ("--lr", dict(type=float, default=3e-3, help='Learning Rate')),
    ("--hidden", dict(type=int, default=128, help='Number of hidden units.')),
    ("--batch_size", dict(type=int, default=500, help='Batch size')),
    ("--nlayers", dict(type=int, default=2, help='Number of layers')),
    ("--EPOCHS", dict(type=int, default=1000, help='Epochs to train')),
    ("--wdecay", dict(type=float, default=0.0, help='Weight decay')),
    ("--stepsize", dict(type=int, default=50, help='Step size')),
    ("--diag_loss", dict(type=float, default=0., help='Penalty on the diag')),
    ("--grad_norm", dict(type=float, default=1, help='Grad norm')),
    ("--outdim", dict(type=int, default=20, help='Output dim')),
    ("--sctorder", dict(type=int, default=4, help='Scattering order')),
    ("--gcnorder", dict(type=int, default=2, help='GCN order')),
    ("--optim", dict(type=str, default='none', help='type of optimizer')),
    # GNN model parameters
    ("--input_dim", dict(type=int, default=2, help='Input dimension for node features')),
    ("--tanh_scale", dict(type=float, default=40.0, help='Scale factor for tanh activation')),
    ("--tau", dict(type=float, default=0.1, help='Temperature parameter for Gumbel-Sinkhorn')),
    ("--n_iter", dict(type=int, default=20, help='Number of iterations for Gumbel-Sinkhorn')),
    ("--noise_scale", dict(type=float, default=0.05, help='Noise scale for Gumbel-Sinkhorn')),
    ("--cheb", dict(type=float, default=0.1, help='chebyshev coefficients: the parameters for the cheb matrix, (1+cheb)**')),
]
parser = argparse.ArgumentParser()
for _flag, _spec in _ARG_SPECS:
    parser.add_argument(_flag, **_spec)
args = parser.parse_args()
from hardpermutation import to_exact_permutation
# Run on the GPU when one is available, otherwise fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
class GraphDataset(Dataset):
    """Dataset of pre-generated ER-graph adjacency matrices and node features.

    Loads a single .pt bundle (keys: 'adjacency_matrices', 'features',
    'metadata') whose file name is derived from the globally parsed CLI
    arguments, and keeps all tensors resident on `device`.
    """

    def __init__(self, data_dir, device=DEVICE):
        # File name encodes graph size and edge probability.
        bundle_path = data_dir + '/er_graph_training_dataset_nodes%d_p%.2f.pt' % (args.num_of_nodes, args.p_edge)
        bundle = torch.load(bundle_path, map_location=device, weights_only=True)
        self.adj_matrices = bundle['adjacency_matrices'].to(device)
        self.features = bundle['features'].to(device)
        self.metadata = bundle['metadata']

    def __len__(self):
        return self.metadata['num_instances']

    def __getitem__(self, idx):
        sample = {
            'adj_matrix': self.adj_matrices[idx],
            'node_features': self.features[idx],
        }
        return sample
def train_model(model, train_loader, optimizer, num_epochs=10, device=DEVICE):
    """Train `model` on the graph-ordering objective.

    For each batch the model emits a (soft) permutation matrix P and the
    loss is mean over the batch of sum((P^T (J - I - A) P) * C), where J
    is the all-ones matrix, A the adjacency matrix, and C a
    Chebyshev-derived weighting matrix. Batches that produce NaN outputs,
    invalid losses, or invalid gradients are skipped rather than applied.

    Args:
        model: callable returning (P, _) given (features, adj); P is
            assumed to be (batch, n, n) — TODO confirm against GNN.
        train_loader: yields dicts with 'adj_matrix' and 'node_features'.
        optimizer: optimizer over model.parameters().
        num_epochs: number of passes over train_loader.
        device: device the constant matrices are moved to; batch tensors
            are presumably already on this device (loaded there by
            GraphDataset).
    """
    model.train()
    # J - I: all-ones matrix minus the identity, in float32.
    JminusI = (torch.ones(args.num_of_nodes, args.num_of_nodes) - torch.eye(args.num_of_nodes)).float()
    JminusI = JminusI.to(device)
    CC = create_chebyshev(args.num_of_nodes).float()  # Ensure Chebyshev matrix is float
    # Center around zero, then exponentiate: entries become (1+cheb)**k.
    CC = CC - args.num_of_nodes / 2
    CC = (1.0 + args.cheb) ** CC
    print(CC)
    # FIX: honor the --grad_norm CLI flag instead of a hardcoded 1.0
    # (the flag's default is 1, so default behavior is unchanged).
    max_grad_norm = float(args.grad_norm)
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(train_loader):
            optimizer.zero_grad()
            # Convert tensors to float if they aren't already.
            adj = batch['adj_matrix'].float()
            features = batch['node_features'].float()
            batch_size = features.size(0)
            try:
                # Forward pass: P is the (soft) permutation matrix.
                P, _ = model(features, adj)
                P = P.float()
                # Skip batches whose output contains NaN, after dumping
                # diagnostics for the offending inputs.
                if torch.isnan(P).any():
                    print(f"NaN detected in model output at epoch {epoch}, batch {batch_idx}")
                    adj_stats, feat_stats = visualize_nan_case(features, adj, epoch, batch_idx)
                    continue
                Batched_JminusI = JminusI.repeat(batch_size, 1, 1)
                Batched_Chebyshev = CC.repeat(batch_size, 1, 1).float()
                # J - I - A: complement-graph adjacency (off-diagonal).
                Batched_JminusIminusA = (Batched_JminusI - adj).float()
                PT = torch.transpose(P, 1, 2)
                AdjP = torch.matmul(Batched_JminusIminusA, P)
                # Small epsilon keeps downstream ops away from exact zero.
                epsilon = 1e-8
                PTAdjP = torch.matmul(PT, AdjP) + epsilon
                # Chebyshev-weighted objective, summed per instance then
                # averaged over the batch.
                loss = torch.mul(PTAdjP, Batched_Chebyshev)
                loss = loss.sum(dim=(1, 2))
                loss = torch.mean(loss)
                # Only apply updates for finite losses and finite gradients.
                if not torch.isnan(loss) and not torch.isinf(loss):
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                    valid_gradients = True
                    for param in model.parameters():
                        if param.grad is not None:
                            if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
                                valid_gradients = False
                                print(f"Invalid gradients detected at epoch {epoch}, batch {batch_idx}")
                                break
                    if valid_gradients:
                        optimizer.step()
                    # NOTE(review): the source's indentation was lost; the
                    # loss is accumulated for every finite-loss batch here —
                    # confirm whether it should count only applied steps.
                    total_loss += loss.item()
                # Print batch statistics every 10 batches.
                if batch_idx % 10 == 0:
                    print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(train_loader)}], '
                          f'Loss: {loss.item():.4f}')
            except RuntimeError as e:
                print(f"Error in batch {batch_idx} of epoch {epoch}: {str(e)}")
                continue
        # Print epoch statistics.
        if len(train_loader) > 0:
            avg_loss = total_loss / len(train_loader)
            print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')
def main():
    """Build the dataset and model, train, and save the final weights.

    Side effects: reads the dataset bundle from DATA_DIR and writes the
    trained state_dict into `save_directory_name`.
    """
    # Parameters
    DATA_DIR = "../GNN/er_graph_dataset"
    BATCH_SIZE = args.batch_size
    # Load dataset (tensors end up resident on DEVICE).
    dataset = GraphDataset(DATA_DIR, device=DEVICE)
    # Create data loader.
    train_loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,    # workers can't be used: data already lives on the GPU
        pin_memory=False  # pinning is for host memory; tensors are on DEVICE
    )
    # NOTE(review): the model's input_dim comes from --input_dim; the
    # dataset's actual feature width (dataset.features.size(-1)) is assumed
    # to match it — confirm, or the first layer will reject the features.
    # (The original computed that width into an unused local; removed.)
    model = GNN(
        input_dim=args.input_dim,
        hidden_dim=args.hidden,
        output_dim=args.num_of_nodes,
        n_layers=args.nlayers,
        sctorder=args.sctorder,
        gcnorder=args.gcnorder,
        TanhScale=args.tanh_scale,
        tau=args.tau,
        n_iter=args.n_iter,
        noise_scale=args.noise_scale
    ).to(DEVICE)
    # NOTE(review): weight decay is hardcoded to 0.001 and silently ignores
    # the --wdecay flag (default 0.0). Kept as-is to preserve behavior —
    # switch to args.wdecay only after confirming the intended default.
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.001)
    # Train model.
    print("Starting training...")
    train_model(
        model=model,
        train_loader=train_loader,
        optimizer=optimizer,
        num_epochs=args.EPOCHS,
        device=DEVICE
    )
    # Encode the hyperparameters into the checkpoint file name.
    model_save_name = (
        f'graph_model_'
        f'S{args.num_of_nodes}_'
        f'E{args.p_edge:.2f}_'
        f'h{args.hidden}_'
        f'l{args.nlayers}_'
        f'sct{args.sctorder}_'
        f'gcn{args.gcnorder}_'
        f'tau{args.tau}_'
        f'ns{args.noise_scale}_'
        f'ts{args.tanh_scale}_'
        f'cheb{args.cheb}_'
        f'ni{args.n_iter}.pt'
    )
    torch.save(model.state_dict(), os.path.join(save_directory_name, model_save_name))
    print(f"Model saved as: {model_save_name}")
    print("Training completed and model saved!")


if __name__ == "__main__":
    main()