segmentation_test

Library for segmentation model

Primary Language: Jupyter Notebook — License: Apache-2.0

Segmentation-related functions will be developed here.

Install

git clone <repo>
pip install -e .

How to use

from pathlib import Path
from fastcore.all import *
import torch
from segmentation_test.tf_model_creation import *
2023-11-29 22:42:36.113027: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2023-11-29 22:42:36.160792: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-29 22:42:36.160818: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-29 22:42:36.162022: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-11-29 22:42:36.170082: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-11-29 22:42:37.092370: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
h,w,c = 1152, 1632,1
n_classes = 1
filter_list=[64, 128, 256, 512]

Create a Normal UNet

normal_unet_model = unet_model(
    input_size=(h, w, c),
    filter_list=filter_list,
    n_classes=n_classes
)
#print(normal_unet_model.summary())

Create a U-Net with Attention

attn_unet_model = unet_model_attention_gates(
    input_size=(h, w, c),
    filter_list=filter_list,
    n_classes=n_classes
)
#print(attn_unet_model.summary())

Create a U-Net model with attention gates and residual connections

res_attn_unet_model = residual_attn_unet(
    input_size=(h, w, c),
    filter_list=filter_list,
    n_classes=n_classes
)
#print(res_attn_unet_model.summary())
Model: "model_3"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
==================================================================================================
 input_6 (InputLayer)        [(None, 1152, 1632, 1)]      0         []                            
                                                                                                  
 conv2d_63 (Conv2D)          (None, 1152, 1632, 64)       640       ['input_6[0][0]']             
                                                                                                  
 batch_normalization_54 (Ba  (None, 1152, 1632, 64)       256       ['conv2d_63[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_62 (Activation)  (None, 1152, 1632, 64)       0         ['batch_normalization_54[0][0]
                                                                    ']                            
                                                                                                  
 dropout_54 (Dropout)        (None, 1152, 1632, 64)       0         ['activation_62[0][0]']       
                                                                                                  
 conv2d_64 (Conv2D)          (None, 1152, 1632, 64)       36928     ['dropout_54[0][0]']          
                                                                                                  
 conv2d_65 (Conv2D)          (None, 1152, 1632, 64)       128       ['input_6[0][0]']             
                                                                                                  
 batch_normalization_55 (Ba  (None, 1152, 1632, 64)       256       ['conv2d_64[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 batch_normalization_56 (Ba  (None, 1152, 1632, 64)       256       ['conv2d_65[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 dropout_55 (Dropout)        (None, 1152, 1632, 64)       0         ['batch_normalization_55[0][0]
                                                                    ']                            
                                                                                                  
 add_4 (Add)                 (None, 1152, 1632, 64)       0         ['batch_normalization_56[0][0]
                                                                    ',                            
                                                                     'dropout_55[0][0]']          
                                                                                                  
 activation_63 (Activation)  (None, 1152, 1632, 64)       0         ['add_4[0][0]']               
                                                                                                  
 max_pooling2d_12 (MaxPooli  (None, 576, 816, 64)         0         ['activation_63[0][0]']       
 ng2D)                                                                                            
                                                                                                  
 average_pooling2d_12 (Aver  (None, 576, 816, 64)         0         ['activation_63[0][0]']       
 agePooling2D)                                                                                    
                                                                                                  
 concatenate_24 (Concatenat  (None, 576, 816, 128)        0         ['max_pooling2d_12[0][0]',    
 e)                                                                  'average_pooling2d_12[0][0]']
                                                                                                  
 conv2d_66 (Conv2D)          (None, 576, 816, 128)        147584    ['concatenate_24[0][0]']      
                                                                                                  
 batch_normalization_57 (Ba  (None, 576, 816, 128)        512       ['conv2d_66[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_64 (Activation)  (None, 576, 816, 128)        0         ['batch_normalization_57[0][0]
                                                                    ']                            
                                                                                                  
 dropout_56 (Dropout)        (None, 576, 816, 128)        0         ['activation_64[0][0]']       
                                                                                                  
 conv2d_67 (Conv2D)          (None, 576, 816, 128)        147584    ['dropout_56[0][0]']          
                                                                                                  
 conv2d_68 (Conv2D)          (None, 576, 816, 128)        16512     ['concatenate_24[0][0]']      
                                                                                                  
 batch_normalization_58 (Ba  (None, 576, 816, 128)        512       ['conv2d_67[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 batch_normalization_59 (Ba  (None, 576, 816, 128)        512       ['conv2d_68[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 dropout_57 (Dropout)        (None, 576, 816, 128)        0         ['batch_normalization_58[0][0]
                                                                    ']                            
                                                                                                  
 add_5 (Add)                 (None, 576, 816, 128)        0         ['batch_normalization_59[0][0]
                                                                    ',                            
                                                                     'dropout_57[0][0]']          
                                                                                                  
 activation_65 (Activation)  (None, 576, 816, 128)        0         ['add_5[0][0]']               
                                                                                                  
 max_pooling2d_13 (MaxPooli  (None, 288, 408, 128)        0         ['activation_65[0][0]']       
 ng2D)                                                                                            
                                                                                                  
 average_pooling2d_13 (Aver  (None, 288, 408, 128)        0         ['activation_65[0][0]']       
 agePooling2D)                                                                                    
                                                                                                  
 concatenate_25 (Concatenat  (None, 288, 408, 256)        0         ['max_pooling2d_13[0][0]',    
 e)                                                                  'average_pooling2d_13[0][0]']
                                                                                                  
 conv2d_69 (Conv2D)          (None, 288, 408, 256)        590080    ['concatenate_25[0][0]']      
                                                                                                  
 batch_normalization_60 (Ba  (None, 288, 408, 256)        1024      ['conv2d_69[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_66 (Activation)  (None, 288, 408, 256)        0         ['batch_normalization_60[0][0]
                                                                    ']                            
                                                                                                  
 dropout_58 (Dropout)        (None, 288, 408, 256)        0         ['activation_66[0][0]']       
                                                                                                  
 conv2d_70 (Conv2D)          (None, 288, 408, 256)        590080    ['dropout_58[0][0]']          
                                                                                                  
 conv2d_71 (Conv2D)          (None, 288, 408, 256)        65792     ['concatenate_25[0][0]']      
                                                                                                  
 batch_normalization_61 (Ba  (None, 288, 408, 256)        1024      ['conv2d_70[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 batch_normalization_62 (Ba  (None, 288, 408, 256)        1024      ['conv2d_71[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 dropout_59 (Dropout)        (None, 288, 408, 256)        0         ['batch_normalization_61[0][0]
                                                                    ']                            
                                                                                                  
 add_6 (Add)                 (None, 288, 408, 256)        0         ['batch_normalization_62[0][0]
                                                                    ',                            
                                                                     'dropout_59[0][0]']          
                                                                                                  
 activation_67 (Activation)  (None, 288, 408, 256)        0         ['add_6[0][0]']               
                                                                                                  
 max_pooling2d_14 (MaxPooli  (None, 144, 204, 256)        0         ['activation_67[0][0]']       
 ng2D)                                                                                            
                                                                                                  
 average_pooling2d_14 (Aver  (None, 144, 204, 256)        0         ['activation_67[0][0]']       
 agePooling2D)                                                                                    
                                                                                                  
 concatenate_26 (Concatenat  (None, 144, 204, 512)        0         ['max_pooling2d_14[0][0]',    
 e)                                                                  'average_pooling2d_14[0][0]']
                                                                                                  
 conv2d_72 (Conv2D)          (None, 144, 204, 512)        2359808   ['concatenate_26[0][0]']      
                                                                                                  
 batch_normalization_63 (Ba  (None, 144, 204, 512)        2048      ['conv2d_72[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_68 (Activation)  (None, 144, 204, 512)        0         ['batch_normalization_63[0][0]
                                                                    ']                            
                                                                                                  
 dropout_60 (Dropout)        (None, 144, 204, 512)        0         ['activation_68[0][0]']       
                                                                                                  
 conv2d_73 (Conv2D)          (None, 144, 204, 512)        2359808   ['dropout_60[0][0]']          
                                                                                                  
 conv2d_74 (Conv2D)          (None, 144, 204, 512)        262656    ['concatenate_26[0][0]']      
                                                                                                  
 batch_normalization_64 (Ba  (None, 144, 204, 512)        2048      ['conv2d_73[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 batch_normalization_65 (Ba  (None, 144, 204, 512)        2048      ['conv2d_74[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 dropout_61 (Dropout)        (None, 144, 204, 512)        0         ['batch_normalization_64[0][0]
                                                                    ']                            
                                                                                                  
 add_7 (Add)                 (None, 144, 204, 512)        0         ['batch_normalization_65[0][0]
                                                                    ',                            
                                                                     'dropout_61[0][0]']          
                                                                                                  
 activation_69 (Activation)  (None, 144, 204, 512)        0         ['add_7[0][0]']               
                                                                                                  
 max_pooling2d_15 (MaxPooli  (None, 72, 102, 512)         0         ['activation_69[0][0]']       
 ng2D)                                                                                            
                                                                                                  
 average_pooling2d_15 (Aver  (None, 72, 102, 512)         0         ['activation_69[0][0]']       
 agePooling2D)                                                                                    
                                                                                                  
 concatenate_27 (Concatenat  (None, 72, 102, 1024)        0         ['max_pooling2d_15[0][0]',    
 e)                                                                  'average_pooling2d_15[0][0]']
                                                                                                  
 conv2d_75 (Conv2D)          (None, 72, 102, 1024)        9438208   ['concatenate_27[0][0]']      
                                                                                                  
 batch_normalization_66 (Ba  (None, 72, 102, 1024)        4096      ['conv2d_75[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_70 (Activation)  (None, 72, 102, 1024)        0         ['batch_normalization_66[0][0]
                                                                    ']                            
                                                                                                  
 dropout_62 (Dropout)        (None, 72, 102, 1024)        0         ['activation_70[0][0]']       
                                                                                                  
 conv2d_76 (Conv2D)          (None, 72, 102, 1024)        9438208   ['dropout_62[0][0]']          
                                                                                                  
 conv2d_77 (Conv2D)          (None, 72, 102, 1024)        1049600   ['concatenate_27[0][0]']      
                                                                                                  
 batch_normalization_67 (Ba  (None, 72, 102, 1024)        4096      ['conv2d_76[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 batch_normalization_68 (Ba  (None, 72, 102, 1024)        4096      ['conv2d_77[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 dropout_63 (Dropout)        (None, 72, 102, 1024)        0         ['batch_normalization_67[0][0]
                                                                    ']                            
                                                                                                  
 add_8 (Add)                 (None, 72, 102, 1024)        0         ['batch_normalization_68[0][0]
                                                                    ',                            
                                                                     'dropout_63[0][0]']          
                                                                                                  
 activation_71 (Activation)  (None, 72, 102, 1024)        0         ['add_8[0][0]']               
                                                                                                  
 gating_signal0_conv (Conv2  (None, 72, 102, 512)         524800    ['activation_71[0][0]']       
 D)                                                                                               
                                                                                                  
 gating_signal0_bn (BatchNo  (None, 72, 102, 512)         2048      ['gating_signal0_conv[0][0]'] 
 rmalization)                                                                                     
                                                                                                  
 gating_signal0_act (Activa  (None, 72, 102, 512)         0         ['gating_signal0_bn[0][0]']   
 tion)                                                                                            
                                                                                                  
 conv2d_78 (Conv2D)          (None, 72, 102, 512)         262656    ['gating_signal0_act[0][0]']  
                                                                                                  
 g_upattention_block0 (Conv  (None, 72, 102, 512)         2359808   ['conv2d_78[0][0]']           
 2DTranspose)                                                                                     
                                                                                                  
 xlattention_block0 (Conv2D  (None, 72, 102, 512)         1049088   ['activation_69[0][0]']       
 )                                                                                                
                                                                                                  
 add_9 (Add)                 (None, 72, 102, 512)         0         ['g_upattention_block0[0][0]',
                                                                     'xlattention_block0[0][0]']  
                                                                                                  
 activation_72 (Activation)  (None, 72, 102, 512)         0         ['add_9[0][0]']               
                                                                                                  
 psiattention_block0 (Conv2  (None, 72, 102, 1)           513       ['activation_72[0][0]']       
 D)                                                                                               
                                                                                                  
 activation_73 (Activation)  (None, 72, 102, 1)           0         ['psiattention_block0[0][0]'] 
                                                                                                  
 up_sampling2d_4 (UpSamplin  (None, 144, 204, 1)          0         ['activation_73[0][0]']       
 g2D)                                                                                             
                                                                                                  
 psi_upattention_block0 (La  (None, 144, 204, 512)        0         ['up_sampling2d_4[0][0]']     
 mbda)                                                                                            
                                                                                                  
 q_attnattention_block0 (Mu  (None, 144, 204, 512)        0         ['psi_upattention_block0[0][0]
 ltiply)                                                            ',                            
                                                                     'activation_69[0][0]']       
                                                                                                  
 q_attn_convattention_block  (None, 144, 204, 512)        262656    ['q_attnattention_block0[0][0]
 0 (Conv2D)                                                         ']                            
                                                                                                  
 upsampling_0 (UpSampling2D  (None, 144, 204, 1024)       0         ['activation_71[0][0]']       
 )                                                                                                
                                                                                                  
 q_attn_bnattention_block0   (None, 144, 204, 512)        2048      ['q_attn_convattention_block0[
 (BatchNormalization)                                               0][0]']                       
                                                                                                  
 concatenate_28 (Concatenat  (None, 144, 204, 1536)       0         ['upsampling_0[0][0]',        
 e)                                                                  'q_attn_bnattention_block0[0]
                                                                    [0]']                         
                                                                                                  
 conv2d_79 (Conv2D)          (None, 144, 204, 512)        7078400   ['concatenate_28[0][0]']      
                                                                                                  
 batch_normalization_69 (Ba  (None, 144, 204, 512)        2048      ['conv2d_79[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_74 (Activation)  (None, 144, 204, 512)        0         ['batch_normalization_69[0][0]
                                                                    ']                            
                                                                                                  
 dropout_64 (Dropout)        (None, 144, 204, 512)        0         ['activation_74[0][0]']       
                                                                                                  
 conv2d_80 (Conv2D)          (None, 144, 204, 512)        2359808   ['dropout_64[0][0]']          
                                                                                                  
 conv2d_81 (Conv2D)          (None, 144, 204, 512)        786944    ['concatenate_28[0][0]']      
                                                                                                  
 batch_normalization_70 (Ba  (None, 144, 204, 512)        2048      ['conv2d_80[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 batch_normalization_71 (Ba  (None, 144, 204, 512)        2048      ['conv2d_81[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 dropout_65 (Dropout)        (None, 144, 204, 512)        0         ['batch_normalization_70[0][0]
                                                                    ']                            
                                                                                                  
 add_10 (Add)                (None, 144, 204, 512)        0         ['batch_normalization_71[0][0]
                                                                    ',                            
                                                                     'dropout_65[0][0]']          
                                                                                                  
 activation_75 (Activation)  (None, 144, 204, 512)        0         ['add_10[0][0]']              
                                                                                                  
 gating_signal1_conv (Conv2  (None, 144, 204, 256)        131328    ['activation_75[0][0]']       
 D)                                                                                               
                                                                                                  
 gating_signal1_bn (BatchNo  (None, 144, 204, 256)        1024      ['gating_signal1_conv[0][0]'] 
 rmalization)                                                                                     
                                                                                                  
 gating_signal1_act (Activa  (None, 144, 204, 256)        0         ['gating_signal1_bn[0][0]']   
 tion)                                                                                            
                                                                                                  
 conv2d_82 (Conv2D)          (None, 144, 204, 256)        65792     ['gating_signal1_act[0][0]']  
                                                                                                  
 g_upattention_block1 (Conv  (None, 144, 204, 256)        590080    ['conv2d_82[0][0]']           
 2DTranspose)                                                                                     
                                                                                                  
 xlattention_block1 (Conv2D  (None, 144, 204, 256)        262400    ['activation_67[0][0]']       
 )                                                                                                
                                                                                                  
 add_11 (Add)                (None, 144, 204, 256)        0         ['g_upattention_block1[0][0]',
                                                                     'xlattention_block1[0][0]']  
                                                                                                  
 activation_76 (Activation)  (None, 144, 204, 256)        0         ['add_11[0][0]']              
                                                                                                  
 psiattention_block1 (Conv2  (None, 144, 204, 1)          257       ['activation_76[0][0]']       
 D)                                                                                               
                                                                                                  
 activation_77 (Activation)  (None, 144, 204, 1)          0         ['psiattention_block1[0][0]'] 
                                                                                                  
 up_sampling2d_5 (UpSamplin  (None, 288, 408, 1)          0         ['activation_77[0][0]']       
 g2D)                                                                                             
                                                                                                  
 psi_upattention_block1 (La  (None, 288, 408, 256)        0         ['up_sampling2d_5[0][0]']     
 mbda)                                                                                            
                                                                                                  
 q_attnattention_block1 (Mu  (None, 288, 408, 256)        0         ['psi_upattention_block1[0][0]
 ltiply)                                                            ',                            
                                                                     'activation_67[0][0]']       
                                                                                                  
 q_attn_convattention_block  (None, 288, 408, 256)        65792     ['q_attnattention_block1[0][0]
 1 (Conv2D)                                                         ']                            
                                                                                                  
 upsampling_1 (UpSampling2D  (None, 288, 408, 512)        0         ['activation_75[0][0]']       
 )                                                                                                
                                                                                                  
 q_attn_bnattention_block1   (None, 288, 408, 256)        1024      ['q_attn_convattention_block1[
 (BatchNormalization)                                               0][0]']                       
                                                                                                  
 concatenate_29 (Concatenat  (None, 288, 408, 768)        0         ['upsampling_1[0][0]',        
 e)                                                                  'q_attn_bnattention_block1[0]
                                                                    [0]']                         
                                                                                                  
 conv2d_83 (Conv2D)          (None, 288, 408, 256)        1769728   ['concatenate_29[0][0]']      
                                                                                                  
 batch_normalization_72 (Ba  (None, 288, 408, 256)        1024      ['conv2d_83[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_78 (Activation)  (None, 288, 408, 256)        0         ['batch_normalization_72[0][0]
                                                                    ']                            
                                                                                                  
 dropout_66 (Dropout)        (None, 288, 408, 256)        0         ['activation_78[0][0]']       
                                                                                                  
 conv2d_84 (Conv2D)          (None, 288, 408, 256)        590080    ['dropout_66[0][0]']          
                                                                                                  
 conv2d_85 (Conv2D)          (None, 288, 408, 256)        196864    ['concatenate_29[0][0]']      
                                                                                                  
 batch_normalization_73 (Ba  (None, 288, 408, 256)        1024      ['conv2d_84[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 batch_normalization_74 (Ba  (None, 288, 408, 256)        1024      ['conv2d_85[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 dropout_67 (Dropout)        (None, 288, 408, 256)        0         ['batch_normalization_73[0][0]
                                                                    ']                            
                                                                                                  
 add_12 (Add)                (None, 288, 408, 256)        0         ['batch_normalization_74[0][0]
                                                                    ',                            
                                                                     'dropout_67[0][0]']          
                                                                                                  
 activation_79 (Activation)  (None, 288, 408, 256)        0         ['add_12[0][0]']              
                                                                                                  
 gating_signal2_conv (Conv2  (None, 288, 408, 128)        32896     ['activation_79[0][0]']       
 D)                                                                                               
                                                                                                  
 gating_signal2_bn (BatchNo  (None, 288, 408, 128)        512       ['gating_signal2_conv[0][0]'] 
 rmalization)                                                                                     
                                                                                                  
 gating_signal2_act (Activa  (None, 288, 408, 128)        0         ['gating_signal2_bn[0][0]']   
 tion)                                                                                            
                                                                                                  
 conv2d_86 (Conv2D)          (None, 288, 408, 128)        16512     ['gating_signal2_act[0][0]']  
                                                                                                  
 g_upattention_block2 (Conv  (None, 288, 408, 128)        147584    ['conv2d_86[0][0]']           
 2DTranspose)                                                                                     
                                                                                                  
 xlattention_block2 (Conv2D  (None, 288, 408, 128)        65664     ['activation_65[0][0]']       
 )                                                                                                
                                                                                                  
 add_13 (Add)                (None, 288, 408, 128)        0         ['g_upattention_block2[0][0]',
                                                                     'xlattention_block2[0][0]']  
                                                                                                  
 activation_80 (Activation)  (None, 288, 408, 128)        0         ['add_13[0][0]']              
                                                                                                  
 psiattention_block2 (Conv2  (None, 288, 408, 1)          129       ['activation_80[0][0]']       
 D)                                                                                               
                                                                                                  
 activation_81 (Activation)  (None, 288, 408, 1)          0         ['psiattention_block2[0][0]'] 
                                                                                                  
 up_sampling2d_6 (UpSamplin  (None, 576, 816, 1)          0         ['activation_81[0][0]']       
 g2D)                                                                                             
                                                                                                  
 psi_upattention_block2 (La  (None, 576, 816, 128)        0         ['up_sampling2d_6[0][0]']     
 mbda)                                                                                            
                                                                                                  
 q_attnattention_block2 (Mu  (None, 576, 816, 128)        0         ['psi_upattention_block2[0][0]
 ltiply)                                                            ',                            
                                                                     'activation_65[0][0]']       
                                                                                                  
 q_attn_convattention_block  (None, 576, 816, 128)        16512     ['q_attnattention_block2[0][0]
 2 (Conv2D)                                                         ']                            
                                                                                                  
 upsampling_2 (UpSampling2D  (None, 576, 816, 256)        0         ['activation_79[0][0]']       
 )                                                                                                
                                                                                                  
 q_attn_bnattention_block2   (None, 576, 816, 128)        512       ['q_attn_convattention_block2[
 (BatchNormalization)                                               0][0]']                       
                                                                                                  
 concatenate_30 (Concatenat  (None, 576, 816, 384)        0         ['upsampling_2[0][0]',        
 e)                                                                  'q_attn_bnattention_block2[0]
                                                                    [0]']                         
                                                                                                  
 conv2d_87 (Conv2D)          (None, 576, 816, 128)        442496    ['concatenate_30[0][0]']      
                                                                                                  
 batch_normalization_75 (Ba  (None, 576, 816, 128)        512       ['conv2d_87[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_82 (Activation)  (None, 576, 816, 128)        0         ['batch_normalization_75[0][0]
                                                                    ']                            
                                                                                                  
 dropout_68 (Dropout)        (None, 576, 816, 128)        0         ['activation_82[0][0]']       
                                                                                                  
 conv2d_88 (Conv2D)          (None, 576, 816, 128)        147584    ['dropout_68[0][0]']          
                                                                                                  
 conv2d_89 (Conv2D)          (None, 576, 816, 128)        49280     ['concatenate_30[0][0]']      
                                                                                                  
 batch_normalization_76 (Ba  (None, 576, 816, 128)        512       ['conv2d_88[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 batch_normalization_77 (Ba  (None, 576, 816, 128)        512       ['conv2d_89[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 dropout_69 (Dropout)        (None, 576, 816, 128)        0         ['batch_normalization_76[0][0]
                                                                    ']                            
                                                                                                  
 add_14 (Add)                (None, 576, 816, 128)        0         ['batch_normalization_77[0][0]
                                                                    ',                            
                                                                     'dropout_69[0][0]']          
                                                                                                  
 activation_83 (Activation)  (None, 576, 816, 128)        0         ['add_14[0][0]']              
                                                                                                  
 gating_signal3_conv (Conv2  (None, 576, 816, 64)         8256      ['activation_83[0][0]']       
 D)                                                                                               
                                                                                                  
 gating_signal3_bn (BatchNo  (None, 576, 816, 64)         256       ['gating_signal3_conv[0][0]'] 
 rmalization)                                                                                     
                                                                                                  
 gating_signal3_act (Activa  (None, 576, 816, 64)         0         ['gating_signal3_bn[0][0]']   
 tion)                                                                                            
                                                                                                  
 conv2d_90 (Conv2D)          (None, 576, 816, 64)         4160      ['gating_signal3_act[0][0]']  
                                                                                                  
 g_upattention_block3 (Conv  (None, 576, 816, 64)         36928     ['conv2d_90[0][0]']           
 2DTranspose)                                                                                     
                                                                                                  
 xlattention_block3 (Conv2D  (None, 576, 816, 64)         16448     ['activation_63[0][0]']       
 )                                                                                                
                                                                                                  
 add_15 (Add)                (None, 576, 816, 64)         0         ['g_upattention_block3[0][0]',
                                                                     'xlattention_block3[0][0]']  
                                                                                                  
 activation_84 (Activation)  (None, 576, 816, 64)         0         ['add_15[0][0]']              
                                                                                                  
 psiattention_block3 (Conv2  (None, 576, 816, 1)          65        ['activation_84[0][0]']       
 D)                                                                                               
                                                                                                  
 activation_85 (Activation)  (None, 576, 816, 1)          0         ['psiattention_block3[0][0]'] 
                                                                                                  
 up_sampling2d_7 (UpSamplin  (None, 1152, 1632, 1)        0         ['activation_85[0][0]']       
 g2D)                                                                                             
                                                                                                  
 psi_upattention_block3 (La  (None, 1152, 1632, 64)       0         ['up_sampling2d_7[0][0]']     
 mbda)                                                                                            
                                                                                                  
 q_attnattention_block3 (Mu  (None, 1152, 1632, 64)       0         ['psi_upattention_block3[0][0]
 ltiply)                                                            ',                            
                                                                     'activation_63[0][0]']       
                                                                                                  
 q_attn_convattention_block  (None, 1152, 1632, 64)       4160      ['q_attnattention_block3[0][0]
 3 (Conv2D)                                                         ']                            
                                                                                                  
 upsampling_3 (UpSampling2D  (None, 1152, 1632, 128)      0         ['activation_83[0][0]']       
 )                                                                                                
                                                                                                  
 q_attn_bnattention_block3   (None, 1152, 1632, 64)       256       ['q_attn_convattention_block3[
 (BatchNormalization)                                               0][0]']                       
                                                                                                  
 concatenate_31 (Concatenat  (None, 1152, 1632, 192)      0         ['upsampling_3[0][0]',        
 e)                                                                  'q_attn_bnattention_block3[0]
                                                                    [0]']                         
                                                                                                  
 conv2d_91 (Conv2D)          (None, 1152, 1632, 64)       110656    ['concatenate_31[0][0]']      
                                                                                                  
 batch_normalization_78 (Ba  (None, 1152, 1632, 64)       256       ['conv2d_91[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_86 (Activation)  (None, 1152, 1632, 64)       0         ['batch_normalization_78[0][0]
                                                                    ']                            
                                                                                                  
 dropout_70 (Dropout)        (None, 1152, 1632, 64)       0         ['activation_86[0][0]']       
                                                                                                  
 conv2d_92 (Conv2D)          (None, 1152, 1632, 64)       36928     ['dropout_70[0][0]']          
                                                                                                  
 conv2d_93 (Conv2D)          (None, 1152, 1632, 64)       12352     ['concatenate_31[0][0]']      
                                                                                                  
 batch_normalization_79 (Ba  (None, 1152, 1632, 64)       256       ['conv2d_92[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 batch_normalization_80 (Ba  (None, 1152, 1632, 64)       256       ['conv2d_93[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 dropout_71 (Dropout)        (None, 1152, 1632, 64)       0         ['batch_normalization_79[0][0]
                                                                    ']                            
                                                                                                  
 add_16 (Add)                (None, 1152, 1632, 64)       0         ['batch_normalization_80[0][0]
                                                                    ',                            
                                                                     'dropout_71[0][0]']          
                                                                                                  
 activation_87 (Activation)  (None, 1152, 1632, 64)       0         ['add_16[0][0]']              
                                                                                                  
 conv2d_94 (Conv2D)          (None, 1152, 1632, 1)        65        ['activation_87[0][0]']       
                                                                                                  
==================================================================================================
Total params: 46052293 (175.68 MB)
Trainable params: 46030789 (175.59 MB)
Non-trainable params: 21504 (84.00 KB)
__________________________________________________________________________________________________
None

In a PyTorch environment, the equivalent workflow is:

from segmentation_test.pytorch_model_development import *
from segmentation_test.dataloader_creation import *
from segmentation_test.pytorch_training_and_loss import *
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Single-channel (grayscale) input and single-class output U-Net.
model = UNet(
    in_channels=1, 
    out_channels=1)

# Paths to the patch images and their corresponding masks.
# NOTE(review): adjust these to your local dataset location.
im_path = Path(r'/home/hasan/workspace/data/microscopy_data/patch_images')
msk_path = Path(r'/home/hasan/workspace/data/microscopy_data/patch_masks')

# Build train/validation loaders with a random split.
# split_per=0.1 controls the split fraction — see output below for resulting sizes.
train_loader, val_loader = create_pytorch_dataloader(
    split_type='random',
    split_per=0.1,
    batch_size=2,
    image_path=im_path,
    mask_path=msk_path,
    transforms=None,
    num_workers=0
)
 Number of images found = 1642
 training dataset length = 164 and validation dataset length=  1478
torch.Size([1, 256, 256]) torch.Size([1, 256, 256])
# Optimizer: AdamW with the base learning rate used after warm-up.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

total_epochs = 10
warmup_epochs = 2      # epochs of LR warm-up before the cosine schedule takes over
initial_lr = 0.002     # peak LR targeted by the warm-up phase

# Cosine annealing with warm restarts: first restart after T_0=2 epochs,
# annealing down to eta_min=0 within each cycle.
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=2, eta_min=0)

loss_fn = FocalLoss()
model_save_path = Path.cwd()  # checkpoints are written to the current working directory

train_segmentation_model(
                            train_loader,
                            val_loader,
                            model,
                            optimizer,
                            scheduler, 
                            loss_fn, 
                            model_save_path, 
                            warmup_epochs, 
                            total_epochs, 
                            initial_lr, 
                            dtype=torch.float,
                            device='cuda')
Epoch 1/10:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 1/10 - Train Loss: 0.0130, IoU: 0.0254, FP: 0.0, FN: 0.0
Validation Loss: 0.0097, IoU: 0.3979, FP: 0.0, FN: 0.0

Epoch 2/10:   0%|          | 0/82 [00:00<?, ?batch/s]