tonyduan/mixture-density-network

arbitrary number of batch dimensions and returned samples

forgi86 opened this issue · 0 comments

Hi,

Thanks for the very nice package, I found it very useful! I added two features in this fork:

https://github.com/forgi86/mixture-density-network

  1. Support for more than one "batch" dimension. For instance, my version also works after adding the following to the example script:
    num_seq = 32
    seq_len = 16
    x = x.reshape(num_seq, seq_len, nx)
    y = y.reshape(num_seq, seq_len, ny)

before training. My real use case is a mixture density head on top of an RNN backbone for sequential data (as the variable names in the code snippet above suggest). With this modification, I do not have to reshape anything to apply the mixture head in my code; it can be fed the RNN features directly.
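To illustrate the idea, here is a minimal sketch (not the code in the fork) of a mixture head that only ever touches the last dimension, so any number of leading batch dimensions passes through unchanged. The sizes, the single-layer `head`, and the parameter layout (logits, means, log standard deviations) are assumptions made for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for this example.
nx, ny, n_components = 3, 2, 4
num_seq, seq_len = 32, 16

# Stand-in for RNN backbone features: (num_seq, seq_len, nx).
features = torch.randn(num_seq, seq_len, nx)

# Hypothetical mixture head: one linear layer that outputs, per component,
# a mixing logit plus ny means and ny log standard deviations.
head = nn.Linear(nx, n_components * (1 + 2 * ny))

# nn.Linear acts on the last dimension only, so the leading
# (num_seq, seq_len) dimensions are preserved.
params = head(features)
logits, mean, log_std = torch.split(
    params, [n_components, n_components * ny, n_components * ny], dim=-1
)

# Reshape with the leading dimensions taken from the input,
# instead of hard-coding a single batch dimension.
batch_shape = features.shape[:-1]
mean = mean.reshape(*batch_shape, n_components, ny)
log_std = log_std.reshape(*batch_shape, n_components, ny)

# The same result via explicit flattening, which is the step
# the modification makes unnecessary.
flat = head(features.reshape(-1, nx)).reshape(num_seq, seq_len, -1)
```

Here `params` and `flat` hold the same values up to floating-point error, which is why the explicit reshape can be dropped.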

  2. Support for sampling multiple values. I added the parameters samples and squeeze to the sample method of the MixtureDensityNetwork class. The default settings (samples=1, squeeze=True) are backward compatible. If instead samples > 1, the second-to-last dimension of the returned tensor is the sample index.
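The shape convention can be sketched with a standalone function. This is an illustration under assumptions, not the fork's implementation: `sample_mixture` is a hypothetical helper, the mixture is assumed to be a diagonal Gaussian mixture, and squeezing is assumed to apply only when samples == 1.

```python
import torch


def sample_mixture(logits, mean, std, samples=1, squeeze=True):
    """Draw samples from a diagonal Gaussian mixture (hypothetical helper).

    logits: (..., K) unnormalized mixing weights
    mean, std: (..., K, ny) per-component parameters
    Returns (..., samples, ny), or (..., ny) if samples == 1 and squeeze.
    """
    # Component index for each draw: (samples, ...) moved to (..., samples).
    comp = torch.distributions.Categorical(logits=logits).sample((samples,))
    comp = comp.movedim(0, -1)

    # Expand the indices over the output dimension and pick the
    # parameters of the chosen component along the K axis.
    ny = mean.shape[-1]
    idx = comp.unsqueeze(-1).expand(*comp.shape, ny)
    chosen_mean = torch.gather(mean, -2, idx)
    chosen_std = torch.gather(std, -2, idx)

    out = chosen_mean + chosen_std * torch.randn_like(chosen_mean)

    # Assumption: the sample axis is dropped only for a single draw.
    if squeeze and samples == 1:
        out = out.squeeze(-2)
    return out


# Example with two batch dimensions (num_seq=32, seq_len=16), K=4, ny=2.
logits = torch.zeros(32, 16, 4)
mean = torch.randn(32, 16, 4, 2)
std = torch.ones(32, 16, 4, 2)

single = sample_mixture(logits, mean, std)               # (32, 16, 2)
many = sample_mixture(logits, mean, std, samples=5)      # (32, 16, 5, 2)
kept = sample_mixture(logits, mean, std, squeeze=False)  # (32, 16, 1, 2)
```

With the defaults the output has the same shape as a single draw per input, which is what keeps existing callers working.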

If you find either of the two changes useful, I can make a PR (they can go independently).

Cheers,
Marco