I'm proposing to get rid of the 2D shape checks. These were added in #88 as part of v0.2.4.
These checks are creating a huge barrier for applications that don't have vectorized data naturally. Flattening and unflattening will hurt efficiency tremendously.
The reason they're there is that several parts of the internal code -- for example misc.batch_mvp and ForwardSDE.dg_ga_jvp -- implicitly assume there is a single batch dimension.
If we can go through and sort those out then I am in favour of this; I agree this is a wart. Ideally we'd be able to have y0 take an arbitrary shape.
Off the top of my head, I think the only parts of the code that need to distinguish batch dimensions from channel dimensions are creating a default Brownian motion (which needs one sample per batch element but not one per channel) and ForwardSDE.dg_ga_jvp, so those would need some way of specifying that detail.
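To make the batch/channel distinction concrete, here is a minimal NumPy sketch (not torchsde code) of sampling Brownian increments for an arbitrarily shaped `y0`. It assumes a hypothetical `batch_dims` argument that says how many leading dimensions of `y0` are batch dimensions. The increment gets one independent sample per batch element, and its trailing size is the noise dimension rather than the state's channel shape.

```python
import numpy as np

def sample_brownian_increment(y0_shape, batch_dims, noise_size, dt, rng):
    """Hypothetical helper: draw dW ~ N(0, dt) with shape batch_shape + (noise_size,).

    Only the leading `batch_dims` entries of `y0_shape` are treated as batch
    dimensions; the remaining (channel) dimensions of y0 do not affect dW.
    """
    batch_shape = tuple(y0_shape[:batch_dims])
    return rng.standard_normal(batch_shape + (noise_size,)) * np.sqrt(dt)

rng = np.random.default_rng(0)
# y0 with two batch dimensions (4, 5) and a (3, 2) channel shape.
dW = sample_brownian_increment((4, 5, 3, 2), batch_dims=2, noise_size=7, dt=0.01, rng=rng)
print(dW.shape)  # (4, 5, 7)
```

With `batch_dims=1` this reduces to the current 2D convention of one batch dimension followed by the noise channels.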
In passing, why is flattening/unflattening hurting efficiency? It should be doable just by re-striding the tensor, which is cheap.
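The re-striding point can be checked directly. Below is a minimal NumPy sketch (PyTorch's `Tensor.view` and `Tensor.reshape` behave analogously on contiguous tensors): flattening `(batch, *channel_shape)` to `(batch, -1)` and back returns views that share memory with the original, so no data is copied. The wrapper `flatten_drift` is hypothetical and only illustrates how a 2D-only solver could accept arbitrarily shaped states.

```python
import numpy as np

def flatten_drift(f, channel_shape):
    """Hypothetical wrapper: adapt a drift f(t, y) acting on y of shape
    (batch, *channel_shape) to the (batch, channels) layout a 2D-only solver expects."""
    def f_flat(t, y_flat):
        y = y_flat.reshape(y_flat.shape[0], *channel_shape)  # unflatten (a view)
        out = f(t, y)
        return out.reshape(out.shape[0], -1)  # flatten (a view if contiguous)
    return f_flat

y = np.arange(2 * 3 * 4, dtype=np.float64).reshape(2, 3, 4)  # batch of 2, channel shape (3, 4)
y_flat = y.reshape(y.shape[0], -1)  # shape (2, 12)
print(np.shares_memory(y, y_flat))  # True: only the shape/strides changed

drift = flatten_drift(lambda t, y: -y, channel_shape=(3, 4))
print(drift(0.0, y_flat).shape)  # (2, 12)
```

One caveat: NumPy's `reshape` silently copies when the array is not contiguous (for example after a transpose), whereas PyTorch's `view` raises an error in that case instead.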