What is meant by global conditioning only?
btickell opened this issue · 1 comments
I am attempting to use your wavenet implementation to model some climate data, where my condition vector changes with time. The code mentions only global conditioning is currently supported. What exactly does this mean from an architecture perspective?
Hey, yes you are right, local conditioning is not fully supported right now.
Let me explain the limitations:
- In training, you would need to check whether the conditioning is global (BxCx1) or local (BxCxT). When it is local, you need to remove the last element so that the lengths of the input and conditioning tensors match, i.e. `cond = cond[..., :-1]`. Training would also be limited to `horizon==1` because of the following point.
- In generation, there is currently no notion of local conditions. You need one condition per generated sample. Currently, the sampler function does not account for an optional condition return value `tuple[sample, condition]`. Alternatively, one could pass a local conditioning tensor to the generator function, but this would require you to generate `N==len(cond)` samples. It's currently not clear to me how to best support local conditioning from an API perspective. If you have an idea, let me know!
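To make the training-side check concrete, here is a minimal sketch of the global/local distinction and the trimming step. It uses NumPy arrays as a stand-in for the library's tensors (the `[..., :-1]` slicing reads the same way), and the function name `split_with_condition` and the one-step input/target split are assumptions for illustration, not code from the repository.

```python
import numpy as np


def split_with_condition(series, cond):
    """Prepare one-step-ahead (horizon==1) training tensors.

    series: array of shape (B, C_in, T)
    cond:   array of shape (B, C, 1) for global or (B, C, T) for local conditioning

    Hypothetical helper: the name and the input/target split are assumptions.
    """
    inputs = series[..., :-1]   # steps 0 .. T-2
    targets = series[..., 1:]   # steps 1 .. T-1, shifted by one

    if cond.shape[-1] == 1:
        # Global conditioning: a single vector that broadcasts over time.
        return inputs, targets, cond

    if cond.shape[-1] != series.shape[-1]:
        raise ValueError("local conditioning must have the same length T as the series")

    # Local conditioning: drop the last element so lengths match the inputs.
    return inputs, targets, cond[..., :-1]


series = np.zeros((2, 1, 8))
x, y, c_local = split_with_condition(series, np.zeros((2, 3, 8)))
_, _, c_global = split_with_condition(series, np.zeros((2, 3, 1)))
print(x.shape, c_local.shape, c_global.shape)
```

Note that a series of length `T==1` makes the two cases indistinguishable by shape alone, so a real implementation might prefer an explicit flag.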
The second point affects the first point as follows: when `horizon>1` in training, the training loop calls a differentiable generator to generate an n-step prediction. Since an API for local conditioning in generation is missing, training is currently limited to `horizon==1`.
If you are willing to stick with `horizon==1`, the required update to the training loop is minimal. For generation, one could write a quick hack for `generate` and `generate_fast` that takes an optional local conditioning tensor and passes it one slice at a time when calling `model.forward(x_i, c=local_cond[..., i:i+1])`.
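A rough sketch of what such a hack could look like is given here. It again uses NumPy in place of real tensors, and `model_step` is a hypothetical callable standing in for `model.forward(x_i, c=...)` followed by sampling. The actual `generate` and `generate_fast` functions presumably also manage past inputs or layer queues, which this sketch leaves out.

```python
import numpy as np


def generate_with_local_cond(model_step, x0, local_cond):
    """Generate one sample per conditioning column.

    model_step: hypothetical callable (x_i, c_i) -> next sample of shape (B, 1, 1)
    x0:         initial input of shape (B, 1, 1)
    local_cond: local conditioning of shape (B, C, N); yields exactly N samples
    """
    samples = []
    x_i = x0
    for i in range(local_cond.shape[-1]):
        c_i = local_cond[..., i:i + 1]   # keep the time axis: shape (B, C, 1)
        x_i = model_step(x_i, c_i)       # feed the previous sample back in
        samples.append(x_i)
    return np.concatenate(samples, axis=-1)  # shape (B, 1, N)


def toy_step(x_i, c_i):
    # Toy stand-in for a model: add the summed condition to the last sample.
    return x_i + c_i.sum(axis=1, keepdims=True)


out = generate_with_local_cond(toy_step, np.zeros((2, 1, 1)), np.ones((2, 3, 5)))
print(out.shape)
```

The number of generated samples is tied to `local_cond.shape[-1]`, which is the `N==len(cond)` restriction mentioned above.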