CompRhys/aviary

`Roost.forward()` and `Wren.forward()` return type tuple or generator?


While adding type hints to the codebase, I remembered I meant to ask: is the return type `Generator[Tensor, None, None]` for `Roost.forward()` and `Wren.forward()` intended?

```python
import torch.nn.functional as F

def forward(self, elem_weights, elem_fea, self_fea_idx, nbr_fea_idx, cry_elem_idx):
    """
    Forward pass through the material_nn and output_nn
    """
    crys_fea = self.material_nn(
        elem_weights, elem_fea, self_fea_idx, nbr_fea_idx, cry_elem_idx
    )
    crys_fea = F.relu(self.trunk_nn(crys_fea))
    # apply neural network to map from learned features to target
    # NOTE: a generator expression, so callers get a one-shot Generator, not a tuple
    return (output_nn(crys_fea) for output_nn in self.output_nns)
```
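To illustrate what callers currently receive, here is a minimal, torch-free sketch. The names `output_nns` and `forward_gen` are hypothetical: plain callables on floats stand in for the real `nn.Module` heads and tensors. The point is only that a returned generator expression is lazy, can be consumed once, and cannot be indexed.

```python
from collections.abc import Callable, Generator

# Hypothetical stand-ins for self.output_nns: plain callables instead of nn.Modules
output_nns: list[Callable[[float], float]] = [lambda x: 2 * x, lambda x: x + 1]


def forward_gen(crys_fea: float) -> Generator[float, None, None]:
    # Mirrors the current return statement: a lazy generator expression
    return (output_nn(crys_fea) for output_nn in output_nns)


outputs = forward_gen(3.0)
first_pass = list(outputs)   # consumes the generator
second_pass = list(outputs)  # already exhausted, so this is empty
print(first_pass, second_pass)  # [6.0, 4.0] []

try:
    forward_gen(3.0)[0]  # generators do not support indexing
except TypeError as err:
    print("indexing fails:", err)
```

Any caller that iterates the outputs twice (e.g. once for logging, once for the loss) would silently see nothing the second time.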

If the return type should be `tuple[Tensor, ...]` instead (note that `tuple[Tensor]` would mean a 1-tuple), we'd need to change it to:

```python
return tuple(output_nn(crys_fea) for output_nn in self.output_nns)
```
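A minimal sketch of the proposed behaviour, again torch-free: the hypothetical `forward_tuple` uses floats where the real method would use tensors, so its annotation `tuple[float, ...]` corresponds to `tuple[Tensor, ...]` in `Roost`/`Wren`. Materializing into a tuple makes the result reusable, indexable and sized.

```python
from collections.abc import Callable

# Hypothetical stand-ins for self.output_nns: plain callables instead of nn.Modules
output_nns: list[Callable[[float], float]] = [lambda x: 2 * x, lambda x: x + 1]


def forward_tuple(crys_fea: float) -> tuple[float, ...]:
    # Same expression as the proposed fix, materialized eagerly into a tuple
    return tuple(output_nn(crys_fea) for output_nn in output_nns)


outputs = forward_tuple(3.0)
print(outputs, len(outputs), outputs[0])  # (6.0, 4.0) 2 6.0
print(list(outputs) == list(outputs))     # True: can be iterated repeatedly
```

The variable-length form `tuple[X, ...]` fits here because the number of outputs depends on how many heads are in `self.output_nns`.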

I can't think of any benefit to it being a generator: it can only be iterated once and doesn't support indexing or `len()`. So maybe it's best to make it a tuple?

Agreed.