VPDSeg model state dictionary keys missing
apenzko opened this issue · 1 comment
Hi,
I'm not able to get the VPD Segmentation model running.
I have downloaded the Stable Diffusion checkpoint v1-5-pruned-emaonly.ckpt from https://github.com/runwayml/stable-diffusion.
Code
The following code produces the error:
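```python
import mmcv
from mmcv.runner import load_checkpoint, wrap_fp16_model
from mmseg.models import build_segmentor

# config_path: the VPD segmentation config; ckpt_path: ../checkpoints/fpn_vpd_sd1-5_512x512_slim.ckpt
cfg = mmcv.Config.fromfile(config_path)
# build the model and load checkpoint
cfg.model.train_cfg = None
model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
    print('not none')
    wrap_fp16_model(model)
checkpoint = load_checkpoint(model, ckpt_path, map_location='cpu')
```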
Output
```
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Restored from ../checkpoints/v1-5-pruned-emaonly.ckpt with 0 missing and 199 unexpected keys
Unexpected Keys: ['model_ema.decay', 'model_ema.num_updates', 'cond_stage_model.transformer.text_model.embeddings.position_ids', ... 'cond_stage_model.transformer.text_model.final_layer_norm.bias']
load checkpoint from local path: ../checkpoints/fpn_vpd_sd1-5_512x512_slim.ckpt
.conda/envs/vpd_env/lib/python3.8/site-packages/mmseg/models/losses/cross_entropy_loss.py:235: UserWarning: Default avg_non_ignore is False, if you would like to ignore the certain label and average loss over non-ignore labels, which is the same with PyTorch official cross_entropy, set avg_non_ignore=True.
  warnings.warn(
The model and loaded state dict do not match exactly
unexpected key in source state_dict: sd_model.lvlb_weights
missing keys in source state_dict: encoder_vq.encoder.conv_in.weight, encoder_vq.encoder.conv_in.bias, ...
```
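To see at a glance which modules a mismatch like this touches, the two key lists can be grouped by their top-level prefix. The sketch below is plain Python; `summarize_key_mismatch` and `top_prefix` are hypothetical helper names, and apart from the `encoder_vq.*` and `sd_model.lvlb_weights` keys copied from the log, the toy keys are made up.

```python
from collections import Counter
from typing import Iterable, Tuple


def top_prefix(key: str, depth: int = 1) -> str:
    """Return the first `depth` dot-separated components of a state-dict key."""
    return ".".join(key.split(".")[:depth])


def summarize_key_mismatch(
    model_keys: Iterable[str], ckpt_keys: Iterable[str], depth: int = 1
) -> Tuple[Counter, Counter]:
    """Count missing and unexpected keys per module prefix.

    missing:    in the model but absent from the checkpoint
    unexpected: in the checkpoint but absent from the model
    """
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    missing = Counter(top_prefix(k, depth) for k in model_set - ckpt_set)
    unexpected = Counter(top_prefix(k, depth) for k in ckpt_set - model_set)
    return missing, unexpected


if __name__ == "__main__":
    # Toy key lists: the encoder_vq / sd_model keys come from the log,
    # "head.toy.weight" is a hypothetical stand-in for everything else.
    model_keys = [
        "encoder_vq.encoder.conv_in.weight",
        "encoder_vq.encoder.conv_in.bias",
        "head.toy.weight",
    ]
    ckpt_keys = ["head.toy.weight", "sd_model.lvlb_weights"]
    missing, unexpected = summarize_key_mismatch(model_keys, ckpt_keys)
    print("missing:", dict(missing))        # missing: {'encoder_vq': 2}
    print("unexpected:", dict(unexpected))  # unexpected: {'sd_model': 1}
```

With the real objects you would pass `model.state_dict().keys()` and the keys of the loaded checkpoint, e.g. `torch.load(ckpt_path, map_location='cpu')['state_dict']`, assuming the file stores its weights under a `state_dict` entry as the log's "source state_dict" wording suggests.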
Has anyone encountered this or a similar issue and can tell me how to fix it?
Grateful for any help!
Turns out it works despite the missing keys!
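A plausible reading of the log, though the thread does not confirm it: the Stable Diffusion checkpoint is restored first with 0 missing keys, and the `_slim` VPD checkpoint appears to omit the frozen `encoder_vq` weights that the model already received from it. Rather than ignoring the warning entirely, one can check that every missing key falls under a prefix expected to be initialized elsewhere. In this sketch `check_missing_keys` is a hypothetical helper, and `encoder_vq.` is the only allowed prefix, taken from the log.

```python
from typing import Iterable, List, Sequence


def check_missing_keys(
    missing_keys: Iterable[str], allowed_prefixes: Sequence[str]
) -> List[str]:
    """Return the missing keys NOT covered by any allowed prefix, sorted.

    An empty result means every missing key belongs to a module that is
    assumed to be initialized elsewhere (here: the VAE encoder restored
    from the Stable Diffusion checkpoint, which is an assumption).
    """
    prefixes = tuple(allowed_prefixes)  # str.startswith accepts a tuple
    return sorted(k for k in missing_keys if not k.startswith(prefixes))


if __name__ == "__main__":
    # Missing keys as reported in the log (truncated there with "...").
    missing = [
        "encoder_vq.encoder.conv_in.weight",
        "encoder_vq.encoder.conv_in.bias",
    ]
    unexplained = check_missing_keys(missing, allowed_prefixes=["encoder_vq."])
    if unexplained:
        raise RuntimeError(f"unexpected missing keys: {unexplained}")
    print("all missing keys are under encoder_vq.")
```

In practice, `torch.nn.Module.load_state_dict(state_dict, strict=False)` returns a named tuple whose `missing_keys` field can be passed in directly. A prefix check only shows which modules the checkpoint skipped; it does not prove their weights were restored, so comparing a few of those tensors against the Stable Diffusion checkpoint would be a stronger test.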