Add list to torch.Tensor injection in yaml config
fguiotte opened this issue · 0 comments
Description & Motivation
I'd like to configure the weight argument of the loss in config.yaml.
Pitch
For example, with this configuration file:
model:
  out_channels: 3
  criterion:
    class_path: torch.nn.CrossEntropyLoss
    init_args:
      reduction: 'mean'
      weight: [0.0, 1.0, 1.0]
The weight parameter should be a torch.Tensor, not a Python list. As a result, with my current config the criterion parameter receives a dictionary, {'class_path': 'torch.nn.CrossEntropyLoss', 'init_args': {'reduction': 'sum', 'weight': [0.0, 1.0, 1.0]}}, instead of the loss instance. If we remove the weight parameter from the example configuration, the loss instance is created as expected.
It would be helpful to have a Python list to torch.Tensor injection that automatically creates the tensors when needed.
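A rough sketch of how such an injection could work, assuming it is driven by the constructor's type hints: when an init_args value is a list and the matching parameter is annotated with a type that has a registered converter (for example torch.Tensor mapped to torch.tensor), the list is converted before the class is instantiated. CONVERTERS, register_converter and inject are hypothetical names, not part of LightningCLI or jsonargparse, and the registry is kept generic so the sketch runs without torch.

```python
import types
import typing
from typing import Any, Callable, Dict

# Hypothetical registry: target type -> function that builds it from plain YAML data.
# For this issue one would register torch.Tensor -> torch.tensor.
CONVERTERS: Dict[type, Callable[[Any], Any]] = {}

# Origins that mean "union": typing.Union, plus X | Y on Python 3.10+.
_UNION_ORIGINS = {typing.Union}
if hasattr(types, "UnionType"):
    _UNION_ORIGINS.add(types.UnionType)


def register_converter(target: type, fn: Callable[[Any], Any]) -> None:
    CONVERTERS[target] = fn


def _candidate_types(annotation: Any) -> tuple:
    # Unwrap Optional[T] / Union[T, None] so Optional[Tensor] still matches Tensor.
    if typing.get_origin(annotation) in _UNION_ORIGINS:
        return tuple(a for a in typing.get_args(annotation) if a is not type(None))
    return (annotation,)


def inject(cls: type, init_args: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of init_args with list values converted to the annotated type."""
    try:
        hints = typing.get_type_hints(cls.__init__)
    except (TypeError, NameError):
        hints = {}  # no usable annotations: leave every value as-is
    converted = dict(init_args)
    for name, value in init_args.items():
        if not isinstance(value, list):
            continue
        for target in _candidate_types(hints.get(name, Any)):
            if target in CONVERTERS:
                converted[name] = CONVERTERS[target](value)
                break
    return converted
```

With torch installed, register_converter(torch.Tensor, torch.tensor) followed by inject(torch.nn.CrossEntropyLoss, init_args) should turn weight into a tensor, assuming the constructor annotates weight as Optional[torch.Tensor]. Keying on type hints rather than parameter names means any Tensor-typed argument would benefit, not only weight.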
Alternatives
In the meantime, I can create the loss myself in the model's __init__ from the returned dictionary.
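The workaround above could look roughly like the following sketch: it resolves class_path with importlib, converts the named init_args from lists, and instantiates the class. build_from_spec, convert and list_args are hypothetical names; in the model one would pass convert=torch.tensor, and converting only weight is an assumption that fits CrossEntropyLoss.

```python
import importlib
from typing import Any, Callable, Dict, Iterable


def build_from_spec(
    spec: Dict[str, Any],
    convert: Callable[[Any], Any],
    list_args: Iterable[str] = ("weight",),
) -> Any:
    """Instantiate the object described by a {'class_path': ..., 'init_args': ...} dict.

    Hypothetical helper: `convert` turns list values into the expected type
    (torch.tensor for a loss weight); `list_args` names the arguments to convert.
    """
    # "torch.nn.CrossEntropyLoss" -> module "torch.nn", attribute "CrossEntropyLoss"
    module_name, _, class_name = spec["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)

    # Copy so the caller's dictionary is left untouched.
    init_args = dict(spec.get("init_args", {}))
    for name in list_args:
        if isinstance(init_args.get(name), list):
            init_args[name] = convert(init_args[name])
    return cls(**init_args)
```

For instance, inside __init__ one could write self.criterion = build_from_spec(criterion, torch.tensor) when criterion arrives as a dict, and keep it unchanged when it is already a loss instance.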
Additional context
No response
cc @Borda