XPixelGroup/HAT

Hi everyone — is this upscaling model based on the SwinIR architecture? Can it work with AUTOMATIC1111's Stable Diffusion webui for image upscaling?

Closed this issue · 7 comments

I'm not very familiar with this — is this model based on the SwinIR architecture?
I put it in the auto1111 Stable Diffusion webui's /model/SwinIR model directory, but it doesn't seem to be loaded.

Looking at the implementation, it currently seems to load only this model:
03_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth
Is that right?
https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/extensions-builtin/SwinIR/swinir_model_arch.py

Is there any chance the webui could use this currently strongest upscaling algorithm?

chxy95 commented

@Erwin11 I had a look — the SD webui currently does not support HAT models. You could try filing a feature request with them~

I had put it in the wrong directory — I tried the "stable-diffusion-webui/models/ESRGAN" directory instead, where the model is recognized but still cannot be used; presumably the webui doesn't support it.

To add, the webui's error message for the unsupported model is: KeyError: 'trunk_conv.weight'
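That KeyError is consistent with an architecture mismatch: the ESRGAN loader looks for RRDB-style keys such as trunk_conv.weight, which a HAT state_dict does not contain. A minimal sketch of that kind of key probing, with plain dicts standing in for torch.load and illustrative marker strings (this is an assumption about the general approach, not the webui's actual detection code):

```python
# Sketch: guess which architecture a checkpoint's state_dict matches by probing
# for characteristic key names. Plain dicts stand in for torch.load("model.pth");
# the marker strings are illustrative, not the webui's real detection logic.
def detect_arch(state_dict):
    keys = list(state_dict)
    if any("trunk_conv" in k for k in keys):      # RRDB/ESRGAN-style trunk conv
        return "ESRGAN"
    if any("conv_block.cab" in k for k in keys):  # HAT's channel-attention block
        return "HAT"
    if any("residual_group" in k for k in keys):  # shared by SwinIR and HAT
        return "SwinIR"
    return "unknown"

# Key names taken from the errors quoted in this thread:
esrgan_sd = {"trunk_conv.weight": None, "trunk_conv.bias": None}
hat_sd = {"layers.3.residual_group.blocks.3.conv_block.cab.3.attention.1.weight": None}

print(detect_arch(esrgan_sd))  # ESRGAN
print(detect_arch(hat_sd))     # HAT
```

Because a HAT checkpoint has no `trunk_conv.*` keys, the ESRGAN code path fails with exactly the KeyError above.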

So putting it under "stable-diffusion-webui/models/SwinIR" is correct after all — the webui just doesn't support it yet. The error is as follows:

RuntimeError: Error(s) in loading state_dict for SwinIR: Missing key(s) in state_dict: "layers.0.residual_group.blocks.0.attn.relative_position_index", "layers.0.residual_group.blocks.1.attn_mask", "layers.0.residual_group.blocks.1.attn.relative_position_index", "layers.0.residual_group.blocks.2.attn.relative_position_index", "layers.0.residual_group.blocks.3.attn_mask", "layers.0.residual_group.blocks.3.attn.relative_position_index", "layers.0.residual_group.blocks.4.attn.relative_position_index", "layers.0.residual_group.blocks.5.attn_mask", "layers.0.residual_group.blocks.5.attn.relative_position_index", "layers.0.conv.0.weight", "layers.0.conv.0.bias", "layers.0.conv.2.weight", "layers.0.conv.2.bias", "layers.0.conv.4.weight", "layers.0.conv.4.bias", "layers.1.residual_group.blocks.0.attn.relative_position_index", "layers.1.residual_group.blocks.1.attn_mask", "layers.1.residual_group.blocks.1.attn.relative_position_index", "layers.1.residual_group.blocks.2.attn.relative_position_index", "layers.1.residual_group.blocks.3.attn_mask", "layers.1.residual_group.blocks.3.attn.relative_position_index", "layers.1.residual_group.blocks.4.attn.relative_position_index", "layers.1.residual_group.blocks.5.attn_mask", "layers.1.residual_group.blocks.5.attn.relative_position_index", "layers.1.conv.0.weight", "layers.1.conv.0.bias", "layers.1.conv.2.weight", "layers.1.conv.2.bias", "layers.1.conv.4.weight", "layers.1.conv.4.bias", "layers.2.residual_group.blocks.0.attn.relative_position_index", "layers.2.residual_group.blocks.1.attn_mask", "layers.2.residual_group.blocks.1.attn.relative_position_index", "layers.2.residual_group.blocks.2.attn.relative_position_index", "layers.2.residual_group.blocks.3.attn_mask", "layers.2.residual_group.blocks.3.attn.relative_position_index", "layers.2.residual_group.blocks.4.attn.relative_position_index", "layers.2.residual_group.blocks.5.attn_mask", "layers.2.residual_group.blocks.5.attn.relative_position_index", "layers.2.conv.0.weight", 
"layers.2.conv.0.bias", "layers.2.conv.2.weight", "layers.2.conv.2.bias", "layers.2.conv.4.weight", "layers.2.conv.4.bias", "layers.3.residual_group.blocks.0.attn.relative_position_index", "layers.3.residual_group.blocks.1.attn_mask", "layers.3.residual_group.blocks.1.attn.relative_position_index", "layers.3.residual_group.blocks.2.attn.relative_position_index", "layers.3.residual_group.blocks.3.attn_mask", "layers.3.residual_group.blocks.3.attn.relative_position_index", "layers.3.residual_group.blocks.4.attn.relative_position_index", "layers.3.residual_group.blocks.5.attn_mask", "layers.3.residual_group.blocks.5.attn.relative_position_index", "layers.3.conv.0.weight", "layers.3.conv.0.bias", "layers.3.conv.2.weight", "layers.3.conv.2.bias", "layers.3.conv.4.weight", "layers.3.conv.4.bias", "layers.4.residual_group.blocks.0.attn.relative_position_index", "layers.4.residual_group.blocks.1.attn_mask", "layers.4.residual_group.blocks.1.attn.relative_position_index", "layers.4.residual_group.blocks.2.attn.relative_position_index", "layers.4.residual_group.blocks.3.attn_mask", "layers.4.residual_group.blocks.3.attn.relative_position_index", "layers.4.residual_group.blocks.4.attn.relative_position_index", "layers.4.residual_group.blocks.5.attn_mask", "layers.4.residual_group.blocks.5.attn.relative_position_index", "layers.4.conv.0.weight", "layers.4.conv.0.bias", "layers.4.conv.2.weight", "layers.4.conv.2.bias", "layers.4.conv.4.weight", "layers.4.conv.4.bias", "layers.5.residual_group.blocks.0.attn.relative_position_index", "layers.5.residual_group.blocks.1.attn_mask", "layers.5.residual_group.blocks.1.attn.relative_position_index", "layers.5.residual_group.blocks.2.attn.relative_position_index", "layers.5.residual_group.blocks.3.attn_mask", "layers.5.residual_group.blocks.3.attn.relative_position_index", "layers.5.residual_group.blocks.4.attn.relative_position_index", "layers.5.residual_group.blocks.5.attn_mask", 
"layers.5.residual_group.blocks.5.attn.relative_position_index", "layers.5.conv.0.weight", "layers.5.conv.0.bias", "layers.5.conv.2.weight", "layers.5.conv.2.bias", "layers.5.conv.4.weight", "layers.5.conv.4.bias", "layers.6.residual_group.blocks.0.norm1.weight", "layers.6.residual_group.blocks.0.norm1.bias", "layers.6.residual_group.blocks.0.attn.relative_position_bias_table", "layers.6.residual_group.blocks.0.attn.relative_position_index", "layers.6.residual_group.blocks.0.attn.qkv.weight", "layers.6.residual_group.blocks.0.attn.qkv.bias", "layers.6.residual_group.blocks.0.attn.proj.weight", "layers.6.residual_group.blocks.0.attn.proj.bias", "layers.6.residual_group.blocks.0.norm2.weight", "layers.6.residual_group.blocks.0.norm2.bias", "layers.6.residual_group.blocks.0.mlp.fc1.weight", "layers.6.residual_group.blocks.0.mlp.fc1.bias", "layers.6.residual_group.blocks.0.mlp.fc2.weight", "layers.6.residual_group.blocks.0.mlp.fc2.bias", "layers.6.residual_group.blocks.1.attn_mask", "layers.6.residual_group.blocks.1.norm1.weight", "layers.6.residual_group.blocks.1.norm1.bias", "layers.6.residual_group.blocks.1.attn.relative_position_bias_table", "layers.6.residual_group.blocks.1.attn.relative_position_index", "layers.6.residual_group.blocks.1.attn.qkv.weight", "layers.6.residual_group.blocks.1.attn.qkv.bias", "layers.6.residual_group.blocks.1.attn.proj.weight", "layers.6.residual_group.blocks.1.attn.proj.bias", "layers.6.residual_group.blocks.1.norm2.weight", "layers.6.residual_group.blocks.1.norm2.bias", "layers.6.residual_group.blocks.1.mlp.fc1.weight", "layers.6.residual_group.blocks.1.mlp.fc1.bias", "layers.6.residual_group.blocks.1.mlp.fc2.weight", "layers.6.residual_group.blocks.1.mlp.fc2.bias", "layers.6.residual_group.blocks.2.norm1.weight", "layers.6.residual_group.blocks.2.norm1.bias", "layers.6.residual_group.blocks.2.attn.relative_position_bias_table", "layers.6.residual_group.blocks.2.attn.relative_position_index", 
"layers.6.residual_group.blocks.2.attn.qkv.weight", "layers.7.residual_group.blocks.2.norm2.weight", "layers.7.residual_group.blocks.2.norm2.bias", "layers.7.residual_group.blocks.2.mlp.fc1.weight", "layers.7.residual_group.blocks.2.mlp.fc1.bias", "layers.7.residual_group.blocks.2.mlp.fc2.weight", "layers.7.residual_group.blocks.2.mlp.fc2.bias", "layers.7.residual_group.blocks.3.attn_mask", "layers.7.residual_group.blocks.3.norm1.weight", "layers.7.residual_group.blocks.3.norm1.bias", "layers.7.residual_group.blocks.3.attn.relative_position_bias_table", "layers.7.residual_group.blocks.3.attn.relative_position_index", "layers.7.residual_group.blocks.3.attn.qkv.weight", "layers.7.residual_group.blocks.3.attn.qkv.bias", "layers.7.residual_group.blocks.3.attn.proj.weight", cks.3.conv_block.cab.3.attention.1.weight", "layers.3.residual_group.blocks.3.conv_block.cab.3.attention.1.bias", "layers.3.residual_group.blocks.3.conv_block.cab.3.attention.3.weight", ······
ize([961, 6]) from checkpoint, the shape in current model is torch.Size([225, 8]). size mismatch for layers.0.residual_group.blocks.3.attn.qkv.weight: copying a param with shape torch.Size([540, 180]) from checkpoint, the shape in current model is torch.Size([720, 240]). size mismatch for layers.0.residual_group.blocks.3.attn.qkv.bias: copying a param with shape torch.Size([540]) from checkpoint, the shape in current model is torch.Size([720]). size mismatch for ape in current model is torch.Size([240, 480]). size mismatch for layers.5.residual_group.blocks.5.mlp.fc2.bias: copying a param with shape torch.Size([180]) from checkpoint, the shape in current model is torch.Size([240]). size mismatch for norm.weight: copying a param with shape torch.Size([180]) from checkpoint, the shape in current model is torch.Size([240]). size mismatch for norm.bias: copying a param with shape torch.Size([180]) from checkpoint, the shape in current model is torch.Size([240]). size mismatch for conv_before_upsample.0.weight: copying a param with shape torch.Size([64, 180, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 240, 3, 3]).
Time taken: 0.86s
Torch active/reserved: 53/68 MiB, Sys VRAM: 2157/6144 MiB (35.11%)
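The size mismatches above all point the same way: the checkpoint was trained with embed_dim=180, while the webui instantiated SwinIR-L with embed_dim=240, and the relative_position_bias_table shapes ([961, 6] vs [225, 8]) imply window_size 16 with 6 heads vs window_size 8 with 8 heads. A sketch of recovering those hyperparameters from tensor shapes instead of hard-coding them (shapes copied from the error log; the dict stands in for torch.load):

```python
# Sketch: infer hyperparameters from the checkpoint's tensor shapes rather than
# hard-coding SwinIR-L's values. Shapes are copied from the error log above;
# the dict stands in for an actual torch.load(...) of the .pth file.
import math

ckpt_shapes = {
    # (out_channels, embed_dim, kH, kW)
    "conv_before_upsample.0.weight": (64, 180, 3, 3),
    # ((2*window_size - 1)**2, num_heads); full key in the log is
    # "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
    "relative_position_bias_table": (961, 6),
}

embed_dim = ckpt_shapes["conv_before_upsample.0.weight"][1]
num_heads = ckpt_shapes["relative_position_bias_table"][1]
side = math.isqrt(ckpt_shapes["relative_position_bias_table"][0])  # 2w - 1
window_size = (side + 1) // 2

print(embed_dim, num_heads, window_size)  # 180 6 16
```

Those inferred values (180-dim embedding, 6 heads, window 16) match HAT's published configs rather than SwinIR-L's, which is why loading it into the SwinIR-L graph fails.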

chxy95 commented

@Erwin11 You could file a feature request, or modify the code their repo provides yourself to add the HAT model. It looks like it would be enough to follow how SwinIR was added and create or adapt the corresponding model and config for HAT.

@chxy95 Thanks for taking the time to reply so quickly! I've already filed an issue with the auto1111 side suggesting this, but there's been no response yet.

Hmm... I also looked at the technical implementation on the webui side by the author @C43H66N12O12S2.
(The built-in SwinIR and SwinIR v2 were both contributed by the same person, @C43H66N12O12S2.)

Using the SwinIR v2 addition as a reference:

The author @C43H66N12O12S2 modified swinir_model.py (the loader/dispatch module?) and added swinir_model_arch_v2.py (the architecture file?).

In swinir_model.py, the net parameters for the SwinIR v2 module were added; the relevant parameters are as follows:
model = net(
    upscale=scale,
    in_chans=3,
    img_size=64,
    window_size=8,
    img_range=1.0,
    depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
    embed_dim=240,
    num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
    mlp_ratio=2,
    upsampler="nearest+conv",
    resi_connection="3conv",
)

I have no idea what these parameters should be for the HAT model.

The newly added architecture file swinir_model_arch_v2.py
was adapted from the built-in swinir_model_arch.py. Since HAT is a new architecture, a similar file would probably need to be added. I'm a layman, so that's as far as my understanding goes.
Let's look forward to their update.

chxy95 commented

@Erwin11 I haven't looked at it in detail, but if you want to try it yourself, it probably isn't too complicated. In principle, you'd add or adapt the network architecture code from this repo, hat_arch.py, into their model file (the counterpart of swinir_model_arch_v2.py), and then set the related parameter configuration by referring to the test option file for the checkpoint you want to inference, e.g. HAT_SRx4_ImageNet-LR.yml.
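For anyone attempting this, the network options for the x4 HAT checkpoints might look roughly like the following. These values are recalled from the HAT repo's test configs and must be verified against HAT_SRx4_ImageNet-LR.yml before use — this is a hedged sketch of what would replace the SwinIR-L kwargs in a webui-side loader, not a confirmed configuration:

```python
# Hedged sketch: candidate network_g options for HAT x4, recalled from the HAT
# repo's test configs -- verify every value against HAT_SRx4_ImageNet-LR.yml.
# In a webui-side loader these kwargs would replace the SwinIR-L ones, e.g.:
#   model = HAT(**hat_srx4_opts)   # HAT imported from this repo's hat_arch.py
hat_srx4_opts = dict(
    upscale=4,
    in_chans=3,
    img_size=64,
    window_size=16,        # matches the [961, 6] bias table in the error above
    compress_ratio=3,      # HAT-specific: channel-attention block compression
    squeeze_factor=30,     # HAT-specific: channel-attention squeeze factor
    conv_scale=0.01,       # HAT-specific: CAB output scaling
    overlap_ratio=0.5,     # HAT-specific: overlapping cross-attention ratio
    img_range=1.0,
    depths=[6] * 6,
    embed_dim=180,         # matches the size-mismatch log (180 vs SwinIR-L's 240)
    num_heads=[6] * 6,
    mlp_ratio=2,
    upsampler="pixelshuffle",
    resi_connection="1conv",
)
print(hat_srx4_opts["embed_dim"], hat_srx4_opts["window_size"])  # 180 16
```

The embed_dim, num_heads, and window_size here at least agree with the shapes reported in the size-mismatch error earlier in this thread; the HAT-specific fields (compress_ratio, squeeze_factor, conv_scale, overlap_ratio) have no SwinIR equivalent and are exactly why the SwinIR loader cannot be reused as-is.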