atomistic-machine-learning/schnetpack

Question about the target name


As described in issue #549, I am running into the same problem.

But the issue has already been closed without explanation.
Could you please help with this question?

Thanks in advance.

Best,
Zeta

Hi @MetaEnigma97,

I am not fully understanding the issue. Could you provide some more details about what you are doing?

Are you implementing your own training loop, or are you using the Trainer from pytorch-lightning?

Or are you writing your own LightningModule?
In that case, take a look at the training_step() function of AtomisticTask. You cannot simply pass pred and batch to the loss_fn, because the forward pass overwrites the target values stored in batch with the model predictions. Instead, create a new dictionary called targets and collect the target properties in it before the forward pass. Then pass targets and pred to the loss_fn.
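Here is a minimal sketch of this pattern that does not depend on any framework. toy_model, mse and training_step are hypothetical stand-ins and are not part of the schnetpack API. The sketch assumes, as described above, that the forward pass writes the prediction into the input dictionary under the same key as the label.

```python
def toy_model(inputs: dict) -> dict:
    """Hypothetical model: its output head writes the prediction into `inputs`."""
    y = [2.0 * x for x in inputs["features"]]
    inputs["tensor_prop"] = y  # overwrites the label stored under the same key
    return {"tensor_prop": y}


def mse(pred: list, target: list) -> float:
    """Mean squared error between two equally long lists."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def training_step(batch: dict, output_keys: list) -> float:
    # Collect the labels before the forward pass modifies `batch` ...
    targets = {key: batch[key] for key in output_keys}
    pred = toy_model(batch)
    # ... and compute the loss from `pred` and `targets`, never from `batch`.
    return sum(mse(pred[key], targets[key]) for key in output_keys)


batch = {"features": [1.0, 2.0], "tensor_prop": [1.0, 1.0]}
print(training_step(batch, ["tensor_prop"]))  # 5.0
print(batch["tensor_prop"])  # [2.0, 4.0]: the label in `batch` is gone
```

Suppose the loss were computed as mse(pred["tensor_prop"], batch["tensor_prop"]) after the forward pass. It would then compare the prediction with itself and always return 0.0.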

Does this help to solve your problem?

Best,
Stefaan

Hi @Stefaanhess,

Thank you very much for your quick reply.

Actually, I am trying to train on my own custom dataset (atoms: ase.Atoms; property: a 3 x 3 equivariant tensor). I use PaiNN with a gated equivariant block to predict this tensor property.

The code for the data preparation looks like this:

import numpy as np
from ase import Atoms

atoms_list = []
property_list = []  # one dict of labels per structure
for pos, cell, num, pbc, prop in zip(pos_list, cell_list, num_list, pbc_list, tensor_props):
    atoms = Atoms(
        positions=pos,
        cell=cell,
        numbers=num,
        pbc=pbc)
    atoms_list.append(atoms)

    property_list.append(
        {"tensor_prop": np.array(prop)})
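Before storing the labels, it can help to check that every entry really is a 3 x 3 float array. The helpers make_property_entry and is_symmetric below are hypothetical and only sketch such a check. The symmetry test rests on an assumption that is worth verifying against the schnetpack source. As far as I can tell, the Polarizability head builds its output from an isotropic diagonal term plus a symmetrized outer product, so it can only represent symmetric tensors.

```python
import numpy as np


def make_property_entry(prop) -> dict:
    """Hypothetical helper: wrap one 3 x 3 label as {"tensor_prop": array}."""
    arr = np.asarray(prop, dtype=np.float64)
    if arr.shape != (3, 3):
        raise ValueError(f"expected a 3 x 3 tensor, got shape {arr.shape}")
    return {"tensor_prop": arr}


def is_symmetric(arr: np.ndarray, atol: float = 1e-8) -> bool:
    """Hypothetical helper: True if the tensor equals its transpose within atol.

    Assumption: a Polarizability-style head can only output symmetric tensors.
    """
    return bool(np.allclose(arr, arr.T, atol=atol))


props = [np.eye(3), [[1.0, 0.5, 0.0], [0.5, 2.0, 0.0], [0.0, 0.0, 3.0]]]
property_list = [make_property_entry(p) for p in props]
print(all(is_symmetric(e["tensor_prop"]) for e in property_list))  # True
```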

The model is defined as follows:

import torch
import torchmetrics
import schnetpack as spk
import schnetpack.transform as trn

cutoff = 5.
n_atom_basis = 30

pairwise_distance = spk.atomistic.PairwiseDistances() # calculates pairwise distances between atoms
radial_basis = spk.nn.GaussianRBF(n_rbf=20, cutoff=cutoff)
painn = spk.representation.PaiNN(
    n_atom_basis=n_atom_basis,
    n_interactions=3,
    radial_basis=radial_basis,
    cutoff_fn=spk.nn.CosineCutoff(cutoff)
)

pred_tensor = spk.atomistic.Polarizability(
    n_in=n_atom_basis,
    polarizability_key="tensor_prop")

nnpot = spk.model.NeuralNetworkPotential(
    representation=painn,
    input_modules=[pairwise_distance],
    output_modules=[pred_tensor],
    postprocessors=[
        trn.CastTo64(),
    ]
)

output_tensor = spk.task.ModelOutput(
    name="tensor_prop",  # same key as the dataset label and polarizability_key
    loss_fn=torch.nn.MSELoss(),
    loss_weight=0.01,
    metrics={
        "MAE": torchmetrics.MeanAbsoluteError()
    }
)

task = spk.task.AtomisticTask(
    model=nnpot,
    outputs=[output_tensor],
    optimizer_cls=torch.optim.AdamW,
    optimizer_args={"lr": 1e-4}
)

I found that during training, the label data under "tensor_prop" is overwritten by the model predictions from the previous epoch.
Could you please explain how I should modify the code so that the labels in my custom dataset are not overwritten?

Thanks!

Best,
Zeta

This looks reasonable to me, and I am not sure what is going on here. How did you find out that the data is overwritten with the values from the previous epoch? If possible, an example would be helpful. How do you train your model? Do you use the Trainer from pytorch-lightning for this?
This is how the training step looks in AtomisticTask:

        # collect the labels before the forward pass
        targets = {
            output.target_property: batch[output.target_property]
            for output in self.outputs
            if not isinstance(output, UnsupervisedModelOutput)
        }
        try:
            targets["considered_atoms"] = batch["considered_atoms"]
        except:
            pass

        # the forward pass writes the predictions into batch
        pred = self.predict_without_postprocessing(batch)
        pred, targets = self.apply_constraints(pred, targets)

        loss = self.loss_fn(pred, targets)

Collecting the labels in targets keeps them from being overwritten before the loss is computed. Is the loss function causing this issue for you, or does it happen in some other part of your code?

@Stefaanhess Thank you for the information.

I simply used the Trainer from pytorch-lightning for training.

I also tried a scalar prediction (the average of the 3 x 3 tensor) by replacing the class spk.atomistic.Polarizability with spk.atomistic.Atomwise in the code above.
I printed the value of batch["tensor_prop"] step by step.
It seems that the line inputs[self.output_key] = y in the class spk.atomistic.Atomwise overwrites the label data, because this is where the value of batch["tensor_prop"] starts to change.
Presumably the same happens at the line inputs[self.polarizability_key] = alpha in the class spk.atomistic.Polarizability.
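The behavior described here can be reproduced without schnetpack. The sketch uses a hypothetical ToyOutputModule that mimics the write pattern inputs[self.output_key] = y. Calling it on batch replaces the label. Calling it on a shallow copy, dict(batch), only rebinds the key in the copy. The shallow copy is an illustrative workaround, assumed here and not part of schnetpack. It also would not protect against in-place modification of the tensors themselves.

```python
class ToyOutputModule:
    """Hypothetical output head mimicking `inputs[self.output_key] = y`."""

    def __init__(self, output_key: str):
        self.output_key = output_key

    def __call__(self, inputs: dict) -> dict:
        y = sum(inputs["features"])  # toy scalar "prediction"
        inputs[self.output_key] = y  # same write pattern as in the thread
        return inputs


head = ToyOutputModule("tensor_prop")

batch = {"features": [1.0, 2.0], "tensor_prop": 0.5}
head(batch)
print(batch["tensor_prop"])  # 3.0: the label was replaced by the prediction

batch = {"features": [1.0, 2.0], "tensor_prop": 0.5}
out = head(dict(batch))  # shallow copy: only the copy's key is rebound
print(batch["tensor_prop"], out["tensor_prop"])  # 0.5 3.0
```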

I am not sure whether this is a bug or a misuse on my part.
I would appreciate it if you could help with this.

Ah OK, I think I know what is going on.
The data is indeed overwritten in batch, but batch is not used in the loss_fn. We first extract the target properties into targets and then compute pred = model(inputs). If you print targets and pred after every training step, they should contain the expected values. You can ignore the values in batch after the forward pass. Only use pred and targets!