yuzuhais/CondConv-pytorch

Can't support batch size >1

XiaoqiangZhou opened this issue · 2 comments

When the batch size is > 1, I get a shape-mismatch error. The cause is the flatten and dropout operations in the definition of `_routing`. I hope you can fix this bug.
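
For context, the failure looks roughly like this (a minimal sketch, assuming `_routing` applies `torch.flatten` to the pooled features before an `nn.Linear`; the names and values here are illustrative, not the exact repo code):

    import torch
    import torch.nn as nn

    fc = nn.Linear(16, 4)              # in_channels=16, num_experts=4 (illustrative values)
    pooled = torch.randn(2, 16, 1, 1)  # batch size 2 after global average pooling
    flat = torch.flatten(pooled)       # shape (32,): the batch dimension gets folded in
    fc(flat)                           # raises a shape-mismatch error; with batch size 1 it works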

lxtGH commented

A simple solution is to modify the forward function to do what the original TensorFlow code does.

def forward(self, inputs):
    res = []
    # Route each sample separately: compute per-example routing weights,
    # combine the expert kernels, and convolve that sample with its kernel.
    for input in inputs:
        input = input.unsqueeze(0)  # restore the batch dimension
        pooled_inputs = self._avg_pooling(input)
        routing_weights = self._routing_fn(pooled_inputs)
        kernels = torch.sum(routing_weights[:, None, None, None, None] * self.weight, 0)
        out = self._conv_forward(input, kernels)
        res.append(out)
    return torch.cat(res, dim=0)
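
With that change, a batch larger than 1 should go through (a quick check; the import path and the assumption that CondConv2D's constructor mirrors nn.Conv2d with an extra num_experts argument are mine):

    import torch
    from condconv import CondConv2D  # hypothetical import path

    conv = CondConv2D(16, 32, kernel_size=3, padding=1, num_experts=4)
    x = torch.randn(8, 16, 64, 64)    # batch size > 1
    y = conv(x)
    print(y.shape)                    # expected: torch.Size([8, 32, 64, 64])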

@lxtGH, thank you very much for your cooperation. #4
I hadn't closed this issue yet, so I'll close it now. Sorry.