wl-zhao/VPD

VPDSeg model state dictionary keys missing

apenzko opened this issue · 1 comment

Hi,
I'm not able to get the VPD Segmentation model running.

I have downloaded the Stable Diffusion checkpoint v1-5-pruned-emaonly.ckpt from https://github.com/runwayml/stable-diffusion.

Code

The following code produces the error:

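```python
import mmcv
from mmcv.runner import load_checkpoint, wrap_fp16_model
from mmseg.models import build_segmentor

# config_path and ckpt_path are defined elsewhere; in the run below,
# ckpt_path was ../checkpoints/fpn_vpd_sd1-5_512x512_slim.ckpt
cfg = mmcv.Config.fromfile(config_path)

# build the model and load checkpoint
cfg.model.train_cfg = None
model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))

# wrap the model for fp16 inference if the config requests it
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
    print('not none')
    wrap_fp16_model(model)

checkpoint = load_checkpoint(model, ckpt_path, map_location='cpu')
```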

Output

```text
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Restored from ../checkpoints/v1-5-pruned-emaonly.ckpt with 0 missing and 199 unexpected keys
Unexpected Keys: ['model_ema.decay', 'model_ema.num_updates', 'cond_stage_model.transformer.text_model.embeddings.position_ids', ... 'cond_stage_model.transformer.text_model.final_layer_norm.bias']

load checkpoint from local path: ../checkpoints/fpn_vpd_sd1-5_512x512_slim.ckpt
.conda/envs/vpd_env/lib/python3.8/site-packages/mmseg/models/losses/cross_entropy_loss.py:235: UserWarning: Default avg_non_ignore is False, if you would like to ignore the certain label and average loss over non-ignore labels, which is the same with PyTorch official cross_entropy, set avg_non_ignore=True.
warnings.warn(
The model and loaded state dict do not match exactly

unexpected key in source state_dict: sd_model.lvlb_weights

missing keys in source state_dict: encoder_vq.encoder.conv_in.weight, encoder_vq.encoder.conv_in.bias, ...
```
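When the list of missing keys is long, it can help to group the mismatched keys by their top-level prefix to see which submodules they belong to. The following is a minimal plain-Python sketch; the function names `key_mismatch` and `group_keys_by_prefix` are hypothetical, and the toy key lists only borrow the `encoder_vq.*` and `sd_model.lvlb_weights` names from the log above (the other keys are made up for illustration).

```python
from collections import Counter
from typing import Dict, Iterable, Set


def group_keys_by_prefix(keys: Iterable[str], depth: int = 1) -> Dict[str, int]:
    """Count state-dict keys by their first `depth` dot-separated components."""
    counts = Counter(".".join(key.split(".")[:depth]) for key in keys)
    return dict(sorted(counts.items()))


def key_mismatch(model_keys: Iterable[str], ckpt_keys: Iterable[str]) -> Dict[str, Set[str]]:
    """Return the keys missing from the checkpoint and the unexpected extra keys."""
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    return {
        "missing": model_set - ckpt_set,     # in the model, absent from the checkpoint
        "unexpected": ckpt_set - model_set,  # in the checkpoint, absent from the model
    }


if __name__ == "__main__":
    # Toy key lists (hypothetical apart from the names taken from the log).
    model_keys = [
        "encoder_vq.encoder.conv_in.weight",
        "encoder_vq.encoder.conv_in.bias",
        "unet.block.weight",
        "decode_head.conv_seg.weight",
    ]
    ckpt_keys = [
        "unet.block.weight",
        "decode_head.conv_seg.weight",
        "sd_model.lvlb_weights",
    ]
    mismatch = key_mismatch(model_keys, ckpt_keys)
    print("missing:", group_keys_by_prefix(mismatch["missing"]))
    print("unexpected:", group_keys_by_prefix(mismatch["unexpected"], depth=2))
```

With the real objects, `model_keys` would presumably be `model.state_dict().keys()` and `ckpt_keys` the keys of the dictionary returned by `torch.load(ckpt_path, map_location='cpu')`, taken from its `'state_dict'` entry if the checkpoint has one. The typing imports keep the sketch compatible with the Python 3.8 environment shown in the log.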

Has anyone encountered this or a similar issue and can tell me how to fix it?
Grateful for any help!

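Turns out it works despite the missing keys! The missing encoder_vq.* weights appear to belong to the Stable Diffusion autoencoder, which was already restored from v1-5-pruned-emaonly.ckpt when the model was built (the "Restored from ..." line in the log). Since mmcv's load_checkpoint defaults to strict=False, it only warns about the mismatch, and those parameters keep their restored values.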
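Why a non-strict load still leaves a working model can be illustrated with a simplified plain-Python model of that behaviour, roughly what `torch.nn.Module.load_state_dict(..., strict=False)` does: keys present in both the model and the checkpoint are overwritten, missing keys keep whatever values the model already holds, and unexpected keys are ignored. This is a sketch, not mmcv's or PyTorch's actual code; `load_non_strict` is a hypothetical name, plain strings stand in for tensors, and the shape checks that PyTorch still performs are left out.

```python
from typing import Any, Dict, List, Tuple


def load_non_strict(
    current: Dict[str, Any], incoming: Dict[str, Any]
) -> Tuple[Dict[str, Any], List[str], List[str]]:
    """Merge `incoming` into `current` the way a non-strict state-dict load would.

    Returns the merged parameters plus the sorted missing and unexpected keys.
    """
    merged = dict(current)  # start from the values the model already holds
    for key, value in incoming.items():
        if key in merged:
            merged[key] = value  # matching key: the checkpoint value wins
    missing = sorted(set(current) - set(incoming))     # keep their current values
    unexpected = sorted(set(incoming) - set(current))  # silently ignored
    return merged, missing, unexpected


if __name__ == "__main__":
    # Values held after building the model (hypothetical stand-ins for tensors).
    built = {
        "encoder_vq.encoder.conv_in.weight": "sd_weights",
        "decode_head.conv_seg.weight": "random_init",
    }
    # A slim checkpoint carrying only the trained parts plus one extra buffer.
    slim = {"decode_head.conv_seg.weight": "trained", "sd_model.lvlb_weights": "buffer"}
    params, missing, unexpected = load_non_strict(built, slim)
    print(params)
    print("missing:", missing, "unexpected:", unexpected)
```

Here the encoder weight keeps its build-time value while the head weight is replaced by the checkpoint value, which matches the assumption that the slim checkpoint omits the frozen Stable Diffusion weights and relies on them being restored when the model is built.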