STMicroelectronics/stm32ai-modelzoo

Issue with layers.Input for a UNet model

kirilllzaitsev opened this issue · 2 comments

Hi, do you have any examples of how to fit architectures such as UNet, autoencoders, etc. onto an STM32 device?
When I try this with the UNet defined below, I get the error: NOT IMPLEMENTED: Order of dimensions of input cannot be interpreted

I suspect the problem is in how I define the input, layers.Input(shape=(*img_size, in_channels), name="input"), but I have seen many similar definitions that work. Could the skip-connection architecture be affecting the TFLite conversion and causing this error?

My model is:

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 14, 14, 16)   32          ['input_1[0][0]']                
                                                                                                  
 batch_normalization (BatchNorm  (None, 14, 14, 16)  64          ['conv2d[0][0]']                 
 alization)                                                                                       
                                                                                                  
 activation (Activation)        (None, 14, 14, 16)   0           ['batch_normalization[0][0]']    
                                                                                                  
 activation_1 (Activation)      (None, 14, 14, 16)   0           ['activation[0][0]']             
                                                                                                  
 separable_conv2d (SeparableCon  (None, 14, 14, 32)  688         ['activation_1[0][0]']           
 v2D)                                                                                             
                                                                                                  
 batch_normalization_1 (BatchNo  (None, 14, 14, 32)  128         ['separable_conv2d[0][0]']       
 rmalization)                                                                                     
                                                                                                  
 activation_2 (Activation)      (None, 14, 14, 32)   0           ['batch_normalization_1[0][0]']  
                                                                                                  
 separable_conv2d_1 (SeparableC  (None, 14, 14, 32)  1344        ['activation_2[0][0]']           
 onv2D)                                                                                           
                                                                                                  
 batch_normalization_2 (BatchNo  (None, 14, 14, 32)  128         ['separable_conv2d_1[0][0]']     
 rmalization)                                                                                     
                                                                                                  
 max_pooling2d (MaxPooling2D)   (None, 7, 7, 32)     0           ['batch_normalization_2[0][0]']  
                                                                                                  
 conv2d_1 (Conv2D)              (None, 7, 7, 32)     544         ['activation[0][0]']             
                                                                                                  
 add (Add)                      (None, 7, 7, 32)     0           ['max_pooling2d[0][0]',          
                                                                  'conv2d_1[0][0]']               
                                                                                                  
 activation_3 (Activation)      (None, 7, 7, 32)     0           ['add[0][0]']                    
                                                                                                  
 conv2d_transpose (Conv2DTransp  (None, 7, 7, 32)    9248        ['activation_3[0][0]']           
 ose)                                                                                             
                                                                                                  
 batch_normalization_3 (BatchNo  (None, 7, 7, 32)    128         ['conv2d_transpose[0][0]']       
 rmalization)                                                                                     
                                                                                                  
 activation_4 (Activation)      (None, 7, 7, 32)     0           ['batch_normalization_3[0][0]']  
                                                                                                  
 conv2d_transpose_1 (Conv2DTran  (None, 7, 7, 32)    9248        ['activation_4[0][0]']           
 spose)                                                                                           
                                                                                                  
 batch_normalization_4 (BatchNo  (None, 7, 7, 32)    128         ['conv2d_transpose_1[0][0]']     
 rmalization)                                                                                     
                                                                                                  
 up_sampling2d_1 (UpSampling2D)  (None, 14, 14, 32)  0           ['add[0][0]']                    
                                                                                                  
 up_sampling2d (UpSampling2D)   (None, 14, 14, 32)   0           ['batch_normalization_4[0][0]']  
                                                                                                  
 conv2d_2 (Conv2D)              (None, 14, 14, 32)   1056        ['up_sampling2d_1[0][0]']        
                                                                                                  
 add_1 (Add)                    (None, 14, 14, 32)   0           ['up_sampling2d[0][0]',          
                                                                  'conv2d_2[0][0]']               
                                                                                                  
 activation_5 (Activation)      (None, 14, 14, 32)   0           ['add_1[0][0]']                  
                                                                                                  
 conv2d_transpose_2 (Conv2DTran  (None, 14, 14, 16)  4624        ['activation_5[0][0]']           
 spose)                                                                                           
                                                                                                  
 batch_normalization_5 (BatchNo  (None, 14, 14, 16)  64          ['conv2d_transpose_2[0][0]']     
 rmalization)                                                                                     
                                                                                                  
 activation_6 (Activation)      (None, 14, 14, 16)   0           ['batch_normalization_5[0][0]']  
                                                                                                  
 conv2d_transpose_3 (Conv2DTran  (None, 14, 14, 16)  2320        ['activation_6[0][0]']           
 spose)                                                                                           
                                                                                                  
 batch_normalization_6 (BatchNo  (None, 14, 14, 16)  64          ['conv2d_transpose_3[0][0]']     
 rmalization)                                                                                     
                                                                                                  
 up_sampling2d_3 (UpSampling2D)  (None, 28, 28, 32)  0           ['add_1[0][0]']                  
                                                                                                  
 up_sampling2d_2 (UpSampling2D)  (None, 28, 28, 16)  0           ['batch_normalization_6[0][0]']  
                                                                                                  
 conv2d_3 (Conv2D)              (None, 28, 28, 16)   528         ['up_sampling2d_3[0][0]']        
                                                                                                  
 add_2 (Add)                    (None, 28, 28, 16)   0           ['up_sampling2d_2[0][0]',        
                                                                  'conv2d_3[0][0]']               
                                                                                                  
 conv2d_4 (Conv2D)              (None, 28, 28, 1)    17          ['add_2[0][0]']                  
                                                                                                  
==================================================================================================
Total params: 30,353
Trainable params: 30,001
Non-trainable params: 352
__________________________________________________________________________________________________
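The parameter counts in the summary can be reproduced by hand. Doing so also reveals the kernel sizes, which the summary does not show. Below is a minimal pure-Python sketch that assumes every convolution has a bias and that the SeparableConv2D layers use depth_multiplier=1. The kernel sizes are inferred from the counts, not taken from the reporter's code. For example, conv2d's 32 parameters match a 1x1 kernel: 1*1*1*16 + 16.

```python
# Keras parameter-count formulas for the layer types in the summary.
# Assumptions: all convolutions use a bias; SeparableConv2D uses depth_multiplier=1.

def conv2d_params(k, c_in, c_out):
    """Conv2D and Conv2DTranspose: k*k*c_in weights per output channel, plus one bias each."""
    return k * k * c_in * c_out + c_out

def separable_conv2d_params(k, c_in, c_out):
    """SeparableConv2D: k x k depthwise kernel per input channel, 1x1 pointwise, plus bias."""
    return k * k * c_in + c_in * c_out + c_out

# Kernel sizes (first argument) are inferred from the counts in the summary.
convs = {
    "conv2d": conv2d_params(1, 1, 16),                         # 32
    "separable_conv2d": separable_conv2d_params(3, 16, 32),    # 688
    "separable_conv2d_1": separable_conv2d_params(3, 32, 32),  # 1344
    "conv2d_1": conv2d_params(1, 16, 32),                      # 544 (residual)
    "conv2d_transpose": conv2d_params(3, 32, 32),              # 9248
    "conv2d_transpose_1": conv2d_params(3, 32, 32),            # 9248
    "conv2d_2": conv2d_params(1, 32, 32),                      # 1056 (residual)
    "conv2d_transpose_2": conv2d_params(3, 32, 16),            # 4624
    "conv2d_transpose_3": conv2d_params(3, 16, 16),            # 2320
    "conv2d_3": conv2d_params(1, 32, 16),                      # 528 (residual)
    "conv2d_4": conv2d_params(1, 16, 1),                       # 17 (output)
}

# BatchNormalization: gamma/beta are trainable, moving mean/variance are not,
# so each BN layer with c channels has 2*c trainable and 2*c non-trainable params.
bn_channels = [16, 32, 32, 32, 32, 16, 16]

non_trainable = sum(2 * c for c in bn_channels)
trainable = sum(convs.values()) + sum(2 * c for c in bn_channels)

print(trainable, non_trainable, trainable + non_trainable)  # 30001 352 30353
```

The three totals match the "Trainable", "Non-trainable" and "Total params" lines of the summary.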

LFOSTM commented

Hello,
This should not happen.
Could you please share the model with us (.h5, .tflite or .onnx)?
Thanks
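The following sketch shows one way to rebuild the summarized model and export it in two of the requested formats. It assumes TensorFlow 2.x with the bundled tf.keras (Keras 2). The architecture is reconstructed from the summary in the original report. Kernel sizes are inferred from the parameter counts. The activation functions are guesses, because the summary does not show them. `build_unet` and the output file names are hypothetical, and this is not the reporter's actual code.

```python
import tensorflow as tf
from tensorflow.keras import layers


def build_unet(img_size=(28, 28), in_channels=1, out_channels=1):
    """Hypothetical reconstruction of the UNet summarized in the issue."""
    inputs = layers.Input(shape=(*img_size, in_channels), name="input")

    # Stem: 1x1 conv with stride 2 -> 14x14x16 (activation type is an assumption)
    x = layers.Conv2D(16, 1, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    previous = x

    # Encoder block: two separable convs, max-pool to 7x7, 1x1 strided residual
    x = layers.Activation("relu")(x)
    x = layers.SeparableConv2D(32, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.SeparableConv2D(32, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D(3, strides=2, padding="same")(x)
    residual = layers.Conv2D(32, 1, strides=2, padding="same")(previous)
    x = layers.add([x, residual])  # skip connection -> 7x7x32
    previous = x

    # Decoder blocks: 7x7 -> 14x14 (32 filters) -> 28x28 (16 filters)
    for filters in (32, 16):
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.UpSampling2D(2)(x)
        # Residual path: upsample the previous block output and match channels
        residual = layers.UpSampling2D(2)(previous)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])
        previous = x

    # 1x1 output conv; the real output activation is unknown, linear used here
    outputs = layers.Conv2D(out_channels, 1, padding="same")(x)
    return tf.keras.Model(inputs, outputs)


model = build_unet()

# Keras HDF5 format (.h5)
model.save("unet.h5")

# TensorFlow Lite flatbuffer (.tflite), float32, no quantization
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open("unet.tflite", "wb") as f:
    f.write(tflite_model)
```

If the rebuilt model reports 30,353 parameters and a (None, 28, 28, 1) output, it matches the summary. Its .tflite file can then be compared against the one that triggers the error.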

Closing this issue, as we could not reproduce the problem on our side and received no additional information from the reporter.