vislearn/FrEIA

Dimension error about the `SequenceINN`

ffhibnese opened this issue · 4 comments

My input tensor has shape 18 x 512, so I initialize the SequenceINN with dims [18, 512]. When I pass a 1 x 18 x 512 tensor into the INN, the process terminates with `RuntimeError: mat1 and mat2 cannot be multiplied`.

Did I make a mistake in how I used the code?

Thanks for the question. Can you provide isolated code that shows how you construct your network, and a stack trace? This would greatly simplify finding the error.

Okay, I'll show my code in detail.
I constructed the network with the following function:

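import torch.nn as nn
from FrEIA.framework import SequenceINN
from FrEIA.modules import AllInOneBlock


def create_inn(style_dim, n_layer, block, c_dim=None):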
    # Define the INN.
    # Affine coupling block
    def subnet_fc(dims_in, dims_out):
        return nn.Sequential(
            nn.Linear(dims_in, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, dims_out),
        )

    if block == 'all_in_one':
        style_component = SequenceINN(style_dim)
        # A simple chain of operations is collected by SequenceINN
        for k in range(n_layer):
            if c_dim is not None:
                style_component.append(AllInOneBlock, cond=0, cond_shape=(c_dim, ), subnet_constructor=subnet_fc, permute_soft=True)
            else:
                style_component.append(AllInOneBlock, subnet_constructor=subnet_fc, permute_soft=True)

    else:
        raise ValueError()

    print("Number of trainable parameters of INN: {}".format(count_parameters(style_component)))
    return style_component

Then I call it as `self.inn = create_inn([18, 512], 8, "all_in_one")`.
Later I pass a 1 x 18 x 512 tensor into self.inn, and the RuntimeError mentioned above is raised.

Thanks for your reply.

Thanks for the code! The subnet constructor assumes that each sample has a single data dimension, but your data is 2D (18 x 512). You need a subnetwork that can process such tensors, e.g. one built from Conv1d layers.
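As a rough illustration of that advice, here is a minimal sketch in plain PyTorch. `subnet_conv` is a hypothetical name, and the 1 x 9 x 512 input is an assumption: it supposes the coupling block hands the subnet about half of the 18 channels at full length 512. The width of 256 and `kernel_size=1` are illustrative choices, not something the library requires.

```python
import torch
import torch.nn as nn


def subnet_fc(dims_in, dims_out):
    # Constructor from the question: nn.Linear acts on the LAST axis,
    # so it expects that axis to have size dims_in.
    return nn.Sequential(
        nn.Linear(dims_in, 256),
        nn.LeakyReLU(0.1),
        nn.Linear(256, dims_out),
    )


def subnet_conv(dims_in, dims_out):
    # Hypothetical replacement: nn.Conv1d acts on the channel axis (axis 1).
    # With kernel_size=1 the length axis is left unchanged.
    return nn.Sequential(
        nn.Conv1d(dims_in, 256, kernel_size=1),
        nn.LeakyReLU(0.1),
        nn.Conv1d(256, dims_out, kernel_size=1),
    )


# Assumed shape of the tensor the subnet receives: (batch, channels, length).
x_half = torch.randn(1, 9, 512)

# The fully connected subnet fails because the last axis is 512, not 9.
try:
    subnet_fc(9, 18)(x_half)
    fc_error = None
except RuntimeError as err:
    fc_error = str(err)

# The convolutional subnet maps 9 channels to 18 and keeps the length.
out = subnet_conv(9, 18)(x_half)

print("Linear subnet error:", fc_error)
print("Conv1d subnet output shape:", tuple(out.shape))
```

To try this in the network from the question, pass `subnet_constructor=subnet_conv` instead of `subnet_fc` in the `style_component.append(AllInOneBlock, ...)` calls. A larger kernel (for example `kernel_size=3` with `padding=1`) would also preserve the length while letting neighbouring positions interact.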


Thanks for your clear reply, my problem is solved successfully!