Getting rid of `_check_2d`
lxuechen opened this issue · 2 comments
I'm proposing to get rid of the 2D shape checks. These were added in #88 as part of v0.2.4.
These checks create a huge barrier for applications whose data isn't naturally vectorized. Having to flatten and unflatten will hurt efficiency tremendously.
The reason they're there is that several parts of the internal code (for example `misc.batch_mvp` and `ForwardSDE.dg_ga_jvp`) implicitly assume there is a single batch dimension.
If we can go through and sort those out, then I am in favour of this; I agree this is a wart. Ideally we'd be able to have `y0` take an arbitrary shape.
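To make the "single batch dimension" assumption concrete, here is a minimal, framework-free sketch of one way a routine written for exactly one batch dimension could be reused for arbitrary leading batch dimensions: collapse the batch dims into one, call the original routine, then restore the batch shape. The names `single_batch_mvp` and `any_batch_mvp` are hypothetical and are not torchsde's `misc.batch_mvp`; tensors are modelled as `(shape, flat row-major list)` pairs rather than `torch.Tensor`.

```python
from math import prod
from typing import List, Tuple

# Hypothetical stand-in for a tensor: (shape, flat row-major data).
Tensor = Tuple[Tuple[int, ...], List[float]]


def single_batch_mvp(mat: Tensor, vec: Tensor) -> Tensor:
    """Batched matrix-vector product assuming exactly ONE batch dimension.

    mat has shape (B, d, m), vec has shape (B, m); the result has shape (B, d).
    """
    (B, d, m), mdata = mat
    (B2, m2), vdata = vec
    assert B == B2 and m == m2, "shape mismatch"
    out = []
    for b in range(B):
        for i in range(d):
            row = b * d * m + i * m  # offset of mat[b, i, 0] in flat storage
            out.append(sum(mdata[row + j] * vdata[b * m + j] for j in range(m)))
    return (B, d), out


def any_batch_mvp(mat: Tensor, vec: Tensor) -> Tensor:
    """Same product for an arbitrary number of leading batch dimensions.

    mat has shape (*batch, d, m), vec has shape (*batch, m). With row-major
    storage, collapsing the batch dims only changes the shape; the flat data
    is passed through untouched.
    """
    mshape, mdata = mat
    vshape, vdata = vec
    batch, (d, m) = tuple(mshape[:-2]), mshape[-2:]
    assert tuple(vshape) == batch + (m,), "shape mismatch"
    B = prod(batch)  # prod(()) == 1, so zero batch dims also works
    _, out = single_batch_mvp(((B, d, m), mdata), ((B, m), vdata))
    return batch + (d,), out
```

For example, `any_batch_mvp(((2, 3, 2, 2), eye_data), ((2, 3, 2), v))` treats the `(2, 3)` prefix as batch dims. The dimensions that still have to be identified explicitly are the trailing matrix/vector ones, which is exactly the batch-vs-channel distinction discussed in this thread.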
Off the top of my head, I think the only parts of the code that need to distinguish batch dimensions from channel dimensions are creating a default Brownian motion (which needs one sample per batch element but not one per channel) and `ForwardSDE.dg_ga_jvp`, so those would need some way of specifying that detail.
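As a rough illustration of "some way of specifying that detail" for the default Brownian motion, here is a sketch that assumes the caller states how many leading dims of `y0` are batch dims, and that the Brownian motion keeps those batch dims while using its own noise size instead of the state's channel dims (as with a `(batch, brownian_size)` Brownian motion). The helper `default_brownian_shape` and its `num_batch_dims` parameter are hypothetical, not part of torchsde.

```python
from typing import Sequence, Tuple


def default_brownian_shape(
    y0_shape: Sequence[int], num_batch_dims: int, brownian_size: int
) -> Tuple[int, ...]:
    """Shape of a default Brownian motion for a state of shape y0_shape.

    Hypothetical helper: the leading num_batch_dims dims of y0 are treated as
    batch dims and kept (one independent sample per batch element); the
    remaining channel dims are replaced by a single noise dim of size
    brownian_size, since the noise size need not match the state size.
    """
    if not 0 <= num_batch_dims <= len(y0_shape):
        raise ValueError("num_batch_dims out of range")
    batch_shape = tuple(y0_shape[:num_batch_dims])
    return batch_shape + (brownian_size,)
```

With `y0` of shape `(4, 3, 5)` and two batch dims, this gives a Brownian motion of shape `(4, 3, brownian_size)`; without `num_batch_dims` there is no way to tell from the shape alone whether the `3` is a batch or a channel dim.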
In passing, why would flattening/unflattening hurt efficiency? It should be doable just by re-striding the tensor, which is cheap.
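To show why re-striding is cheap, here is a framework-free sketch of a strided view over shared flat storage: reshaping a contiguous view only recomputes the shape and strides (O(ndim) work) and never copies the data. `View` and `contiguous_strides` are hypothetical illustrations, not PyTorch internals, although contiguous PyTorch tensors behave analogously.

```python
from math import prod
from typing import List, Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major (C-contiguous) strides, in elements, for a given shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


class View:
    """A minimal strided view over shared flat storage (stand-in for a tensor)."""

    def __init__(self, storage: List[float], shape: Sequence[int]):
        assert len(storage) == prod(shape), "shape does not match storage size"
        self.storage = storage  # shared by reference, never copied
        self.shape = tuple(shape)
        self.strides = contiguous_strides(shape)

    def __getitem__(self, index: Sequence[int]) -> float:
        # Element offset is the dot product of the index with the strides.
        offset = sum(i * s for i, s in zip(index, self.strides))
        return self.storage[offset]

    def reshape(self, shape: Sequence[int]) -> "View":
        # Only metadata changes: new shape and strides over the same storage.
        return View(self.storage, shape)


data = [float(i) for i in range(24)]
x = View(data, (2, 3, 4))     # e.g. (batch, height, width)
flat = x.reshape((2, 12))     # flatten the non-batch dims
back = flat.reshape((2, 3, 4))  # unflatten again
```

This relies on the data being contiguous. In PyTorch, `Tensor.view` raises an error and `torch.reshape` may fall back to a copy when the strides are incompatible (for example after a transpose), so that is the case where flattening can actually cost something.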
I am completely aware of why they are needed. I will come up with a design doc next weekend.