Summary
Dear developers,
I plan to add the data modifier to the PyTorch backend and implement the DPLR method there. While pre-processing the data is straightforward, it is still unclear to me how to combine the model and the data modifier into one frozen model. I am thinking about creating a wrapper class after training, as shown below. Do you have any suggestions or comments on this approach?
Best.
Detailed Description
from typing import Optional

import torch


class DataModifierModelWrapper(torch.nn.Module):
    """Combine a trained model with an optional frozen data modifier."""

    def __init__(
        self,
        model: torch.nn.Module,
        modifier: Optional[torch.nn.Module] = None,
    ) -> None:
        super().__init__()
        self.model = model
        if modifier is not None:
            # Freeze the modifier parameters so they are not updated further.
            for p in modifier.parameters():
                p.requires_grad = False
        self.modifier = modifier

    def forward(
        self,
        coord: torch.Tensor,
        atype: torch.Tensor,
        spin: Optional[torch.Tensor] = None,
        box: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
    ):
        input_dict = {
            "coord": coord,
            "atype": atype,
            "box": box,
            "do_atomic_virial": do_atomic_virial,
            "fparam": fparam,
            "aparam": aparam,
        }
        # Pass the spin only when the underlying model supports it.
        has_spin = getattr(self.model, "has_spin", False)
        if callable(has_spin):
            has_spin = has_spin()
        if has_spin:
            input_dict["spin"] = spin
        model_pred = self.model(**input_dict)
        if self.modifier is not None:
            modifier_pred = self.modifier(**input_dict)
            # Add the modifier corrections to the matching model outputs.
            for k, v in modifier_pred.items():
                model_pred[k] = model_pred[k] + v
        return model_pred
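To check that the wrapping idea composes with freezing, here is a minimal, self-contained sketch of the same pattern. ToyModel, ToyModifier, and Wrapper are hypothetical stand-ins (they are not DeePMD-kit classes); the wrapper freezes the modifier's parameters and adds its corrections to the model's outputs, and the combined module is scripted with torch.jit.script so it could be saved as a single frozen artifact:

```python
from typing import Dict

import torch


class ToyModel(torch.nn.Module):
    """Hypothetical stand-in for a trained model; returns a dict of predictions."""

    def forward(self, coord: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {"energy": coord.sum(dim=[-2, -1])}


class ToyModifier(torch.nn.Module):
    """Hypothetical stand-in for a data modifier producing additive corrections."""

    def __init__(self) -> None:
        super().__init__()
        self.scale = torch.nn.Parameter(torch.tensor(0.5))

    def forward(self, coord: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {"energy": self.scale * coord.sum(dim=[-2, -1])}


class Wrapper(torch.nn.Module):
    """Minimal version of the wrapper: model output plus frozen corrections."""

    def __init__(self, model: torch.nn.Module, modifier: torch.nn.Module) -> None:
        super().__init__()
        self.model = model
        # Freeze the modifier so only the base model would be trainable.
        for p in modifier.parameters():
            p.requires_grad = False
        self.modifier = modifier

    def forward(self, coord: torch.Tensor) -> Dict[str, torch.Tensor]:
        pred = self.model(coord)
        corr = self.modifier(coord)
        out: Dict[str, torch.Tensor] = {}
        for k, v in pred.items():
            out[k] = v + corr[k]
        return out


wrapped = Wrapper(ToyModel(), ToyModifier())
# Script the combined module so it can be serialized as one frozen model.
scripted = torch.jit.script(wrapped)
coord = torch.ones(1, 2, 3)  # (batch, natoms, 3)
out = scripted(coord)
print(out["energy"])  # model energy 6.0 + correction 3.0
```

If this composes as shown, scripting after wrapping would let the existing freeze workflow emit one file containing both the model and the modifier.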
Further Information, Files, and Links
No response