1D convolution requires 3-dimensional input
samgoldman97 commented
```python
test = ntorch.randn((3, 5), names=("embedding", "time"))
conv = ntorch.nn.Conv1d(in_channels=3, out_channels=10, kernel_size=2).spec("embedding", "time", "convout")
conv(test)
```

Raises:

```
Expected 3-dimensional input for 3-dimensional weight [10, 3, 2], but got 2-dimensional input of size [3, 5] instead
```
Whereas if I unsqueeze the first dimension, I get the right output:
```python
test = ntorch.randn((1, 3, 5), names=("batch", "embedding", "time"))
conv = ntorch.nn.Conv1d(in_channels=3, out_channels=10, kernel_size=2).spec("embedding", "time", "convout")
conv(test)
```
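Until the wrapper handles this, a small call-site helper can paper over the missing batch dim. This is a minimal sketch against the `ntorch` API used above; `with_batch` and the `"batch"` name are hypothetical, and I'm assuming `.values` and `.shape.keys()` behave as they do elsewhere in namedtensor:

```python
def with_batch(layer, x, dim="batch"):
    # Prepend a singleton batch dim, apply the layer, then drop it again.
    batched = ntorch.ntensor(x.values.unsqueeze(0),
                             names=(dim,) + tuple(x.shape.keys()))
    out = layer(batched)
    return ntorch.ntensor(out.values.squeeze(0),
                          names=tuple(k for k in out.shape.keys() if k != dim))

out = with_batch(conv, ntorch.randn((3, 5), names=("embedding", "time")))
```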
I propose that the wrapper should handle this broadcasting and unsqueeze the first dimension automatically if it's missing, perhaps in `_Update`. Something along these lines could be a quick fix, but I'm not sure what else it would affect:
```python
class _Update:
    def rename(self, **kwargs):
        self._updates = kwargs
        return self

    def __call__(self, input):
        if "_spec" in self.__dict__:
            input = input.transpose(*self._input_order).contiguous()
            updates = {k: v for (v, k) in self._output_update.items()}
            # If the weight has one more dim than the input, the batch dim
            # is missing; prepend a singleton "broadcast" dimension.
            if len(self.weight.shape) == 1 + len(input.shape):
                input = ntorch.ntensor(
                    input.values.unsqueeze(0),
                    names=("broadcast",) + tuple(input.shape.keys()),
                )
            return input.op(super(_Update, self).forward, **updates)
        else:
            updates = {} if "_updates" not in self.__dict__ else self._updates
            return input.op(super(_Update, self).forward, **updates)
```
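One open question with this sketch: when the unsqueeze fires, the result would still carry the synthetic "broadcast" dim. A symmetric clean-up before returning, in the same hedged style (assuming `.shape` is keyed by dim names, as the proposal's use of `input.shape.keys()` suggests), might look like:

```python
out = input.op(super(_Update, self).forward, **updates)
if "broadcast" in out.shape:
    # Drop the singleton dim we added, keeping the remaining names in order.
    out = ntorch.ntensor(
        out.values.squeeze(0),
        names=tuple(k for k in out.shape.keys() if k != "broadcast"),
    )
return out
```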
srush commented
Sure, I'm for this. Basically all NN layers require a batch dim.
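For context, this mirrors plain PyTorch, where (at least on the PyTorch versions current when this issue was filed) `Conv1d` raises the same error on unbatched input:

```python
import torch

x = torch.randn(3, 5)  # (channels, length) -- no batch dim
conv = torch.nn.Conv1d(in_channels=3, out_channels=10, kernel_size=2)
# conv(x) would raise: "Expected 3-dimensional input for 3-dimensional
# weight [10, 3, 2], but got 2-dimensional input of size [3, 5] instead"
out = conv(x.unsqueeze(0))  # add a batch dim -> out.shape == (1, 10, 4)
```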