arbitrary number of batch dimensions and returned samples
forgi86 opened this issue · 0 comments
Hi,
Thanks for the very nice package, I found it very useful! I added two features in this fork:
https://github.com/forgi86/mixture-density-network
- Support for more than one "batch" dimension. For instance, my version also works when the following lines are added to the example script:
num_seq = 32
seq_len = 16
x = x.reshape(num_seq, seq_len, nx)
y = y.reshape(num_seq, seq_len, ny)
before training. My real use case is to use the mixture density head on top of an RNN backbone for sequential data (as the variable names in the code snippet above suggest). With this modification, I do not have to reshape anything in my code to apply the mixture head; it can be fed the RNN features directly.
- Support for sampling multiple values. I added the parameters `samples` and `squeeze` to the `sample` method of the `MixtureDensityNetwork` class. The default settings (`samples=1, squeeze=True`) are backward compatible. If instead `samples > 1`, then the second-last dimension of the returned tensor is the sample index.
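To illustrate the shape convention described in the two points above, here is a minimal NumPy sketch. It is not the code from the fork: the function name `sample_mixture`, its argument layout, and the diagonal-Gaussian parameterization are assumptions made only for this example. It handles any number of leading batch dimensions and puts the sample index in the second-last dimension of the result.

```python
import numpy as np


def sample_mixture(log_pi, mu, sigma, samples=1, squeeze=True, rng=None):
    """Sample from a diagonal Gaussian mixture with any number of batch dims.

    log_pi:    (..., K)     unnormalized log mixture weights
    mu, sigma: (..., K, D)  per-component means and standard deviations
    Returns (..., D) if samples == 1 and squeeze, else (..., samples, D).
    """
    rng = np.random.default_rng() if rng is None else rng
    batch_shape = log_pi.shape[:-1]
    num_components = log_pi.shape[-1]

    # Softmax over the component axis (shifted by the max for stability).
    pi = np.exp(log_pi - log_pi.max(axis=-1, keepdims=True))
    pi = pi / pi.sum(axis=-1, keepdims=True)

    # Inverse-CDF draw of a component index for every batch entry and sample.
    cdf = np.cumsum(pi, axis=-1)                   # (..., K)
    u = rng.random(batch_shape + (samples, 1))     # (..., samples, 1)
    idx = (u > cdf[..., None, :]).sum(axis=-1)     # (..., samples)
    idx = np.minimum(idx, num_components - 1)      # guard against round-off

    # Pick the chosen component's parameters: (..., samples, D).
    mu_sel = np.take_along_axis(mu, idx[..., None], axis=-2)
    sigma_sel = np.take_along_axis(sigma, idx[..., None], axis=-2)

    out = mu_sel + sigma_sel * rng.standard_normal(mu_sel.shape)
    if samples == 1 and squeeze:
        out = out.squeeze(axis=-2)
    return out


# Two batch dimensions, as in the (num_seq, seq_len, ...) reshape above.
num_seq, seq_len, K, ny = 32, 16, 3, 2
rng = np.random.default_rng(0)
log_pi = rng.normal(size=(num_seq, seq_len, K))
mu = rng.normal(size=(num_seq, seq_len, K, ny))
sigma = np.ones((num_seq, seq_len, K, ny))

y1 = sample_mixture(log_pi, mu, sigma, rng=rng)             # shape (32, 16, 2)
y5 = sample_mixture(log_pi, mu, sigma, samples=5, rng=rng)  # shape (32, 16, 5, 2)
```

With `samples=1, squeeze=False` the same call would keep a singleton sample dimension, giving shape `(32, 16, 1, 2)`.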
If you find either of the two changes useful, I can make a PR (they can go independently).
Cheers,
Marco