/ SRCNN_TORCH /
SRCNN 的 Pytorch实现版本
Build the SRCNN with Pytorch
v 0.1.0
We refer to the architecture of ESRGAN to build SRCNN.
- dataloader : To load the data
- model : Describe the specific architecture of SRCNN
- train:The main entrance of the program
-
torch.utils.data.Dataset
The most basic class to load data
The main job is to implement the following two member functions
-
# ***************************************************************** # how to load data each time and how to get the length of data # ***************************************************************** def __getitem__(self, index): img_path, label = self.data[index].img_path, self.data[index].label img = Image.open(img_path) return img, label def __len__(self): return len(self.data)
-
-
torch.utils.data.DataLoader
This class encapsulates the Dataset class.
-
# defination of the DataLoader class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
As you can see, there are several main parameters:
- Dataset : The custom Dataset class above.
- Collate_fn: This function is used to package the batch, which is described in detail later.
- Num_worker: Very simple multi-threaded method, as long as it is set to >=1, you can multi-thread pre-read data.
And we should Implement a function named "_iter_",we will use the DataLoaderIter class in it.
def __iter__(self): return DataLoaderIter(self)
-
-
torch.utils.data.dataloader.DataLoaderIter
DataLoadIter is the Iterator of the DataLoader
from torch.utils.data import Dataset import os from PIL import Image import numpy as np import torchvision.transforms as transform class Imgdataset(Dataset): def __init__(self, path): super(Imgdataset, self).__init__() self.data = [] if os.path.exists(path): dir_list = os.listdir(path) LR_dir = dir_list[0] HR_dir = dir_list[1] LR_path = path + '/' + LR_dir HR_path = path + '/' + HR_dir if os.path.exists(LR_path) and os.path.exists(HR_path): LR_data = os.listdir(LR_path) HR_data = os.listdir(HR_path) self.data = [{'LR':LR_path +'/'+ LR_data[i], 'HR':HR_path +'/'+ HR_data[i]} for i in range(len(LR_data))] else: raise FileNotFoundError('path doesnt exist!') def __getitem__(self, index): LR_path, HR_path = self.data[index]["LR"], self.data[index]["HR"] tran = transform.ToTensor() LR_img = Image.open(LR_path) HR_img = Image.open(HR_path) # print(tran(img).shape) return tran(LR_img), tran(HR_img) def __len__(self): return len(self.data)
- conv1: conv2d(in_channels=3, out_channels=64, kernal_size=9, stride=1)
- conv2: conv2d(in_channels=64, out_channels=32, kernal_size=3, stride=1)
- conv3:conv2d(in_channels=32, out_channels=3, kernal_size=1, stride=1)
- ReLU: Activation function
import torch
import torch.nn as nn
class NetWork(nn.Module):
def __init__(self):
super(NetWork, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=9, stride=1)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2d(64, 32, kernel_size=1, stride=1)
self.relu2 = nn.ReLU()
self.conv3 = nn.Conv2d(32, 3, kernel_size=5, stride=1)
def forward(self, data):
out = self.conv1(data)
out = self.relu1(out)
out = self.conv2(out)
out = self.relu2(out)
out = self.conv3(out)
return out
**NOTE:To avoid border effects during training, all the convolutional layers have no padding, and the result we get from the conv3 won't have the same size as the input image, it will be a little bit smaller than the origin picture,so we just calcuate the MSELoss by central overlapping area **
there are a little bit of points that we should note:
-
how to load our data?
Not to mention what we said above, we just need to instantiate a DataLoader class with our custom DataSet and then iterate over it.
AND WE SHOULD NOTE:
for batch_id ,batch in enumerate(train_data_loader): # # batch is a tensor included twenty LR_data and twenty HR_data (e.g.) #
-
how to build our net?
just instantiate it!
-
how to train our model?
we use Adam to optimize our Loss function:MSELoss.
optimizer = optim.Adam(net.parameters(), lr=0.1) loss = nn.MSELoss()
-
how to use cuda?
if not torch.cuda.is_available(): raise Exception('NO GPU!')