coreylowman/dfdx

Grouped Linear layers

opfromthestart opened this issue · 1 comments

I have a model that needs to run in real time, and one of its layers has a width of 6000, which makes it take about 0.3 seconds to run (I need 30-60 fps, i.e. roughly 17-33 ms per frame). One solution I can think of is a grouped linear layer: GroupLinear<I, O, C> takes an input of shape (B, C*I) and returns an output of shape (B, C*O), where each of the C blocks of I inputs is processed independently of the others. I don't think new kernels would need to be written, since it just amounts to C smaller linear layers.

The weights should probably be stored as something like a Tensor<Rank3<C, I, O>,...>.
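
To make the proposal concrete, here is a minimal sketch of the forward pass in plain Rust, without dfdx. It is only an illustration of the idea, not dfdx's API: the struct layout, the flat `Vec<f32>` storage, the per-group bias, and the `zeros` and `forward` names are all assumptions on my part. The weights follow the (C, I, O) layout suggested above.

```rust
/// Hypothetical grouped linear layer (not part of dfdx).
/// Generic order matches the proposed `GroupLinear<I, O, C>`.
pub struct GroupLinear<const I: usize, const O: usize, const C: usize> {
    /// Flattened (C, I, O) weights, row-major: index = c * I * O + i * O + o.
    pub weight: Vec<f32>,
    /// Flattened (C, O) bias (an assumption; the issue only mentions weights).
    pub bias: Vec<f32>,
}

impl<const I: usize, const O: usize, const C: usize> GroupLinear<I, O, C> {
    /// Creates a layer with all parameters set to zero.
    pub fn zeros() -> Self {
        Self {
            weight: vec![0.0; C * I * O],
            bias: vec![0.0; C * O],
        }
    }

    /// `input` is a flattened (batch, C * I) matrix in row-major order.
    /// Returns a flattened (batch, C * O) matrix.
    pub fn forward(&self, input: &[f32], batch: usize) -> Vec<f32> {
        assert_eq!(input.len(), batch * C * I, "input must be (batch, C * I)");
        let mut out = vec![0.0f32; batch * C * O];
        for b in 0..batch {
            for c in 0..C {
                // Each group only sees its own block of I inputs.
                let x = &input[b * C * I + c * I..][..I];
                let w = &self.weight[c * I * O..][..I * O];
                let y = &mut out[b * C * O + c * O..][..O];
                y.copy_from_slice(&self.bias[c * O..][..O]);
                for i in 0..I {
                    for o in 0..O {
                        y[o] += x[i] * w[i * O + o];
                    }
                }
            }
        }
        out
    }
}
```

A dense layer mapping C*I inputs to C*O outputs has (C*I)*(C*O) weights, while the grouped version has C*I*O, so both the parameter count and the multiply-adds drop by a factor of C. In a real implementation the inner loops would presumably be replaced by a batched matrix multiply over the C dimension.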