Can't support batch size >1
XiaoqiangZhou opened this issue · 2 comments
XiaoqiangZhou commented
When the batch size is > 1, a shape mismatch error is raised. The cause is the flatten and dropout operations in the definition of _routing. Hope you can fix this bug.
lxtGH commented
A simple solution is to modify the forward function to process the batch one sample at a time, as the original TensorFlow code does:
```python
def forward(self, inputs):
    b, _, _, _ = inputs.size()
    res = []
    for input in inputs:  # iterate over the batch dimension
        input = input.unsqueeze(0)  # restore a batch dim of 1
        pooled_inputs = self._avg_pooling(input)
        routing_weights = self._routing_fn(pooled_inputs)
        # combine the expert kernels with the per-sample routing weights
        kernels = torch.sum(routing_weights[:, None, None, None, None] * self.weight, 0)
        out = self._conv_forward(input, kernels)
        res.append(out)
    return torch.cat(res, dim=0)
```
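To illustrate, here is a minimal, self-contained sketch of a CondConv-style layer using this per-sample forward. It is not the repo's exact class: the expert count, the shape of `self.weight`, and the `_routing_fn` (flatten + linear + sigmoid) are assumptions made so the example runs standalone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CondConv2d(nn.Module):
    """Hypothetical minimal CondConv-style layer (sketch, not the repo's class)."""
    def __init__(self, in_channels, out_channels, kernel_size, num_experts=3):
        super().__init__()
        self.kernel_size = kernel_size
        # one kernel per expert: (num_experts, out, in, k, k)
        self.weight = nn.Parameter(
            torch.randn(num_experts, out_channels, in_channels,
                        kernel_size, kernel_size)
        )
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        # assumed routing function: pooled features -> per-expert weights
        self._routing_fn = nn.Sequential(
            nn.Flatten(), nn.Linear(in_channels, num_experts), nn.Sigmoid()
        )

    def _conv_forward(self, input, weight):
        return F.conv2d(input, weight, padding=self.kernel_size // 2)

    def forward(self, inputs):
        res = []
        for input in inputs:                      # loop over the batch
            input = input.unsqueeze(0)            # (1, C, H, W)
            pooled = self._avg_pooling(input)     # (1, C, 1, 1)
            routing_weights = self._routing_fn(pooled)  # (1, num_experts)
            # weighted sum of expert kernels -> one kernel for this sample
            kernels = torch.sum(
                routing_weights[0][:, None, None, None, None] * self.weight, 0
            )
            res.append(self._conv_forward(input, kernels))
        return torch.cat(res, dim=0)

x = torch.randn(4, 8, 16, 16)   # batch size 4 now works
conv = CondConv2d(8, 16, 3)
out = conv(x)
print(out.shape)                # torch.Size([4, 16, 16, 16])
```

Note this trades speed for correctness: the Python loop runs one conv per sample. The usual batched alternative is a grouped convolution with `groups=b`, but the loop stays closest to the TensorFlow reference.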