Program randomly hangs and does not complete data generation with trained Pseudo ISP
MrBled opened this issue · 1 comment
I could not determine the cause, but I was unable to run the Stage 2 training, specifically `Generate_Synthesis_Dataset()`.
I worked around this by adding a `train` flag to the three models and skipping weight initialization when the flag is `False`. Without this change the program randomly hangs and has to be killed manually from htop/nvtop.
The change is as follows:
import torch.nn as nn

class RGB2PACK(nn.Module):
    def __init__(self, channels=3, filters_num=128, filters_pack=4, train=True):
        super(RGB2PACK, self).__init__()
        # RGB2RAW Network
        self.RGB2RAW = nn.Sequential(
            nn.Conv2d(channels, filters_num, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(filters_num, filters_num, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(filters_num, filters_num, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(filters_num, filters_num, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(filters_num, filters_num, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(filters_num, channels, 3, 1, 1),
            nn.ReLU(True))
        # Mosaic
        self.mosaic = Mosaic_Operation()
        # Randomly initialize weights only when training from scratch;
        # skip this when pretrained weights will be loaded instead.
        if train:
            self._initialize_weights()
Apply the same change to the other two networks.
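A minimal sketch of how the flag would be used when generating the synthesis dataset, assuming the trained weights are restored with `load_state_dict` after the model is constructed (the issue does not show the loading code). `ToyRGB2PACK`, its `_initialize_weights` body, and `load_for_inference` are hypothetical stand-ins; `Mosaic_Operation` is omitted because its definition is not shown. Because `load_state_dict` overwrites every parameter, skipping the random initialization does not change the weights the model ends up with.

```python
import torch
import torch.nn as nn


class ToyRGB2PACK(nn.Module):
    """Hypothetical, reduced stand-in for RGB2PACK (no Mosaic_Operation)."""

    def __init__(self, channels=3, filters_num=16, train=True):
        super().__init__()
        self.RGB2RAW = nn.Sequential(
            nn.Conv2d(channels, filters_num, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(filters_num, channels, 3, 1, 1),
            nn.ReLU(True))
        # Same pattern as in the issue: random init only when training.
        if train:
            self._initialize_weights()

    def _initialize_weights(self):
        # Hypothetical initializer; the repository's own may differ.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                nn.init.zeros_(m.bias)

    def forward(self, x):
        return self.RGB2RAW(x)


def load_for_inference(state_dict):
    # train=False skips random init; load_state_dict then overwrites
    # every parameter with the trained values.
    model = ToyRGB2PACK(train=False)
    model.load_state_dict(state_dict)
    model.eval()  # inference mode for data generation
    return model


torch.manual_seed(0)
trained = ToyRGB2PACK(train=True)
# In the real pipeline this state dict would come from a saved checkpoint.
model = load_for_inference(trained.state_dict())
with torch.no_grad():
    out = model(torch.rand(1, 3, 8, 8))
print(out.shape)
```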
Thank you. I will check it.