麻烦看一下这样是否正确
Opened this issue · 5 comments
import torch
from basicsr.archs.hit_srf_arch import HiT_SRF
import torchvision.transforms as transforms
from PIL import Image
import os
加载图像
image_path = "C:\Users\17525\Desktop\tupian\rectImage1.bmp"
image = Image.open(image_path).convert('RGB')
定义预处理转换
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.4488, 0.4371, 0.4040], std=[1.0, 1.0, 1.0])
])
应用预处理
input_tensor = transform(image).unsqueeze(0)
初始化模型
model = HiT_SRF(upscale=2)
加载本地模型权重
local_model_path = "D:\chaofenbian\HiT-SR-main\HiT-SR\HiT-SRF-2x.pth"
检查文件是否存在
if not os.path.exists(local_model_path):
print(f"File not found: {local_model_path}")
else:
try:
state_dict = torch.load(local_model_path, map_location='cpu', weights_only=True)
model.load_state_dict(state_dict, strict=False)
except Exception as e:
print(f"An error occurred: {e}")
设置模型为评估模式
model.eval()
确保使用与模型训练时相同的设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
input_tensor = input_tensor.to(device)
进行推理
with torch.no_grad():
output_tensor = model(input_tensor)
反标准化
output_tensor = output_tensor.squeeze(0).cpu()
output_tensor = output_tensor.clamp(0, 1)
转换为图像
output_image = transforms.ToPILImage()(output_tensor)
保存或显示输出图像
如果运行正确结HiT-SRF-2x的结果应该是这样,供您参考
目前我们只在超分任务上进行过实验,还没有测试过去噪性能。如果您感兴趣的话可以参考SwinIR训练一个去噪模型看看效果如何。
对于运行时间,我们目前训练的模型都带超分,所以只能提供超分的运行时间作为参考。在论文图1b中我们测试了从360✖640超分到720x1280的速度,在A100 GPU上大约为331ms (HiT-SRF)。这个速度会随着不同设备而变化,因此在你们目标设备上测试结果才更准确。
谢谢,您可以用超分去噪的数据集训练一个新的网络,可能效果会比我们预训练的网络更好(我们训练时没有考虑噪声)。
祝项目一切顺利!