`groups` support for Conv1D?
purefunctor opened this issue · 3 comments
Hi, this is a really awesome project!
I'm trying to port a model that makes use of 1D convolutions, but the immediate thing I ran into was that the Conv1D layer doesn't have a parameter for groups, as is present in PyTorch/TensorFlow. Learning resources on low-level NN programming are a little terse, but I'd like to tackle this!
My (high-level) understanding of it goes something along the lines of:
import torch
from torchinfo import summary

x = torch.nn.Conv1d(9, 3, kernel_size=1, groups=1, bias=False)
y = torch.nn.Conv1d(9, 3, kernel_size=1, groups=3, bias=False)
summary(x)
summary(y)
Which gives:
=================================================================
Layer (type:depth-idx) Param #
=================================================================
Conv1d 27
=================================================================
Total params: 27
Trainable params: 27
Non-trainable params: 0
=================================================================
=================================================================
Layer (type:depth-idx) Param #
=================================================================
Conv1d 9
=================================================================
Total params: 9
Trainable params: 9
Non-trainable params: 0
=================================================================
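The two parameter counts follow directly from the Conv1d weight shape, out_channels × (in_channels / groups) × kernel_size:

```python
# Conv1d parameter count with bias=False:
#   out_channels * (in_channels // groups) * kernel_size
assert 3 * (9 // 1) * 1 == 27  # x: groups=1
assert 3 * (9 // 3) * 1 == 9   # y: groups=3
```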
In the case of x, with groups=1, each input channel has its own 1x1 kernel. This is reflected in the weights:
tensor([[[-0.2372],
[-0.0259],
[ 0.0075],
[ 0.0602],
[ 0.0205],
[-0.2290],
[-0.3304],
[ 0.1302],
[ 0.0697]],
[[ 0.3148],
[-0.1045],
[ 0.2087],
[ 0.2184],
[ 0.2869],
[-0.1255],
[-0.0349],
[ 0.2754],
[ 0.0341]],
[[-0.0065],
[ 0.0904],
[ 0.1445],
[ 0.0337],
[-0.0661],
[ 0.2763],
[-0.1375],
[ 0.0841],
[-0.0864]]]) torch.Size([3, 9, 1])
For each output channel, there are 9 1x1 kernels which correspond to each input channel.
Meanwhile, for the case of y with groups=3, it has the following weights:
tensor([[[-0.0507],
[ 0.1446],
[ 0.1827]],
[[-0.1260],
[ 0.2465],
[ 0.5095]],
[[ 0.4771],
[ 0.1377],
[-0.0265]]]) torch.Size([3, 3, 1])
For each output channel, there are 3 1x1 kernels which correspond to each input channel group.
An intuitive way I've found to see how this works is:
x.weight.requires_grad = False
y.weight.requires_grad = False
torch.nn.init.constant_(x.weight[0], 1.0)
torch.nn.init.constant_(x.weight[1], 0.9)
torch.nn.init.constant_(x.weight[2], 0.8)
torch.nn.init.constant_(y.weight[0], 1.0)
torch.nn.init.constant_(y.weight[1], 0.9)
torch.nn.init.constant_(y.weight[2], 0.8)
i = torch.ones(9, 2)
print(i)
print(x(i))
print(y(i))
and this yields:
tensor([[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]])
tensor([[9.0000, 9.0000],
[8.1000, 8.1000],
[7.2000, 7.2000]])
tensor([[3.0000, 3.0000],
[2.7000, 2.7000],
[2.4000, 2.4000]])
The result of y(i) yields significantly less "energy" than x(i), as each output channel now has fewer kernels to work with: 3 instead of 9.
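To double-check that intuition, here's a small sketch (standard PyTorch only) that reproduces the grouped convolution by hand, by slicing the input into per-group channel blocks:

```python
import torch

# groups=3 conv: 9 input channels split into 3 groups of 3,
# each group producing 1 of the 3 output channels here.
y = torch.nn.Conv1d(9, 3, kernel_size=1, groups=3, bias=False)

i = torch.ones(9, 2)
out = y(i)

# Reproduce the grouped convolution manually: output channel g only
# sees input channels [3g, 3g+3).
manual = torch.stack([
    (y.weight[g] * i[3 * g : 3 * g + 3]).sum(dim=0)
    for g in range(3)
])
assert torch.allclose(out, manual)
```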
Hello!
Thanks for the issue, and for the exploration of the "groups" functionality. I'll just add the description that I found in the Conv1D documentation:
groups controls the connections between inputs and outputs. in_channels and out_channels must both be divisible by groups. For example,
- At groups=1, all inputs are convolved to all outputs.
- At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.
- At groups=in_channels, each input channel is convolved with its own set of filters (of size out_channels/in_channels).
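That "two conv layers side by side" description can be verified directly. This sketch (names are just for illustration) builds the two half-size convolutions from the grouped weights and checks the concatenated result:

```python
import torch

torch.manual_seed(0)

# groups=2: equivalent to two independent convs, each seeing half
# the input channels and producing half the output channels.
conv = torch.nn.Conv1d(4, 6, kernel_size=3, groups=2, bias=False)

# Build the two "half" convs by slicing the grouped weights
# (weight shape is [out_channels, in_channels // groups, kernel_size]).
a = torch.nn.Conv1d(2, 3, kernel_size=3, bias=False)
b = torch.nn.Conv1d(2, 3, kernel_size=3, bias=False)
with torch.no_grad():
    a.weight.copy_(conv.weight[:3])
    b.weight.copy_(conv.weight[3:])

x = torch.randn(1, 4, 10)
expected = torch.cat([a(x[:, :2]), b(x[:, 2:])], dim=1)
assert torch.allclose(conv(x), expected, atol=1e-6)
```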
I think I have a rough sense of how to implement this, but it may be a few weeks until I have time to really sit down and work on it. With that in mind, I guess I'll list out my workflow for how I would probably approach the problem, and then if you (or someone else) would like to take a shot at it, that would be cool!
- Set up a test network. You're already most of the way there, but I'd probably make a copy of the conv1d_torch model used for testing, and maybe even add some layers to make sure that some edge cases are being covered (different kernel sizes, dilations, etc).
- Set up a test program. Again, maybe starting with a copy of torch_conv1d_test would be easiest.
- Modify the STL Conv1D implementation to add "groups" functionality:
- There's no reason you have to start with the STL implementation, but it's usually the one I find most readable.
- I usually start with the templated version of the layer, but there's no reason you have to.
- Figure out how to handle the differences in the weights layout. Since the "grouped" Conv1D seems to have a different number of weights, you'll probably need to adjust the way the weights are stored in memory (i.e. the sizes of some vectors/arrays) and the way they're loaded into memory.
- Modify the forward() method to give the correct output. This is where having the test case with your reference PyTorch model is very helpful. My gut feeling is that you'll need to put the main convolution loop inside a loop over the number of groups.
- "Port" your changes to the other RTNeural backends. This can be a bit difficult, and is something that I'm definitely prepared to help with.
- Go back and check that the other Conv1D tests still pass
- Run the layer benchmarks to make sure that we haven't made the Conv1D layer slower when groups = 1.
Anyway, hopefully this is helpful in case you or anyone wants to tackle this. If not, just give it a little time while I finish up some other things, and I can come back to this. If you do start working on it, feel free to message me with any questions or intermediate progress updates!
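For anyone picking this up, here's a rough NumPy reference of what the grouped forward pass computes (a sketch under my own assumptions, not RTNeural code; no padding, stride 1), with the per-output-channel loop wrapped in a loop over groups as suggested above:

```python
import numpy as np

def conv1d_grouped(x, w, groups, dilation=1):
    """Reference grouped 1D convolution (no padding, stride 1).

    x: (in_channels, length)
    w: (out_channels, in_channels // groups, kernel_size)
    """
    out_ch, group_in, kernel = w.shape
    group_out = out_ch // groups
    out_len = x.shape[1] - dilation * (kernel - 1)
    y = np.zeros((out_ch, out_len))
    for g in range(groups):  # outer loop over groups
        # this group's slice of the input channels
        xg = x[g * group_in:(g + 1) * group_in]
        for oc in range(g * group_out, (g + 1) * group_out):
            for t in range(out_len):
                # gather the (possibly dilated) input columns for this tap
                cols = xg[:, t:t + dilation * kernel:dilation]
                y[oc, t] = np.sum(w[oc] * cols)
    return y
```

Having a slow-but-obvious reference like this to compare against makes it much easier to validate the faster backend implementations later.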
That's indeed really helpful, yup!
I'm trying to understand what the "state" and "state_cols" are in the convolutions, do you have any pointers on that?
Another good way to think about grouped convolutions is that each group gets its own convolution: if I had a 9in->9out convolution with 3 groups, I'd have 3 independent 3in->3out convolutions whose outputs get concatenated.
Very cool! For state_cols, the idea is that it's a "helper" variable to store only the columns of the state that will be multiplied by the weights (they're not guaranteed to be contiguous depending on the dilation rate). See how the state_cols are set here.
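As a rough illustration of that idea (a simplified Python sketch of the concept, not the actual RTNeural code, and it assumes the state is simply a per-channel history with the newest sample in the last column): with dilation, the kernel taps land on non-contiguous columns of the state, and a state_cols-style helper just copies those columns out before the multiply.

```python
import numpy as np

def gather_state_cols(state, kernel_size, dilation):
    """Pick out the state columns that the kernel taps touch.

    state: (channels, history_length), newest sample in the last column.
    """
    history = state.shape[1]
    # Tap k looks back k * dilation samples from the newest column.
    idx = [history - 1 - k * dilation for k in range(kernel_size)]
    return state[:, idx]  # shape: (channels, kernel_size)

state = np.arange(12).reshape(2, 6)  # 2 channels, 6 past samples
# taps at columns 5, 3, 1 -> [[5, 3, 1], [11, 9, 7]]
print(gather_state_cols(state, kernel_size=3, dilation=2))
```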