wl-zhao/VPD

VPDSeg doesn't run - questions on backbone parameters

DianCh opened this issue · 3 comments

Hi! Thank you for releasing this wonderful work. I have a couple of questions from playing with the backbone in this repo:

  1. The output channels from self.unet(latents, t, c_crossattn=[c_crossattn]) are [320, 650, 1290, 1280], so why does the VPDSeg config use [320, 790, 1430, 1280] for the FPN? Am I missing anything? Using the config I got the following error:
RuntimeError: Given groups=1, weight of size [256, 790, 1, 1], expected input[1, 650, 32, 32] to have 790 channels, but got 650 channels instead
  2. When using distributed training, I got errors about the following parameters not receiving gradients:
img_backbone.unet.unet.diffusion_model.out.0.weight
img_backbone.unet.unet.diffusion_model.out.0.bias
img_backbone.unet.unet.diffusion_model.out.2.weight
img_backbone.unet.unet.diffusion_model.out.2.bias

which seems to be because self.out from UNetModel is never used in the forward defined by the wrapper:

import torch  # needed for torch.cat below

def register_hier_output(model):
    self = model.diffusion_model
    from ldm.modules.diffusionmodules.util import checkpoint, timestep_embedding

    def forward(x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (self.num_classes is not None), \
            "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0], )
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            # import pdb; pdb.set_trace()
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        out_list = []

        # collect decoder features after output blocks 1, 4, and 7
        # (progressively finer scales); the final feature map is appended below
        for i_out, module in enumerate(self.output_blocks):
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
            if i_out in [1, 4, 7]:
                out_list.append(h)
        h = h.type(x.dtype)

        # note: self.out is never applied here, which is why its parameters
        # receive no gradients under DDP
        out_list.append(h)
        return out_list

    self.forward = forward

I'm using the SD commit 21f890f as suggested. To double-check that I'm using the right code without manually deleting self.out: how did you avoid this issue during training?
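For completeness, this is how I confirmed that these are the only parameters without gradients (a minimal single-process sketch; the dummy loss is only for the check and is not taken from the repo's training code):

import torch

# run one forward/backward pass and list trainable parameters whose .grad is
# still None afterwards - these are the ones DDP will complain about
model.zero_grad()
feats = model.extract_feat(torch.randn(1, 3, 512, 512))
loss = sum(f.sum() for f in feats)  # dummy loss touching every returned feature
loss.backward()

for name, p in model.named_parameters():
    if p.requires_grad and p.grad is None:
        print("no gradient:", name)  # e.g. ...unet.diffusion_model.out.0.weight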

Can you please help me with these questions? Thank you very much in advance!

Hi, thanks for your interest in our work. I hope the following is helpful.

  1. The output channels of the UNetWrapper depend on the number of classes. In our VPDSeg, there are 150 categories in the ADE20K dataset, so the output channels should be [320, 640 + 150, 1280 + 150, 1280] = [320, 790, 1430, 1280].

  2. Since the final output layer of the UNet produces a 4-channel output (too small to be useful as a feature map), we drop that layer and use the feature map before it instead. The gradient error can be solved either by deleting the out layer or by setting find_unused_parameters=True in the DDP constructor (see the sketch below).
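For reference, both workarounds in a minimal sketch (model, local_rank, and unet are illustrative names, not taken from the repo's training code; unet stands for the wrapped ldm UNetModel, i.e. diffusion_model):

import torch.nn as nn

# Option A: let DDP tolerate parameters that never receive gradients
ddp_model = nn.parallel.DistributedDataParallel(
    model.cuda(),
    device_ids=[local_rank],
    find_unused_parameters=True,
)

# Option B: drop the UNet's final output layer entirely, since the patched
# forward above never calls it
del unet.out

If training goes through an mmsegmentation-style train script, setting find_unused_parameters = True at the top level of the config should have the same effect, since that flag is forwarded to the distributed wrapper.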

@wl-zhao Thank you for the reply! I got the channel error using the released VPDSeg config - how should I change the code to incorporate the 150 classes into the UNetWrapper? The model doesn't seem to have the right channels.

To be more specific, I set a breakpoint right after model = build_segmentor(cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg')) in train.py and tested the backbone with the following:

>>> a = torch.randn(1, 3, 512, 512)
>>> feat = model.extract_feat(a)
>>> for x in feat:
...     print(x.shape)
...
torch.Size([1, 320, 64, 64])
torch.Size([1, 650, 32, 32])
torch.Size([1, 1290, 16, 16])
torch.Size([1, 1280, 8, 8])

It seems that a wrong class_embeddings.pth (one containing only 10 classes) is being used. Please refer to this url for the correct one.
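To verify which embeddings are being picked up, a quick check (assuming class_embeddings.pth stores a tensor whose first dimension is the number of classes, which is consistent with the shapes reported above):

import torch

emb = torch.load('class_embeddings.pth', map_location='cpu')
num_classes = emb.shape[0]
print(num_classes)  # should be 150 for ADE20K; 10 would explain the shapes above

# one extra channel per class is concatenated at the two middle scales,
# so the expected FPN in_channels are:
in_channels = [320, 640 + num_classes, 1280 + num_classes, 1280]
print(in_channels)  # 150 -> [320, 790, 1430, 1280]; 10 -> [320, 650, 1290, 1280]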