From 9bc7d3372b4d47072acf9f3db3df064d6bf3d218 Mon Sep 17 00:00:00 2001 From: Manda Kausthubh Date: Tue, 8 Jul 2025 13:04:21 +0530 Subject: [PATCH 01/14] Basic implementation of the generator and discriminator --- torch_molecule/generator/molgan/__init__.py | 1 + .../generator/molgan/discriminator.py | 0 torch_molecule/generator/molgan/generator.py | 76 +++++++++++++++++++ .../generator/molgan/modeling_molgan.py | 0 4 files changed, 77 insertions(+) create mode 100644 torch_molecule/generator/molgan/__init__.py create mode 100644 torch_molecule/generator/molgan/discriminator.py create mode 100644 torch_molecule/generator/molgan/generator.py create mode 100644 torch_molecule/generator/molgan/modeling_molgan.py diff --git a/torch_molecule/generator/molgan/__init__.py b/torch_molecule/generator/molgan/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/torch_molecule/generator/molgan/__init__.py @@ -0,0 +1 @@ + diff --git a/torch_molecule/generator/molgan/discriminator.py b/torch_molecule/generator/molgan/discriminator.py new file mode 100644 index 0000000..e69de29 diff --git a/torch_molecule/generator/molgan/generator.py b/torch_molecule/generator/molgan/generator.py new file mode 100644 index 0000000..d8d559d --- /dev/null +++ b/torch_molecule/generator/molgan/generator.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + + +class MolGANConfig: + """ + Configuration class for MolGAN Generator and Discriminator. + + This class stores hyperparameters and architectural details used to construct + the MolGAN generator and other related modules. It allows modular control over + model depth, input/output dimensionality, and Gumbel-softmax behavior. + """ + def __init__(self, + latent_dim=56, + hidden_dims=[128, 128, 256], + num_nodes=9, + num_atom_types=5, + num_bond_types=4, + tau=1.0): + self.latent_dim = latent_dim + self.hidden_dims = hidden_dims + self.num_nodes = num_nodes + self.num_atom_types = num_atom_types + self.num_bond_types = num_bond_types + self.tau = tau + + + +class MolGANGenerator(nn.Module): + + """ + Generator network for MolGAN. + + Maps a latent vector z to a molecular graph represented by: + - Adjacency tensor A ∈ [B, Y, N, N] (bonds) + - Node features X ∈ [B, N, T] (atoms) + + Uses Gumbel-Softmax to approximate discrete molecular structure. 
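+
+    Example (illustrative sketch; shapes assume the default MolGANConfig above):
+
+        >>> import torch
+        >>> config = MolGANConfig()
+        >>> gen = MolGANGenerator(config)
+        >>> adj, node = gen(torch.randn(32, config.latent_dim))
+        >>> adj.shape, node.shape    # [B, Y, N, N], [B, N, T]
+        (torch.Size([32, 4, 9, 9]), torch.Size([32, 9, 5]))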
+ """ + + def __init__(self, config): + super().__init__() + self.config = config + + output_dim = (config.num_nodes * config.num_atom_types) + \ + (config.num_nodes * config.num_nodes * config.num_bond_types) + + layers = [] + input_dim = config.latent_dim + for hidden_dim in config.hidden_dims: + layers.append(nn.Linear(input_dim, hidden_dim)) + layers.append(nn.ReLU()) + input_dim = hidden_dim + layers.append(nn.Linear(input_dim, output_dim)) + + self.fc = nn.Sequential(*layers) + + def forward(self, z): + B = z.size(0) + out = self.fc(z) + + N, T, Y = self.config.num_nodes, self.config.num_atom_types, self.config.num_bond_types + node_size = N * T + adj_size = N * N * Y + + node_flat, adj_flat = torch.split(out, [node_size, adj_size], dim=1) + node = node_flat.view(B, N, T) + adj = adj_flat.view(B, Y, N, N) + + # Gumbel-softmax + node = F.gumbel_softmax(node, tau=self.config.tau, hard=True, dim=-1) + adj = F.gumbel_softmax(adj, tau=self.config.tau, hard=True, dim=1) + + return adj, node diff --git a/torch_molecule/generator/molgan/modeling_molgan.py b/torch_molecule/generator/molgan/modeling_molgan.py new file mode 100644 index 0000000..e69de29 From 0b289165bab41f26c7b0c79670b54cc2e6836505 Mon Sep 17 00:00:00 2001 From: Manda Kausthubh Date: Tue, 8 Jul 2025 13:09:57 +0530 Subject: [PATCH 02/14] Adding config class for discriminator --- .../generator/molgan/discriminator.py | 112 ++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/torch_molecule/generator/molgan/discriminator.py b/torch_molecule/generator/molgan/discriminator.py index e69de29..198295a 100644 --- a/torch_molecule/generator/molgan/discriminator.py +++ b/torch_molecule/generator/molgan/discriminator.py @@ -0,0 +1,112 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MolGANDiscriminatorConfig: + """ + Configuration class for MolGAN Discriminator. + + Stores architectural hyperparameters and allows modular configuration. + """ + + def __init__(self, + num_atom_types=5, + num_bond_types=4, + num_nodes=9, + hidden_dim=128, + num_layers=2): + """ + Parameters + ---------- + num_atom_types : int + Number of atom types in node features (input channels). + + num_bond_types : int + Number of bond types (number of relational edge types). + + num_nodes : int + Max number of nodes in the graph (used for flattening before readout). + + hidden_dim : int + Hidden dimension size for R-GCN layers. + + num_layers : int + Number of stacked R-GCN layers. + """ + self.num_atom_types = num_atom_types + self.num_bond_types = num_bond_types + self.num_nodes = num_nodes + self.hidden_dim = hidden_dim + self.num_layers = num_layers + + +class RelationalGCNLayer(nn.Module): + def __init__(self, in_dim, out_dim, num_relations): + super().__init__() + self.num_relations = num_relations + self.linears = nn.ModuleList([nn.Linear(in_dim, out_dim) for _ in range(num_relations)]) + self.bias = nn.Parameter(torch.zeros(out_dim)) + + def forward(self, adj, h): + """ + adj: [B, Y, N, N] + h: [B, N, D] + """ + out = 0 + for i in range(self.num_relations): + adj_i = adj[:, i, :, :] + h_i = self.linears[i](h) + out += torch.bmm(adj_i, h_i) + + out = out + self.bias + return F.relu(out) + + +class MolGANDiscriminator(nn.Module): + """ + Discriminator network for MolGAN using stacked Relational GCNs. 
+ """ + + def __init__(self, config: MolGANDiscriminatorConfig): + super().__init__() + self.config = config + + self.gcn_layers = nn.ModuleList() + self.gcn_layers.append( + RelationalGCNLayer(config.num_atom_types, config.hidden_dim, config.num_bond_types) + ) + + for _ in range(1, config.num_layers): + self.gcn_layers.append( + RelationalGCNLayer(config.hidden_dim, config.hidden_dim, config.num_bond_types) + ) + + self.readout = nn.Sequential( + nn.Linear(config.num_nodes * config.hidden_dim, config.hidden_dim), + nn.ReLU(), + nn.Linear(config.hidden_dim, 1) + ) + + def forward(self, adj, node): + """ + Parameters: + adj: Tensor of shape [B, Y, N, N] -- adjacency tensor + node: Tensor of shape [B, N, T] -- one-hot or softmax node features + + Returns: + Tensor of shape [B] with real/fake logits + """ + h = node + for gcn in self.gcn_layers: + h = gcn(adj, h) + + h = h.view(h.size(0), -1) + return self.readout(h).squeeze(-1) + + + + + + + From e146dcba65ec66d94c52c5a6255f3d5875e05105 Mon Sep 17 00:00:00 2001 From: Manda Kausthubh Date: Tue, 8 Jul 2025 14:35:30 +0530 Subject: [PATCH 03/14] Added dataset class for MolGAN --- torch_molecule/generator/molgan/dataset.py | 85 ++++++++++++++++++++++ torch_molecule/generator/molgan/gan.py | 31 ++++++++ 2 files changed, 116 insertions(+) create mode 100644 torch_molecule/generator/molgan/dataset.py create mode 100644 torch_molecule/generator/molgan/gan.py diff --git a/torch_molecule/generator/molgan/dataset.py b/torch_molecule/generator/molgan/dataset.py new file mode 100644 index 0000000..5318b33 --- /dev/null +++ b/torch_molecule/generator/molgan/dataset.py @@ -0,0 +1,85 @@ +from torch.utils.data import Dataset +import torch +from rdkit import Chem +from rdkit.Chem import AllChem +from rdkit.Chem import MolFromSmiles +import numpy as np + +class MolGraphDataset(Dataset): + """ + Dataset for MolGAN that converts SMILES to graph tensors. + Outputs: + - adj: [Y, N, N] + - node: [N, T] + - reward: float (optional) + """ + + def __init__(self, smiles_list, atom_types, bond_types, max_nodes=9, rewards=None): + """ + Parameters + ---------- + smiles_list : List[str] + List of SMILES strings + + atom_types : List[str] + Ordered list of allowed atom types (e.g., ['C', 'O', 'N', 'F']) + + bond_types : List[int] + List of allowed bond types (RDKit enums: SINGLE=1, DOUBLE=2, etc.) 
+ + max_nodes : int + Max number of atoms in any molecule (pad or skip otherwise) + + rewards : Optional[List[float]] + Precomputed rewards (e.g., QED values) + """ + self.smiles_list = smiles_list + self.atom_types = atom_types + self.bond_types = bond_types + self.max_nodes = max_nodes + self.rewards = rewards if rewards is not None else [0.0] * len(smiles_list) + + self.atom_type_map = {atom: i for i, atom in enumerate(atom_types)} + self.bond_type_map = {b: i for i, b in enumerate(bond_types)} + + self.data = [self._smiles_to_graph(s) for s in smiles_list] + + def _smiles_to_graph(self, smiles): + mol = Chem.MolFromSmiles(smiles) + mol = Chem.AddHs(mol) + num_atoms = mol.GetNumAtoms() + + if num_atoms > self.max_nodes: + raise ValueError(f"Too many atoms in molecule: {num_atoms} > {self.max_nodes}") + + # Node features + node = np.zeros((self.max_nodes, len(self.atom_types))) + for i, atom in enumerate(mol.GetAtoms()): + atom_type = atom.GetSymbol() + if atom_type not in self.atom_type_map: + continue + node[i, self.atom_type_map[atom_type]] = 1 + + # Adjacency tensor + adj = np.zeros((len(self.bond_types), self.max_nodes, self.max_nodes)) + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + bond_type = int(bond.GetBondTypeAsDouble()) + if bond_type not in self.bond_type_map: + continue + k = self.bond_type_map[bond_type] + adj[k, i, j] = 1 + adj[k, j, i] = 1 # symmetric + + return {"adj": adj, "node": node} + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + sample = self.data[idx] + adj = torch.tensor(sample["adj"], dtype=torch.float32) + node = torch.tensor(sample["node"], dtype=torch.float32) + reward = torch.tensor(self.rewards[idx], dtype=torch.float32) + return {"adj": adj, "node": node, "reward": reward} diff --git a/torch_molecule/generator/molgan/gan.py b/torch_molecule/generator/molgan/gan.py new file mode 100644 index 0000000..81c2770 --- /dev/null +++ b/torch_molecule/generator/molgan/gan.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +from .generator import MolGANGenerator +from .discriminator import MolGANDiscriminator + +class MolGAN(nn.Module): + """ + Combined MolGAN model: generator + discriminator + """ + + def __init__(self, gen_config, disc_config): + super().__init__() + self.generator = MolGANGenerator(gen_config) + self.discriminator = MolGANDiscriminator(disc_config) + + def generate(self, z): + """Forward pass through generator only.""" + return self.generator(z) + + def discriminate(self, adj, node): + """Forward pass through discriminator only.""" + return self.discriminator(adj, node) + + def forward(self, z): + """ + Combined forward pass (generator → discriminator). + Used for adversarial training. 
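+
+        Illustrative sketch (gen_config and disc_config are assumed to be the
+        config objects from generator.py and discriminator.py):
+
+            >>> import torch
+            >>> model = MolGAN(gen_config, disc_config)
+            >>> adj_fake, node_fake, pred_fake = model(torch.randn(16, gen_config.latent_dim))
+            >>> pred_fake.shape
+            torch.Size([16])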
+ """ + adj_fake, node_fake = self.generator(z) + pred_fake = self.discriminator(adj_fake, node_fake) + return adj_fake, node_fake, pred_fake From 377946640b92d6430f3cfa24703bfb1214c40712 Mon Sep 17 00:00:00 2001 From: Manda Kausthubh Date: Tue, 8 Jul 2025 15:52:40 +0530 Subject: [PATCH 04/14] Added additional reward function --- torch_molecule/generator/molgan/dataset.py | 131 +++++------ .../generator/molgan/discriminator.py | 23 +- torch_molecule/generator/molgan/gan_utils.py | 85 +++++++ torch_molecule/generator/molgan/rewards.py | 208 ++++++++++++++++++ 4 files changed, 361 insertions(+), 86 deletions(-) create mode 100644 torch_molecule/generator/molgan/gan_utils.py create mode 100644 torch_molecule/generator/molgan/rewards.py diff --git a/torch_molecule/generator/molgan/dataset.py b/torch_molecule/generator/molgan/dataset.py index 5318b33..ba74d63 100644 --- a/torch_molecule/generator/molgan/dataset.py +++ b/torch_molecule/generator/molgan/dataset.py @@ -1,85 +1,88 @@ from torch.utils.data import Dataset import torch +from typing import List, Optional, Union, Callable from rdkit import Chem -from rdkit.Chem import AllChem -from rdkit.Chem import MolFromSmiles -import numpy as np + +# assumes you already have this: +from ...utils.graph.graph_from_smiles import graph_from_smiles + class MolGraphDataset(Dataset): """ - Dataset for MolGAN that converts SMILES to graph tensors. - Outputs: - - adj: [Y, N, N] - - node: [N, T] - - reward: float (optional) + Dataset for MolGAN: converts SMILES strings to graph format. + + Outputs a dict with: + - 'adj': [Y, N, N] adjacency tensor + - 'node': [N, T] node feature matrix + - 'reward': float (optional) + - 'smiles': original SMILES (optional) """ - def __init__(self, smiles_list, atom_types, bond_types, max_nodes=9, rewards=None): + def __init__(self, + smiles_list: List[str], + reward_function: Optional[Callable[[str], float]] = None, + max_nodes: int = 9, + drop_invalid: bool = True): """ Parameters ---------- smiles_list : List[str] - List of SMILES strings + List of SMILES strings to convert into graph format. - atom_types : List[str] - Ordered list of allowed atom types (e.g., ['C', 'O', 'N', 'F']) - - bond_types : List[int] - List of allowed bond types (RDKit enums: SINGLE=1, DOUBLE=2, etc.) + reward_function : Callable[[str], float], optional + If provided, computes a scalar reward per molecule (e.g., QED, logP). + Must accept a SMILES string and return a float. max_nodes : int - Max number of atoms in any molecule (pad or skip otherwise) + Maximum allowed number of atoms (molecules exceeding this are dropped). - rewards : Optional[List[float]] - Precomputed rewards (e.g., QED values) + drop_invalid : bool + Whether to skip invalid or unparsable SMILES. 
""" - self.smiles_list = smiles_list - self.atom_types = atom_types - self.bond_types = bond_types - self.max_nodes = max_nodes - self.rewards = rewards if rewards is not None else [0.0] * len(smiles_list) - - self.atom_type_map = {atom: i for i, atom in enumerate(atom_types)} - self.bond_type_map = {b: i for i, b in enumerate(bond_types)} - - self.data = [self._smiles_to_graph(s) for s in smiles_list] - - def _smiles_to_graph(self, smiles): - mol = Chem.MolFromSmiles(smiles) - mol = Chem.AddHs(mol) - num_atoms = mol.GetNumAtoms() - - if num_atoms > self.max_nodes: - raise ValueError(f"Too many atoms in molecule: {num_atoms} > {self.max_nodes}") - - # Node features - node = np.zeros((self.max_nodes, len(self.atom_types))) - for i, atom in enumerate(mol.GetAtoms()): - atom_type = atom.GetSymbol() - if atom_type not in self.atom_type_map: - continue - node[i, self.atom_type_map[atom_type]] = 1 - - # Adjacency tensor - adj = np.zeros((len(self.bond_types), self.max_nodes, self.max_nodes)) - for bond in mol.GetBonds(): - i = bond.GetBeginAtomIdx() - j = bond.GetEndAtomIdx() - bond_type = int(bond.GetBondTypeAsDouble()) - if bond_type not in self.bond_type_map: - continue - k = self.bond_type_map[bond_type] - adj[k, i, j] = 1 - adj[k, j, i] = 1 # symmetric - - return {"adj": adj, "node": node} + self.samples = [] + + for smiles in smiles_list: + try: + mol = Chem.MolFromSmiles(smiles) + if mol is None: + raise ValueError("Invalid SMILES") + + if mol.GetNumAtoms() > max_nodes: + raise ValueError("Too many atoms") + + # Compute reward if needed + reward = reward_function(smiles) if reward_function else 0.0 + + # Convert to graph + graph = graph_from_smiles(smiles, properties=reward) + + # Sanity check + if 'adj' not in graph or 'node' not in graph: + raise ValueError("Incomplete graph data") + + graph['reward'] = reward + graph['smiles'] = smiles + self.samples.append(graph) + + except Exception as e: + if not drop_invalid: + self.samples.append({ + "adj": torch.zeros(1, max_nodes, max_nodes), + "node": torch.zeros(max_nodes, 1), + "reward": 0.0, + "smiles": smiles + }) + else: + print(f"[MolGraphDataset] Skipping SMILES {smiles}: {e}") def __len__(self): - return len(self.data) + return len(self.samples) def __getitem__(self, idx): - sample = self.data[idx] - adj = torch.tensor(sample["adj"], dtype=torch.float32) - node = torch.tensor(sample["node"], dtype=torch.float32) - reward = torch.tensor(self.rewards[idx], dtype=torch.float32) - return {"adj": adj, "node": node, "reward": reward} + sample = self.samples[idx] + return { + "adj": torch.tensor(sample["adj"], dtype=torch.float32), + "node": torch.tensor(sample["node"], dtype=torch.float32), + "reward": torch.tensor(sample["reward"], dtype=torch.float32), + "smiles": sample["smiles"] + } diff --git a/torch_molecule/generator/molgan/discriminator.py b/torch_molecule/generator/molgan/discriminator.py index 198295a..d3afb20 100644 --- a/torch_molecule/generator/molgan/discriminator.py +++ b/torch_molecule/generator/molgan/discriminator.py @@ -1,6 +1,5 @@ -import torch import torch.nn as nn -import torch.nn.functional as F +from .gan_utils import RelationalGCNLayer class MolGANDiscriminatorConfig: @@ -41,26 +40,6 @@ def __init__(self, self.num_layers = num_layers -class RelationalGCNLayer(nn.Module): - def __init__(self, in_dim, out_dim, num_relations): - super().__init__() - self.num_relations = num_relations - self.linears = nn.ModuleList([nn.Linear(in_dim, out_dim) for _ in range(num_relations)]) - self.bias = nn.Parameter(torch.zeros(out_dim)) 
- - def forward(self, adj, h): - """ - adj: [B, Y, N, N] - h: [B, N, D] - """ - out = 0 - for i in range(self.num_relations): - adj_i = adj[:, i, :, :] - h_i = self.linears[i](h) - out += torch.bmm(adj_i, h_i) - - out = out + self.bias - return F.relu(out) class MolGANDiscriminator(nn.Module): diff --git a/torch_molecule/generator/molgan/gan_utils.py b/torch_molecule/generator/molgan/gan_utils.py new file mode 100644 index 0000000..f773fc2 --- /dev/null +++ b/torch_molecule/generator/molgan/gan_utils.py @@ -0,0 +1,85 @@ +from typing import Optional +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from rdkit import Chem + + + + +class RelationalGCNLayer(nn.Module): + def __init__(self, in_dim, out_dim, num_relations): + super().__init__() + self.num_relations = num_relations + self.linears = nn.ModuleList([nn.Linear(in_dim, out_dim) for _ in range(num_relations)]) + self.bias = nn.Parameter(torch.zeros(out_dim)) + + def forward(self, adj, h): + """ + adj: [B, Y, N, N] + h: [B, N, D] + """ + out = 0 + for i in range(self.num_relations): + adj_i = adj[:, i, :, :] + h_i = self.linears[i](h) + out += torch.bmm(adj_i, h_i) + + out = out + self.bias + return F.relu(out) + + + + +def molgan_graph_from_smiles(smiles: str, atom_vocab: list, bond_types: list, max_nodes: int) -> Optional[dict]: + """ + Convert SMILES to MolGAN-style (adjacency, node) graph. + + Parameters + ---------- + smiles : str + SMILES string + + atom_vocab : list of str + List of allowed atom types (e.g., ['C', 'N', 'O', 'F']) + + bond_types : list of float + List of bond types (e.g., [1.0, 1.5, 2.0, 3.0]) + + max_nodes : int + Maximum number of atoms + + Returns + ------- + dict with keys: + 'adj': [Y, N, N] tensor + 'node': [N, T] tensor + """ + mol = Chem.MolFromSmiles(smiles) + if mol is None or mol.GetNumAtoms() > max_nodes: + return None + + T = len(atom_vocab) + Y = len(bond_types) + + node = np.zeros((max_nodes, T)) + for i, atom in enumerate(mol.GetAtoms()): + symbol = atom.GetSymbol() + if symbol in atom_vocab: + node[i, atom_vocab.index(symbol)] = 1 + + adj = np.zeros((Y, max_nodes, max_nodes)) + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + btype = bond.GetBondTypeAsDouble() + if btype in bond_types: + k = bond_types.index(btype) + adj[k, i, j] = 1 + adj[k, j, i] = 1 + + return { + "adj": torch.tensor(adj, dtype=torch.float32).unsqueeze(0), + "node": torch.tensor(node, dtype=torch.float32).unsqueeze(0) + } diff --git a/torch_molecule/generator/molgan/rewards.py b/torch_molecule/generator/molgan/rewards.py new file mode 100644 index 0000000..5096222 --- /dev/null +++ b/torch_molecule/generator/molgan/rewards.py @@ -0,0 +1,208 @@ +import torch +import torch.nn as nn +from typing import List, Optional, Union +from rdkit import Chem +from rdkit.Chem import QED, Crippen, rdMolDescriptors +from .gan_utils import RelationalGCNLayer, molgan_graph_from_smiles +from ...utils.graph.graph_from_smiles import graph_from_smiles + + +# Non-Neural reward functions based on RDKit +def qed_reward(smiles: str) -> float: + mol = Chem.MolFromSmiles(smiles) + return QED.qed(mol) if mol else 0.0 + +def logp_reward(smiles: str) -> float: + mol = Chem.MolFromSmiles(smiles) + return Crippen.MolLogP(mol) if mol else 0.0 + +def weight_reward(smiles: str) -> float: + mol = Chem.MolFromSmiles(smiles) + return rdMolDescriptors.CalcExactMolWt(mol) if mol else 0.0 + +def combo_reward(smiles: str, weights=(0.7, 0.3)) -> float: + mol = Chem.MolFromSmiles(smiles) + if mol 
is None: + return 0.0 + qed_score = QED.qed(mol) + logp_score = Crippen.MolLogP(mol) + return weights[0] * qed_score + weights[1] * logp_score + +class RewardOracle: + def __init__(self, kind="qed"): + if kind == "qed": + self.func = qed_reward + elif kind == "logp": + self.func = logp_reward + elif kind == "combo": + self.func = lambda s: combo_reward(s, weights=(0.6, 0.4)) + else: + raise ValueError(f"Unknown reward type: {kind}") + + def __call__(self, smiles: str) -> float: + return self.func(smiles) + + + + + +# Reward Network using Relational GCNs +class RewardNeuralNetwork(nn.Module): + """ + Reward Network that predicts reward from (adj, node) graphs. + """ + + def __init__(self, num_atom_types=5, num_bond_types=4, hidden_dim=128, num_layers=2, num_nodes=9): + super().__init__() + self.gcn_layers = nn.ModuleList() + self.gcn_layers.append(RelationalGCNLayer(num_atom_types, hidden_dim, num_bond_types)) + + for _ in range(1, num_layers): + self.gcn_layers.append(RelationalGCNLayer(hidden_dim, hidden_dim, num_bond_types)) + + self.readout = nn.Sequential( + nn.Linear(num_nodes * hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, 1) + ) + + def forward(self, adj, node): + """ + adj: [B, Y, N, N] + node: [B, N, T] + """ + h = node + for layer in self.gcn_layers: + h = layer(adj, h) + + h = h.view(h.size(0), -1) + return self.readout(h).squeeze(-1) + + +def fit_reward_network( + reward_model: RewardNeuralNetwork, + train_loader, + epochs: int = 10, + lr: float = 1e-3, + weight_decay: float = 0.0, + device: str = "cpu", + verbose: bool = True +): + """ + Train the reward model to approximate oracle rewards. + + Parameters + ---------- + reward_model : RewardNeuralNetwork + The neural network to train + + train_loader : DataLoader + Yields batches of (adj, node, reward) + + epochs : int + Number of training epochs + + lr : float + Learning rate + + weight_decay : float + Optional L2 regularization + + device : str + Device to run on ("cpu" or "cuda") + + verbose : bool + Whether to print losses + """ + model = reward_model.to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) + criterion = nn.MSELoss() + + model.train() + for epoch in range(epochs): + epoch_losses = [] + + for batch in train_loader: + adj = batch["adj"].to(device) # [B, Y, N, N] + node = batch["node"].to(device) # [B, N, T] + reward = batch["reward"].to(device) # [B] + + pred = model(adj, node) # [B] + loss = criterion(pred, reward) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + epoch_losses.append(loss.item()) + + if verbose: + print(f"[Epoch {epoch+1}/{epochs}] RewardNet Loss: {sum(epoch_losses)/len(epoch_losses):.4f}") + + + + + +# Combined reward wrapper: which uses either neural or oracle rewards +class RewardNetwork: + """ + Combined reward network that can use either neural or oracle rewards. + """ + + def __init__(self, kind: str = "qed", num_atom_types=5, num_bond_types=4, hidden_dim=128, num_layers=2, num_nodes=9): + if kind in ["qed", "logp", "combo"]: + self.oracle = RewardOracle(kind) + self.neural = None + else: + self.oracle = None + self.neural = RewardNeuralNetwork(num_atom_types, num_bond_types, hidden_dim, num_layers, num_nodes) + + def train_neural(self, train_loader, epochs=10, lr=1e-3, weight_decay=0.0, device="cpu", verbose=True): + """ + Train the neural reward network using the provided DataLoader. + """ + if self.neural is None: + raise ValueError("No neural network defined. 
Use an oracle reward instead.") + + fit_reward_network( + self.neural, + train_loader, + epochs=epochs, + lr=lr, + weight_decay=weight_decay, + device=device, + verbose=verbose + ) + + + def default_converter(self, smiles: str) -> tuple: + try: + graph = molgan_graph_from_smiles( + smiles, + atom_vocab=["C", "N", "O", "F"], + bond_types=[1.0, 1.5, 2.0, 3.0], + max_nodes=9 + ) + + if graph is None: + return None, None + + adj = torch.tensor(graph["adj"], dtype=torch.float32).unsqueeze(0) + node = torch.tensor(graph["node"], dtype=torch.float32).unsqueeze(0) + return adj, node + except Exception as e: + print(f"[RewardNetwork] SMILES conversion failed: {smiles} → {e}") + return None, None + + def __call__(self, smiles: str) -> float: + if self.oracle: + return self.oracle(smiles) + elif self.neural: + # Convert SMILES to graph representation and pass through neural network + adj, node = self.default_converter(smiles) + if adj is not None and node is not None: + return self.neural(adj, node).item() + else: + return 0.0 + else: + raise ValueError("No valid reward function defined.") From 8bd133b6115c826a90d2b1251fde29f90bf573dd Mon Sep 17 00:00:00 2001 From: Manda Kausthubh Date: Tue, 8 Jul 2025 23:29:29 +0530 Subject: [PATCH 05/14] Created MolGAN's basic implementation --- torch_molecule/generator/molgan/dataset.py | 3 +- torch_molecule/generator/molgan/gan.py | 266 ++++++++++++++++-- torch_molecule/generator/molgan/gan_utils.py | 110 ++++++++ .../generator/molgan/modeling_molgan.py | 103 +++++++ torch_molecule/generator/molgan/rewards.py | 93 +++--- 5 files changed, 505 insertions(+), 70 deletions(-) diff --git a/torch_molecule/generator/molgan/dataset.py b/torch_molecule/generator/molgan/dataset.py index ba74d63..322d946 100644 --- a/torch_molecule/generator/molgan/dataset.py +++ b/torch_molecule/generator/molgan/dataset.py @@ -1,9 +1,8 @@ from torch.utils.data import Dataset import torch -from typing import List, Optional, Union, Callable +from typing import List, Optional, Callable from rdkit import Chem -# assumes you already have this: from ...utils.graph.graph_from_smiles import graph_from_smiles diff --git a/torch_molecule/generator/molgan/gan.py b/torch_molecule/generator/molgan/gan.py index 81c2770..1055e6d 100644 --- a/torch_molecule/generator/molgan/gan.py +++ b/torch_molecule/generator/molgan/gan.py @@ -1,31 +1,265 @@ +from typing import Optional, List import torch import torch.nn as nn + from .generator import MolGANGenerator from .discriminator import MolGANDiscriminator +from .rewards import RewardNetwork +from .gan_utils import decode_smiles, encode_smiles_to_graph +from ...utils.graph.graph_to_smiles import graph_to_smiles + class MolGAN(nn.Module): + """ - Combined MolGAN model: generator + discriminator + Full MolGAN model integrating: + - Generator + - Discriminator + - Reward Network (oracle or neural) """ - def __init__(self, gen_config, disc_config): + def __init__( + self, + generator_config, + discriminator_config, + reward_config, + use_reward=True, + reward_lambda=1.0, + device="cpu"): super().__init__() - self.generator = MolGANGenerator(gen_config) - self.discriminator = MolGANDiscriminator(disc_config) + self.device = device + self.use_reward = use_reward + self.reward_lambda = reward_lambda + + self.generator = MolGANGenerator(generator_config).to(device) + self.discriminator = MolGANDiscriminator(discriminator_config).to(device) + self.reward = RewardNetwork(**reward_config) if use_reward else None + + self.gen_opt = 
torch.optim.Adam(self.generator.parameters(), lr=generator_config.get("lr", 1e-3)) + self.dis_opt = torch.optim.Adam(self.discriminator.parameters(), lr=discriminator_config.get("lr", 1e-3)) + + def generate(self, batch_size): + z = torch.randn(batch_size, self.generator.latent_dim).to(self.device) + adj, node = self.generator(z) + return adj, node + + def compute_rewards( + self, + smiles_list: Optional[List[str]] = None, + adj: Optional[torch.Tensor] = None, + node: Optional[torch.Tensor] = None, + ): + """ + Compute reward using the internal RewardNetwork, either from SMILES or from graph tensors. + + Parameters + ---------- + smiles_list : List[str], optional + List of SMILES strings to compute rewards for + + adj : Tensor [B, Y, N, N], optional + Adjacency tensor + + node : Tensor [B, N, T], optional + Node tensor + + Returns + ------- + Tensor [B] + Reward values + """ + if self.reward is None: + if adj is None or node is None: + raise ValueError("Either smiles_list or (adj, node) must be provided for reward computation.") + return torch.zeros(adj.size(0), device=self.device) + + if smiles_list is not None: + adjs, nodes = [], [] + for smiles in smiles_list: + try: + encoded_graph = encode_smiles_to_graph( + smiles, + atom_vocab=self.atom_decoder, + bond_types=self.bond_types, + max_nodes=self.max_nodes + ) + if encoded_graph is None: + raise ValueError(f"Invalid SMILES: {smiles}") + a, n = encoded_graph + adjs.append(a) + nodes.append(n) + except Exception: + # fallback to zeros if decoding fails + adjs.append(torch.zeros(len(self.bond_types), self.max_nodes, self.max_nodes)) + nodes.append(torch.zeros(self.max_nodes, len(self.atom_decoder))) + + adj_batch = torch.stack(adjs).to(self.device) + node_batch = torch.stack(nodes).to(self.device) + return self.reward(adj_batch, node_batch) + + elif adj is not None and node is not None: + return self.reward(adj.to(self.device), node.to(self.device)) + + else: + raise ValueError("Either smiles_list or (adj, node) must be provided for reward computation.") + + + def fit( + self, + data_loader, + epochs: int = 10, + log_every: int = 1, + reward_scale: float = 1.0 + ): + """ + Train the MolGAN model using adversarial and (optional) reward-based learning. + + Parameters + ---------- + data_loader : DataLoader + DataLoader yielding batches of {"adj", "node", "smiles"} dictionaries. + + epochs : int + Number of training epochs. + + log_every : int + Frequency of logging losses. + + reward_scale : float + Weight of reward loss in the generator's total loss. 
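+
+        Illustrative usage (assumes `loader` is a DataLoader yielding such batches):
+
+            >>> model.fit(loader, epochs=20, log_every=5, reward_scale=0.5)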
+ """ + self.generator.train() + self.discriminator.train() + if self.reward and hasattr(self.reward, 'neural') and self.reward.neural: + self.reward.neural.eval() + + for epoch in range(epochs): + d_losses, g_losses, reward_vals = [], [], [] - def generate(self, z): - """Forward pass through generator only.""" - return self.generator(z) + for batch in data_loader: + real_adj = batch["adj"].to(self.device) # [B, Y, N, N] + real_node = batch["node"].to(self.device) # [B, N, T] + smiles = batch.get("smiles", None) - def discriminate(self, adj, node): - """Forward pass through discriminator only.""" - return self.discriminator(adj, node) + batch_size = real_adj.size(0) - def forward(self, z): + # === Train Discriminator === + self.dis_opt.zero_grad() + fake_adj, fake_node = self.generate(batch_size) + + real_logits = self.discriminator(real_adj, real_node) + fake_logits = self.discriminator(fake_adj.detach(), fake_node.detach()) + + d_loss = -torch.mean(real_logits) + torch.mean(fake_logits) + d_loss.backward() + self.dis_opt.step() + + # === Train Generator === + self.gen_opt.zero_grad() + fake_logits = self.discriminator(fake_adj, fake_node) + g_loss = -torch.mean(fake_logits) + + # === Add reward loss if applicable === + if self.use_reward: + rewards = self.compute_rewards(adj=fake_adj, node=fake_node) # [B] + reward_loss = -rewards.mean() + g_loss += reward_scale * reward_loss + reward_vals.append(rewards.mean().item()) + else: + reward_vals.append(0.0) + + g_loss.backward() + self.gen_opt.step() + + d_losses.append(d_loss.item()) + g_losses.append(g_loss.item()) + + if (epoch + 1) % log_every == 0: + print(f"[Epoch {epoch+1}/{epochs}] " + f"D_loss: {sum(d_losses)/len(d_losses):.4f} | " + f"G_loss: {sum(g_losses)/len(g_losses):.4f} | " + f"Reward: {sum(reward_vals)/len(reward_vals):.4f}") + + + def fit( + self, + data_loader, + epochs: int = 10, + log_every: int = 1, + reward_scale: float = 1.0 + ): """ - Combined forward pass (generator → discriminator). - Used for adversarial training. + Train the MolGAN model using adversarial and (optional) reward-based learning. + + Parameters + ---------- + data_loader : DataLoader + DataLoader yielding batches of {"adj", "node", "smiles"} dictionaries. + + epochs : int + Number of training epochs. + + log_every : int + Frequency of logging losses. + + reward_scale : float + Weight of reward loss in the generator's total loss. 
""" - adj_fake, node_fake = self.generator(z) - pred_fake = self.discriminator(adj_fake, node_fake) - return adj_fake, node_fake, pred_fake + self.generator.train() + self.discriminator.train() + if self.reward and hasattr(self.reward, 'neural') and self.reward.neural: + self.reward.neural.eval() + + for epoch in range(epochs): + d_losses, g_losses, reward_vals = [], [], [] + + for batch in data_loader: + real_adj = batch["adj"].to(self.device) + real_node = batch["node"].to(self.device) + smiles = batch.get("smiles", None) + + batch_size = real_adj.size(0) + + # === Train Discriminator === + self.dis_opt.zero_grad() + fake_adj, fake_node = self.generate(batch_size) + + real_logits = self.discriminator(real_adj, real_node) + fake_logits = self.discriminator(fake_adj.detach(), fake_node.detach()) + + d_loss = -torch.mean(real_logits) + torch.mean(fake_logits) + d_loss.backward() + self.dis_opt.step() + + # === Train Generator === + self.gen_opt.zero_grad() + fake_logits = self.discriminator(fake_adj, fake_node) + g_loss = -torch.mean(fake_logits) + + # === Add reward loss if applicable === + if self.use_reward: + # Convert fake graphs to SMILES if reward expects SMILES + if self.reward.oracle is not None: + smiles_fake = self.decode_smiles(fake_adj, fake_node) + rewards = self.compute_rewards(smiles_list=smiles_fake) + else: + rewards = self.compute_rewards(adj=fake_adj, node=fake_node) + + reward_loss = -rewards.mean() + g_loss += reward_scale * reward_loss + reward_vals.append(rewards.mean().item()) + else: + reward_vals.append(0.0) + + g_loss.backward() + self.gen_opt.step() + + d_losses.append(d_loss.item()) + g_losses.append(g_loss.item()) + + if (epoch + 1) % log_every == 0: + print(f"[Epoch {epoch+1}/{epochs}] " + f"D_loss: {sum(d_losses)/len(d_losses):.4f} | " + f"G_loss: {sum(g_losses)/len(g_losses):.4f} | " + f"Reward: {sum(reward_vals)/len(reward_vals):.4f}") + diff --git a/torch_molecule/generator/molgan/gan_utils.py b/torch_molecule/generator/molgan/gan_utils.py index f773fc2..153286e 100644 --- a/torch_molecule/generator/molgan/gan_utils.py +++ b/torch_molecule/generator/molgan/gan_utils.py @@ -4,6 +4,8 @@ import torch.nn.functional as F import numpy as np from rdkit import Chem +from ...utils.graph.graph_to_smiles import graph_to_smiles +from ...utils.graph.graph_from_smiles import graph_from_smiles @@ -30,6 +32,114 @@ def forward(self, adj, h): return F.relu(out) +def encode_smiles_to_graph( + smiles: str, + atom_vocab: list = ["C", "N", "O", "F"], + bond_types: list = [1.0, 1.5, 2.0, 3.0], + max_nodes: int = 9 +) -> Optional[tuple[torch.Tensor, torch.Tensor]]: + """ + Convert a SMILES string into (adj, node) tensors. 
+ + Parameters + ---------- + smiles : str + Input SMILES string + + atom_vocab : list of str + List of valid atom types + + bond_types : list of float + Allowed bond types (e.g., 1.0: single, 2.0: double) + + max_nodes : int + Max number of atoms (graph will be padded) + + Returns + ------- + adj : Tensor [Y, N, N] + Multi-relational adjacency tensor + + node : Tensor [N, T] + One-hot atom features + """ + mol = Chem.MolFromSmiles(smiles) + if mol is None or mol.GetNumAtoms() > max_nodes: + # raise ValueError(f"Invalid or oversized molecule: {smiles}") + return None + + N = max_nodes + T = len(atom_vocab) + Y = len(bond_types) + + # Initialize node features + node = np.zeros((N, T), dtype=np.float32) + for i, atom in enumerate(mol.GetAtoms()): + if i >= N: + break + symbol = atom.GetSymbol() + if symbol in atom_vocab: + node[i, atom_vocab.index(symbol)] = 1.0 + + # Initialize adjacency tensor + adj = np.zeros((Y, N, N), dtype=np.float32) + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + btype = bond.GetBondTypeAsDouble() + if btype in bond_types and i < N and j < N: + k = bond_types.index(btype) + adj[k, i, j] = 1.0 + adj[k, j, i] = 1.0 # undirected + + # Convert to torch.Tensor + return torch.tensor(adj), torch.tensor(node) + + +def decode_smiles( + adj: torch.Tensor, + node: torch.Tensor, + atom_decoder: list = ["C", "N", "O", "F"] + ) -> list: + """ + Convert a batch of (adj, node) tensors to SMILES strings. + + Parameters + ---------- + adj : torch.Tensor + Adjacency tensor of shape [B, Y, N, N] + + node : torch.Tensor + Node feature tensor of shape [B, N, T] + + atom_decoder : list of str + Atom types in order of one-hot encoding indices + + Returns + ------- + List[str or None] + Decoded SMILES strings or None for invalid molecules + """ + # Ensure tensors are detached and moved to CPU + adj_np = adj.detach().cpu().numpy() + node_np = node.detach().cpu().numpy() + + # Build molecule list + molecule_list = list(zip(node_np, adj_np)) + + # Decode into SMILES strings + smiles_list = graph_to_smiles(molecule_list, atom_decoder) + + return smiles_list + + + + + + + + + def molgan_graph_from_smiles(smiles: str, atom_vocab: list, bond_types: list, max_nodes: int) -> Optional[dict]: diff --git a/torch_molecule/generator/molgan/modeling_molgan.py b/torch_molecule/generator/molgan/modeling_molgan.py index e69de29..504b35c 100644 --- a/torch_molecule/generator/molgan/modeling_molgan.py +++ b/torch_molecule/generator/molgan/modeling_molgan.py @@ -0,0 +1,103 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .generator import MolGANGenerator +from .discriminator import MolGANDiscriminator +from .rewards import RewardNetwork + + +class MolGAN(nn.Module): + """ + Full MolGAN model integrating: + - Generator + - Discriminator + - Reward Network (oracle or neural) + """ + + def __init__( + self, + generator_config, + discriminator_config, + reward_config, + use_reward=True, + reward_lambda=1.0, + device="cpu" + ): + super().__init__() + self.device = device + self.use_reward = use_reward + self.reward_lambda = reward_lambda + + self.generator = MolGANGenerator(generator_config).to(device) + self.discriminator = MolGANDiscriminator(discriminator_config).to(device) + self.reward = RewardNetwork(**reward_config) if use_reward else None + + self.gen_opt = torch.optim.Adam(self.generator.parameters(), lr=generator_config.get("lr", 1e-3)) + self.dis_opt = torch.optim.Adam(self.discriminator.parameters(), lr=discriminator_config.get("lr", 
1e-3)) + + def generate(self, batch_size): + z = torch.randn(batch_size, self.generator.latent_dim).to(self.device) + adj, node = self.generator(z) + return adj, node + + def compute_rewards(self, smiles_list=None, adj=None, node=None): + if self.reward is None: + if adj is None or node is None: + return torch.zeros(adj.size(0), device=self.device) + return torch.tensor([ + self.reward(smiles=s) for s in smiles_list + ], dtype=torch.float32, device=self.device) if smiles_list else self.reward(adj=adj, node=node) + + def decode_smiles(self, adj, node): + """ + Convert batch of (adj, node) to SMILES strings + This requires your graph_to_smiles function + """ + from your_utils.graph_to_smiles import graph_to_smiles + graphs = list(zip(node.cpu().numpy(), adj.cpu().numpy())) + return graph_to_smiles(graphs, atom_decoder=["C", "N", "O", "F"]) + + def fit(self, data_loader, epochs=10, log_every=1): + for epoch in range(epochs): + epoch_d_loss, epoch_g_loss, epoch_rewards = [], [], [] + for batch in data_loader: + real_adj = batch["adj"].to(self.device) + real_node = batch["node"].to(self.device) + real_smiles = batch.get("smiles", None) + + # === Train Discriminator === + self.dis_opt.zero_grad() + fake_adj, fake_node = self.generate(real_adj.size(0)) + + real_logits = self.discriminator(real_adj, real_node) + fake_logits = self.discriminator(fake_adj.detach(), fake_node.detach()) + d_loss = -torch.mean(real_logits) + torch.mean(fake_logits) + d_loss.backward() + self.dis_opt.step() + + # === Train Generator === + self.gen_opt.zero_grad() + fake_logits = self.discriminator(fake_adj, fake_node) + g_loss = -torch.mean(fake_logits) + + # Add reward loss if enabled + if self.use_reward: + smiles_fake = self.decode_smiles(fake_adj, fake_node) + rewards = self.compute_rewards(smiles_list=smiles_fake) + reward_loss = -torch.mean(rewards) + g_loss = g_loss + self.reward_lambda * reward_loss + else: + rewards = torch.zeros(real_adj.size(0)) + + g_loss.backward() + self.gen_opt.step() + + epoch_d_loss.append(d_loss.item()) + epoch_g_loss.append(g_loss.item()) + epoch_rewards.append(rewards.mean().item()) + + if (epoch + 1) % log_every == 0: + print(f"Epoch [{epoch + 1}/{epochs}] D_loss: {sum(epoch_d_loss)/len(epoch_d_loss):.4f}, " + f"G_loss: {sum(epoch_g_loss)/len(epoch_g_loss):.4f}, " + f"Reward: {sum(epoch_rewards)/len(epoch_rewards):.4f}") diff --git a/torch_molecule/generator/molgan/rewards.py b/torch_molecule/generator/molgan/rewards.py index 5096222..c0abfc8 100644 --- a/torch_molecule/generator/molgan/rewards.py +++ b/torch_molecule/generator/molgan/rewards.py @@ -1,10 +1,10 @@ +from typing import Optional import torch import torch.nn as nn -from typing import List, Optional, Union from rdkit import Chem from rdkit.Chem import QED, Crippen, rdMolDescriptors from .gan_utils import RelationalGCNLayer, molgan_graph_from_smiles -from ...utils.graph.graph_from_smiles import graph_from_smiles +from ...utils.graph.graph_to_smiles import graph_to_smiles # Non-Neural reward functions based on RDKit @@ -123,9 +123,9 @@ def fit_reward_network( epoch_losses = [] for batch in train_loader: - adj = batch["adj"].to(device) # [B, Y, N, N] - node = batch["node"].to(device) # [B, N, T] - reward = batch["reward"].to(device) # [B] + adj = batch["adj"].to(device) + node = batch["node"].to(device) + reward = batch["reward"].to(device) pred = model(adj, node) # [B] loss = criterion(pred, reward) @@ -146,63 +146,52 @@ def fit_reward_network( # Combined reward wrapper: which uses either neural or oracle rewards class 
RewardNetwork: """ - Combined reward network that can use either neural or oracle rewards. + Combined reward network that uses either a neural model or an oracle. + Accepts (adj, node) tensors as standard input. """ - def __init__(self, kind: str = "qed", num_atom_types=5, num_bond_types=4, hidden_dim=128, num_layers=2, num_nodes=9): + def __init__( + self, + kind="qed", + reward_net: Optional[RewardNeuralNetwork] = None, + atom_decoder=None, + device="cpu"): + self.kind = kind + self.device = device + self.atom_decoder = atom_decoder or ["C", "N", "O", "F"] + if kind in ["qed", "logp", "combo"]: self.oracle = RewardOracle(kind) self.neural = None - else: + elif kind == "neural": + assert reward_net is not None, "reward_net must be provided for 'neural' mode" self.oracle = None - self.neural = RewardNeuralNetwork(num_atom_types, num_bond_types, hidden_dim, num_layers, num_nodes) + self.neural = reward_net.to(device).eval() + else: + raise ValueError(f"Invalid kind: {kind}") - def train_neural(self, train_loader, epochs=10, lr=1e-3, weight_decay=0.0, device="cpu", verbose=True): - """ - Train the neural reward network using the provided DataLoader. + def __call__(self, adj: torch.Tensor, node: torch.Tensor) -> torch.Tensor: """ - if self.neural is None: - raise ValueError("No neural network defined. Use an oracle reward instead.") - - fit_reward_network( - self.neural, - train_loader, - epochs=epochs, - lr=lr, - weight_decay=weight_decay, - device=device, - verbose=verbose - ) - + Compute reward from graph tensors. - def default_converter(self, smiles: str) -> tuple: - try: - graph = molgan_graph_from_smiles( - smiles, - atom_vocab=["C", "N", "O", "F"], - bond_types=[1.0, 1.5, 2.0, 3.0], - max_nodes=9 - ) + Parameters + ---------- + adj : Tensor [B, Y, N, N] + node : Tensor [B, N, T] - if graph is None: - return None, None + Returns + ------- + Tensor [B] : reward per sample + """ + if self.neural: + with torch.no_grad(): + return self.neural(adj.to(self.device), node.to(self.device)) - adj = torch.tensor(graph["adj"], dtype=torch.float32).unsqueeze(0) - node = torch.tensor(graph["node"], dtype=torch.float32).unsqueeze(0) - return adj, node - except Exception as e: - print(f"[RewardNetwork] SMILES conversion failed: {smiles} → {e}") - return None, None + elif self.oracle: + graphs = list(zip(node.cpu().numpy(), adj.cpu().numpy())) + smiles_list = graph_to_smiles(graphs, self.atom_decoder) + rewards = [self.oracle(s) if s else 0.0 for s in smiles_list] + return torch.tensor(rewards, dtype=torch.float32, device=self.device) - def __call__(self, smiles: str) -> float: - if self.oracle: - return self.oracle(smiles) - elif self.neural: - # Convert SMILES to graph representation and pass through neural network - adj, node = self.default_converter(smiles) - if adj is not None and node is not None: - return self.neural(adj, node).item() - else: - return 0.0 else: - raise ValueError("No valid reward function defined.") + raise ValueError("No reward function defined.") From b7fc773451c50bbaf3b2e32597a9e6abe4f6ae5a Mon Sep 17 00:00:00 2001 From: Manda Kausthubh Date: Wed, 9 Jul 2025 15:20:04 +0530 Subject: [PATCH 06/14] Has errors needs fixes --- torch_molecule/generator/molgan/dataset.py | 18 +- .../generator/molgan/discriminator.py | 32 +- torch_molecule/generator/molgan/gan.py | 344 +++++++----------- torch_molecule/generator/molgan/generator.py | 36 +- .../generator/molgan/modeling_molgan.py | 8 +- torch_molecule/generator/molgan/rewards.py | 11 +- 6 files changed, 213 insertions(+), 236 
deletions(-) diff --git a/torch_molecule/generator/molgan/dataset.py b/torch_molecule/generator/molgan/dataset.py index 322d946..1626513 100644 --- a/torch_molecule/generator/molgan/dataset.py +++ b/torch_molecule/generator/molgan/dataset.py @@ -3,6 +3,8 @@ from typing import List, Optional, Callable from rdkit import Chem +from .rewards import RewardNetwork +from .gan_utils import encode_smiles_to_graph from ...utils.graph.graph_from_smiles import graph_from_smiles @@ -19,7 +21,7 @@ class MolGraphDataset(Dataset): def __init__(self, smiles_list: List[str], - reward_function: Optional[Callable[[str], float]] = None, + reward_function: Optional[RewardNetwork] = None, max_nodes: int = 9, drop_invalid: bool = True): """ @@ -50,7 +52,11 @@ def __init__(self, raise ValueError("Too many atoms") # Compute reward if needed - reward = reward_function(smiles) if reward_function else 0.0 + graph = encode_smiles_to_graph(smiles) + if graph is None: + raise ValueError("Failed to encode SMILES to graph") + adj, node = graph + reward = reward_function(node, adj) if reward_function else 0.0 # Convert to graph graph = graph_from_smiles(smiles, properties=reward) @@ -85,3 +91,11 @@ def __getitem__(self, idx): "reward": torch.tensor(sample["reward"], dtype=torch.float32), "smiles": sample["smiles"] } + +def molgan_collate_fn(batch): + adj = torch.stack([item["adj"] for item in batch], dim=0) + node = torch.stack([item["node"] for item in batch], dim=0) + reward = torch.stack([item["reward"] for item in batch], dim=0) + smiles = [item["smiles"] for item in batch] + return {"adj": adj, "node": node, "reward": reward, "smiles": smiles} + diff --git a/torch_molecule/generator/molgan/discriminator.py b/torch_molecule/generator/molgan/discriminator.py index d3afb20..7773e39 100644 --- a/torch_molecule/generator/molgan/discriminator.py +++ b/torch_molecule/generator/molgan/discriminator.py @@ -47,24 +47,40 @@ class MolGANDiscriminator(nn.Module): Discriminator network for MolGAN using stacked Relational GCNs. 
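+
+    Example (illustrative sketch, assuming the keyword defaults below):
+
+        >>> import torch
+        >>> disc = MolGANDiscriminator(hidden_dims=[128, 128])
+        >>> adj, node = torch.rand(8, 4, 9, 9), torch.rand(8, 9, 5)
+        >>> disc(adj, node).shape
+        torch.Size([8])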
""" - def __init__(self, config: MolGANDiscriminatorConfig): + def __init__(self, + num_atom_types=5, + num_bond_types=4, + num_nodes=9, + hidden_dims=[128, 128]): super().__init__() - self.config = config + + self.num_atom_types = num_atom_types + self.num_bond_types = num_bond_types + self.num_nodes = num_nodes + self.hidden_dim = hidden_dims + self.num_layers = len(hidden_dims) + 1 # I'm including the input layer self.gcn_layers = nn.ModuleList() self.gcn_layers.append( - RelationalGCNLayer(config.num_atom_types, config.hidden_dim, config.num_bond_types) + RelationalGCNLayer(num_atom_types, hidden_dims[0], num_bond_types) ) - for _ in range(1, config.num_layers): - self.gcn_layers.append( - RelationalGCNLayer(config.hidden_dim, config.hidden_dim, config.num_bond_types) + # for _ in range(1, num_layers): + # self.gcn_layers.append( + # RelationalGCNLayer(hidden_dims, hidden_dims, num_bond_types) + # ) + + input_dim = hidden_dims[0] + for hidden_dim in hidden_dims[1:]: + self.gcn_layers.append( + RelationalGCNLayer(input_dim, hidden_dim, num_bond_types) ) + input_dim = hidden_dim self.readout = nn.Sequential( - nn.Linear(config.num_nodes * config.hidden_dim, config.hidden_dim), + nn.Linear(num_nodes * hidden_dims[-1], hidden_dims[-1]), nn.ReLU(), - nn.Linear(config.hidden_dim, 1) + nn.Linear(hidden_dims[-1], 1) ) def forward(self, adj, node): diff --git a/torch_molecule/generator/molgan/gan.py b/torch_molecule/generator/molgan/gan.py index 1055e6d..b7b0b74 100644 --- a/torch_molecule/generator/molgan/gan.py +++ b/torch_molecule/generator/molgan/gan.py @@ -1,265 +1,193 @@ -from typing import Optional, List import torch -import torch.nn as nn +import torch.nn.functional as F +from dataclasses import dataclass +from typing import Optional, List + +from torch_molecule.base.generator import BaseMolecularGenerator +# If for future compatibility, do ensure Configs are imported from .generator import MolGANGenerator from .discriminator import MolGANDiscriminator from .rewards import RewardNetwork from .gan_utils import decode_smiles, encode_smiles_to_graph +from .dataset import MolGraphDataset, molgan_collate_fn from ...utils.graph.graph_to_smiles import graph_to_smiles +from typing import List, Optional +from dataclasses import field +import numpy as np +import torch -class MolGAN(nn.Module): +# The actual MolGAN implementation +@dataclass +class MolGAN(BaseMolecularGenerator): + """MolGAN implementation compatible with BaseMolecularGenerator interface.""" - """ - Full MolGAN model integrating: - - Generator - - Discriminator - - Reward Network (oracle or neural) - """ + model_name: str = field(default="MolGAN") def __init__( - self, - generator_config, - discriminator_config, - reward_config, - use_reward=True, - reward_lambda=1.0, - device="cpu"): + self, + latent_dim: int = 56, + hidden_dims_gen: List[int] = [128,128], + hidden_dims_disc: List[int] = [128, 128], + num_nodes: int = 9, + tau: float = 1.0, + num_atom_types: int = 5, + num_bond_types: int = 4, + use_reward: bool = False, + reward_network: Optional[RewardNetwork] = None, + device: Optional[str] = None + ): super().__init__() - self.device = device - self.use_reward = use_reward - self.reward_lambda = reward_lambda - - self.generator = MolGANGenerator(generator_config).to(device) - self.discriminator = MolGANDiscriminator(discriminator_config).to(device) - self.reward = RewardNetwork(**reward_config) if use_reward else None - - self.gen_opt = torch.optim.Adam(self.generator.parameters(), lr=generator_config.get("lr", 1e-3)) - 
self.dis_opt = torch.optim.Adam(self.discriminator.parameters(), lr=discriminator_config.get("lr", 1e-3)) - def generate(self, batch_size): - z = torch.randn(batch_size, self.generator.latent_dim).to(self.device) - adj, node = self.generator(z) - return adj, node + self.latent_dim = latent_dim + self.hidden_dims_gen = hidden_dims_gen + self.hidden_dims_disc = hidden_dims_disc + self.num_nodes = num_nodes + self.num_atom_types = num_atom_types + self.num_bond_types = num_bond_types + self.use_reward = use_reward + self.reward = reward_network + self.tau = tau + + self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) + + self.generator = MolGANGenerator( + latent_dim=latent_dim, + hidden_dims=hidden_dims_gen, + num_nodes=num_nodes, + num_atom_types=num_atom_types, + num_bond_types=num_bond_types, + tau=tau + ).to(self.device) + + self.discriminator = MolGANDiscriminator( + hidden_dims=hidden_dims_disc, + num_nodes=num_nodes, + num_atom_types=num_atom_types, + num_bond_types=num_bond_types + ).to(self.device) + + self.gen_opt = torch.optim.Adam(self.generator.parameters(), lr=1e-4) + self.dis_opt = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4) - def compute_rewards( + def fit( self, - smiles_list: Optional[List[str]] = None, - adj: Optional[torch.Tensor] = None, - node: Optional[torch.Tensor] = None, - ): + X: List[str], + y: Optional[np.ndarray] = None, + epochs: int = 10, + batch_size: int = 32 + ) -> "MolGAN": """ - Compute reward using the internal RewardNetwork, either from SMILES or from graph tensors. + Fit the MolGAN model to a list of SMILES strings. Parameters ---------- - smiles_list : List[str], optional - List of SMILES strings to compute rewards for + X : List[str] + List of training SMILES strings. - adj : Tensor [B, Y, N, N], optional - Adjacency tensor + y : Optional[np.ndarray] + Optional reward targets. (Unused if using oracle or no reward) - node : Tensor [B, N, T], optional - Node tensor + epochs : int + Number of training epochs. + + batch_size : int + Batch size for training. Returns ------- - Tensor [B] - Reward values + self : MolGAN + The trained model. """ - if self.reward is None: - if adj is None or node is None: - raise ValueError("Either smiles_list or (adj, node) must be provided for reward computation.") - return torch.zeros(adj.size(0), device=self.device) - - if smiles_list is not None: - adjs, nodes = [], [] - for smiles in smiles_list: - try: - encoded_graph = encode_smiles_to_graph( - smiles, - atom_vocab=self.atom_decoder, - bond_types=self.bond_types, - max_nodes=self.max_nodes - ) - if encoded_graph is None: - raise ValueError(f"Invalid SMILES: {smiles}") - a, n = encoded_graph - adjs.append(a) - nodes.append(n) - except Exception: - # fallback to zeros if decoding fails - adjs.append(torch.zeros(len(self.bond_types), self.max_nodes, self.max_nodes)) - nodes.append(torch.zeros(self.max_nodes, len(self.atom_decoder))) - - adj_batch = torch.stack(adjs).to(self.device) - node_batch = torch.stack(nodes).to(self.device) - return self.reward(adj_batch, node_batch) - - elif adj is not None and node is not None: - return self.reward(adj.to(self.device), node.to(self.device)) - - else: - raise ValueError("Either smiles_list or (adj, node) must be provided for reward computation.") + from torch.utils.data import DataLoader - def fit( - self, - data_loader, - epochs: int = 10, - log_every: int = 1, - reward_scale: float = 1.0 - ): - """ - Train the MolGAN model using adversarial and (optional) reward-based learning. 
- - Parameters - ---------- - data_loader : DataLoader - DataLoader yielding batches of {"adj", "node", "smiles"} dictionaries. - - epochs : int - Number of training epochs. + dataset = MolGraphDataset( + smiles_list=X, + reward_function=self.reward if self.use_reward else None, + max_nodes=self.num_nodes + ) - log_every : int - Frequency of logging losses. + train_loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + collate_fn=molgan_collate_fn, + drop_last=True + ) - reward_scale : float - Weight of reward loss in the generator's total loss. - """ self.generator.train() self.discriminator.train() - if self.reward and hasattr(self.reward, 'neural') and self.reward.neural: - self.reward.neural.eval() - for epoch in range(epochs): - d_losses, g_losses, reward_vals = [], [], [] + for epoch in range(1, epochs + 1): + epoch_d_loss = [] + epoch_g_loss = [] - for batch in data_loader: - real_adj = batch["adj"].to(self.device) # [B, Y, N, N] - real_node = batch["node"].to(self.device) # [B, N, T] - smiles = batch.get("smiles", None) + for batch in train_loader: + real_adj = batch["adj"].to(self.device) # [B, Y, N, N] + real_node = batch["node"].to(self.device) # [B, N, T] + real_reward = batch["reward"].to(self.device) # [B] - batch_size = real_adj.size(0) + batch_size_actual = real_adj.size(0) + z = torch.randn(batch_size_actual, self.latent_dim).to(self.device) # === Train Discriminator === self.dis_opt.zero_grad() - fake_adj, fake_node = self.generate(batch_size) - real_logits = self.discriminator(real_adj, real_node) - fake_logits = self.discriminator(fake_adj.detach(), fake_node.detach()) + # Real loss + d_real = self.discriminator(real_adj, real_node) + d_loss_real = F.binary_cross_entropy_with_logits(d_real, torch.ones_like(d_real)) + + # Fake loss + with torch.no_grad(): + fake_adj, fake_node = self.generator(z) + d_fake = self.discriminator(fake_adj, fake_node) + d_loss_fake = F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake)) - d_loss = -torch.mean(real_logits) + torch.mean(fake_logits) + d_loss = d_loss_real + d_loss_fake d_loss.backward() self.dis_opt.step() # === Train Generator === self.gen_opt.zero_grad() - fake_logits = self.discriminator(fake_adj, fake_node) - g_loss = -torch.mean(fake_logits) - - # === Add reward loss if applicable === - if self.use_reward: - rewards = self.compute_rewards(adj=fake_adj, node=fake_node) # [B] - reward_loss = -rewards.mean() - g_loss += reward_scale * reward_loss - reward_vals.append(rewards.mean().item()) + fake_adj, fake_node = self.generator(z) + d_fake = self.discriminator(fake_adj, fake_node) + + g_adv_loss = F.binary_cross_entropy_with_logits(d_fake, torch.ones_like(d_fake)) + + # Reward-guided loss (optional) + if self.use_reward and self.reward is not None: + with torch.no_grad(): + rwd = self.reward(fake_adj, fake_node) # [B] + g_rwd_loss = -rwd.mean() else: - reward_vals.append(0.0) + g_rwd_loss = 0.0 + g_loss = g_adv_loss + g_rwd_loss g_loss.backward() self.gen_opt.step() - d_losses.append(d_loss.item()) - g_losses.append(g_loss.item()) - - if (epoch + 1) % log_every == 0: - print(f"[Epoch {epoch+1}/{epochs}] " - f"D_loss: {sum(d_losses)/len(d_losses):.4f} | " - f"G_loss: {sum(g_losses)/len(g_losses):.4f} | " - f"Reward: {sum(reward_vals)/len(reward_vals):.4f}") - + epoch_d_loss.append(d_loss.item()) + epoch_g_loss.append(g_loss.item()) - def fit( - self, - data_loader, - epochs: int = 10, - log_every: int = 1, - reward_scale: float = 1.0 - ): - """ - Train the MolGAN model using adversarial and 
(optional) reward-based learning. + print(f"[Epoch {epoch}/{epochs}] D_loss: {np.mean(epoch_d_loss):.4f} | G_loss: {np.mean(epoch_g_loss):.4f}") - Parameters - ---------- - data_loader : DataLoader - DataLoader yielding batches of {"adj", "node", "smiles"} dictionaries. + return self - epochs : int - Number of training epochs. - log_every : int - Frequency of logging losses. - - reward_scale : float - Weight of reward loss in the generator's total loss. + def generate(self, n_samples: int, **kwargs) -> List[str]: """ - self.generator.train() - self.discriminator.train() - if self.reward and hasattr(self.reward, 'neural') and self.reward.neural: - self.reward.neural.eval() - - for epoch in range(epochs): - d_losses, g_losses, reward_vals = [], [], [] + Generate molecules from random latent vectors. - for batch in data_loader: - real_adj = batch["adj"].to(self.device) - real_node = batch["node"].to(self.device) - smiles = batch.get("smiles", None) - - batch_size = real_adj.size(0) - - # === Train Discriminator === - self.dis_opt.zero_grad() - fake_adj, fake_node = self.generate(batch_size) - - real_logits = self.discriminator(real_adj, real_node) - fake_logits = self.discriminator(fake_adj.detach(), fake_node.detach()) - - d_loss = -torch.mean(real_logits) + torch.mean(fake_logits) - d_loss.backward() - self.dis_opt.step() - - # === Train Generator === - self.gen_opt.zero_grad() - fake_logits = self.discriminator(fake_adj, fake_node) - g_loss = -torch.mean(fake_logits) - - # === Add reward loss if applicable === - if self.use_reward: - # Convert fake graphs to SMILES if reward expects SMILES - if self.reward.oracle is not None: - smiles_fake = self.decode_smiles(fake_adj, fake_node) - rewards = self.compute_rewards(smiles_list=smiles_fake) - else: - rewards = self.compute_rewards(adj=fake_adj, node=fake_node) - - reward_loss = -rewards.mean() - g_loss += reward_scale * reward_loss - reward_vals.append(rewards.mean().item()) - else: - reward_vals.append(0.0) - - g_loss.backward() - self.gen_opt.step() - - d_losses.append(d_loss.item()) - g_losses.append(g_loss.item()) - - if (epoch + 1) % log_every == 0: - print(f"[Epoch {epoch+1}/{epochs}] " - f"D_loss: {sum(d_losses)/len(d_losses):.4f} | " - f"G_loss: {sum(g_losses)/len(g_losses):.4f} | " - f"Reward: {sum(reward_vals)/len(reward_vals):.4f}") + Returns + ------- + List[str] : Valid SMILES strings + """ + self.generator.eval() + with torch.no_grad(): + z = torch.randn(n_samples, self.latent_dim).to(self.device) + adj, node = self.generator(z) + smiles = decode_smiles(adj, node) + return [s for s in smiles if s is not None] diff --git a/torch_molecule/generator/molgan/generator.py b/torch_molecule/generator/molgan/generator.py index d8d559d..a03e75f 100644 --- a/torch_molecule/generator/molgan/generator.py +++ b/torch_molecule/generator/molgan/generator.py @@ -3,8 +3,9 @@ import torch.nn.functional as F - -class MolGANConfig: +# Including MolGANGeneratorConfig to allow for better instantiation +# if required for future iternations +class MolGANGeneratorConfig: """ Configuration class for MolGAN Generator and Discriminator. @@ -27,7 +28,7 @@ def __init__(self, self.tau = tau - +# MolGANGenerator class MolGANGenerator(nn.Module): """ @@ -40,16 +41,27 @@ class MolGANGenerator(nn.Module): Uses Gumbel-Softmax to approximate discrete molecular structure. 
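For context on the Gumbel-Softmax step named in this docstring: with hard=True, F.gumbel_softmax returns one-hot samples in the forward pass while routing gradients through the soft relaxation (straight-through), which is what lets the generator emit discrete-looking atom and bond assignments yet remain trainable. A small shape-level sketch using the default dimensions above (tensor names are illustrative):

import torch
import torch.nn.functional as F

B, Y, N, T = 2, 4, 9, 5                  # batch, bond types, nodes, atom types
node_logits = torch.randn(B, N, T)
adj_logits = torch.randn(B, Y, N, N)

# hard=True: one-hot forward pass, soft (differentiable) backward pass
node = F.gumbel_softmax(node_logits, tau=1.0, hard=True, dim=-1)
adj = F.gumbel_softmax(adj_logits, tau=1.0, hard=True, dim=1)

assert node.shape == (B, N, T) and adj.shape == (B, Y, N, N)
# exactly one atom type per node, one bond type per node pair
assert torch.allclose(node.sum(dim=-1), torch.ones(B, N))
assert torch.allclose(adj.sum(dim=1), torch.ones(B, N, N))

Lower tau pushes the relaxation closer to a hard argmax at the cost of noisier gradients; tau is exposed as a config field for exactly this trade-off.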
""" - def __init__(self, config): + def __init__(self, + latent_dim=56, + hidden_dims=[128, 128, 256], + num_nodes=9, + num_atom_types=5, + num_bond_types=4, + tau=1.0): super().__init__() - self.config = config + self.latent_dim = latent_dim + self.hidden_dims = hidden_dims + self.num_nodes = num_nodes + self.num_atom_types = num_atom_types + self.num_bond_types = num_bond_types + self.tau = tau - output_dim = (config.num_nodes * config.num_atom_types) + \ - (config.num_nodes * config.num_nodes * config.num_bond_types) + output_dim = (num_nodes * num_atom_types) + \ + (num_nodes * num_nodes * num_bond_types) layers = [] - input_dim = config.latent_dim - for hidden_dim in config.hidden_dims: + input_dim = latent_dim + for hidden_dim in hidden_dims: layers.append(nn.Linear(input_dim, hidden_dim)) layers.append(nn.ReLU()) input_dim = hidden_dim @@ -61,7 +73,7 @@ def forward(self, z): B = z.size(0) out = self.fc(z) - N, T, Y = self.config.num_nodes, self.config.num_atom_types, self.config.num_bond_types + N, T, Y = self.num_nodes, self.num_atom_types, self.num_bond_types node_size = N * T adj_size = N * N * Y @@ -70,7 +82,7 @@ def forward(self, z): adj = adj_flat.view(B, Y, N, N) # Gumbel-softmax - node = F.gumbel_softmax(node, tau=self.config.tau, hard=True, dim=-1) - adj = F.gumbel_softmax(adj, tau=self.config.tau, hard=True, dim=1) + node = F.gumbel_softmax(node, tau=self.tau, hard=True, dim=-1) + adj = F.gumbel_softmax(adj, tau=self.tau, hard=True, dim=1) return adj, node diff --git a/torch_molecule/generator/molgan/modeling_molgan.py b/torch_molecule/generator/molgan/modeling_molgan.py index 504b35c..eaa7a65 100644 --- a/torch_molecule/generator/molgan/modeling_molgan.py +++ b/torch_molecule/generator/molgan/modeling_molgan.py @@ -5,6 +5,7 @@ from .generator import MolGANGenerator from .discriminator import MolGANDiscriminator from .rewards import RewardNetwork +from ...utils.graph.graph_to_smiles import graph_to_smiles class MolGAN(nn.Module): @@ -44,17 +45,17 @@ def generate(self, batch_size): def compute_rewards(self, smiles_list=None, adj=None, node=None): if self.reward is None: if adj is None or node is None: - return torch.zeros(adj.size(0), device=self.device) + raise ValueError("Either smiles_list or (adj, node) must be provided for reward computation.") + return torch.zeros(adj.size(0), device=self.device) return torch.tensor([ self.reward(smiles=s) for s in smiles_list - ], dtype=torch.float32, device=self.device) if smiles_list else self.reward(adj=adj, node=node) + ], dtype=torch.float32, device=self.device) if smiles_list else self.reward(self.decode_smiles(adj, node)) def decode_smiles(self, adj, node): """ Convert batch of (adj, node) to SMILES strings This requires your graph_to_smiles function """ - from your_utils.graph_to_smiles import graph_to_smiles graphs = list(zip(node.cpu().numpy(), adj.cpu().numpy())) return graph_to_smiles(graphs, atom_decoder=["C", "N", "O", "F"]) @@ -64,7 +65,6 @@ def fit(self, data_loader, epochs=10, log_every=1): for batch in data_loader: real_adj = batch["adj"].to(self.device) real_node = batch["node"].to(self.device) - real_smiles = batch.get("smiles", None) # === Train Discriminator === self.dis_opt.zero_grad() diff --git a/torch_molecule/generator/molgan/rewards.py b/torch_molecule/generator/molgan/rewards.py index c0abfc8..ebf7d96 100644 --- a/torch_molecule/generator/molgan/rewards.py +++ b/torch_molecule/generator/molgan/rewards.py @@ -52,7 +52,12 @@ class RewardNeuralNetwork(nn.Module): Reward Network that predicts reward from 
(adj, node) graphs. """ - def __init__(self, num_atom_types=5, num_bond_types=4, hidden_dim=128, num_layers=2, num_nodes=9): + def __init__(self, + num_atom_types=5, + num_bond_types=4, + hidden_dim=128, + num_layers=2, + num_nodes=9): super().__init__() self.gcn_layers = nn.ModuleList() self.gcn_layers.append(RelationalGCNLayer(num_atom_types, hidden_dim, num_bond_types)) @@ -163,6 +168,8 @@ def __init__( if kind in ["qed", "logp", "combo"]: self.oracle = RewardOracle(kind) self.neural = None + if reward_net is not None: + raise ValueError("reward_net should not be provided for oracle modes") elif kind == "neural": assert reward_net is not None, "reward_net must be provided for 'neural' mode" self.oracle = None @@ -183,7 +190,7 @@ def __call__(self, adj: torch.Tensor, node: torch.Tensor) -> torch.Tensor: ------- Tensor [B] : reward per sample """ - if self.neural: + if self.neural is not None: with torch.no_grad(): return self.neural(adj.to(self.device), node.to(self.device)) From e671553737020c0c058fe471872901a02e4aeaa9 Mon Sep 17 00:00:00 2001 From: Manda Kausthubh Date: Wed, 9 Jul 2025 21:46:17 +0530 Subject: [PATCH 07/14] Almost complete implementation --- torch_molecule/generator/molgan/dataset.py | 32 +++- torch_molecule/generator/molgan/gan.py | 72 +++++++-- torch_molecule/generator/molgan/gan_utils.py | 143 +++++++---------- .../generator/molgan/modeling_molgan.py | 103 ------------- .../molgan/{rewards.py => rewards_molgan.py} | 145 ++++++++++-------- 5 files changed, 225 insertions(+), 270 deletions(-) rename torch_molecule/generator/molgan/{rewards.py => rewards_molgan.py} (54%) diff --git a/torch_molecule/generator/molgan/dataset.py b/torch_molecule/generator/molgan/dataset.py index 1626513..09ee873 100644 --- a/torch_molecule/generator/molgan/dataset.py +++ b/torch_molecule/generator/molgan/dataset.py @@ -1,9 +1,9 @@ from torch.utils.data import Dataset import torch -from typing import List, Optional, Callable +from typing import List, Optional from rdkit import Chem -from .rewards import RewardNetwork +from .rewards_molgan import RewardNetwork from .gan_utils import encode_smiles_to_graph from ...utils.graph.graph_from_smiles import graph_from_smiles @@ -99,3 +99,31 @@ def molgan_collate_fn(batch): smiles = [item["smiles"] for item in batch] return {"adj": adj, "node": node, "reward": reward, "smiles": smiles} + + +class MolGraphRewardDataset(torch.utils.data.Dataset): + def __init__(self, smiles_list, reward_fn, max_nodes=9): + self.samples = [] + for smiles in smiles_list: + adj_node = encode_smiles_to_graph(smiles, max_nodes=max_nodes) + if adj_node is None: + continue + adj, node = adj_node + reward = reward_fn(smiles) + self.samples.append({ + "adj": adj, + "node": node, + "reward": reward + }) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + sample = self.samples[idx] + return { + "adj": sample["adj"], + "node": sample["node"], + "reward": torch.tensor(sample["reward"], dtype=torch.float32) + } + diff --git a/torch_molecule/generator/molgan/gan.py b/torch_molecule/generator/molgan/gan.py index b7b0b74..2866d12 100644 --- a/torch_molecule/generator/molgan/gan.py +++ b/torch_molecule/generator/molgan/gan.py @@ -1,22 +1,21 @@ import torch +import os, json +import numpy as np +import warnings import torch.nn.functional as F from dataclasses import dataclass from typing import Optional, List +from dataclasses import field from torch_molecule.base.generator import BaseMolecularGenerator # If for future compatibility, do ensure Configs are 
imported
from .generator import MolGANGenerator
from .discriminator import MolGANDiscriminator
-from .rewards import RewardNetwork
-from .gan_utils import decode_smiles, encode_smiles_to_graph
+from .rewards_molgan import RewardOracle
+from .gan_utils import decode_smiles_from_graph
 from .dataset import MolGraphDataset, molgan_collate_fn
-from ...utils.graph.graph_to_smiles import graph_to_smiles
-from typing import List, Optional
-from dataclasses import field
-import numpy as np
-import torch
 
 
 # The actual MolGAN implementation
 @dataclass
@@ -35,7 +34,6 @@ def __init__(
         num_atom_types: int = 5,
         num_bond_types: int = 4,
         use_reward: bool = False,
-        reward_network: Optional[RewardNetwork] = None,
         device: Optional[str] = None
     ):
         super().__init__()
@@ -47,7 +45,6 @@ def __init__(
         self.num_atom_types = num_atom_types
         self.num_bond_types = num_bond_types
         self.use_reward = use_reward
-        self.reward = reward_network
         self.tau = tau
 
         self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
@@ -75,8 +72,9 @@ def fit(
         self,
         X: List[str],
         y: Optional[np.ndarray] = None,
+        reward: Optional[RewardOracle] = None,
         epochs: int = 10,
-        batch_size: int = 32
+        batch_size: int = 32,
     ) -> "MolGAN":
         """
         Fit the MolGAN model to a list of SMILES strings.
@@ -101,11 +99,14 @@ def fit(
             The trained model.
         """
 
+        if y is not None:
+            warnings.warn("y is not used in MolGAN training. Use reward function instead.")
+
         from torch.utils.data import DataLoader
 
         dataset = MolGraphDataset(
             smiles_list=X,
-            reward_function=self.reward if self.use_reward else None,
+            reward_function=reward if self.use_reward else None,
             max_nodes=self.num_nodes
         )
 
@@ -125,9 +126,9 @@ def fit(
             epoch_g_loss = []
 
             for batch in train_loader:
-                real_adj = batch["adj"].to(self.device)       # [B, Y, N, N]
-                real_node = batch["node"].to(self.device)     # [B, N, T]
-                real_reward = batch["reward"].to(self.device) # [B]
+                real_adj = batch["adj"].to(self.device)
+                real_node = batch["node"].to(self.device)
+                real_reward = batch["reward"].to(self.device)
 
                 batch_size_actual = real_adj.size(0)
                 z = torch.randn(batch_size_actual, self.latent_dim).to(self.device)
@@ -157,9 +158,9 @@ def fit(
                 g_adv_loss = F.binary_cross_entropy_with_logits(d_fake, torch.ones_like(d_fake))
 
                 # Reward-guided loss (optional)
-                if self.use_reward and self.reward is not None:
+                if self.use_reward and reward is not None:
                     with torch.no_grad():
-                        rwd = self.reward(fake_adj, fake_node)  # [B]
+                        rwd = reward(fake_adj, fake_node)  # [B]
                     g_rwd_loss = -rwd.mean()
                 else:
                     g_rwd_loss = 0.0
@@ -188,6 +189,43 @@ def generate(self, n_samples: int, **kwargs) -> List[str]:
         with torch.no_grad():
             z = torch.randn(n_samples, self.latent_dim).to(self.device)
             adj, node = self.generator(z)
-            smiles = decode_smiles(adj, node)
+            # decode_smiles_from_graph handles a single graph at a time,
+            # so decode the batch sample by sample
+            smiles = [decode_smiles_from_graph(a, n) for a, n in zip(adj, node)]
             return [s for s in smiles if s is not None]
+
+
+    # Initial implementation for saving and loading the model
+    def save_pretrained(self, save_directory: str, configfile: Optional[str] = None):
+        os.makedirs(save_directory, exist_ok=True)
+        if configfile is None:
+            configfile = "config.json"
+
+        # Save model weights
+        torch.save(self.generator.state_dict(), os.path.join(save_directory, "generator.pt"))
+        torch.save(self.discriminator.state_dict(), os.path.join(save_directory, "discriminator.pt"))
+
+        # Save config
+        config = {
+            "latent_dim": self.latent_dim,
+            "hidden_dims_gen": self.hidden_dims_gen,
+            "hidden_dims_disc": self.hidden_dims_disc,
+            "num_nodes": self.num_nodes,
+            "num_atom_types":
self.num_atom_types, + "num_bond_types": self.num_bond_types, + "tau": self.tau, + "use_reward": self.use_reward + } + with open(os.path.join(save_directory, configfile), "w") as f: + json.dump(config, f) + + @classmethod + def from_pretrained(cls, load_directory: str, device: Optional[str] = None, configfile: str = "config.json") -> "MolGAN": + with open(os.path.join(load_directory, configfile)) as f: + config = json.load(f) + + model = cls(**config, device=device) + model.generator.load_state_dict(torch.load(os.path.join(load_directory, "generator.pt"), map_location=device)) + model.discriminator.load_state_dict(torch.load(os.path.join(load_directory, "discriminator.pt"), map_location=device)) + return model + diff --git a/torch_molecule/generator/molgan/gan_utils.py b/torch_molecule/generator/molgan/gan_utils.py index 153286e..6cd9208 100644 --- a/torch_molecule/generator/molgan/gan_utils.py +++ b/torch_molecule/generator/molgan/gan_utils.py @@ -4,10 +4,12 @@ import torch.nn.functional as F import numpy as np from rdkit import Chem -from ...utils.graph.graph_to_smiles import graph_to_smiles -from ...utils.graph.graph_from_smiles import graph_from_smiles - - +from ...utils.graph.graph_to_smiles import ( + build_molecule_with_partial_charges, + correct_mol, + mol2smiles, + get_mol +) class RelationalGCNLayer(nn.Module): @@ -96,100 +98,69 @@ def encode_smiles_to_graph( return torch.tensor(adj), torch.tensor(node) -def decode_smiles( - adj: torch.Tensor, - node: torch.Tensor, - atom_decoder: list = ["C", "N", "O", "F"] - ) -> list: +ATOM_DECODER = ["C", "N", "O", "F"] # Adjust based on your vocabulary +BOND_DICT = [ + None, + Chem.rdchem.BondType.SINGLE, + Chem.rdchem.BondType.DOUBLE, + Chem.rdchem.BondType.TRIPLE, + Chem.rdchem.BondType.AROMATIC, +] + +def decode_smiles_from_graph( + adj: torch.Tensor, + node: torch.Tensor, + atom_decoder: Optional[list] = ATOM_DECODER +) -> Optional[str]: """ - Convert a batch of (adj, node) tensors to SMILES strings. + Converts (adj, node) graph back to a SMILES string. Parameters ---------- adj : torch.Tensor - Adjacency tensor of shape [B, Y, N, N] - + Tensor of shape [Y, N, N] with binary bond type edges. node : torch.Tensor - Node feature tensor of shape [B, N, T] - - atom_decoder : list of str - Atom types in order of one-hot encoding indices - - Returns - ------- - List[str or None] - Decoded SMILES strings or None for invalid molecules - """ - # Ensure tensors are detached and moved to CPU - adj_np = adj.detach().cpu().numpy() - node_np = node.detach().cpu().numpy() - - # Build molecule list - molecule_list = list(zip(node_np, adj_np)) - - # Decode into SMILES strings - smiles_list = graph_to_smiles(molecule_list, atom_decoder) - - return smiles_list - - - - - - - - - - - -def molgan_graph_from_smiles(smiles: str, atom_vocab: list, bond_types: list, max_nodes: int) -> Optional[dict]: - """ - Convert SMILES to MolGAN-style (adjacency, node) graph. - - Parameters - ---------- - smiles : str - SMILES string - - atom_vocab : list of str - List of allowed atom types (e.g., ['C', 'N', 'O', 'F']) - - bond_types : list of float - List of bond types (e.g., [1.0, 1.5, 2.0, 3.0]) - - max_nodes : int - Maximum number of atoms + Tensor of shape [N, T] with atom type softmax/one-hot. + atom_decoder : list + List mapping indices to atom symbols. Returns ------- - dict with keys: - 'adj': [Y, N, N] tensor - 'node': [N, T] tensor + Optional[str] + SMILES string if successful, None otherwise. 
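A quick usage sketch for the two graph codecs in this file, assuming the package layout of this series is importable (the sample molecule is illustrative):

from torch_molecule.generator.molgan.gan_utils import (
    encode_smiles_to_graph,
    decode_smiles_from_graph,
)

graph = encode_smiles_to_graph("CCO", max_nodes=9)   # ethanol fits in 9 nodes
if graph is not None:
    adj, node = graph              # adj: [Y, N, N], node: [N, T]
    print(adj.shape, node.shape)   # torch.Size([4, 9, 9]) torch.Size([9, 4])
    print(decode_smiles_from_graph(adj, node))       # a SMILES string, or None

Note that decode_smiles_from_graph operates on one graph, not a batch; batched generator output has to be decoded sample by sample.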
""" - mol = Chem.MolFromSmiles(smiles) - if mol is None or mol.GetNumAtoms() > max_nodes: + try: + atom_types = node.argmax(dim=-1) # [N] + edge_types = torch.argmax(adj, dim=0) # [N, N], index of strongest bond type + + # Convert to RDKit Mol + mol_init = build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder) + + # Try to correct connectivity and valency + for connection in (True, False): + mol_corr, _ = correct_mol(mol_init, connection=connection) + if mol_corr is not None: + break + else: + mol_corr = mol_init # fallback + + # Final sanitization + smiles = mol2smiles(mol_corr) + if not smiles: + smiles = Chem.MolToSmiles(mol_corr) + + # Canonicalize and return + mol = get_mol(smiles) + if mol is not None: + frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=False) + largest = max(frags, key=lambda m: m.GetNumAtoms()) + final_smiles = mol2smiles(largest) + return final_smiles if final_smiles and len(final_smiles) > 1 else None return None - T = len(atom_vocab) - Y = len(bond_types) + except Exception as e: + print(f"[MolGAN Decode] Error during decoding: {e}") + return None - node = np.zeros((max_nodes, T)) - for i, atom in enumerate(mol.GetAtoms()): - symbol = atom.GetSymbol() - if symbol in atom_vocab: - node[i, atom_vocab.index(symbol)] = 1 - adj = np.zeros((Y, max_nodes, max_nodes)) - for bond in mol.GetBonds(): - i = bond.GetBeginAtomIdx() - j = bond.GetEndAtomIdx() - btype = bond.GetBondTypeAsDouble() - if btype in bond_types: - k = bond_types.index(btype) - adj[k, i, j] = 1 - adj[k, j, i] = 1 - return { - "adj": torch.tensor(adj, dtype=torch.float32).unsqueeze(0), - "node": torch.tensor(node, dtype=torch.float32).unsqueeze(0) - } diff --git a/torch_molecule/generator/molgan/modeling_molgan.py b/torch_molecule/generator/molgan/modeling_molgan.py index eaa7a65..e69de29 100644 --- a/torch_molecule/generator/molgan/modeling_molgan.py +++ b/torch_molecule/generator/molgan/modeling_molgan.py @@ -1,103 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .generator import MolGANGenerator -from .discriminator import MolGANDiscriminator -from .rewards import RewardNetwork -from ...utils.graph.graph_to_smiles import graph_to_smiles - - -class MolGAN(nn.Module): - """ - Full MolGAN model integrating: - - Generator - - Discriminator - - Reward Network (oracle or neural) - """ - - def __init__( - self, - generator_config, - discriminator_config, - reward_config, - use_reward=True, - reward_lambda=1.0, - device="cpu" - ): - super().__init__() - self.device = device - self.use_reward = use_reward - self.reward_lambda = reward_lambda - - self.generator = MolGANGenerator(generator_config).to(device) - self.discriminator = MolGANDiscriminator(discriminator_config).to(device) - self.reward = RewardNetwork(**reward_config) if use_reward else None - - self.gen_opt = torch.optim.Adam(self.generator.parameters(), lr=generator_config.get("lr", 1e-3)) - self.dis_opt = torch.optim.Adam(self.discriminator.parameters(), lr=discriminator_config.get("lr", 1e-3)) - - def generate(self, batch_size): - z = torch.randn(batch_size, self.generator.latent_dim).to(self.device) - adj, node = self.generator(z) - return adj, node - - def compute_rewards(self, smiles_list=None, adj=None, node=None): - if self.reward is None: - if adj is None or node is None: - raise ValueError("Either smiles_list or (adj, node) must be provided for reward computation.") - return torch.zeros(adj.size(0), device=self.device) - return torch.tensor([ - self.reward(smiles=s) 
for s in smiles_list - ], dtype=torch.float32, device=self.device) if smiles_list else self.reward(self.decode_smiles(adj, node)) - - def decode_smiles(self, adj, node): - """ - Convert batch of (adj, node) to SMILES strings - This requires your graph_to_smiles function - """ - graphs = list(zip(node.cpu().numpy(), adj.cpu().numpy())) - return graph_to_smiles(graphs, atom_decoder=["C", "N", "O", "F"]) - - def fit(self, data_loader, epochs=10, log_every=1): - for epoch in range(epochs): - epoch_d_loss, epoch_g_loss, epoch_rewards = [], [], [] - for batch in data_loader: - real_adj = batch["adj"].to(self.device) - real_node = batch["node"].to(self.device) - - # === Train Discriminator === - self.dis_opt.zero_grad() - fake_adj, fake_node = self.generate(real_adj.size(0)) - - real_logits = self.discriminator(real_adj, real_node) - fake_logits = self.discriminator(fake_adj.detach(), fake_node.detach()) - d_loss = -torch.mean(real_logits) + torch.mean(fake_logits) - d_loss.backward() - self.dis_opt.step() - - # === Train Generator === - self.gen_opt.zero_grad() - fake_logits = self.discriminator(fake_adj, fake_node) - g_loss = -torch.mean(fake_logits) - - # Add reward loss if enabled - if self.use_reward: - smiles_fake = self.decode_smiles(fake_adj, fake_node) - rewards = self.compute_rewards(smiles_list=smiles_fake) - reward_loss = -torch.mean(rewards) - g_loss = g_loss + self.reward_lambda * reward_loss - else: - rewards = torch.zeros(real_adj.size(0)) - - g_loss.backward() - self.gen_opt.step() - - epoch_d_loss.append(d_loss.item()) - epoch_g_loss.append(g_loss.item()) - epoch_rewards.append(rewards.mean().item()) - - if (epoch + 1) % log_every == 0: - print(f"Epoch [{epoch + 1}/{epochs}] D_loss: {sum(epoch_d_loss)/len(epoch_d_loss):.4f}, " - f"G_loss: {sum(epoch_g_loss)/len(epoch_g_loss):.4f}, " - f"Reward: {sum(epoch_rewards)/len(epoch_rewards):.4f}") diff --git a/torch_molecule/generator/molgan/rewards.py b/torch_molecule/generator/molgan/rewards_molgan.py similarity index 54% rename from torch_molecule/generator/molgan/rewards.py rename to torch_molecule/generator/molgan/rewards_molgan.py index ebf7d96..83261b1 100644 --- a/torch_molecule/generator/molgan/rewards.py +++ b/torch_molecule/generator/molgan/rewards_molgan.py @@ -1,9 +1,9 @@ -from typing import Optional +from typing import List, Optional import torch import torch.nn as nn from rdkit import Chem from rdkit.Chem import QED, Crippen, rdMolDescriptors -from .gan_utils import RelationalGCNLayer, molgan_graph_from_smiles +from .gan_utils import RelationalGCNLayer from ...utils.graph.graph_to_smiles import graph_to_smiles @@ -28,7 +28,7 @@ def combo_reward(smiles: str, weights=(0.7, 0.3)) -> float: logp_score = Crippen.MolLogP(mol) return weights[0] * qed_score + weights[1] * logp_score -class RewardOracle: +class RewardOracleNonNeural: def __init__(self, kind="qed"): if kind == "qed": self.func = qed_reward @@ -55,20 +55,21 @@ class RewardNeuralNetwork(nn.Module): def __init__(self, num_atom_types=5, num_bond_types=4, - hidden_dim=128, - num_layers=2, + hidden_dims:List[int]=[128, 128], num_nodes=9): super().__init__() self.gcn_layers = nn.ModuleList() - self.gcn_layers.append(RelationalGCNLayer(num_atom_types, hidden_dim, num_bond_types)) + self.gcn_layers.append(RelationalGCNLayer(num_atom_types, hidden_dims[0], num_bond_types)) - for _ in range(1, num_layers): - self.gcn_layers.append(RelationalGCNLayer(hidden_dim, hidden_dim, num_bond_types)) + current_dim = hidden_dims[0] + self.num_layers = len(hidden_dims) + for i in range(1, 
self.num_layers):
+            self.gcn_layers.append(RelationalGCNLayer(current_dim, hidden_dims[i], num_bond_types))
 
         self.readout = nn.Sequential(
-            nn.Linear(num_nodes * hidden_dim, hidden_dim),
+            nn.Linear(num_nodes * hidden_dims[-1], hidden_dims[-1]),
             nn.ReLU(),
-            nn.Linear(hidden_dim, 1)
+            nn.Linear(hidden_dims[-1], 1)
         )
 
     def forward(self, adj, node):
@@ -84,74 +85,70 @@ def forward(self, adj, node):
         return self.readout(h).squeeze(-1)
 
 
-def fit_reward_network(
-    reward_model: RewardNeuralNetwork,
-    train_loader,
-    epochs: int = 10,
-    lr: float = 1e-3,
-    weight_decay: float = 0.0,
-    device: str = "cpu",
-    verbose: bool = True
-):
-    """
-    Train the reward model to approximate oracle rewards.
+    def fit(
+        self,
+        train_loader,
+        epochs: int = 10,
+        lr: float = 1e-3,
+        weight_decay: float = 0.0,
+        verbose: bool = True
+    ):
+        """
+        Train the reward network to approximate oracle rewards.
 
-    Parameters
-    ----------
-    reward_model : RewardNeuralNetwork
-        The neural network to train
+        Parameters
+        ----------
+        train_loader : DataLoader
+            Yields batches of (adj, node, reward)
 
-    train_loader : DataLoader
-        Yields batches of (adj, node, reward)
+        epochs : int
+            Number of training epochs
 
-    epochs : int
-        Number of training epochs
+        lr : float
+            Learning rate
 
-    lr : float
-        Learning rate
+        weight_decay : float
+            Optional L2 regularization
 
-    weight_decay : float
-        Optional L2 regularization
-
-    device : str
-        Device to run on ("cpu" or "cuda")
-
-    verbose : bool
-        Whether to print losses
-    """
-    model = reward_model.to(device)
-    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
-    criterion = nn.MSELoss()
+        verbose : bool
+            Whether to print losses
+        """
+        optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
+        criterion = nn.MSELoss()
+        # run on whichever device the module's parameters live on
+        device = next(self.parameters()).device
 
-    model.train()
-    for epoch in range(epochs):
-        epoch_losses = []
+        self.train()
+        for epoch in range(epochs):
+            epoch_losses = []
 
-        for batch in train_loader:
-            adj = batch["adj"].to(device)
-            node = batch["node"].to(device)
-            reward = batch["reward"].to(device)
+            for batch in train_loader:
+                adj = batch["adj"].to(device)
+                node = batch["node"].to(device)
+                reward = batch["reward"].to(device)
 
-            pred = model(adj, node)  # [B]
-            loss = criterion(pred, reward)
+                pred = self(adj, node)  # [B]
+                loss = criterion(pred, reward)
 
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
 
-            epoch_losses.append(loss.item())
+                epoch_losses.append(loss.item())
 
-    if verbose:
-        print(f"[Epoch {epoch+1}/{epochs}] RewardNet Loss: {sum(epoch_losses)/len(epoch_losses):.4f}")
+            if verbose:
+                print(f"[Epoch {epoch+1}/{epochs}] RewardNet Loss: {sum(epoch_losses)/len(epoch_losses):.4f}")
 
 
 # Combined reward wrapper: which uses either neural or oracle rewards
-class RewardNetwork:
+class RewardOracle:
     """
-    Combined reward network that uses either a neural model or an oracle.
+    Combined reward network that uses either a neural network or an oracle.
     Accepts (adj, node) tensors as standard input.
     """
 
@@ -159,14 +158,32 @@ def __init__(
         self,
         kind="qed",
         reward_net: Optional[RewardNeuralNetwork] = None,
-        atom_decoder=None,
-        device="cpu"):
+        atom_decoder: List[str] = ["C", "N", "O", "F"],
+        device="cpu"
+    ):
+        """
+        Parameters
+        ----------
+        kind : str
+            Type of reward function to use. Options:
+            - "qed": QED score
+            - "logp": LogP score
+            - "combo": Combination of QED and LogP
+            - "neural": Use a neural network for reward prediction
+        reward_net : RewardNeuralNetwork, optional
+            If kind is "neural", this should be a trained RewardNeuralNetwork instance.
+        atom_decoder : list of str, optional
+            If kind is "qed", "logp", or "combo", this should be the list of atom types used to decode graphs.
+            Defaults to ["C", "N", "O", "F"].
+        device : str
+            Device to run the reward computation on ("cpu" or "cuda").
+        """
         self.kind = kind
         self.device = device
         self.atom_decoder = atom_decoder or ["C", "N", "O", "F"]
 
         if kind in ["qed", "logp", "combo"]:
-            self.oracle = RewardOracle(kind)
+            self.oracle = RewardOracleNonNeural(kind)
             self.neural = None
             if reward_net is not None:
                 raise ValueError("reward_net should not be provided for oracle modes")
@@ -177,7 +194,11 @@ def __init__(
         else:
             raise ValueError(f"Invalid kind: {kind}")
 
-    def __call__(self, adj: torch.Tensor, node: torch.Tensor) -> torch.Tensor:
+    def __call__(
+        self,
+        adj: torch.Tensor,
+        node: torch.Tensor
+    ) -> torch.Tensor:
         """
         Compute reward from graph tensors.
 
From b1f73dc38104c306ccf5c5ef044dcaf3a793b1a9 Mon Sep 17 00:00:00 2001
From: Manda Kausthubh
Date: Wed, 9 Jul 2025 22:17:52 +0530
Subject: [PATCH 08/14] Finished writing tests

---
 tests/generator/molgan.py                   | 65 +++++++++++++++++++
 torch_molecule/generator/molgan/__init__.py | 16 +++++
 torch_molecule/generator/molgan/gan.py      | 20 +++++++
 3 files changed, 101 insertions(+)
 create mode 100644 tests/generator/molgan.py

diff --git a/tests/generator/molgan.py b/tests/generator/molgan.py
new file mode 100644
index 0000000..a3b3fc1
--- /dev/null
+++ b/tests/generator/molgan.py
@@ -0,0 +1,65 @@
+import os
+import numpy as np
+import torch
+from rdkit import RDLogger
+from torch_molecule.generator.molgan import MolGAN, RewardOracle
+
+RDLogger.DisableLog("rdApp.*")
+
+def test_molgan():
+    # Sample SMILES list
+    smiles_list = [
+        "CCO", "CCN", "CCC", "COC",
+        "CCCl", "CCF", "CBr", "CN(C)C", "CC(=O)O", "c1ccccc1",
+        'CNC[C@H]1OCc2cnnn2CCCC(=O)N([C@H](C)CO)C[C@@H]1C',
+        'CNC[C@@H]1OCc2cnnn2CCCC(=O)N([C@H](C)CO)C[C@H]1C',
+        'C[C@H]1CN([C@@H](C)CO)C(=O)CCCn2cc(nn2)CO[C@@H]1CN(C)C(=O)CCC(F)(F)F',
+        'CC1=CC=C(C=C1)C2=CC(=NN2C3=CC=C(C=C3)S(=O)(=O)N)C(F)(F)F'
+    ]
+
+    # 1. Initialize MolGAN
+    print("\n=== Testing MolGAN Initialization ===")
+    GANConfig = {
+        "num_nodes": 9,
+        "num_atom_types": 5,
+        "num_bond_types": 4,
+        "latent_dim": 56,
+        "hidden_dims_gen": [128, 128],
+        "hidden_dims_disc": [128, 128],
+        "tau": 1.0,
+    }
+    model = MolGAN(**GANConfig, device="cpu")
+    print("MolGAN initialized successfully")
+
+    # 2. Fit with QED reward
+    print("\n=== Testing MolGAN Training with QED Reward ===")
+    reward = RewardOracle(kind="qed")
+    model.fit(X=smiles_list, reward=reward, epochs=5, batch_size=16)
+    print("MolGAN trained successfully")
+
+    # 3. Generation
+    print("\n=== Testing MolGAN Generation ===")
+    gen_smiles = model.generate(n_samples=10)
+    print(f"Generated {len(gen_smiles)} SMILES")
+    print("Example generated molecules:", gen_smiles[:3])
+
+    # 4. Save and Reload
+    print("\n=== Testing MolGAN Save & Load ===")
+    save_dir = "molgan-test"
+    model.save_pretrained(save_dir)
+    print(f"Model saved to {save_dir}")
+
+    model2 = MolGAN.from_pretrained(save_dir)
+    print("Model loaded successfully")
+
+    gen_smiles2 = model2.generate(n_samples=5)
+    print("Generated after loading:", gen_smiles2[:3])
+
+    # 5.
Cleanup + import shutil + if os.path.exists(save_dir): + shutil.rmtree(save_dir) + print(f"Cleaned up {save_dir}") + +if __name__ == "__main__": + test_molgan() diff --git a/torch_molecule/generator/molgan/__init__.py b/torch_molecule/generator/molgan/__init__.py index 8b13789..84c8d82 100644 --- a/torch_molecule/generator/molgan/__init__.py +++ b/torch_molecule/generator/molgan/__init__.py @@ -1 +1,17 @@ +from .dataset import MolGraphDataset, MolGraphRewardDataset, molgan_collate_fn +from .gan import MolGAN +from .generator import MolGANGenerator +from .discriminator import MolGANDiscriminator +from .rewards_molgan import RewardOracle, RewardOracleNonNeural, RewardNeuralNetwork +__all__ = [ + "MolGAN", + "MolGraphDataset", + "MolGraphRewardDataset", + "molgan_collate_fn", + "MolGANGenerator", + "MolGANDiscriminator", + "RewardOracle", + "RewardOracleNonNeural", + "RewardNeuralNetwork" +] diff --git a/torch_molecule/generator/molgan/gan.py b/torch_molecule/generator/molgan/gan.py index 2866d12..bc4a7ac 100644 --- a/torch_molecule/generator/molgan/gan.py +++ b/torch_molecule/generator/molgan/gan.py @@ -229,3 +229,23 @@ def from_pretrained(cls, load_directory: str, device: Optional[str] = None, conf model.discriminator.load_state_dict(torch.load(os.path.join(load_directory, "discriminator.pt"), map_location=device)) return model + + def _setup_optimizers(self): + return self.gen_opt, self.dis_opt # Or return a scheduler if applicable + + def _train_epoch(self, train_loader, optimizer): + # Delegate to your existing training loop inside `fit()` + # If not reusable, just raise NotImplementedError + raise NotImplementedError("MolGAN does not use `_train_epoch`; training is handled in `fit()`") + + def _get_model_params(self, checkpoint=None): + return { + "latent_dim": self.latent_dim, + "hidden_dims_gen": self.hidden_dims_gen, + "hidden_dims_disc": self.hidden_dims_disc, + "num_nodes": self.num_nodes, + "tau": self.tau, + "num_atom_types": self.num_atom_types, + "num_bond_types": self.num_bond_types, + "use_reward": self.use_reward + } From c953abf68391532bd2a8bb68cc8d1179b291d2f1 Mon Sep 17 00:00:00 2001 From: Manda Kausthubh Date: Wed, 9 Jul 2025 23:10:23 +0530 Subject: [PATCH 09/14] Finished basic changes --- tests/generator/molgan.py | 9 ++++++--- torch_molecule/generator/molgan/gan.py | 7 +++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/generator/molgan.py b/tests/generator/molgan.py index a3b3fc1..09e8265 100644 --- a/tests/generator/molgan.py +++ b/tests/generator/molgan.py @@ -1,8 +1,9 @@ import os -import numpy as np -import torch from rdkit import RDLogger -from torch_molecule.generator.molgan import MolGAN, RewardOracle +from torch_molecule.generator.molgan import ( + MolGAN, + RewardOracle, +) RDLogger.DisableLog("rdApp.*") @@ -16,6 +17,7 @@ def test_molgan(): 'C[C@H]1CN([C@@H](C)CO)C(=O)CCCn2cc(nn2)CO[C@@H]1CN(C)C(=O)CCC(F)(F)F', 'CC1=CC=C(C=C1)C2=CC(=NN2C3=CC=C(C=C3)S(=O)(=O)N)C(F)(F)F' ] + model_decoder = ["C", "N", "O", "F", "Cl", "Br"] # 1. 
Initialize MolGAN
     print("\n=== Testing MolGAN Initialization ===")
@@ -27,6 +29,7 @@ def test_molgan():
         "hidden_dims_gen": [128, 128],
         "hidden_dims_disc": [128, 128],
         "tau": 1.0,
+        "use_reward": True,
     }
     model = MolGAN(**GANConfig, device="cpu")
     print("MolGAN initialized successfully")
diff --git a/torch_molecule/generator/molgan/gan.py b/torch_molecule/generator/molgan/gan.py
index bc4a7ac..925fd56 100644
--- a/torch_molecule/generator/molgan/gan.py
+++ b/torch_molecule/generator/molgan/gan.py
@@ -102,6 +102,13 @@ def fit(
         if y is not None:
             warnings.warn("y is not used in MolGAN training. Use reward function instead.")
 
+        if reward is not None and reward.kind == "neural":
+            if len(reward.atom_decoder) != self.num_atom_types:
+                raise ValueError(
+                    f"Reward network atom decoder size {len(reward.atom_decoder)} does not match model's num_atom_types {self.num_atom_types}"
+                )
+
+
         from torch.utils.data import DataLoader
 
From db68dc5924626e6a7e0b0cff0538273dcf2f049e Mon Sep 17 00:00:00 2001
From: Manda Kausthubh
Date: Fri, 26 Sep 2025 21:31:20 +0530
Subject: [PATCH 10/14] Complete clean-up

---
 torch_molecule/generator/molgan/__init__.py  |  17 --
 torch_molecule/generator/molgan/dataset.py   | 129 ---------
 .../generator/molgan/discriminator.py        | 107 --------
 torch_molecule/generator/molgan/gan.py       | 258 ------------------
 torch_molecule/generator/molgan/gan_utils.py | 166 -----------
 torch_molecule/generator/molgan/generator.py |  88 ------
 .../generator/molgan/modeling_molgan.py      |   0
 .../generator/molgan/rewards_molgan.py       | 225 ---------------
 8 files changed, 990 deletions(-)
 delete mode 100644 torch_molecule/generator/molgan/dataset.py
 delete mode 100644 torch_molecule/generator/molgan/discriminator.py
 delete mode 100644 torch_molecule/generator/molgan/gan.py
 delete mode 100644 torch_molecule/generator/molgan/gan_utils.py
 delete mode 100644 torch_molecule/generator/molgan/generator.py
 delete mode 100644 torch_molecule/generator/molgan/modeling_molgan.py
 delete mode 100644 torch_molecule/generator/molgan/rewards_molgan.py

diff --git a/torch_molecule/generator/molgan/__init__.py b/torch_molecule/generator/molgan/__init__.py
index 84c8d82..e69de29 100644
--- a/torch_molecule/generator/molgan/__init__.py
+++ b/torch_molecule/generator/molgan/__init__.py
@@ -1,17 +0,0 @@
-from .dataset import MolGraphDataset, MolGraphRewardDataset, molgan_collate_fn
-from .gan import MolGAN
-from .generator import MolGANGenerator
-from .discriminator import MolGANDiscriminator
-from .rewards_molgan import RewardOracle, RewardOracleNonNeural, RewardNeuralNetwork
-
-__all__ = [
-    "MolGAN",
-    "MolGraphDataset",
-    "MolGraphRewardDataset",
-    "molgan_collate_fn",
-    "MolGANGenerator",
-    "MolGANDiscriminator",
-    "RewardOracle",
-    "RewardOracleNonNeural",
-    "RewardNeuralNetwork"
-]
diff --git a/torch_molecule/generator/molgan/dataset.py b/torch_molecule/generator/molgan/dataset.py
deleted file mode 100644
index 09ee873..0000000
--- a/torch_molecule/generator/molgan/dataset.py
+++ /dev/null
@@ -1,129 +0,0 @@
-from torch.utils.data import Dataset
-import torch
-from typing import List, Optional
-from rdkit import Chem
-
-from .rewards_molgan import RewardNetwork
-from .gan_utils import encode_smiles_to_graph
-from ...utils.graph.graph_from_smiles import graph_from_smiles
-
-
-class MolGraphDataset(Dataset):
-    """
-    Dataset for MolGAN: converts SMILES strings to graph format.
- - Outputs a dict with: - - 'adj': [Y, N, N] adjacency tensor - - 'node': [N, T] node feature matrix - - 'reward': float (optional) - - 'smiles': original SMILES (optional) - """ - - def __init__(self, - smiles_list: List[str], - reward_function: Optional[RewardNetwork] = None, - max_nodes: int = 9, - drop_invalid: bool = True): - """ - Parameters - ---------- - smiles_list : List[str] - List of SMILES strings to convert into graph format. - - reward_function : Callable[[str], float], optional - If provided, computes a scalar reward per molecule (e.g., QED, logP). - Must accept a SMILES string and return a float. - - max_nodes : int - Maximum allowed number of atoms (molecules exceeding this are dropped). - - drop_invalid : bool - Whether to skip invalid or unparsable SMILES. - """ - self.samples = [] - - for smiles in smiles_list: - try: - mol = Chem.MolFromSmiles(smiles) - if mol is None: - raise ValueError("Invalid SMILES") - - if mol.GetNumAtoms() > max_nodes: - raise ValueError("Too many atoms") - - # Compute reward if needed - graph = encode_smiles_to_graph(smiles) - if graph is None: - raise ValueError("Failed to encode SMILES to graph") - adj, node = graph - reward = reward_function(node, adj) if reward_function else 0.0 - - # Convert to graph - graph = graph_from_smiles(smiles, properties=reward) - - # Sanity check - if 'adj' not in graph or 'node' not in graph: - raise ValueError("Incomplete graph data") - - graph['reward'] = reward - graph['smiles'] = smiles - self.samples.append(graph) - - except Exception as e: - if not drop_invalid: - self.samples.append({ - "adj": torch.zeros(1, max_nodes, max_nodes), - "node": torch.zeros(max_nodes, 1), - "reward": 0.0, - "smiles": smiles - }) - else: - print(f"[MolGraphDataset] Skipping SMILES {smiles}: {e}") - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - sample = self.samples[idx] - return { - "adj": torch.tensor(sample["adj"], dtype=torch.float32), - "node": torch.tensor(sample["node"], dtype=torch.float32), - "reward": torch.tensor(sample["reward"], dtype=torch.float32), - "smiles": sample["smiles"] - } - -def molgan_collate_fn(batch): - adj = torch.stack([item["adj"] for item in batch], dim=0) - node = torch.stack([item["node"] for item in batch], dim=0) - reward = torch.stack([item["reward"] for item in batch], dim=0) - smiles = [item["smiles"] for item in batch] - return {"adj": adj, "node": node, "reward": reward, "smiles": smiles} - - - -class MolGraphRewardDataset(torch.utils.data.Dataset): - def __init__(self, smiles_list, reward_fn, max_nodes=9): - self.samples = [] - for smiles in smiles_list: - adj_node = encode_smiles_to_graph(smiles, max_nodes=max_nodes) - if adj_node is None: - continue - adj, node = adj_node - reward = reward_fn(smiles) - self.samples.append({ - "adj": adj, - "node": node, - "reward": reward - }) - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - sample = self.samples[idx] - return { - "adj": sample["adj"], - "node": sample["node"], - "reward": torch.tensor(sample["reward"], dtype=torch.float32) - } - diff --git a/torch_molecule/generator/molgan/discriminator.py b/torch_molecule/generator/molgan/discriminator.py deleted file mode 100644 index 7773e39..0000000 --- a/torch_molecule/generator/molgan/discriminator.py +++ /dev/null @@ -1,107 +0,0 @@ -import torch.nn as nn -from .gan_utils import RelationalGCNLayer - - -class MolGANDiscriminatorConfig: - """ - Configuration class for MolGAN Discriminator. 
- - Stores architectural hyperparameters and allows modular configuration. - """ - - def __init__(self, - num_atom_types=5, - num_bond_types=4, - num_nodes=9, - hidden_dim=128, - num_layers=2): - """ - Parameters - ---------- - num_atom_types : int - Number of atom types in node features (input channels). - - num_bond_types : int - Number of bond types (number of relational edge types). - - num_nodes : int - Max number of nodes in the graph (used for flattening before readout). - - hidden_dim : int - Hidden dimension size for R-GCN layers. - - num_layers : int - Number of stacked R-GCN layers. - """ - self.num_atom_types = num_atom_types - self.num_bond_types = num_bond_types - self.num_nodes = num_nodes - self.hidden_dim = hidden_dim - self.num_layers = num_layers - - - - -class MolGANDiscriminator(nn.Module): - """ - Discriminator network for MolGAN using stacked Relational GCNs. - """ - - def __init__(self, - num_atom_types=5, - num_bond_types=4, - num_nodes=9, - hidden_dims=[128, 128]): - super().__init__() - - self.num_atom_types = num_atom_types - self.num_bond_types = num_bond_types - self.num_nodes = num_nodes - self.hidden_dim = hidden_dims - self.num_layers = len(hidden_dims) + 1 # I'm including the input layer - - self.gcn_layers = nn.ModuleList() - self.gcn_layers.append( - RelationalGCNLayer(num_atom_types, hidden_dims[0], num_bond_types) - ) - - # for _ in range(1, num_layers): - # self.gcn_layers.append( - # RelationalGCNLayer(hidden_dims, hidden_dims, num_bond_types) - # ) - - input_dim = hidden_dims[0] - for hidden_dim in hidden_dims[1:]: - self.gcn_layers.append( - RelationalGCNLayer(input_dim, hidden_dim, num_bond_types) - ) - input_dim = hidden_dim - - self.readout = nn.Sequential( - nn.Linear(num_nodes * hidden_dims[-1], hidden_dims[-1]), - nn.ReLU(), - nn.Linear(hidden_dims[-1], 1) - ) - - def forward(self, adj, node): - """ - Parameters: - adj: Tensor of shape [B, Y, N, N] -- adjacency tensor - node: Tensor of shape [B, N, T] -- one-hot or softmax node features - - Returns: - Tensor of shape [B] with real/fake logits - """ - h = node - for gcn in self.gcn_layers: - h = gcn(adj, h) - - h = h.view(h.size(0), -1) - return self.readout(h).squeeze(-1) - - - - - - - diff --git a/torch_molecule/generator/molgan/gan.py b/torch_molecule/generator/molgan/gan.py deleted file mode 100644 index 925fd56..0000000 --- a/torch_molecule/generator/molgan/gan.py +++ /dev/null @@ -1,258 +0,0 @@ -import torch -import os, json -import numpy as np -import warnings -import torch.nn.functional as F -from dataclasses import dataclass -from typing import Optional, List -from dataclasses import field - -from torch_molecule.base.generator import BaseMolecularGenerator - -# If for future compatibility, do ensure Configs are imported -from .generator import MolGANGenerator -from .discriminator import MolGANDiscriminator -from .rewards_molgan import RewardOracle -from .gan_utils import decode_smiles_from_graph -from .dataset import MolGraphDataset, molgan_collate_fn - - -# The actual MolGAN implementation -@dataclass -class MolGAN(BaseMolecularGenerator): - """MolGAN implementation compatible with BaseMolecularGenerator interface.""" - - model_name: str = field(default="MolGAN") - - def __init__( - self, - latent_dim: int = 56, - hidden_dims_gen: List[int] = [128,128], - hidden_dims_disc: List[int] = [128, 128], - num_nodes: int = 9, - tau: float = 1.0, - num_atom_types: int = 5, - num_bond_types: int = 4, - use_reward: bool = False, - device: Optional[str] = None - ): - super().__init__() - - 
self.latent_dim = latent_dim - self.hidden_dims_gen = hidden_dims_gen - self.hidden_dims_disc = hidden_dims_disc - self.num_nodes = num_nodes - self.num_atom_types = num_atom_types - self.num_bond_types = num_bond_types - self.use_reward = use_reward - self.tau = tau - - self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) - - self.generator = MolGANGenerator( - latent_dim=latent_dim, - hidden_dims=hidden_dims_gen, - num_nodes=num_nodes, - num_atom_types=num_atom_types, - num_bond_types=num_bond_types, - tau=tau - ).to(self.device) - - self.discriminator = MolGANDiscriminator( - hidden_dims=hidden_dims_disc, - num_nodes=num_nodes, - num_atom_types=num_atom_types, - num_bond_types=num_bond_types - ).to(self.device) - - self.gen_opt = torch.optim.Adam(self.generator.parameters(), lr=1e-4) - self.dis_opt = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4) - - def fit( - self, - X: List[str], - y: Optional[np.ndarray] = None, - reward: Optional[RewardOracle] = None, - epochs: int = 10, - batch_size: int = 32, - ) -> "MolGAN": - """ - Fit the MolGAN model to a list of SMILES strings. - - Parameters - ---------- - X : List[str] - List of training SMILES strings. - - y : Optional[np.ndarray] - Optional reward targets. (Unused if using oracle or no reward) - - epochs : int - Number of training epochs. - - batch_size : int - Batch size for training. - - Returns - ------- - self : MolGAN - The trained model. - """ - - if y is not None: - warnings.warn("y is not used in MolGAN training. Use reward function instead.") - - if reward is not None and reward.kind is "neural": - if len(reward.atom_decoder) != self.num_atom_types: - raise ValueError( - f"Reward network atom decoder size {len(reward.atom_decoder)} does not match model's num_atom_types {self.num_atom_types}" - ) - - - from torch.utils.data import DataLoader - - dataset = MolGraphDataset( - smiles_list=X, - reward_function = reward if self.use_reward else None, - max_nodes=self.num_nodes - ) - - train_loader = DataLoader( - dataset, - batch_size=batch_size, - shuffle=True, - collate_fn=molgan_collate_fn, - drop_last=True - ) - - self.generator.train() - self.discriminator.train() - - for epoch in range(1, epochs + 1): - epoch_d_loss = [] - epoch_g_loss = [] - - for batch in train_loader: - real_adj = batch["adj"].to(self.device) - real_node = batch["node"].to(self.device) - real_reward = batch["reward"].to(self.device) - - batch_size_actual = real_adj.size(0) - z = torch.randn(batch_size_actual, self.latent_dim).to(self.device) - - # === Train Discriminator === - self.dis_opt.zero_grad() - - # Real loss - d_real = self.discriminator(real_adj, real_node) - d_loss_real = F.binary_cross_entropy_with_logits(d_real, torch.ones_like(d_real)) - - # Fake loss - with torch.no_grad(): - fake_adj, fake_node = self.generator(z) - d_fake = self.discriminator(fake_adj, fake_node) - d_loss_fake = F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake)) - - d_loss = d_loss_real + d_loss_fake - d_loss.backward() - self.dis_opt.step() - - # === Train Generator === - self.gen_opt.zero_grad() - fake_adj, fake_node = self.generator(z) - d_fake = self.discriminator(fake_adj, fake_node) - - g_adv_loss = F.binary_cross_entropy_with_logits(d_fake, torch.ones_like(d_fake)) - - # Reward-guided loss (optional) - if self.use_reward and reward is not None: - with torch.no_grad(): - rwd = reward(fake_adj, fake_node) # [B] - g_rwd_loss = -rwd.mean() - else: - g_rwd_loss = 0.0 - - g_loss = g_adv_loss + g_rwd_loss - 
g_loss.backward() - self.gen_opt.step() - - epoch_d_loss.append(d_loss.item()) - epoch_g_loss.append(g_loss.item()) - - print(f"[Epoch {epoch}/{epochs}] D_loss: {np.mean(epoch_d_loss):.4f} | G_loss: {np.mean(epoch_g_loss):.4f}") - - return self - - - def generate(self, n_samples: int, **kwargs) -> List[str]: - """ - Generate molecules from random latent vectors. - - Returns - ------- - List[str] : Valid SMILES strings - """ - self.generator.eval() - with torch.no_grad(): - z = torch.randn(n_samples, self.latent_dim).to(self.device) - adj, node = self.generator(z) - smiles = decode_smiles_from_graph(adj, node) - if smiles is None: - return [] - return [s for s in smiles if s is not None] - - - # Initial implementation for saving and loading the model - def save_pretrained(self, save_directory: str, configfile: Optional[str] = None): - os.makedirs(save_directory, exist_ok=True) - if configfile is None: - configfile = "config.json" - - # Save model weights - torch.save(self.generator.state_dict(), os.path.join(save_directory, "generator.pt")) - torch.save(self.discriminator.state_dict(), os.path.join(save_directory, "discriminator.pt")) - - # Save config - config = { - "latent_dim": self.latent_dim, - "hidden_dims_gen": self.hidden_dims_gen, - "hidden_dims_disc": self.hidden_dims_disc, - "num_nodes": self.num_nodes, - "num_atom_types": self.num_atom_types, - "num_bond_types": self.num_bond_types, - "tau": self.tau, - "use_reward": self.use_reward - } - with open(os.path.join(save_directory, configfile), "w") as f: - json.dump(config, f) - - @classmethod - def from_pretrained(cls, load_directory: str, device: Optional[str] = None, configfile: str = "config.json") -> "MolGAN": - with open(os.path.join(load_directory, configfile)) as f: - config = json.load(f) - - model = cls(**config, device=device) - model.generator.load_state_dict(torch.load(os.path.join(load_directory, "generator.pt"), map_location=device)) - model.discriminator.load_state_dict(torch.load(os.path.join(load_directory, "discriminator.pt"), map_location=device)) - return model - - - def _setup_optimizers(self): - return self.gen_opt, self.dis_opt # Or return a scheduler if applicable - - def _train_epoch(self, train_loader, optimizer): - # Delegate to your existing training loop inside `fit()` - # If not reusable, just raise NotImplementedError - raise NotImplementedError("MolGAN does not use `_train_epoch`; training is handled in `fit()`") - - def _get_model_params(self, checkpoint=None): - return { - "latent_dim": self.latent_dim, - "hidden_dims_gen": self.hidden_dims_gen, - "hidden_dims_disc": self.hidden_dims_disc, - "num_nodes": self.num_nodes, - "tau": self.tau, - "num_atom_types": self.num_atom_types, - "num_bond_types": self.num_bond_types, - "use_reward": self.use_reward - } diff --git a/torch_molecule/generator/molgan/gan_utils.py b/torch_molecule/generator/molgan/gan_utils.py deleted file mode 100644 index 6cd9208..0000000 --- a/torch_molecule/generator/molgan/gan_utils.py +++ /dev/null @@ -1,166 +0,0 @@ -from typing import Optional -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from rdkit import Chem -from ...utils.graph.graph_to_smiles import ( - build_molecule_with_partial_charges, - correct_mol, - mol2smiles, - get_mol -) - - -class RelationalGCNLayer(nn.Module): - def __init__(self, in_dim, out_dim, num_relations): - super().__init__() - self.num_relations = num_relations - self.linears = nn.ModuleList([nn.Linear(in_dim, out_dim) for _ in range(num_relations)]) - 
self.bias = nn.Parameter(torch.zeros(out_dim)) - - def forward(self, adj, h): - """ - adj: [B, Y, N, N] - h: [B, N, D] - """ - out = 0 - for i in range(self.num_relations): - adj_i = adj[:, i, :, :] - h_i = self.linears[i](h) - out += torch.bmm(adj_i, h_i) - - out = out + self.bias - return F.relu(out) - - -def encode_smiles_to_graph( - smiles: str, - atom_vocab: list = ["C", "N", "O", "F"], - bond_types: list = [1.0, 1.5, 2.0, 3.0], - max_nodes: int = 9 -) -> Optional[tuple[torch.Tensor, torch.Tensor]]: - """ - Convert a SMILES string into (adj, node) tensors. - - Parameters - ---------- - smiles : str - Input SMILES string - - atom_vocab : list of str - List of valid atom types - - bond_types : list of float - Allowed bond types (e.g., 1.0: single, 2.0: double) - - max_nodes : int - Max number of atoms (graph will be padded) - - Returns - ------- - adj : Tensor [Y, N, N] - Multi-relational adjacency tensor - - node : Tensor [N, T] - One-hot atom features - """ - mol = Chem.MolFromSmiles(smiles) - if mol is None or mol.GetNumAtoms() > max_nodes: - # raise ValueError(f"Invalid or oversized molecule: {smiles}") - return None - - N = max_nodes - T = len(atom_vocab) - Y = len(bond_types) - - # Initialize node features - node = np.zeros((N, T), dtype=np.float32) - for i, atom in enumerate(mol.GetAtoms()): - if i >= N: - break - symbol = atom.GetSymbol() - if symbol in atom_vocab: - node[i, atom_vocab.index(symbol)] = 1.0 - - # Initialize adjacency tensor - adj = np.zeros((Y, N, N), dtype=np.float32) - for bond in mol.GetBonds(): - i = bond.GetBeginAtomIdx() - j = bond.GetEndAtomIdx() - btype = bond.GetBondTypeAsDouble() - if btype in bond_types and i < N and j < N: - k = bond_types.index(btype) - adj[k, i, j] = 1.0 - adj[k, j, i] = 1.0 # undirected - - # Convert to torch.Tensor - return torch.tensor(adj), torch.tensor(node) - - -ATOM_DECODER = ["C", "N", "O", "F"] # Adjust based on your vocabulary -BOND_DICT = [ - None, - Chem.rdchem.BondType.SINGLE, - Chem.rdchem.BondType.DOUBLE, - Chem.rdchem.BondType.TRIPLE, - Chem.rdchem.BondType.AROMATIC, -] - -def decode_smiles_from_graph( - adj: torch.Tensor, - node: torch.Tensor, - atom_decoder: Optional[list] = ATOM_DECODER -) -> Optional[str]: - """ - Converts (adj, node) graph back to a SMILES string. - - Parameters - ---------- - adj : torch.Tensor - Tensor of shape [Y, N, N] with binary bond type edges. - node : torch.Tensor - Tensor of shape [N, T] with atom type softmax/one-hot. - atom_decoder : list - List mapping indices to atom symbols. - - Returns - ------- - Optional[str] - SMILES string if successful, None otherwise. 
- """ - try: - atom_types = node.argmax(dim=-1) # [N] - edge_types = torch.argmax(adj, dim=0) # [N, N], index of strongest bond type - - # Convert to RDKit Mol - mol_init = build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder) - - # Try to correct connectivity and valency - for connection in (True, False): - mol_corr, _ = correct_mol(mol_init, connection=connection) - if mol_corr is not None: - break - else: - mol_corr = mol_init # fallback - - # Final sanitization - smiles = mol2smiles(mol_corr) - if not smiles: - smiles = Chem.MolToSmiles(mol_corr) - - # Canonicalize and return - mol = get_mol(smiles) - if mol is not None: - frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=False) - largest = max(frags, key=lambda m: m.GetNumAtoms()) - final_smiles = mol2smiles(largest) - return final_smiles if final_smiles and len(final_smiles) > 1 else None - return None - - except Exception as e: - print(f"[MolGAN Decode] Error during decoding: {e}") - return None - - - diff --git a/torch_molecule/generator/molgan/generator.py b/torch_molecule/generator/molgan/generator.py deleted file mode 100644 index a03e75f..0000000 --- a/torch_molecule/generator/molgan/generator.py +++ /dev/null @@ -1,88 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -# Including MolGANGeneratorConfig to allow for better instantiation -# if required for future iternations -class MolGANGeneratorConfig: - """ - Configuration class for MolGAN Generator and Discriminator. - - This class stores hyperparameters and architectural details used to construct - the MolGAN generator and other related modules. It allows modular control over - model depth, input/output dimensionality, and Gumbel-softmax behavior. - """ - def __init__(self, - latent_dim=56, - hidden_dims=[128, 128, 256], - num_nodes=9, - num_atom_types=5, - num_bond_types=4, - tau=1.0): - self.latent_dim = latent_dim - self.hidden_dims = hidden_dims - self.num_nodes = num_nodes - self.num_atom_types = num_atom_types - self.num_bond_types = num_bond_types - self.tau = tau - - -# MolGANGenerator -class MolGANGenerator(nn.Module): - - """ - Generator network for MolGAN. - - Maps a latent vector z to a molecular graph represented by: - - Adjacency tensor A ∈ [B, Y, N, N] (bonds) - - Node features X ∈ [B, N, T] (atoms) - - Uses Gumbel-Softmax to approximate discrete molecular structure. 
- """ - - def __init__(self, - latent_dim=56, - hidden_dims=[128, 128, 256], - num_nodes=9, - num_atom_types=5, - num_bond_types=4, - tau=1.0): - super().__init__() - self.latent_dim = latent_dim - self.hidden_dims = hidden_dims - self.num_nodes = num_nodes - self.num_atom_types = num_atom_types - self.num_bond_types = num_bond_types - self.tau = tau - - output_dim = (num_nodes * num_atom_types) + \ - (num_nodes * num_nodes * num_bond_types) - - layers = [] - input_dim = latent_dim - for hidden_dim in hidden_dims: - layers.append(nn.Linear(input_dim, hidden_dim)) - layers.append(nn.ReLU()) - input_dim = hidden_dim - layers.append(nn.Linear(input_dim, output_dim)) - - self.fc = nn.Sequential(*layers) - - def forward(self, z): - B = z.size(0) - out = self.fc(z) - - N, T, Y = self.num_nodes, self.num_atom_types, self.num_bond_types - node_size = N * T - adj_size = N * N * Y - - node_flat, adj_flat = torch.split(out, [node_size, adj_size], dim=1) - node = node_flat.view(B, N, T) - adj = adj_flat.view(B, Y, N, N) - - # Gumbel-softmax - node = F.gumbel_softmax(node, tau=self.tau, hard=True, dim=-1) - adj = F.gumbel_softmax(adj, tau=self.tau, hard=True, dim=1) - - return adj, node diff --git a/torch_molecule/generator/molgan/modeling_molgan.py b/torch_molecule/generator/molgan/modeling_molgan.py deleted file mode 100644 index e69de29..0000000 diff --git a/torch_molecule/generator/molgan/rewards_molgan.py b/torch_molecule/generator/molgan/rewards_molgan.py deleted file mode 100644 index 83261b1..0000000 --- a/torch_molecule/generator/molgan/rewards_molgan.py +++ /dev/null @@ -1,225 +0,0 @@ -from typing import List, Optional -import torch -import torch.nn as nn -from rdkit import Chem -from rdkit.Chem import QED, Crippen, rdMolDescriptors -from .gan_utils import RelationalGCNLayer -from ...utils.graph.graph_to_smiles import graph_to_smiles - - -# Non-Neural reward functions based on RDKit -def qed_reward(smiles: str) -> float: - mol = Chem.MolFromSmiles(smiles) - return QED.qed(mol) if mol else 0.0 - -def logp_reward(smiles: str) -> float: - mol = Chem.MolFromSmiles(smiles) - return Crippen.MolLogP(mol) if mol else 0.0 - -def weight_reward(smiles: str) -> float: - mol = Chem.MolFromSmiles(smiles) - return rdMolDescriptors.CalcExactMolWt(mol) if mol else 0.0 - -def combo_reward(smiles: str, weights=(0.7, 0.3)) -> float: - mol = Chem.MolFromSmiles(smiles) - if mol is None: - return 0.0 - qed_score = QED.qed(mol) - logp_score = Crippen.MolLogP(mol) - return weights[0] * qed_score + weights[1] * logp_score - -class RewardOracleNonNeural: - def __init__(self, kind="qed"): - if kind == "qed": - self.func = qed_reward - elif kind == "logp": - self.func = logp_reward - elif kind == "combo": - self.func = lambda s: combo_reward(s, weights=(0.6, 0.4)) - else: - raise ValueError(f"Unknown reward type: {kind}") - - def __call__(self, smiles: str) -> float: - return self.func(smiles) - - - - - -# Reward Network using Relational GCNs -class RewardNeuralNetwork(nn.Module): - """ - Reward Network that predicts reward from (adj, node) graphs. 
- """ - - def __init__(self, - num_atom_types=5, - num_bond_types=4, - hidden_dims:List[int]=[128, 128], - num_nodes=9): - super().__init__() - self.gcn_layers = nn.ModuleList() - self.gcn_layers.append(RelationalGCNLayer(num_atom_types, hidden_dims[0], num_bond_types)) - - current_dim = hidden_dims[0] - self.num_layers = len(hidden_dims) - for i in range(1, self.num_layers): - self.gcn_layers.append(RelationalGCNLayer(current_dim, hidden_dims[i], num_bond_types)) - - self.readout = nn.Sequential( - nn.Linear(num_nodes * hidden_dims[-1], hidden_dims[-1]), - nn.ReLU(), - nn.Linear(hidden_dims[-1], 1) - ) - - def forward(self, adj, node): - """ - adj: [B, Y, N, N] - node: [B, N, T] - """ - h = node - for layer in self.gcn_layers: - h = layer(adj, h) - - h = h.view(h.size(0), -1) - return self.readout(h).squeeze(-1) - - - def fit( - self, - train_loader, - epochs: int = 10, - lr: float = 1e-3, - weight_decay: float = 0.0, - verbose: bool = True - ): - """ - Train the reward self to approximate oracle rewards. - - Parameters - ---------- - reward_self : RewardNeuralNetwork - The neural network to train - - train_loader : DataLoader - Yields batches of (adj, node, reward) - - epochs : int - Number of training epochs - - lr : float - Learning rate - - weight_decay : float - Optional L2 regularization - - device : str - Device to run on ("cpu" or "cuda") - - verbose : bool - Whether to print losses - """ - optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay) - criterion = nn.MSELoss() - - self.train() - for epoch in range(epochs): - epoch_losses = [] - - for batch in train_loader: - adj = batch["adj"].to(self.device) - node = batch["node"].to(self.device) - reward = batch["reward"].to(self.device) - - pred = self(adj, node) # [B] - loss = criterion(pred, reward) - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - epoch_losses.append(loss.item()) - - if verbose: - print(f"[Epoch {epoch+1}/{epochs}] RewardNet Loss: {sum(epoch_losses)/len(epoch_losses):.4f}") - - - - - -# Combined reward wrapper: which uses either neural or oracle rewards -class RewardOracle: - """ - Combined reward network that uses either a neural self or an oracle. - Accepts (adj, node) tensors as standard input. - """ - - def __init__( - self, - kind="qed", - reward_net: Optional[RewardNeuralNetwork] = None, - atom_decoder: List[str]=["C", "N", "O", "F"], - device="cpu" - ): - """ - Parameters - ---------- - kind : str - Type of reward function to use. Options: - - "qed": QED score - - "logp": LogP score - - "combo": Combination of QED and LogP - - "neural": Use a neural network for reward prediction - reward_net : RewardNeuralNetwork, optional - If kind is "neural", this should be a trained RewardNeuralNetwork instance. - atom_decoder : list of str, Optional - If kind is "qed", "logp", or "combo", this should be a list of atom types to decode graphs. - Defaults to ["C", "N", "O", "F"]. - device : str - Device to run the reward computation on ("cpu" or "cuda"). 
- """ - self.kind = kind - self.device = device - self.atom_decoder = atom_decoder or ["C", "N", "O", "F"] - - if kind in ["qed", "logp", "combo"]: - self.oracle = RewardOracleNonNeural(kind) - self.neural = None - if reward_net is not None: - raise ValueError("reward_net should not be provided for oracle modes") - elif kind == "neural": - assert reward_net is not None, "reward_net must be provided for 'neural' mode" - self.oracle = None - self.neural = reward_net.to(device).eval() - else: - raise ValueError(f"Invalid kind: {kind}") - - def __call__( - self, - adj: torch.Tensor, - node: torch.Tensor - ) -> torch.Tensor: - """ - Compute reward from graph tensors. - - Parameters - ---------- - adj : Tensor [B, Y, N, N] - node : Tensor [B, N, T] - - Returns - ------- - Tensor [B] : reward per sample - """ - if self.neural is not None: - with torch.no_grad(): - return self.neural(adj.to(self.device), node.to(self.device)) - - elif self.oracle: - graphs = list(zip(node.cpu().numpy(), adj.cpu().numpy())) - smiles_list = graph_to_smiles(graphs, self.atom_decoder) - rewards = [self.oracle(s) if s else 0.0 for s in smiles_list] - return torch.tensor(rewards, dtype=torch.float32, device=self.device) - - else: - raise ValueError("No reward function defined.") From 4ee78ffd0972136eb4ca93c159013b5c86c18d0a Mon Sep 17 00:00:00 2001 From: Manda Kausthubh Date: Sat, 27 Sep 2025 11:59:30 +0530 Subject: [PATCH 11/14] Added new Generator and Discriminator --- .../generator/molgan/molgan_dataset.py | 85 ++++++++++ .../generator/molgan/molgan_gen_disc.py | 146 ++++++++++++++++++ .../generator/molgan/molgan_model.py | 26 ++++ .../generator/molgan/molgan_r_gcn.py | 53 +++++++ .../generator/molgan/molgan_utils.py | 0 5 files changed, 310 insertions(+) create mode 100644 torch_molecule/generator/molgan/molgan_dataset.py create mode 100644 torch_molecule/generator/molgan/molgan_gen_disc.py create mode 100644 torch_molecule/generator/molgan/molgan_model.py create mode 100644 torch_molecule/generator/molgan/molgan_r_gcn.py create mode 100644 torch_molecule/generator/molgan/molgan_utils.py diff --git a/torch_molecule/generator/molgan/molgan_dataset.py b/torch_molecule/generator/molgan/molgan_dataset.py new file mode 100644 index 0000000..e440ea3 --- /dev/null +++ b/torch_molecule/generator/molgan/molgan_dataset.py @@ -0,0 +1,85 @@ +import torch +from torch.utils.data import Dataset +from rdkit import Chem +from typing import List + + +class MolGANDataset(Dataset): + """ + PyTorch Dataset class for MolGAN model, specifically dealing with converting + SMILES data to Graph tensor data, which is suitable for MolGAN training. + """ + + def __init__( + self, + data: List[str], + atom_types: List[str], + bond_types: List[str], + max_num_atoms: int = 50, + ): + """ + Initialize the MolGANDataset. + + Parameters + ---------- + data : list of str + List of SMILES strings representing the molecules. + atom_types : list of str + List of allowed atom types. + bond_types : list of str + List of allowed bond types. + max_num_atoms : int + Maximum number of atoms in a molecule for padding purposes. 
+        """
+        self.data = data
+        self.atom_types = atom_types
+        self.bond_types = bond_types
+        self.max_num_atoms = max_num_atoms
+        self.atom_type_to_idx = {atom: idx for idx, atom in enumerate(atom_types)}
+        self.bond_type_to_idx = {bond: idx for idx, bond in enumerate(bond_types)}
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        smiles = self.data[idx]
+        mol = Chem.MolFromSmiles(smiles)
+        if mol is None:
+            raise ValueError(f"Invalid SMILES string at index {idx}: {smiles}")
+
+        num_atoms = mol.GetNumAtoms()
+        if num_atoms > self.max_num_atoms:
+            raise ValueError(f"Molecule at index {idx} exceeds max_num_atoms: {num_atoms} > {self.max_num_atoms}")
+
+        # Initialize node features and adjacency matrix
+        node_features = torch.zeros((self.max_num_atoms, len(self.atom_types)), dtype=torch.float)
+        adjacency_matrix = torch.zeros((self.max_num_atoms, self.max_num_atoms, len(self.bond_types)), dtype=torch.float)
+
+        # Fill node features
+        for i, atom in enumerate(mol.GetAtoms()):
+            atom_type = atom.GetSymbol()
+            if atom_type in self.atom_type_to_idx:
+                node_features[i, self.atom_type_to_idx[atom_type]] = 1.0
+
+        # Fill adjacency matrix
+        for bond in mol.GetBonds():
+            begin_idx = bond.GetBeginAtomIdx()
+            end_idx = bond.GetEndAtomIdx()
+            bond_type = str(bond.GetBondType())
+            if bond_type in self.bond_type_to_idx:
+                adjacency_matrix[begin_idx, end_idx, self.bond_type_to_idx[bond_type]] = 1.0
+                adjacency_matrix[end_idx, begin_idx, self.bond_type_to_idx[bond_type]] = 1.0
+
+        return node_features, adjacency_matrix
+
diff --git a/torch_molecule/generator/molgan/molgan_gen_disc.py b/torch_molecule/generator/molgan/molgan_gen_disc.py
new file mode 100644
index 0000000..b6fadbb
--- /dev/null
+++ b/torch_molecule/generator/molgan/molgan_gen_disc.py
@@ -0,0 +1,146 @@
+import torch
+from dataclasses import dataclass
+from typing import Tuple
+from .molgan_r_gcn import RelationalGCNLayer  # Local import to avoid circular dependency
+
+@dataclass
+class MolGANGeneratorConfig:
+    def __init__(
+        self,
+        z_dim: int = 32,
+        g_conv_dim: int = 64,
+        d_conv_dim: int = 64,
+        g_num_layers: int = 3,
+        d_num_layers: int = 3,
+        num_atom_types: int = 5,
+        num_bond_types: int = 4,
+        max_num_atoms: int = 9,
+        dropout: float = 0.0,
+        use_batchnorm: bool = True,
+    ):
+        self.z_dim = z_dim
+        self.g_conv_dim = g_conv_dim
+        self.d_conv_dim = d_conv_dim
+        self.g_num_layers = g_num_layers
+        self.d_num_layers = d_num_layers
+        self.num_atom_types = num_atom_types
+        self.num_bond_types = num_bond_types
+        self.max_num_atoms = max_num_atoms
+        self.dropout = dropout
+        self.use_batchnorm = use_batchnorm
+
+
+# MolGAN Generator
+class MolGANGenerator(torch.nn.Module):
+    def __init__(self, config: MolGANGeneratorConfig):
+        super(MolGANGenerator, self).__init__()
+        self.z_dim = config.z_dim
+        self.g_conv_dim = config.g_conv_dim
+        self.g_num_layers = config.g_num_layers
+        self.num_atom_types = config.num_atom_types
+        self.num_bond_types = config.num_bond_types
+        self.max_num_atoms = config.max_num_atoms
+        self.dropout = config.dropout
+        self.use_batchnorm = config.use_batchnorm
+
+        layers = []
+        input_dim = self.z_dim
+        for i in range(self.g_num_layers):
+            output_dim = self.g_conv_dim * (2 ** i)
+            layers.append(torch.nn.Linear(input_dim, output_dim))
+            if self.use_batchnorm:
+                layers.append(torch.nn.BatchNorm1d(output_dim))
+            layers.append(torch.nn.ReLU())
+            if self.dropout > 0:
+                layers.append(torch.nn.Dropout(self.dropout))
+            input_dim = output_dim
+
+        self.fc_layers = torch.nn.Sequential(*layers)
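+        # Two linear heads below split the shared trunk output into atom-type
+        # logits and bond-type logits; forward() reshapes them into graph tensors.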
+ self.atom_fc = torch.nn.Linear(input_dim, self.max_num_atoms * self.num_atom_types) + self.bond_fc = torch.nn.Linear(input_dim, self.num_bond_types * self.max_num_atoms * self.max_num_atoms) + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = z.size(0) + h = self.fc_layers(z) + atom_logits = self.atom_fc(h).view(batch_size, self.max_num_atoms, self.num_atom_types) + # Output bond logits with [batch, num_bond_types, max_num_atoms, max_num_atoms] order + bond_logits = self.bond_fc(h).view(batch_size, self.num_bond_types, self.max_num_atoms, self.max_num_atoms) + return atom_logits, bond_logits + + + +# MolGAN Discriminator +@dataclass +class MolGANDiscriminatorConfig: + def __init__( + self, + in_dim: int = 5, # Number of atom types (node feature dim). Typically set automatically. + hidden_dim: int = 64, # Hidden feature/channel size for GCN layers. + num_layers: int = 3, # Number of R-GCN layers (depth). + num_relations: int = 4, # Number of bond types (relation types per edge). + max_num_atoms: int = 9, # Max node count in padded tensor. + dropout: float = 0.0, # Dropout between layers. + use_batchnorm: bool = True, # BatchNorm or similar normalization. + readout: str = 'sum', # Readout type (sum/mean/max for pooling nodes to graph-level vector) + ): + self.in_dim = in_dim + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.num_relations = num_relations + self.max_num_atoms = max_num_atoms + self.dropout = dropout + self.use_batchnorm = use_batchnorm + self.readout = readout + + +class MolGANDiscriminator(torch.nn.Module): + def __init__(self, config: MolGANDiscriminatorConfig): + super(MolGANDiscriminator, self).__init__() + + self.in_dim = config.in_dim + self.hidden_dim = config.hidden_dim + self.num_layers = config.num_layers + self.num_relations = config.num_relations + self.max_num_atoms = config.max_num_atoms + self.dropout = config.dropout + self.use_batchnorm = config.use_batchnorm + self.readout = config.readout + + layers = [] + input_dim = self.in_dim + for i in range(self.num_layers): + output_dim = self.hidden_dim * (2 ** i) + layers.append(RelationalGCNLayer(input_dim, output_dim, self.num_relations)) + if self.use_batchnorm: + layers.append(torch.nn.BatchNorm1d(self.max_num_atoms)) + layers.append(torch.nn.LeakyReLU(0.2)) + if self.dropout > 0: + layers.append(torch.nn.Dropout(self.dropout)) + input_dim = output_dim + + self.gcn_layers = torch.nn.ModuleList(layers) + self.fc = torch.nn.Linear(input_dim, 1) + + def forward(self, atom_feats: torch.Tensor, adj: torch.Tensor) -> torch.Tensor: + # atom_feats: [batch, max_num_atoms, num_atom_types] + # adj: [batch, num_bond_types, max_num_atoms, max_num_atoms] + h = atom_feats + for layer in self.gcn_layers: + if isinstance(layer, RelationalGCNLayer): + h = layer(h, adj) + else: + h = layer(h) + + if self.readout == 'sum': + g = h.sum(dim=1) # [batch, hidden_dim] + elif self.readout == 'mean': + g = h.mean(dim=1) + elif self.readout == 'max': + g, _ = h.max(dim=1) + else: + raise ValueError(f"Unknown readout type: {self.readout}") + + out = self.fc(g) # [batch, 1] + return out.squeeze(-1) # [batch] + + diff --git a/torch_molecule/generator/molgan/molgan_model.py b/torch_molecule/generator/molgan/molgan_model.py new file mode 100644 index 0000000..ffa0f7c --- /dev/null +++ b/torch_molecule/generator/molgan/molgan_model.py @@ -0,0 +1,26 @@ +from typing import Optional +import torch +from dataclasses import dataclass +from .molgan_gen_disc import * +from .molgan_dataset 
import MolGANDataset + + +class MolGANModel(torch.nn.Module): + def __init__( + self, + generator_config: MolGANGeneratorConfig, + discriminator_config: MolGANDiscriminatorConfig, + reward_network_config: Optional[MolGANDiscriminatorConfig] = None, + ): + super(MolGANModel, self).__init__() + + # Initialize generator and discriminator + self.gen = MolGANGenerator(generator_config) + self.disc = MolGANDiscriminator(discriminator_config) + + # By default, the reward network is the same as the discriminator + self.reward_net = ( + MolGANDiscriminator(reward_network_config) + if reward_network_config is not None + else MolGANDiscriminator(discriminator_config) + ) diff --git a/torch_molecule/generator/molgan/molgan_r_gcn.py b/torch_molecule/generator/molgan/molgan_r_gcn.py new file mode 100644 index 0000000..c0e8124 --- /dev/null +++ b/torch_molecule/generator/molgan/molgan_r_gcn.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn + +class RelationalGCNLayer(nn.Module): + """ + Relational Graph Convolutional Layer for fully connected dense graphs. + Input: + - node_feats: [batch, num_nodes, in_dim] + - adj: [batch, num_relations, num_nodes, num_nodes] + Output: + - node_feats: [batch, num_nodes, out_dim] + """ + def __init__(self, in_dim, out_dim, num_relations, use_bias=True): + super(RelationalGCNLayer, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.num_relations = num_relations + + # One weight matrix per relation/bond type + self.rel_weights = nn.Parameter(torch.Tensor(num_relations, in_dim, out_dim)) + if use_bias: + self.bias = nn.Parameter(torch.Tensor(out_dim)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_uniform_(self.rel_weights) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, node_feats, adj): + # node_feats: [batch, num_nodes, in_dim] + # adj: [batch, num_relations, num_nodes, num_nodes] + batch_size, num_nodes, _ = node_feats.shape + + out = torch.zeros(batch_size, num_nodes, self.out_dim, device=node_feats.device) + + for rel in range(self.num_relations): + # Multiply node features by relation weight + # [batch, num_nodes, in_dim] @ [in_dim, out_dim] -> [batch, num_nodes, out_dim] + h_rel = torch.matmul(node_feats, self.rel_weights[rel]) + # Propagate messages using adjacency for this relation: + # [batch, num_nodes, out_dim] ← [batch, num_nodes, num_nodes] @ [batch, num_nodes, out_dim] + # Here adj[:, rel, :, :] gives [batch, num_nodes, num_nodes] + out += torch.bmm(adj[:, rel], h_rel) + + if self.bias is not None: + out += self.bias + + return out # You can add activation after this (ReLU, LeakyReLU, etc.) 
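
A quick shape check for the layer above (an illustrative sketch, not part of any patch; the import path assumes the package layout used in this series, and all sizes are arbitrary):

    import torch
    from torch_molecule.generator.molgan.molgan_r_gcn import RelationalGCNLayer

    layer = RelationalGCNLayer(in_dim=5, out_dim=64, num_relations=4)
    node_feats = torch.randn(8, 9, 5)   # [batch, num_nodes, in_dim]
    adj = torch.rand(8, 4, 9, 9)        # [batch, num_relations, num_nodes, num_nodes]
    out = layer(node_feats, adj)
    assert out.shape == (8, 9, 64)      # [batch, num_nodes, out_dim]

Each forward pass applies a per-relation linear map, propagates messages with that relation's adjacency slice, and sums the results over relations.
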
+ diff --git a/torch_molecule/generator/molgan/molgan_utils.py b/torch_molecule/generator/molgan/molgan_utils.py new file mode 100644 index 0000000..e69de29 From 0faad3d8dd0d4dce1fdffec8c2b805e52d6b036d Mon Sep 17 00:00:00 2001 From: Manda Kausthubh Date: Sun, 28 Sep 2025 00:40:48 +0530 Subject: [PATCH 12/14] Basic Training loop added --- .../generator/molgan/molgan_gen_disc.py | 32 ++++- .../generator/molgan/molgan_model.py | 136 ++++++++++++++++-- 2 files changed, 153 insertions(+), 15 deletions(-) diff --git a/torch_molecule/generator/molgan/molgan_gen_disc.py b/torch_molecule/generator/molgan/molgan_gen_disc.py index b6fadbb..e1c2506 100644 --- a/torch_molecule/generator/molgan/molgan_gen_disc.py +++ b/torch_molecule/generator/molgan/molgan_gen_disc.py @@ -121,26 +121,44 @@ def __init__(self, config: MolGANDiscriminatorConfig): self.gcn_layers = torch.nn.ModuleList(layers) self.fc = torch.nn.Linear(input_dim, 1) - def forward(self, atom_feats: torch.Tensor, adj: torch.Tensor) -> torch.Tensor: + def forward( + self, + atom_feats: torch.Tensor, + adj: torch.Tensor, + mask: torch.Tensor + ) -> torch.Tensor: # atom_feats: [batch, max_num_atoms, num_atom_types] # adj: [batch, num_bond_types, max_num_atoms, max_num_atoms] + # mask: [batch, max_num_atoms] (float, 1=real, 0=pad) h = atom_feats for layer in self.gcn_layers: if isinstance(layer, RelationalGCNLayer): h = layer(h, adj) else: - h = layer(h) + # If using BatchNorm1d, input should be [batch, features, nodes] + if isinstance(layer, torch.nn.BatchNorm1d): + # Permute for batchnorm: [batch, nodes, features] → [batch, features, nodes] + h = layer(h.permute(0, 2, 1)).permute(0, 2, 1) + else: + h = layer(h) + + # MASKED GRAPH READOUT + # mask: [batch, max_num_atoms] float + mask = mask.unsqueeze(-1) # [batch, max_num_atoms, 1] + h_masked = h * mask # zeros padded nodes if self.readout == 'sum': - g = h.sum(dim=1) # [batch, hidden_dim] + g = h_masked.sum(dim=1) # [batch, hidden_dim] elif self.readout == 'mean': - g = h.mean(dim=1) + # Prevent divide-by-zero with (mask.sum(dim=1, keepdim=True)+1e-8) + g = h_masked.sum(dim=1) / (mask.sum(dim=1) + 1e-8) elif self.readout == 'max': - g, _ = h.max(dim=1) + # Set padded to large neg, then max + h_masked_pad = h.clone() + h_masked_pad[mask.squeeze(-1) == 0] = float('-inf') + g, _ = h_masked_pad.max(dim=1) else: raise ValueError(f"Unknown readout type: {self.readout}") out = self.fc(g) # [batch, 1] return out.squeeze(-1) # [batch] - - diff --git a/torch_molecule/generator/molgan/molgan_model.py b/torch_molecule/generator/molgan/molgan_model.py index ffa0f7c..9d99318 100644 --- a/torch_molecule/generator/molgan/molgan_model.py +++ b/torch_molecule/generator/molgan/molgan_model.py @@ -1,26 +1,146 @@ from typing import Optional import torch -from dataclasses import dataclass +import torch.nn as nn from .molgan_gen_disc import * -from .molgan_dataset import MolGANDataset -class MolGANModel(torch.nn.Module): +class MolGANModel(nn.Module): def __init__( self, - generator_config: MolGANGeneratorConfig, - discriminator_config: MolGANDiscriminatorConfig, + generator_config: Optional[MolGANGeneratorConfig] = None, + discriminator_config: Optional[MolGANDiscriminatorConfig] = None, reward_network_config: Optional[MolGANDiscriminatorConfig] = None, ): super(MolGANModel, self).__init__() # Initialize generator and discriminator - self.gen = MolGANGenerator(generator_config) - self.disc = MolGANDiscriminator(discriminator_config) + self.gen_config = generator_config if generator_config is not None else 
MolGANGeneratorConfig() + self.gen = MolGANGenerator(self.gen_config) + + self.disc_config = discriminator_config if discriminator_config is not None else MolGANDiscriminatorConfig() + self.disc = MolGANDiscriminator(self.disc_config) # By default, the reward network is the same as the discriminator self.reward_net = ( MolGANDiscriminator(reward_network_config) if reward_network_config is not None - else MolGANDiscriminator(discriminator_config) + else MolGANDiscriminator(self.disc_config) ) + + + def generate(self, batch_size: int): + """Generate a batch of molecules.""" + return self.gen(batch_size) + + + def discriminate( + self, + atom_type_matrix: torch.Tensor, + bond_type_tensor: torch.Tensor, + molecule_mask: Optional[torch.Tensor], + ): + """Discriminate a batch of molecules.""" + return self.disc(atom_type_matrix, bond_type_tensor, molecule_mask) + + def reward( + self, + atom_type_matrix: torch.Tensor, + bond_type_tensor: torch.Tensor, + molecule_mask: Optional[torch.Tensor], + ): + """Compute reward for a batch of molecules.""" + return self.reward_net(atom_type_matrix, bond_type_tensor, molecule_mask) + + def config_training( + self, + gen_optimizer: torch.optim.Optimizer, + disc_optimizer: torch.optim.Optimizer, + lambda_rl: float = 0.0, + reward_optimizer: Optional[torch.optim.Optimizer] = None, + gen_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + disc_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + reward_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + ): + """Configure optimizers and schedulers for training.""" + self.gen_optimizer = gen_optimizer + self.disc_optimizer = disc_optimizer + self.reward_optimizer = reward_optimizer + self.lambda_rl = lambda_rl + + self.gen_scheduler = gen_scheduler + self.disc_scheduler = disc_scheduler + self.reward_scheduler = reward_scheduler + + def training_step(self, batch, reward_fn, pretrain=False): + node_features, adjacency_matrix = batch + batch_size = node_features.size(0) + adjacency_matrix = adjacency_matrix.permute(0, 3, 1, 2) + mask = (node_features.sum(-1) != 0).float() + + z = torch.randn(batch_size, self.gen_config.z_dim, device=node_features.device) + fake_atom_logits, fake_bond_logits = self.gen(z) + fake_atom = torch.softmax(fake_atom_logits, -1) + fake_bond = torch.softmax(fake_bond_logits, 1) + fake_mask = (fake_atom.argmax(-1) != 0).float() + + # === Discriminator update === + self.disc_optimizer.zero_grad() + real_scores = self.disc(node_features, adjacency_matrix, mask) + fake_scores = self.disc(fake_atom, fake_bond, fake_mask) + wgan_loss = -(real_scores.mean() - fake_scores.mean()) + wgan_loss.backward() + self.disc_optimizer.step() + if self.disc_scheduler: self.disc_scheduler.step() + + # === Reward net update === + if self.reward_optimizer is not None: + self.reward_optimizer.zero_grad() + reward_targets = reward_fn(node_features, adjacency_matrix, mask) + pred_rewards = self.reward_net(node_features, adjacency_matrix, mask) + r_loss = torch.nn.functional.mse_loss(pred_rewards, reward_targets) + r_loss.backward() + self.reward_optimizer.step() + if self.reward_scheduler: self.reward_scheduler.step() + else: + r_loss = torch.tensor(0.0, device=node_features.device) + + # === Generator update === + self.gen_optimizer.zero_grad() + fake_atom_logits, fake_bond_logits = self.gen(z) + fake_atom = torch.softmax(fake_atom_logits, -1) + fake_bond = torch.softmax(fake_bond_logits, 1) + fake_mask = (fake_atom.argmax(-1) != 0).float() + fake_scores = 
self.disc(fake_atom.detach(), fake_bond.detach(), fake_mask)  # detach: the critic step must not backpropagate into the generator
+        wgan_loss = -(real_scores.mean() - fake_scores.mean())
+        wgan_loss.backward()
+        self.disc_optimizer.step()
+        if self.disc_scheduler: self.disc_scheduler.step()
+
+        # === Reward net update ===
+        if self.reward_optimizer is not None:
+            self.reward_optimizer.zero_grad()
+            reward_targets = reward_fn(node_features, adjacency_matrix, mask)
+            pred_rewards = self.reward_net(node_features, adjacency_matrix, mask)
+            r_loss = torch.nn.functional.mse_loss(pred_rewards, reward_targets)
+            r_loss.backward()
+            self.reward_optimizer.step()
+            if self.reward_scheduler: self.reward_scheduler.step()
+        else:
+            r_loss = torch.tensor(0.0, device=node_features.device)
+
+        # === Generator update ===
+        self.gen_optimizer.zero_grad()
+        fake_atom_logits, fake_bond_logits = self.gen(z)
+        fake_atom = torch.softmax(fake_atom_logits, -1)
+        fake_bond = torch.softmax(fake_bond_logits, 1)
+        fake_mask = (fake_atom.argmax(-1) != 0).float()
+        fake_scores = self.disc(fake_atom, fake_bond, fake_mask)
+        g_wgan_loss = -fake_scores.mean()
+        if not pretrain and hasattr(self, 'lambda_rl') and self.lambda_rl > 0:
+            # Keep gradients flowing through the reward net so the RL term can
+            # actually update the generator (wrapping this in no_grad would
+            # silently zero its contribution to the loss)
+            rewards = self.reward_net(fake_atom, fake_bond, fake_mask)
+            rl_loss = -rewards.mean()
+        else:
+            rl_loss = torch.tensor(0.0, device=node_features.device)
+        total_loss = g_wgan_loss + getattr(self, 'lambda_rl', 0.0) * rl_loss
+        total_loss.backward()
+        self.gen_optimizer.step()
+        if self.gen_scheduler: self.gen_scheduler.step()
+
+        return {
+            'd_loss': wgan_loss.item(),
+            'g_loss': g_wgan_loss.item(),
+            'rl_loss': rl_loss.item(),
+            'r_loss': r_loss.item() if self.reward_optimizer is not None else None
+        }
+
+
+    def train_epoch(self, dataloader, reward_fn, pretrain=False, log_interval=100):
+        self.gen.train()
+        self.disc.train()
+        if self.reward_net: self.reward_net.train()
+        for i, batch in enumerate(dataloader):
+            result = self.training_step(batch, reward_fn, pretrain)
+            if i % log_interval == 0:
+                print({k: (round(v, 5) if v is not None else None) for k, v in result.items()})
+

From 44c574646ad1b1cf02f0ec35495bfd733e119688 Mon Sep 17 00:00:00 2001
From: Manda Kausthubh
Date: Sun, 28 Sep 2025 00:51:59 +0530
Subject: [PATCH 13/14] Created abstractions

---
 .../generator/molgan/molgan_generator.py      | 78 +++++++++++++++++++
 .../generator/molgan/molgan_model.py          |  2 +-
 .../generator/molgan/molgan_utils.py          |  0
 3 files changed, 79 insertions(+), 1 deletion(-)
 create mode 100644 torch_molecule/generator/molgan/molgan_generator.py
 delete mode 100644 torch_molecule/generator/molgan/molgan_utils.py

diff --git a/torch_molecule/generator/molgan/molgan_generator.py b/torch_molecule/generator/molgan/molgan_generator.py
new file mode 100644
index 0000000..96889ff
--- /dev/null
+++ b/torch_molecule/generator/molgan/molgan_generator.py
@@ -0,0 +1,78 @@
+import torch
+import numpy as np
+from typing import Optional, List, Union
+from .molgan_model import MolGANModel
+from .molgan_dataset import MolGANDataset
+from torch_molecule.base.generator import BaseMolecularGenerator
+
+class MolGANMolecularGenerator(BaseMolecularGenerator):
+    """
+    MolGAN model wrapper for standardized molecular generation API.
+    Inherits fit/generate signature from BaseMolecularGenerator.
+    """
+    def __init__(
+        self,
+        generator, discriminator, reward_net=None,
+        gen_config=None, disc_config=None,
+        device: Optional[Union[torch.device, str]] = None,
+        model_name: str = "MolGANMolecularGenerator"
+    ):
+        super().__init__(device=device, model_name=model_name)
+        self.generator = generator
+        self.discriminator = discriminator
+        self.reward_net = reward_net
+        self.gen_config = gen_config
+        self.disc_config = disc_config
+
+        self.device = device if device is not None else torch.device("cpu")
+        self.generator.to(self.device)
+        self.discriminator.to(self.device)
+        if self.reward_net: self.reward_net.to(self.device)
+        self.is_fitted = False
+
+    def fit(self, X: List[str], y: Optional[np.ndarray] = None, epochs=100, batch_size=32, **kwargs):
+        """
+        Train MolGAN on molecules (SMILES).
+        """
+        # 1. Prepare dataset (requires atom_types and bond_types in kwargs)
+        atom_types = kwargs.get("atom_types")
+        bond_types = kwargs.get("bond_types")
+        max_num_atoms = kwargs.get("max_num_atoms", 50)
+        dataset = MolGANDataset(X, atom_types, bond_types, max_num_atoms)
+        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+        # 2. 
Set up optimizers/schedulers etc (credits: your config_training & usual setup) + self.config_training( + gen_optimizer=kwargs["gen_optimizer"], + disc_optimizer=kwargs["disc_optimizer"], + reward_optimizer=kwargs.get("reward_optimizer", None), + gen_scheduler=kwargs.get("gen_scheduler", None), + disc_scheduler=kwargs.get("disc_scheduler", None), + reward_scheduler=kwargs.get("reward_scheduler", None), + ) + # 3. Training loop (calls training_step as above) + for epoch in range(epochs): + for batch in dataloader: + # Keep reward_fn optional for RL phase + reward_fn = kwargs.get("reward_fn", None) + pretrain = kwargs.get("pretrain", False) + self.training_step(batch, reward_fn, pretrain) + self.is_fitted = True + return self + + def generate(self, n_samples: int, **kwargs) -> List[str]: + """ + Generate n_samples molecules as SMILES. + """ + self.generator.eval() + zs = torch.randn(n_samples, self.generator.z_dim, device=self.device) + with torch.no_grad(): + atom_logits, bond_logits = self.generator(zs) + # [n_samples, max_num_atoms, num_atom_types], [n_samples, num_bond_types, max_num_atoms, max_num_atoms] + atom_types = atom_logits.argmax(dim=-1).cpu().numpy() # [n_samples, max_num_atoms] + bond_types = bond_logits.argmax(dim=1).cpu().numpy() # [n_samples, max_num_atoms, max_num_atoms] + # Convert to SMILES (implement graph2smiles utility based on your atom_types/bond_types) + smiles_strings = [ + graph2smiles(atom_types[i], bond_types[i], kwargs.get("atom_types"), kwargs.get("bond_types")) + for i in range(n_samples) + ] + return smiles_strings diff --git a/torch_molecule/generator/molgan/molgan_model.py b/torch_molecule/generator/molgan/molgan_model.py index 9d99318..1dba5de 100644 --- a/torch_molecule/generator/molgan/molgan_model.py +++ b/torch_molecule/generator/molgan/molgan_model.py @@ -55,7 +55,7 @@ def config_training( self, gen_optimizer: torch.optim.Optimizer, disc_optimizer: torch.optim.Optimizer, - lambda_rl: float = 0.0, + lambda_rl: float = 0.25, reward_optimizer: Optional[torch.optim.Optimizer] = None, gen_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, disc_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, diff --git a/torch_molecule/generator/molgan/molgan_utils.py b/torch_molecule/generator/molgan/molgan_utils.py deleted file mode 100644 index e69de29..0000000 From f46ce262d0e15e04f80383fdee4eaabd94c42f57 Mon Sep 17 00:00:00 2001 From: Manda Kausthubh Date: Thu, 2 Oct 2025 16:28:55 +0530 Subject: [PATCH 14/14] Added default function and MolGANDataset implementation --- .../generator/molgan/molgan_dataset.py | 132 +++++---- .../generator/molgan/molgan_gen_disc.py | 41 ++- .../generator/molgan/molgan_generator.py | 250 +++++++++++++----- .../generator/molgan/molgan_model.py | 40 ++- .../generator/molgan/molgan_utils.py | 14 + 5 files changed, 352 insertions(+), 125 deletions(-) create mode 100644 torch_molecule/generator/molgan/molgan_utils.py diff --git a/torch_molecule/generator/molgan/molgan_dataset.py b/torch_molecule/generator/molgan/molgan_dataset.py index e440ea3..bcea90a 100644 --- a/torch_molecule/generator/molgan/molgan_dataset.py +++ b/torch_molecule/generator/molgan/molgan_dataset.py @@ -1,81 +1,105 @@ +from typing import List, Optional, Callable +from rdkit import Chem import torch from torch.utils.data import Dataset -from rdkit import Chem -from typing import List - +from .molgan_utils import qed_reward_fn class MolGANDataset(Dataset): """ - PyTorch Dataset class for MolGAN model, specifically dealing with converting - 
SMILES data to Graph tensor data, which is suitable for MolGAN training. + A PyTorch Dataset for MolGAN, with all RDKit and graph tensor processing + precomputed in __init__ for fast, pure-tensor __getitem__ access. + Optionally caches property values for each molecule. """ - def __init__( self, data: List[str], atom_types: List[str], bond_types: List[str], max_num_atoms: int = 50, + cache_properties: bool = False, + property_fn: Optional[Callable] = None, + return_mol: bool = False, + device: Optional[torch.device] = None ): - """ - Initialize the MolGANDataset. - - Parameters - ---------- - data : list of str - List of SMILES strings representing the molecules. - atom_types : list of str - List of allowed atom types. - bond_types : list of str - List of allowed bond types. - max_num_atoms : int - Maximum number of atoms in a molecule for padding purposes. - """ self.data = data self.atom_types = atom_types self.bond_types = bond_types self.max_num_atoms = max_num_atoms self.atom_type_to_idx = {atom: idx for idx, atom in enumerate(atom_types)} self.bond_type_to_idx = {bond: idx for idx, bond in enumerate(bond_types)} + self.return_mol = return_mol + self.device = torch.device(device) if device is not None else None + + self.node_features = [] + self.adjacency_matrices = [] + self.mols = [] + self.cached_properties = [] if cache_properties and property_fn else None + + self.property_fn = property_fn if property_fn is not None else qed_reward_fn + + for idx, smiles in enumerate(self.data): + mol = Chem.MolFromSmiles(smiles) + self.mols.append(mol) + # Default: if invalid, fill with zeros and (optionally) property 0 + nf = torch.zeros((self.max_num_atoms, len(self.atom_types)), dtype=torch.float) + adj = torch.zeros((self.max_num_atoms, self.max_num_atoms, len(self.bond_types)), dtype=torch.float) + prop_val = 0.0 if cache_properties else None + + if mol is not None: + num_atoms = mol.GetNumAtoms() + if num_atoms > self.max_num_atoms: + raise ValueError(f"Molecule at index {idx} exceeds max_num_atoms: {num_atoms} > {self.max_num_atoms}") + + for i, atom in enumerate(mol.GetAtoms()): + atom_type = atom.GetSymbol() + if atom_type in self.atom_type_to_idx: + nf[i, self.atom_type_to_idx[atom_type]] = 1.0 + + for bond in mol.GetBonds(): + begin_idx = bond.GetBeginAtomIdx() + end_idx = bond.GetEndAtomIdx() + bond_type = str(bond.GetBondType()) + if bond_type in self.bond_type_to_idx: + bidx = self.bond_type_to_idx[bond_type] + adj[begin_idx, end_idx, bidx] = 1.0 + adj[end_idx, begin_idx, bidx] = 1.0 + + if cache_properties and self.property_fn: + try: + prop_val = self.property_fn(mol) + except Exception: + prop_val = 0.0 + + # Move tensors to device immediately if a device is set + if self.device is not None: + nf = nf.to(self.device) + adj = adj.to(self.device) + + self.node_features.append(nf) + self.adjacency_matrices.append(adj) + if cache_properties and property_fn and self.cached_properties is not None: + self.cached_properties.append(prop_val) def __len__(self): return len(self.data) def __getitem__(self, idx): - smiles = self.data[idx] - mol = Chem.MolFromSmiles(smiles) - if mol is None: - raise ValueError(f"Invalid SMILES string at index {idx}: {smiles}") - - num_atoms = mol.GetNumAtoms() - if num_atoms > self.max_num_atoms: - raise ValueError(f"Molecule at index {idx} exceeds max_num_atoms: {num_atoms} > {self.max_num_atoms}") - - # Initialize node features and adjacency matrix - node_features = torch.zeros((self.max_num_atoms, len(self.atom_types)), dtype=torch.float) - 
adjacency_matrix = torch.zeros((self.max_num_atoms, self.max_num_atoms, len(self.bond_types)), dtype=torch.float) - - # Fill node features - for i, atom in enumerate(mol.GetAtoms()): - atom_type = atom.GetSymbol() - if atom_type in self.atom_type_to_idx: - node_features[i, self.atom_type_to_idx[atom_type]] = 1.0 - - # Fill adjacency matrix - for bond in mol.GetBonds(): - begin_idx = bond.GetBeginAtomIdx() - end_idx = bond.GetEndAtomIdx() - bond_type = str(bond.GetBondType()) - if bond_type in self.bond_type_to_idx: - adjacency_matrix[begin_idx, end_idx, self.bond_type_to_idx[bond_type]] = 1.0 - adjacency_matrix[end_idx, begin_idx, self.bond_type_to_idx[bond_type]] = 1.0 - - return node_features, adjacency_matrix - - - - - + parts = [ + self.node_features[idx], + self.adjacency_matrices[idx] + ] + # add optional property + if self.cached_properties is not None: + parts.append(self.cached_properties[idx]) + # add optional Mol object (can always access it if you want) + if self.return_mol: + parts.append(self.mols[idx]) + + # Default: (node_features, adjacency_matrix) + # With property: (node_features, adjacency_matrix, property) + # With property and mol: (node_features, adjacency_matrix, property, mol) + # With only mol: (node_features, adjacency_matrix, mol) + return tuple(parts) diff --git a/torch_molecule/generator/molgan/molgan_gen_disc.py b/torch_molecule/generator/molgan/molgan_gen_disc.py index e1c2506..e71b971 100644 --- a/torch_molecule/generator/molgan/molgan_gen_disc.py +++ b/torch_molecule/generator/molgan/molgan_gen_disc.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from typing import Tuple from .molgan_r_gcn import RelationalGCNLayer # Local import to avoid circular dependency +import torch.nn.functional as F @dataclass class MolGANGeneratorConfig: @@ -16,7 +17,9 @@ def __init__( num_bond_types: int = 4, max_num_atoms: int = 9, dropout: float = 0.0, + tau: float = 1.0, use_batchnorm: bool = True, + device: str = 'cuda' if torch.cuda.is_available() else 'cpu', ): self.z_dim = z_dim self.g_conv_dim = g_conv_dim @@ -28,6 +31,8 @@ def __init__( self.max_num_atoms = max_num_atoms self.dropout = dropout self.use_batchnorm = use_batchnorm + self.tau = tau # Gumbel-Softmax temperature + self.device = device # MolGAN Generotor @@ -42,6 +47,9 @@ def __init__(self, config: MolGANGeneratorConfig): self.max_num_atoms = config.max_num_atoms self.dropout = config.dropout self.use_batchnorm = config.use_batchnorm + self.tau = config.tau + self.device = config.device + self.to(self.device) layers = [] input_dim = self.z_dim @@ -59,13 +67,38 @@ def __init__(self, config: MolGANGeneratorConfig): self.atom_fc = torch.nn.Linear(input_dim, self.max_num_atoms * self.num_atom_types) self.bond_fc = torch.nn.Linear(input_dim, self.num_bond_types * self.max_num_atoms * self.max_num_atoms) - def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, z: torch.Tensor, sample_mode='softmax') -> Tuple[torch.Tensor, torch.Tensor]: batch_size = z.size(0) h = self.fc_layers(z) atom_logits = self.atom_fc(h).view(batch_size, self.max_num_atoms, self.num_atom_types) # Output bond logits with [batch, num_bond_types, max_num_atoms, max_num_atoms] order bond_logits = self.bond_fc(h).view(batch_size, self.num_bond_types, self.max_num_atoms, self.max_num_atoms) - return atom_logits, bond_logits + + # Nodes + if sample_mode == 'softmax': + node = torch.softmax(atom_logits, dim=-1) + elif sample_mode == 'soft_gumbel': + node = F.gumbel_softmax(atom_logits, tau=self.tau, 
hard=False, dim=-1)
+        elif sample_mode == 'hard_gumbel':
+            node = F.gumbel_softmax(atom_logits, tau=self.tau, hard=True, dim=-1)
+        elif sample_mode == 'argmax':
+            # 'argmax' returns index tensors rather than one-hot/probability tensors
+            node = atom_logits.argmax(dim=-1)
+        else:
+            raise ValueError(f"Unknown sample_mode: {sample_mode}")
+
+        # Adjacency
+        if sample_mode == 'softmax':
+            adj = torch.softmax(bond_logits, dim=1)
+        elif sample_mode == 'soft_gumbel':
+            adj = F.gumbel_softmax(bond_logits, tau=self.tau, hard=False, dim=1)
+        elif sample_mode == 'hard_gumbel':
+            adj = F.gumbel_softmax(bond_logits, tau=self.tau, hard=True, dim=1)
+        elif sample_mode == 'argmax':
+            # Mirror the node branch so 'argmax' (used at inference) is supported
+            adj = bond_logits.argmax(dim=1)
+        else:
+            raise ValueError(f"Unknown sample_mode: {sample_mode}")
+
+        return node, adj
+
+
@@ -82,6 +115,7 @@ def __init__(
         dropout: float = 0.0,            # Dropout between layers.
         use_batchnorm: bool = True,      # BatchNorm or similar normalization.
         readout: str = 'sum',            # Readout type (sum/mean/max for pooling nodes to graph-level vector)
+        device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
     ):
         self.in_dim = in_dim
         self.hidden_dim = hidden_dim
@@ -91,6 +125,7 @@ def __init__(
         self.dropout = dropout
         self.use_batchnorm = use_batchnorm
         self.readout = readout
+        self.device = device
 
 
 class MolGANDiscriminator(torch.nn.Module):
@@ -120,6 +155,8 @@ def __init__(self, config: MolGANDiscriminatorConfig):
 
         self.gcn_layers = torch.nn.ModuleList(layers)
         self.fc = torch.nn.Linear(input_dim, 1)
+        self.device = config.device
+        self.to(self.device)
 
     def forward(
         self,
diff --git a/torch_molecule/generator/molgan/molgan_generator.py b/torch_molecule/generator/molgan/molgan_generator.py
index 96889ff..4a02164 100644
--- a/torch_molecule/generator/molgan/molgan_generator.py
+++ b/torch_molecule/generator/molgan/molgan_generator.py
@@ -1,78 +1,206 @@
 import torch
-import numpy as np
-from typing import Optional, List, Union
+from typing import Optional, Union, List, Callable
 from .molgan_model import MolGANModel
-from .molgan_dataset import MolGANDataset
+from .molgan_gen_disc import MolGANGeneratorConfig, MolGANDiscriminatorConfig
 from torch_molecule.base.generator import BaseMolecularGenerator
+from .molgan_dataset import MolGANDataset
+from torch_molecule.utils import graph_to_smiles, graph_from_smiles
+
+
+class MolGANGenerativeModel(BaseMolecularGenerator):
 
-class MolGANMolecularGenerator(BaseMolecularGenerator):
     """
-    MolGAN model wrapper for standardized molecular generation API.
-    Inherits fit/generate signature from BaseMolecularGenerator.
+    This generator implements the MolGAN model for molecular graph generation.
+
+    The model uses a GAN-like architecture with a generator and discriminator,
+    combined with a reward network to optimize for desired molecular properties.
+    The generator produces molecular graphs as node-feature and adjacency tensors;
+    the discriminator and reward network evaluate their validity and quality.
+    The reward network can be trained to optimize for specific chemical properties,
+    such as drug-likeness or synthetic accessibility.
+
+    References:
+    ----------
+    - De Cao, N., & Kipf, T. (2018). MolGAN: An implicit generative model for small molecular graphs.
+      arXiv preprint arXiv:1805.11973. Link: https://arxiv.org/pdf/1805.11973
+
+    Parameters:
+    ----------
+    generator_config : MolGANGeneratorConfig, optional
+        Configuration for the generator network. If None, default values are used.
+
+    discriminator_config : MolGANDiscriminatorConfig, optional
+        Configuration for the discriminator and reward network. If None, default values are used.
+ + Lambda_rl : float, Optional + Weight for the reinforcement learning reward in the generator loss. Default is 0.25. + + device : Optional[Union[torch.device, str]], optional + Device to run the model on. If None, defaults to CPU or GPU if available. + + model_name : str, Optional + Name of the model. Default is "MolGANGenerativeModel". + """ + def __init__( self, - generator, discriminator, reward_net=None, - gen_config=None, disc_config=None, + generator_config: Optional[MolGANGeneratorConfig] = None, + discriminator_config: Optional[MolGANDiscriminatorConfig] = None, + lambda_rl: float = 0.25, device: Optional[Union[torch.device, str]] = None, - model_name: str = "MolGANMolecularGenerator" + model_name: str = "MolGANGenerativeModel", ): super().__init__(device=device, model_name=model_name) - self.generator = generator - self.discriminator = discriminator - self.reward_net = reward_net - self.gen_config = gen_config - self.disc_config = disc_config - - self.device = device if device is not None else torch.device("cpu") - self.generator.to(self.device) - self.discriminator.to(self.device) - if self.reward_net: self.reward_net.to(self.device) - self.is_fitted = False - - def fit(self, X: List[str], y: Optional[np.ndarray] = None, epochs=100, batch_size=32, **kwargs): + + # Initialize MolGAN model + self.model = MolGANModel( + generator_config=generator_config, + discriminator_config=discriminator_config, + reward_network_config=discriminator_config, + ).to(self.device) + + self.lambda_rl = lambda_rl + self.gen_optimizer = None + self.disc_optimizer = None + self.reward_optimizer = None + self.gen_scheduler = None + self.disc_scheduler = None + self.reward_scheduler = None + self.use_reward = False + + self.epoch = 0 + self.step = 0 + + def training_config( + self, + lambda_rl: float = 0.25, + reward_function: Optional[Callable] = None, + gen_optimizer: Optional[torch.optim.Optimizer] = None, + disc_optimizer: Optional[torch.optim.Optimizer] = None, + reward_optimizer: Optional[torch.optim.Optimizer] = None, + gen_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + disc_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + reward_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + use_reward: bool = True, + epochs: int = 300, + batch_size: int = 32, + atom_types: List[str] = ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I', 'H'], + bond_types: List[str] = ['SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC'], + max_num_atoms: int = 50, + ): """ - Train MolGAN on molecules (SMILES). + Configure training parameters for MolGAN. + + Parameters: + ---------- + lambda_rl : float, Optional + Weight for the reinforcement learning reward in the generator loss. Default is 0.25. + + reward_function : Optional[Callable], optional + + gen_optimizer : torch.optim.Optimizer + Optimizer for the generator network. + + disc_optimizer : torch.optim.Optimizer + Optimizer for the discriminator network. + + reward_optimizer : Optional[torch.optim.Optimizer], optional + Optimizer for the reward network. If None, the discriminator optimizer is used. + + gen_scheduler : Optional[torch.optim.lr_scheduler._LRScheduler], optional + Learning rate scheduler for the generator optimizer. + + disc_scheduler : Optional[torch.optim.lr_scheduler._LRScheduler], optional + Learning rate scheduler for the discriminator optimizer. + + reward_scheduler : Optional[torch.optim.lr_scheduler._LRScheduler], optional + Learning rate scheduler for the reward optimizer. 
+ + use_reward : bool, optional + Whether to use the reward network during training. Default is True. + + epochs : int + Number of training epochs. Default is 300. + + atom_types : List[str] + List of atom types to consider in the molecular graphs. Default includes common organic atoms. + + bond_types : List[str] + List of bond types to consider in the molecular graphs. Default includes common bond types. + + max_num_atoms : int + Maximum number of atoms in the generated molecular graphs. Default is 50. """ - # 1. Prepare dataset (requires atom_types and bond_types in kwargs) - atom_types = kwargs.get("atom_types") - bond_types = kwargs.get("bond_types") - max_num_atoms = kwargs.get("max_num_atoms", 50) - dataset = MolGANDataset(X, atom_types, bond_types, max_num_atoms) - dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) - # 2. Set up optimizers/schedulers etc (credits: your config_training & usual setup) - self.config_training( - gen_optimizer=kwargs["gen_optimizer"], - disc_optimizer=kwargs["disc_optimizer"], - reward_optimizer=kwargs.get("reward_optimizer", None), - gen_scheduler=kwargs.get("gen_scheduler", None), - disc_scheduler=kwargs.get("disc_scheduler", None), - reward_scheduler=kwargs.get("reward_scheduler", None), + + if gen_optimizer is None: gen_optimizer = torch.optim.Adam(self.model.gen.parameters(), lr=0.0001, betas=(0.5, 0.999)) + if disc_optimizer is None: disc_optimizer = torch.optim.Adam(self.model.disc.parameters(), lr=0.0001, betas=(0.5, 0.999)) + if reward_optimizer is None: reward_optimizer = disc_optimizer + + self.model.config_training( + gen_optimizer=gen_optimizer, + disc_optimizer=disc_optimizer, + lambda_rl=lambda_rl, + reward_optimizer=reward_optimizer, + gen_scheduler=gen_scheduler, + disc_scheduler=disc_scheduler, + reward_scheduler=reward_scheduler, ) - # 3. Training loop (calls training_step as above) - for epoch in range(epochs): - for batch in dataloader: - # Keep reward_fn optional for RL phase - reward_fn = kwargs.get("reward_fn", None) - pretrain = kwargs.get("pretrain", False) - self.training_step(batch, reward_fn, pretrain) - self.is_fitted = True - return self + self.lambda_rl = lambda_rl + self.reward_function = reward_function + self.gen_optimizer = gen_optimizer + self.disc_optimizer = disc_optimizer + self.reward_optimizer = ( + reward_optimizer if reward_optimizer is not None else disc_optimizer + ) + self.gen_scheduler = gen_scheduler + self.disc_scheduler = disc_scheduler + self.reward_scheduler = reward_scheduler + self.use_reward = use_reward + self.epochs = epochs + self.atom_types = atom_types + self.bond_types = bond_types + self.max_num_atoms = max_num_atoms + self.batch_size = batch_size + - def generate(self, n_samples: int, **kwargs) -> List[str]: + def fit( self, X:List[str], y=None ) -> "BaseMolecularGenerator": """ - Generate n_samples molecules as SMILES. + Fit the MolGAN model to the training data. + + Parameters: + ---------- + X : List[str] + List of SMILES strings representing the training molecules. + + y : Optional[np.ndarray], optional + Optional array of target values for supervised training. Default is None. 
(Not used in MolGAN)
        """
-        self.generator.eval()
-        zs = torch.randn(n_samples, self.generator.z_dim, device=self.device)
-        with torch.no_grad():
-            atom_logits, bond_logits = self.generator(zs)
-        # [n_samples, max_num_atoms, num_atom_types], [n_samples, num_bond_types, max_num_atoms, max_num_atoms]
-        atom_types = atom_logits.argmax(dim=-1).cpu().numpy()  # [n_samples, max_num_atoms]
-        bond_types = bond_logits.argmax(dim=1).cpu().numpy()  # [n_samples, max_num_atoms, max_num_atoms]
-        # Convert to SMILES (implement graph2smiles utility based on your atom_types/bond_types)
-        smiles_strings = [
-            graph2smiles(atom_types[i], bond_types[i], kwargs.get("atom_types"), kwargs.get("bond_types"))
-            for i in range(n_samples)
-        ]
-        return smiles_strings
+
+        if self.gen_optimizer is None or self.disc_optimizer is None:
+            # Set default optimizers if training was not configured explicitly
+            self.training_config(
+                gen_optimizer=torch.optim.Adam(self.model.gen.parameters(), lr=0.0001, betas=(0.5, 0.999)),
+                disc_optimizer=torch.optim.Adam(self.model.disc.parameters(), lr=0.0001, betas=(0.5, 0.999)),
+                lambda_rl=0.25,
+                use_reward=True,
+            )
+
+        # Create a dataloader from the SMILES strings
+        dataset = MolGANDataset(data=X, atom_types=self.atom_types, bond_types=self.bond_types, max_num_atoms=self.max_num_atoms, return_mol=False, device=self.device)
+        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
+
+        self.model.train()
+        for _ in range(self.epochs):
+            self.model.train_epoch(
+                dataloader,
+                reward_fn=None if not self.use_reward else self.reward_function
+            )
+
+        return self
+
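
Typical use of the wrapper after this patch, as a minimal sketch rather than a definitive recipe: the SMILES list and epoch count are placeholders; the config values are chosen so the dataset tensors (ten default atom types, four bond types) match the network dimensions, which otherwise default to five atom types and nine atoms; use_batchnorm=False sidesteps the discriminator's BatchNorm1d channel-count assumption; a constant reward stands in for a real property oracle; and BaseMolecularGenerator is assumed to accept a plain device string:

    import torch
    from torch_molecule.generator.molgan.molgan_gen_disc import (
        MolGANGeneratorConfig, MolGANDiscriminatorConfig
    )
    from torch_molecule.generator.molgan.molgan_generator import MolGANGenerativeModel

    def constant_reward(node, adj, mask):
        # Placeholder reward: one target per molecule, as training_step expects
        return torch.zeros(node.size(0), device=node.device)

    gen_cfg = MolGANGeneratorConfig(num_atom_types=10, max_num_atoms=20, device='cpu')
    disc_cfg = MolGANDiscriminatorConfig(in_dim=10, max_num_atoms=20,
                                         use_batchnorm=False, device='cpu')
    model = MolGANGenerativeModel(gen_cfg, disc_cfg, device='cpu')
    model.training_config(epochs=1, batch_size=8, max_num_atoms=20,
                          reward_function=constant_reward)
    model.fit(["CCO", "c1ccccc1", "CC(=O)O", "CCN"])

    model.model.eval()
    node, adj = model.model.generate(4)  # index tensors in eval ('argmax') mode

Decoding the generated tensors back to SMILES is a separate step; the graph_to_smiles import above suggests that is the intended route.
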
diff --git a/torch_molecule/generator/molgan/molgan_model.py b/torch_molecule/generator/molgan/molgan_model.py
index 1dba5de..07a3b68 100644
--- a/torch_molecule/generator/molgan/molgan_model.py
+++ b/torch_molecule/generator/molgan/molgan_model.py
@@ -15,7 +15,7 @@ def __init__(
 
         # Initialize generator and discriminator
         self.gen_config = generator_config if generator_config is not None else MolGANGeneratorConfig()
-        self.gen = MolGANGenerator(self.gen_config)
+        self.gen: MolGANGenerator = MolGANGenerator(self.gen_config)
 
         self.disc_config = discriminator_config if discriminator_config is not None else MolGANDiscriminatorConfig()
         self.disc = MolGANDiscriminator(self.disc_config)
@@ -28,9 +28,20 @@ def __init__(
         )
 
 
-    def generate(self, batch_size: int):
+    def generate(self, batch_size: int, sample_mode: Optional[str] = None):
         """Generate a batch of molecules."""
-        return self.gen(batch_size)
+        z = torch.randn(
+            batch_size,
+            self.gen_config.z_dim,
+            device=torch.device(self.gen.device)
+        )
+        if sample_mode is None:
+            # Soft samples while training, discrete argmax samples for inference
+            if self.training:
+                return self.gen(z, sample_mode='softmax')
+            else:
+                return self.gen(z, sample_mode='argmax')
+        else:
+            return self.gen(z, sample_mode=sample_mode)
 
@@ -78,7 +89,7 @@ def training_step(self, batch, reward_fn, pretrain=False):
         mask = (node_features.sum(-1) != 0).float()
 
         z = torch.randn(batch_size, self.gen_config.z_dim, device=node_features.device)
-        fake_atom_logits, fake_bond_logits = self.gen(z)
-        fake_atom = torch.softmax(fake_atom_logits, -1)
-        fake_bond = torch.softmax(fake_bond_logits, 1)
+        # 'softmax' mode already returns probabilities; applying softmax again would double-smooth them
+        fake_atom, fake_bond = self.gen(z, sample_mode='softmax')
         fake_mask = (fake_atom.argmax(-1) != 0).float()
@@ -106,9 +117,7 @@ def training_step(self, batch, reward_fn, pretrain=False):
 
         # === Generator update ===
         self.gen_optimizer.zero_grad()
-        fake_atom_logits, fake_bond_logits = self.gen(z)
-        fake_atom = torch.softmax(fake_atom_logits, -1)
-        fake_bond = torch.softmax(fake_bond_logits, 1)
+        # Re-sample fresh tensors so the generator update has its own graph
+        fake_atom, fake_bond = self.gen(z, sample_mode='softmax')
         fake_mask = (fake_atom.argmax(-1) != 0).float()
         fake_scores = self.disc(fake_atom, fake_bond, fake_mask)
         g_wgan_loss = -fake_scores.mean()
@@ -140,7 +149,22 @@ def train_epoch(self, dataloader, reward_fn, pretrain=False, log_interval=100):
             if i % log_interval == 0:
                 print({k: (round(v, 5) if v is not None else None) for k, v in result.items()})
 
-
+    def evaluate(self, dataloader, reward_fn):
+        """Compute validation losses without taking optimizer steps."""
+        self.gen.eval()
+        self.disc.eval()
+        if self.reward_net: self.reward_net.eval()
+        eval_metrics = {'d_loss': 0.0, 'g_loss': 0.0, 'rl_loss': 0.0, 'r_loss': 0.0}
+        count = 0
+        with torch.no_grad():
+            for batch in dataloader:
+                node_features, adjacency_matrix = batch
+                adjacency_matrix = adjacency_matrix.permute(0, 3, 1, 2)
+                mask = (node_features.sum(-1) != 0).float()
+                z = torch.randn(node_features.size(0), self.gen_config.z_dim, device=node_features.device)
+                fake_atom, fake_bond = self.gen(z, sample_mode='softmax')
+                fake_mask = (fake_atom.argmax(-1) != 0).float()
+                real_scores = self.disc(node_features, adjacency_matrix, mask)
+                fake_scores = self.disc(fake_atom, fake_bond, fake_mask)
+                eval_metrics['d_loss'] += (-(real_scores.mean() - fake_scores.mean())).item()
+                eval_metrics['g_loss'] += (-fake_scores.mean()).item()
+                eval_metrics['rl_loss'] += (-self.reward_net(fake_atom, fake_bond, fake_mask).mean()).item()
+                if reward_fn is not None:
+                    reward_targets = reward_fn(node_features, adjacency_matrix, mask)
+                    pred_rewards = self.reward_net(node_features, adjacency_matrix, mask)
+                    eval_metrics['r_loss'] += torch.nn.functional.mse_loss(pred_rewards, reward_targets).item()
+                count += 1
+        for k in eval_metrics.keys():
+            eval_metrics[k] /= max(count, 1)
+        return {k: round(v, 5) for k, v in eval_metrics.items()}
diff --git a/torch_molecule/generator/molgan/molgan_utils.py b/torch_molecule/generator/molgan/molgan_utils.py
new file mode 100644
index 0000000..1f68ce3
--- /dev/null
+++ b/torch_molecule/generator/molgan/molgan_utils.py
@@ -0,0 +1,14 @@
+from rdkit.Chem import QED
+
+# This is used as the default reward function for MolGAN
+def qed_reward_fn(mol):
+    """
+    Computes the QED score of a single RDKit Mol object.
+    Returns 0.0 for invalid molecules or errors.
+    """
+    if mol is not None:
+        try:
+            return QED.qed(mol)
+        except Exception:
+            return 0.0
+    return 0.0
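
As a closing reference, the default reward hook and the property-caching path of the dataset can be exercised together; a minimal sketch (the SMILES strings and vocabularies are placeholders, and the import paths assume the package layout above):

    from rdkit import Chem
    from torch_molecule.generator.molgan.molgan_utils import qed_reward_fn
    from torch_molecule.generator.molgan.molgan_dataset import MolGANDataset

    print(qed_reward_fn(Chem.MolFromSmiles("CCO")))  # roughly 0.4 for ethanol
    dataset = MolGANDataset(
        data=["CCO", "c1ccccc1"],
        atom_types=["C", "N", "O", "F"],
        bond_types=["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"],
        max_num_atoms=9,
        cache_properties=True,
        property_fn=qed_reward_fn,
    )
    node_features, adjacency, qed = dataset[0]  # cached property is appended

Note that qed_reward_fn scores a single RDKit Mol, while the reward callable passed to training consumes batched graph tensors; bridging the two requires decoding generated graphs back to molecules first.
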