git clone <repo>
pip install -e .
from pathlib import Path
from fastcore.all import *
import torch
from segmentation_test.tf_model_creation import *
2023-11-29 22:42:36.113027: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2023-11-29 22:42:36.160792: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-29 22:42:36.160818: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-29 22:42:36.162022: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-11-29 22:42:36.170082: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-11-29 22:42:37.092370: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
# Input geometry: single-channel (grayscale) images, 1152 x 1632 pixels.
h, w, c = 1152, 1632, 1
n_classes = 1                        # single-logit binary segmentation output
filter_list = [64, 128, 256, 512]    # encoder filter counts, one per level

# Build the three U-Net variants with an identical configuration so their
# parameter counts and summaries are directly comparable.
normal_unet_model = unet_model(
    input_size=(h, w, c), filter_list=filter_list, n_classes=n_classes,
)
attn_unet_model = unet_model_attention_gates(
    input_size=(h, w, c), filter_list=filter_list, n_classes=n_classes,
)
res_attn_unet_model = residual_attn_unet(
    input_size=(h, w, c), filter_list=filter_list, n_classes=n_classes,
)
Model: "model_3"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_6 (InputLayer) [(None, 1152, 1632, 1)] 0 []
conv2d_63 (Conv2D) (None, 1152, 1632, 64) 640 ['input_6[0][0]']
batch_normalization_54 (Ba (None, 1152, 1632, 64) 256 ['conv2d_63[0][0]']
tchNormalization)
activation_62 (Activation) (None, 1152, 1632, 64) 0 ['batch_normalization_54[0][0]
']
dropout_54 (Dropout) (None, 1152, 1632, 64) 0 ['activation_62[0][0]']
conv2d_64 (Conv2D) (None, 1152, 1632, 64) 36928 ['dropout_54[0][0]']
conv2d_65 (Conv2D) (None, 1152, 1632, 64) 128 ['input_6[0][0]']
batch_normalization_55 (Ba (None, 1152, 1632, 64) 256 ['conv2d_64[0][0]']
tchNormalization)
batch_normalization_56 (Ba (None, 1152, 1632, 64) 256 ['conv2d_65[0][0]']
tchNormalization)
dropout_55 (Dropout) (None, 1152, 1632, 64) 0 ['batch_normalization_55[0][0]
']
add_4 (Add) (None, 1152, 1632, 64) 0 ['batch_normalization_56[0][0]
',
'dropout_55[0][0]']
activation_63 (Activation) (None, 1152, 1632, 64) 0 ['add_4[0][0]']
max_pooling2d_12 (MaxPooli (None, 576, 816, 64) 0 ['activation_63[0][0]']
ng2D)
average_pooling2d_12 (Aver (None, 576, 816, 64) 0 ['activation_63[0][0]']
agePooling2D)
concatenate_24 (Concatenat (None, 576, 816, 128) 0 ['max_pooling2d_12[0][0]',
e) 'average_pooling2d_12[0][0]']
conv2d_66 (Conv2D) (None, 576, 816, 128) 147584 ['concatenate_24[0][0]']
batch_normalization_57 (Ba (None, 576, 816, 128) 512 ['conv2d_66[0][0]']
tchNormalization)
activation_64 (Activation) (None, 576, 816, 128) 0 ['batch_normalization_57[0][0]
']
dropout_56 (Dropout) (None, 576, 816, 128) 0 ['activation_64[0][0]']
conv2d_67 (Conv2D) (None, 576, 816, 128) 147584 ['dropout_56[0][0]']
conv2d_68 (Conv2D) (None, 576, 816, 128) 16512 ['concatenate_24[0][0]']
batch_normalization_58 (Ba (None, 576, 816, 128) 512 ['conv2d_67[0][0]']
tchNormalization)
batch_normalization_59 (Ba (None, 576, 816, 128) 512 ['conv2d_68[0][0]']
tchNormalization)
dropout_57 (Dropout) (None, 576, 816, 128) 0 ['batch_normalization_58[0][0]
']
add_5 (Add) (None, 576, 816, 128) 0 ['batch_normalization_59[0][0]
',
'dropout_57[0][0]']
activation_65 (Activation) (None, 576, 816, 128) 0 ['add_5[0][0]']
max_pooling2d_13 (MaxPooli (None, 288, 408, 128) 0 ['activation_65[0][0]']
ng2D)
average_pooling2d_13 (Aver (None, 288, 408, 128) 0 ['activation_65[0][0]']
agePooling2D)
concatenate_25 (Concatenat (None, 288, 408, 256) 0 ['max_pooling2d_13[0][0]',
e) 'average_pooling2d_13[0][0]']
conv2d_69 (Conv2D) (None, 288, 408, 256) 590080 ['concatenate_25[0][0]']
batch_normalization_60 (Ba (None, 288, 408, 256) 1024 ['conv2d_69[0][0]']
tchNormalization)
activation_66 (Activation) (None, 288, 408, 256) 0 ['batch_normalization_60[0][0]
']
dropout_58 (Dropout) (None, 288, 408, 256) 0 ['activation_66[0][0]']
conv2d_70 (Conv2D) (None, 288, 408, 256) 590080 ['dropout_58[0][0]']
conv2d_71 (Conv2D) (None, 288, 408, 256) 65792 ['concatenate_25[0][0]']
batch_normalization_61 (Ba (None, 288, 408, 256) 1024 ['conv2d_70[0][0]']
tchNormalization)
batch_normalization_62 (Ba (None, 288, 408, 256) 1024 ['conv2d_71[0][0]']
tchNormalization)
dropout_59 (Dropout) (None, 288, 408, 256) 0 ['batch_normalization_61[0][0]
']
add_6 (Add) (None, 288, 408, 256) 0 ['batch_normalization_62[0][0]
',
'dropout_59[0][0]']
activation_67 (Activation) (None, 288, 408, 256) 0 ['add_6[0][0]']
max_pooling2d_14 (MaxPooli (None, 144, 204, 256) 0 ['activation_67[0][0]']
ng2D)
average_pooling2d_14 (Aver (None, 144, 204, 256) 0 ['activation_67[0][0]']
agePooling2D)
concatenate_26 (Concatenat (None, 144, 204, 512) 0 ['max_pooling2d_14[0][0]',
e) 'average_pooling2d_14[0][0]']
conv2d_72 (Conv2D) (None, 144, 204, 512) 2359808 ['concatenate_26[0][0]']
batch_normalization_63 (Ba (None, 144, 204, 512) 2048 ['conv2d_72[0][0]']
tchNormalization)
activation_68 (Activation) (None, 144, 204, 512) 0 ['batch_normalization_63[0][0]
']
dropout_60 (Dropout) (None, 144, 204, 512) 0 ['activation_68[0][0]']
conv2d_73 (Conv2D) (None, 144, 204, 512) 2359808 ['dropout_60[0][0]']
conv2d_74 (Conv2D) (None, 144, 204, 512) 262656 ['concatenate_26[0][0]']
batch_normalization_64 (Ba (None, 144, 204, 512) 2048 ['conv2d_73[0][0]']
tchNormalization)
batch_normalization_65 (Ba (None, 144, 204, 512) 2048 ['conv2d_74[0][0]']
tchNormalization)
dropout_61 (Dropout) (None, 144, 204, 512) 0 ['batch_normalization_64[0][0]
']
add_7 (Add) (None, 144, 204, 512) 0 ['batch_normalization_65[0][0]
',
'dropout_61[0][0]']
activation_69 (Activation) (None, 144, 204, 512) 0 ['add_7[0][0]']
max_pooling2d_15 (MaxPooli (None, 72, 102, 512) 0 ['activation_69[0][0]']
ng2D)
average_pooling2d_15 (Aver (None, 72, 102, 512) 0 ['activation_69[0][0]']
agePooling2D)
concatenate_27 (Concatenat (None, 72, 102, 1024) 0 ['max_pooling2d_15[0][0]',
e) 'average_pooling2d_15[0][0]']
conv2d_75 (Conv2D) (None, 72, 102, 1024) 9438208 ['concatenate_27[0][0]']
batch_normalization_66 (Ba (None, 72, 102, 1024) 4096 ['conv2d_75[0][0]']
tchNormalization)
activation_70 (Activation) (None, 72, 102, 1024) 0 ['batch_normalization_66[0][0]
']
dropout_62 (Dropout) (None, 72, 102, 1024) 0 ['activation_70[0][0]']
conv2d_76 (Conv2D) (None, 72, 102, 1024) 9438208 ['dropout_62[0][0]']
conv2d_77 (Conv2D) (None, 72, 102, 1024) 1049600 ['concatenate_27[0][0]']
batch_normalization_67 (Ba (None, 72, 102, 1024) 4096 ['conv2d_76[0][0]']
tchNormalization)
batch_normalization_68 (Ba (None, 72, 102, 1024) 4096 ['conv2d_77[0][0]']
tchNormalization)
dropout_63 (Dropout) (None, 72, 102, 1024) 0 ['batch_normalization_67[0][0]
']
add_8 (Add) (None, 72, 102, 1024) 0 ['batch_normalization_68[0][0]
',
'dropout_63[0][0]']
activation_71 (Activation) (None, 72, 102, 1024) 0 ['add_8[0][0]']
gating_signal0_conv (Conv2 (None, 72, 102, 512) 524800 ['activation_71[0][0]']
D)
gating_signal0_bn (BatchNo (None, 72, 102, 512) 2048 ['gating_signal0_conv[0][0]']
rmalization)
gating_signal0_act (Activa (None, 72, 102, 512) 0 ['gating_signal0_bn[0][0]']
tion)
conv2d_78 (Conv2D) (None, 72, 102, 512) 262656 ['gating_signal0_act[0][0]']
g_upattention_block0 (Conv (None, 72, 102, 512) 2359808 ['conv2d_78[0][0]']
2DTranspose)
xlattention_block0 (Conv2D (None, 72, 102, 512) 1049088 ['activation_69[0][0]']
)
add_9 (Add) (None, 72, 102, 512) 0 ['g_upattention_block0[0][0]',
'xlattention_block0[0][0]']
activation_72 (Activation) (None, 72, 102, 512) 0 ['add_9[0][0]']
psiattention_block0 (Conv2 (None, 72, 102, 1) 513 ['activation_72[0][0]']
D)
activation_73 (Activation) (None, 72, 102, 1) 0 ['psiattention_block0[0][0]']
up_sampling2d_4 (UpSamplin (None, 144, 204, 1) 0 ['activation_73[0][0]']
g2D)
psi_upattention_block0 (La (None, 144, 204, 512) 0 ['up_sampling2d_4[0][0]']
mbda)
q_attnattention_block0 (Mu (None, 144, 204, 512) 0 ['psi_upattention_block0[0][0]
ltiply) ',
'activation_69[0][0]']
q_attn_convattention_block (None, 144, 204, 512) 262656 ['q_attnattention_block0[0][0]
0 (Conv2D) ']
upsampling_0 (UpSampling2D (None, 144, 204, 1024) 0 ['activation_71[0][0]']
)
q_attn_bnattention_block0 (None, 144, 204, 512) 2048 ['q_attn_convattention_block0[
(BatchNormalization) 0][0]']
concatenate_28 (Concatenat (None, 144, 204, 1536) 0 ['upsampling_0[0][0]',
e) 'q_attn_bnattention_block0[0]
[0]']
conv2d_79 (Conv2D) (None, 144, 204, 512) 7078400 ['concatenate_28[0][0]']
batch_normalization_69 (Ba (None, 144, 204, 512) 2048 ['conv2d_79[0][0]']
tchNormalization)
activation_74 (Activation) (None, 144, 204, 512) 0 ['batch_normalization_69[0][0]
']
dropout_64 (Dropout) (None, 144, 204, 512) 0 ['activation_74[0][0]']
conv2d_80 (Conv2D) (None, 144, 204, 512) 2359808 ['dropout_64[0][0]']
conv2d_81 (Conv2D) (None, 144, 204, 512) 786944 ['concatenate_28[0][0]']
batch_normalization_70 (Ba (None, 144, 204, 512) 2048 ['conv2d_80[0][0]']
tchNormalization)
batch_normalization_71 (Ba (None, 144, 204, 512) 2048 ['conv2d_81[0][0]']
tchNormalization)
dropout_65 (Dropout) (None, 144, 204, 512) 0 ['batch_normalization_70[0][0]
']
add_10 (Add) (None, 144, 204, 512) 0 ['batch_normalization_71[0][0]
',
'dropout_65[0][0]']
activation_75 (Activation) (None, 144, 204, 512) 0 ['add_10[0][0]']
gating_signal1_conv (Conv2 (None, 144, 204, 256) 131328 ['activation_75[0][0]']
D)
gating_signal1_bn (BatchNo (None, 144, 204, 256) 1024 ['gating_signal1_conv[0][0]']
rmalization)
gating_signal1_act (Activa (None, 144, 204, 256) 0 ['gating_signal1_bn[0][0]']
tion)
conv2d_82 (Conv2D) (None, 144, 204, 256) 65792 ['gating_signal1_act[0][0]']
g_upattention_block1 (Conv (None, 144, 204, 256) 590080 ['conv2d_82[0][0]']
2DTranspose)
xlattention_block1 (Conv2D (None, 144, 204, 256) 262400 ['activation_67[0][0]']
)
add_11 (Add) (None, 144, 204, 256) 0 ['g_upattention_block1[0][0]',
'xlattention_block1[0][0]']
activation_76 (Activation) (None, 144, 204, 256) 0 ['add_11[0][0]']
psiattention_block1 (Conv2 (None, 144, 204, 1) 257 ['activation_76[0][0]']
D)
activation_77 (Activation) (None, 144, 204, 1) 0 ['psiattention_block1[0][0]']
up_sampling2d_5 (UpSamplin (None, 288, 408, 1) 0 ['activation_77[0][0]']
g2D)
psi_upattention_block1 (La (None, 288, 408, 256) 0 ['up_sampling2d_5[0][0]']
mbda)
q_attnattention_block1 (Mu (None, 288, 408, 256) 0 ['psi_upattention_block1[0][0]
ltiply) ',
'activation_67[0][0]']
q_attn_convattention_block (None, 288, 408, 256) 65792 ['q_attnattention_block1[0][0]
1 (Conv2D) ']
upsampling_1 (UpSampling2D (None, 288, 408, 512) 0 ['activation_75[0][0]']
)
q_attn_bnattention_block1 (None, 288, 408, 256) 1024 ['q_attn_convattention_block1[
(BatchNormalization) 0][0]']
concatenate_29 (Concatenat (None, 288, 408, 768) 0 ['upsampling_1[0][0]',
e) 'q_attn_bnattention_block1[0]
[0]']
conv2d_83 (Conv2D) (None, 288, 408, 256) 1769728 ['concatenate_29[0][0]']
batch_normalization_72 (Ba (None, 288, 408, 256) 1024 ['conv2d_83[0][0]']
tchNormalization)
activation_78 (Activation) (None, 288, 408, 256) 0 ['batch_normalization_72[0][0]
']
dropout_66 (Dropout) (None, 288, 408, 256) 0 ['activation_78[0][0]']
conv2d_84 (Conv2D) (None, 288, 408, 256) 590080 ['dropout_66[0][0]']
conv2d_85 (Conv2D) (None, 288, 408, 256) 196864 ['concatenate_29[0][0]']
batch_normalization_73 (Ba (None, 288, 408, 256) 1024 ['conv2d_84[0][0]']
tchNormalization)
batch_normalization_74 (Ba (None, 288, 408, 256) 1024 ['conv2d_85[0][0]']
tchNormalization)
dropout_67 (Dropout) (None, 288, 408, 256) 0 ['batch_normalization_73[0][0]
']
add_12 (Add) (None, 288, 408, 256) 0 ['batch_normalization_74[0][0]
',
'dropout_67[0][0]']
activation_79 (Activation) (None, 288, 408, 256) 0 ['add_12[0][0]']
gating_signal2_conv (Conv2 (None, 288, 408, 128) 32896 ['activation_79[0][0]']
D)
gating_signal2_bn (BatchNo (None, 288, 408, 128) 512 ['gating_signal2_conv[0][0]']
rmalization)
gating_signal2_act (Activa (None, 288, 408, 128) 0 ['gating_signal2_bn[0][0]']
tion)
conv2d_86 (Conv2D) (None, 288, 408, 128) 16512 ['gating_signal2_act[0][0]']
g_upattention_block2 (Conv (None, 288, 408, 128) 147584 ['conv2d_86[0][0]']
2DTranspose)
xlattention_block2 (Conv2D (None, 288, 408, 128) 65664 ['activation_65[0][0]']
)
add_13 (Add) (None, 288, 408, 128) 0 ['g_upattention_block2[0][0]',
'xlattention_block2[0][0]']
activation_80 (Activation) (None, 288, 408, 128) 0 ['add_13[0][0]']
psiattention_block2 (Conv2 (None, 288, 408, 1) 129 ['activation_80[0][0]']
D)
activation_81 (Activation) (None, 288, 408, 1) 0 ['psiattention_block2[0][0]']
up_sampling2d_6 (UpSamplin (None, 576, 816, 1) 0 ['activation_81[0][0]']
g2D)
psi_upattention_block2 (La (None, 576, 816, 128) 0 ['up_sampling2d_6[0][0]']
mbda)
q_attnattention_block2 (Mu (None, 576, 816, 128) 0 ['psi_upattention_block2[0][0]
ltiply) ',
'activation_65[0][0]']
q_attn_convattention_block (None, 576, 816, 128) 16512 ['q_attnattention_block2[0][0]
2 (Conv2D) ']
upsampling_2 (UpSampling2D (None, 576, 816, 256) 0 ['activation_79[0][0]']
)
q_attn_bnattention_block2 (None, 576, 816, 128) 512 ['q_attn_convattention_block2[
(BatchNormalization) 0][0]']
concatenate_30 (Concatenat (None, 576, 816, 384) 0 ['upsampling_2[0][0]',
e) 'q_attn_bnattention_block2[0]
[0]']
conv2d_87 (Conv2D) (None, 576, 816, 128) 442496 ['concatenate_30[0][0]']
batch_normalization_75 (Ba (None, 576, 816, 128) 512 ['conv2d_87[0][0]']
tchNormalization)
activation_82 (Activation) (None, 576, 816, 128) 0 ['batch_normalization_75[0][0]
']
dropout_68 (Dropout) (None, 576, 816, 128) 0 ['activation_82[0][0]']
conv2d_88 (Conv2D) (None, 576, 816, 128) 147584 ['dropout_68[0][0]']
conv2d_89 (Conv2D) (None, 576, 816, 128) 49280 ['concatenate_30[0][0]']
batch_normalization_76 (Ba (None, 576, 816, 128) 512 ['conv2d_88[0][0]']
tchNormalization)
batch_normalization_77 (Ba (None, 576, 816, 128) 512 ['conv2d_89[0][0]']
tchNormalization)
dropout_69 (Dropout) (None, 576, 816, 128) 0 ['batch_normalization_76[0][0]
']
add_14 (Add) (None, 576, 816, 128) 0 ['batch_normalization_77[0][0]
',
'dropout_69[0][0]']
activation_83 (Activation) (None, 576, 816, 128) 0 ['add_14[0][0]']
gating_signal3_conv (Conv2 (None, 576, 816, 64) 8256 ['activation_83[0][0]']
D)
gating_signal3_bn (BatchNo (None, 576, 816, 64) 256 ['gating_signal3_conv[0][0]']
rmalization)
gating_signal3_act (Activa (None, 576, 816, 64) 0 ['gating_signal3_bn[0][0]']
tion)
conv2d_90 (Conv2D) (None, 576, 816, 64) 4160 ['gating_signal3_act[0][0]']
g_upattention_block3 (Conv (None, 576, 816, 64) 36928 ['conv2d_90[0][0]']
2DTranspose)
xlattention_block3 (Conv2D (None, 576, 816, 64) 16448 ['activation_63[0][0]']
)
add_15 (Add) (None, 576, 816, 64) 0 ['g_upattention_block3[0][0]',
'xlattention_block3[0][0]']
activation_84 (Activation) (None, 576, 816, 64) 0 ['add_15[0][0]']
psiattention_block3 (Conv2 (None, 576, 816, 1) 65 ['activation_84[0][0]']
D)
activation_85 (Activation) (None, 576, 816, 1) 0 ['psiattention_block3[0][0]']
up_sampling2d_7 (UpSamplin (None, 1152, 1632, 1) 0 ['activation_85[0][0]']
g2D)
psi_upattention_block3 (La (None, 1152, 1632, 64) 0 ['up_sampling2d_7[0][0]']
mbda)
q_attnattention_block3 (Mu (None, 1152, 1632, 64) 0 ['psi_upattention_block3[0][0]
ltiply) ',
'activation_63[0][0]']
q_attn_convattention_block (None, 1152, 1632, 64) 4160 ['q_attnattention_block3[0][0]
3 (Conv2D) ']
upsampling_3 (UpSampling2D (None, 1152, 1632, 128) 0 ['activation_83[0][0]']
)
q_attn_bnattention_block3 (None, 1152, 1632, 64) 256 ['q_attn_convattention_block3[
(BatchNormalization) 0][0]']
concatenate_31 (Concatenat (None, 1152, 1632, 192) 0 ['upsampling_3[0][0]',
e) 'q_attn_bnattention_block3[0]
[0]']
conv2d_91 (Conv2D) (None, 1152, 1632, 64) 110656 ['concatenate_31[0][0]']
batch_normalization_78 (Ba (None, 1152, 1632, 64) 256 ['conv2d_91[0][0]']
tchNormalization)
activation_86 (Activation) (None, 1152, 1632, 64) 0 ['batch_normalization_78[0][0]
']
dropout_70 (Dropout) (None, 1152, 1632, 64) 0 ['activation_86[0][0]']
conv2d_92 (Conv2D) (None, 1152, 1632, 64) 36928 ['dropout_70[0][0]']
conv2d_93 (Conv2D) (None, 1152, 1632, 64) 12352 ['concatenate_31[0][0]']
batch_normalization_79 (Ba (None, 1152, 1632, 64) 256 ['conv2d_92[0][0]']
tchNormalization)
batch_normalization_80 (Ba (None, 1152, 1632, 64) 256 ['conv2d_93[0][0]']
tchNormalization)
dropout_71 (Dropout) (None, 1152, 1632, 64) 0 ['batch_normalization_79[0][0]
']
add_16 (Add) (None, 1152, 1632, 64) 0 ['batch_normalization_80[0][0]
',
'dropout_71[0][0]']
activation_87 (Activation) (None, 1152, 1632, 64) 0 ['add_16[0][0]']
conv2d_94 (Conv2D) (None, 1152, 1632, 1) 65 ['activation_87[0][0]']
==================================================================================================
Total params: 46052293 (175.68 MB)
Trainable params: 46030789 (175.59 MB)
Non-trainable params: 21504 (84.00 KB)
__________________________________________________________________________________________________
None
# Bring in the PyTorch side of the project: model, dataloaders, training loop.
# (The original cell imported pytorch_model_development twice; the duplicate
# line has been removed.)
from segmentation_test.pytorch_model_development import *
from segmentation_test.dataloader_creation import *
from segmentation_test.pytorch_training_and_loss import *
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
# Plain U-Net: one input channel (grayscale), one output channel (mask logit).
model = UNet(in_channels=1, out_channels=1)

# On-disk locations of the pre-cut image/mask patches.
im_path = Path(r'/home/hasan/workspace/data/microscopy_data/patch_images')
msk_path = Path(r'/home/hasan/workspace/data/microscopy_data/patch_masks')

# Random split with split_per=0.1, batches of 2, no augmentation transforms,
# loading in the main process (num_workers=0).
train_loader, val_loader = create_pytorch_dataloader(
    split_type='random', split_per=0.1, batch_size=2,
    image_path=im_path, mask_path=msk_path,
    transforms=None, num_workers=0,
)
Number of images found = 1642
training dataset length = 164 and validation dataset length= 1478
torch.Size([1, 256, 256]) torch.Size([1, 256, 256])
# Optimiser over all model parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

#| export
# Training-schedule hyper-parameters. The original cell assigned several of
# these twice (warmup_epochs was first computed as 0.1 * total_epochs, then
# overwritten with 2; total_epochs and initial_lr were each re-assigned to
# the same value). Only the effective values are kept here.
total_epochs = 10
warmup_epochs = 2
initial_lr = 0.002

# Cosine annealing with warm restarts: first restart after 2 epochs,
# LR decays to eta_min=0 within each cycle.
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=2, eta_min=0)

# Loss for the binary masks. The original cell built two identical FocalLoss
# instances (`criterion` was never used afterwards); `criterion` is kept as
# an alias so any external code referring to it still works.
loss_fn = FocalLoss()
criterion = loss_fn

# Checkpoints are written to the current working directory.
model_save_path = Path.cwd()

train_segmentation_model(
    train_loader,
    val_loader,
    model,
    optimizer,
    scheduler,
    loss_fn,
    model_save_path,
    warmup_epochs,
    total_epochs,
    initial_lr,
    dtype=torch.float,
    device='cuda')
Epoch 1/10: 0%| | 0/82 [00:00<?, ?batch/s]
Epoch 1/10 - Train Loss: 0.0130, IoU: 0.0254, FP: 0.0, FN: 0.0
Validation Loss: 0.0097, IoU: 0.3979, FP: 0.0, FN: 0.0
Epoch 2/10: 0%| | 0/82 [00:00<?, ?batch/s]