[Feature Request] Data modifier in PyTorch #4626

@ChiahsinChu

Summary

Dear developers,

I plan to add the data modifier to the PyTorch backend and implement the DPLR (Deep Potential Long-Range) method there. The data pre-processing is straightforward, but it is still unclear to me how to combine the model and the data modifier into a single frozen model. I am considering a wrapper class that is applied after training, as shown below. Do you have any suggestions or comments on this approach?

Best.

Detailed Description

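from typing import Dict, Optional

import torch


class DataModifierModelWrapper(torch.nn.Module):
    """Combine a trained model and a fixed data modifier into one module.

    The modifier's predictions (e.g. the long-range correction in DPLR)
    are added key by key to the model's predictions.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        modifier: Optional[torch.nn.Module] = None,
    ) -> None:
        super().__init__()
        self.model = model
        if modifier is not None:
            # Freeze the modifier: no gradient updates, inference-mode layers
            for p in modifier.parameters():
                p.requires_grad = False
            modifier.eval()
        self.modifier = modifier

    def forward(
        self,
        coord: torch.Tensor,
        atype: torch.Tensor,
        spin: Optional[torch.Tensor] = None,
        box: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        input_dict = {
            "coord": coord,
            "atype": atype,
            "box": box,
            "do_atomic_virial": do_atomic_virial,
            "fparam": fparam,
            "aparam": aparam,
        }
        # Pass `spin` only to spin models; `has_spin` may be an attribute or a method
        has_spin = getattr(self.model, "has_spin", False)
        if callable(has_spin):
            has_spin = has_spin()
        if has_spin:
            input_dict["spin"] = spin

        model_pred = self.model(**input_dict)
        if self.modifier is not None:
            modifier_pred = self.modifier(**input_dict)
            # Add each modifier output to the model output with the same key;
            # a key the model does not predict raises KeyError
            for k, v in modifier_pred.items():
                model_pred[k] = model_pred[k] + v
        return model_pred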
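If the frozen model is a TorchScript module produced with `torch.jit.script` (an assumption about the freeze step), the wrapper as written probably cannot be compiled directly: it relies on `**` keyword unpacking and on a dict mixing tensors, booleans and `None`, which TorchScript generally does not support. The sketch below is one hypothetical workaround, not deepmd-kit code. `ToyModel`, `ToyModifier` and `ScriptableWrapper` are illustrative names, and only `coord`, `atype` and `box` are passed, as explicit arguments. It scripts the combined module, round-trips it through `torch.jit.save`/`torch.jit.load`, and checks that the loaded module returns the model energy plus the modifier correction.

```python
import io
from typing import Dict, Optional

import torch


class ToyModel(torch.nn.Module):
    """Hypothetical short-range model: energy = w * sum(coord ** 2)."""

    def __init__(self) -> None:
        super().__init__()
        self.w = torch.nn.Parameter(torch.tensor(1.0))

    def forward(
        self, coord: torch.Tensor, atype: torch.Tensor, box: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        return {"energy": self.w * (coord**2).sum()}


class ToyModifier(torch.nn.Module):
    """Hypothetical correction term: a constant energy shift."""

    def __init__(self) -> None:
        super().__init__()
        self.shift = torch.nn.Parameter(torch.tensor(0.5))

    def forward(
        self, coord: torch.Tensor, atype: torch.Tensor, box: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        return {"energy": self.shift}


class ScriptableWrapper(torch.nn.Module):
    """Hypothetical TorchScript-friendly wrapper: explicit arguments, no **kwargs."""

    def __init__(self, model: torch.nn.Module, modifier: torch.nn.Module) -> None:
        super().__init__()
        self.model = model
        for p in modifier.parameters():
            p.requires_grad = False  # the modifier stays fixed
        self.modifier = modifier.eval()

    def forward(
        self, coord: torch.Tensor, atype: torch.Tensor, box: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        model_pred = self.model(coord, atype, box)
        modifier_pred = self.modifier(coord, atype, box)
        out: Dict[str, torch.Tensor] = {}
        for k, v in model_pred.items():
            out[k] = v
        # Add each modifier output to the model output with the same key
        for k, v in modifier_pred.items():
            out[k] = out[k] + v
        return out


wrapper = ScriptableWrapper(ToyModel(), ToyModifier())
scripted = torch.jit.script(wrapper)

# Round-trip through an in-memory buffer, as a frozen model file would be
buf = io.BytesIO()
torch.jit.save(scripted, buf)
buf.seek(0)
frozen = torch.jit.load(buf)

coord = torch.tensor([[1.0, 2.0, 2.0]])  # sum of squares = 9
atype = torch.tensor([0])
energy = frozen(coord, atype)["energy"].item()
print(energy)  # 9.5 = 1.0 * 9 + 0.5
```

Passing every input explicitly is more verbose than building `input_dict`, but it gives TorchScript a fixed, typed signature to compile; optional inputs such as `spin`, `fparam` and `aparam` would have to be added the same way.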

Further Information, Files, and Links

No response
