Doodleverse/segmentation_gym

Add 'transfer learning' Transformer-based model architecture

dbuscombe-usgs opened this issue · 29 comments

After discussions with (and initial exploration by) @ebgoldstein, I think it would be advantageous to add a non-CNN model architecture for transfer-learning tasks. This might be especially useful for tasks that are not well suited to training a UNet or Res-UNet from scratch.

A good candidate is the SegFormer.

A basic implementation has been worked out by @ebgoldstein (https://github.com/ebgoldstein/Coastal_TF_templates/blob/main/src/Segformer.ipynb); it uses the smallest pretrained model, mit-b0.

An additional advantage is that it can accept input images of any size (within memory constraints).

To add it as an optional model in Gym, the following steps are required. I'm currently implementing these on a new branch:

  • update doodleverse-utils\model_imports.py to add a segformer model architecture, adding transformers\TFSegformerForSemanticSegmentation as a dependency
  • update doodleverse-utils\imports.py so that plot_seg_history_iou can handle segformer training curves (which have no mean IoU or other metrics)
  • add transformers library to the conda yml file
  • add note about conda recipe in readme
  • in do_train, add a function read_seg_dataset_multiclass_segformer to parse npz files into tensors for model training
  • in do_train, save model weights using model.save_weights if model==segformer
  • in do_train, only use mixed precision if model is NOT segformer
  • in do_train, only use weighted loss if model is NOT segformer
  • in do_train, only plot the model if it is NOT segformer
  • in do_train, update plotcomp_n_metrics and do_viz section to deal with differently shaped model inputs
  • update seg_images_in_folder to accept 'segformer' model
  • update doodleverse-utils\prediction_imports\do_seg to accept 'segformer' model
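The reshaping that a parser such as read_seg_dataset_multiclass_segformer has to do can be sketched in NumPy. This is a minimal sketch, not the Gym implementation: it assumes each npz sample yields a channels-last image of shape (H, W, bands) and a one-hot label of shape (H, W, NCLASSES), and that the Hugging Face TF SegFormer wants channels-first pixel_values plus an integer label map (its internal loss is, as far as I understand it, a sparse cross-entropy on class indices). The helper name to_segformer_inputs is hypothetical.

```python
import numpy as np

def to_segformer_inputs(image_hwc, label_onehot):
    """Hypothetical helper: convert one Gym-style sample to SegFormer layout.

    image_hwc:    array of shape (H, W, bands), channels-last
    label_onehot: array of shape (H, W, NCLASSES), one-hot encoded
    Returns (pixel_values, labels) with shapes (bands, H, W) and (H, W).
    """
    image_hwc = np.asarray(image_hwc, dtype=np.float32)
    # SegFormer's pixel_values are channels-first: (bands, H, W)
    pixel_values = np.transpose(image_hwc, (2, 0, 1))
    # integer class indices instead of one-hot vectors
    labels = np.argmax(label_onehot, axis=-1).astype(np.int32)
    return pixel_values, labels

# usage with random data: 768x768 image, 3 bands, 4 classes
img = np.random.rand(768, 768, 3)
lab = np.eye(4)[np.random.randint(0, 4, size=(768, 768))]
pv, lb = to_segformer_inputs(img, lab)
print(pv.shape, lb.shape)  # (3, 768, 768) (768, 768)
```

Batching these pairs (for example with tf.data) would then give tensors of shape (batch, bands, H, W) and (batch, H, W).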

One notable property of the model outputs is that they are always a quarter of the input size, (W/4, H/4), so they have to be resized to match TARGET_SIZE.

I am going to deal with this by keeping the per-class scores (the logits) as floats until the last moment, i.e. resizing the channels-first logits tensor before the argmax operation:

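import tensorflow as tf
from skimage.transform import resize
est_label = model.predict(img).logits  # channels-first logits, shape (1, NCLASSES, H/4, W/4)
nR, nC = label.shape                   # full-resolution (TARGET_SIZE) shape
# resize the float logits first, then take the argmax over the class axis (axis=1)
est_label = resize(est_label, (1, NCLASSES, nR, nC), preserve_range=True, clip=True)
imgPredict = tf.math.argmax(est_label, axis=1)[0]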
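Why resize the float scores rather than the argmax result: interpolating the per-class scores and then taking the argmax yields smooth class boundaries at full resolution, whereas upsampling an integer label map can only replicate blocky 4x4 cells. The NumPy-only sketch below stands in for skimage.transform.resize with plain bilinear interpolation (half-pixel convention, edges clamped), so its edge handling will differ slightly from skimage's; the function names are hypothetical.

```python
import numpy as np

def bilinear_resize_chw(x, out_h, out_w):
    """Bilinearly resize a channels-first float array (C, h, w) to (C, out_h, out_w)."""
    c, h, w = x.shape
    # map output pixel centres back to input coordinates (half-pixel convention)
    ys = np.clip((np.arange(out_h) + 0.5) * h / out_h - 0.5, 0, h - 1)
    xs = np.clip((np.arange(out_w) + 0.5) * w / out_w - 0.5, 0, w - 1)
    y0 = np.floor(ys).astype(int)
    x0 = np.floor(xs).astype(int)
    y1 = np.minimum(y0 + 1, h - 1)
    x1 = np.minimum(x0 + 1, w - 1)
    wy = (ys - y0)[None, :, None]  # vertical weights, shape (1, out_h, 1)
    wx = (xs - x0)[None, None, :]  # horizontal weights, shape (1, 1, out_w)
    top = x[:, y0][:, :, x0] * (1 - wx) + x[:, y0][:, :, x1] * wx
    bottom = x[:, y1][:, :, x0] * (1 - wx) + x[:, y1][:, :, x1] * wx
    return top * (1 - wy) + bottom * wy

def logits_to_label(logits_nchw, n_rows, n_cols):
    """(1, NCLASSES, h, w) logits -> (n_rows, n_cols) integer label map."""
    upsampled = bilinear_resize_chw(logits_nchw[0], n_rows, n_cols)
    return np.argmax(upsampled, axis=0)  # argmax only after resizing

# toy example: 2 classes on a 4x4 grid, class 1 wins on the right half
logits = np.zeros((1, 2, 4, 4))
logits[0, 1] = np.arange(4) - 1.5  # column scores: -1.5, -0.5, 0.5, 1.5
label = logits_to_label(logits, 16, 16)
print(label[0])  # [0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1]
```

The class boundary lands exactly midway across the 16-pixel row, where the interpolated class-1 score crosses zero, rather than snapping to a 4-pixel block edge.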

By far the most difficult part of this process is updating plotcomp_n_metrics: the channels-first layout and the small (quarter-resolution) model outputs break a lot of existing code.

This is now implemented in https://github.com/Doodleverse/segmentation_gym/tree/dev_models

This requires an update to doodleverse-utils 0.0.20 https://pypi.org/project/doodleverse-utils/0.0.20/

@ebgoldstein (and @CameronBodine): it would be great if you could give this a try! I tried it on a simple 2-class model based on Coast Train, and the outputs were decent, but not as good as a UNet trained from scratch.

[Images: example prediction overlays from the 2-class Coast Train water model (ct_ortho_all_water_768_aug_nd_data_0000003003, 0000003004, 0000003029 and 0000002999)]

  • Requires "MODEL": "segformer" in the config file
  • Works with existing npz files
  • Note that this model has many more parameters, so a smaller batch size will likely be needed (on an RTX 3080 with a target size of 768x768, batch sizes of 4 and 6 worked, but 8 was too large)
  • Note the following config hyperparams are ignored when using segformer: "KERNEL", "STRIDE", "FILTERS", "DROPOUT", "DROPOUT_CHANGE_PER_LAYER", "DROPOUT_TYPE", "USE_DROPOUT_ON_UPSAMPLING", "LOSS"
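The segformer-specific switches described in this thread (no mixed precision, no weighted loss, no model plot, saving with model.save_weights, and the ignored config hyperparameters) could be derived from the config in one place. This is a hypothetical helper, not Gym's code: the function name and the returned keys are illustrative, and only the config key names come from this thread.

```python
# Config keys that this thread says are ignored when MODEL is "segformer"
SEGFORMER_IGNORED_KEYS = (
    "KERNEL", "STRIDE", "FILTERS", "DROPOUT", "DROPOUT_CHANGE_PER_LAYER",
    "DROPOUT_TYPE", "USE_DROPOUT_ON_UPSAMPLING", "LOSS",
)

def training_options(config):
    """Hypothetical helper: derive model-specific switches from a Gym-style config dict."""
    is_segformer = config.get("MODEL", "").lower() == "segformer"
    # report which of the ignored keys the user actually set
    ignored = [k for k in SEGFORMER_IGNORED_KEYS if k in config] if is_segformer else []
    return {
        "use_mixed_precision": not is_segformer,
        "use_weighted_loss": not is_segformer,
        "plot_model": not is_segformer,
        "save_weights_only": is_segformer,  # segformer: save with model.save_weights(...)
        "ignored_keys": ignored,
    }

config = {"MODEL": "segformer", "BATCH_SIZE": 4, "KERNEL": 7, "LOSS": "dice"}
opts = training_options(config)
print(opts["ignored_keys"])  # ['KERNEL', 'LOSS']
```

Returning the list of ignored keys makes it easy to print a warning, so users are not surprised that changing KERNEL or LOSS has no effect on a segformer run.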

To do:

  • update batch_do_train
  • update do_train and doodleverse-utils so model can compile and train using custom mean IoU and Dice metrics
  • update segmentation zoo\script funcs to accept 'segformer' model
  • update Gym 'utils' scripts where needed
  • update wiki and other docs
  • extensive testing!
    • 3 band
    • 1 band
    • >3 bands
    • nclasses=2
    • nclasses>2
    • how well does it work?
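Until the model can compile with the custom metrics, mean IoU and Dice can still be computed after prediction. A minimal NumPy sketch, assuming prediction and ground truth are (H, W) arrays of class indices; it is not the doodleverse-utils implementation, and skipping classes that are absent from both maps is one choice among several.

```python
import numpy as np

def mean_iou_and_dice(y_true, y_pred, nclasses):
    """Mean IoU and mean Dice over classes for two integer label maps of equal shape."""
    ious, dices = [], []
    for k in range(nclasses):
        t = y_true == k
        p = y_pred == k
        inter = np.logical_and(t, p).sum()
        union = np.logical_or(t, p).sum()
        if union == 0:
            continue  # class absent from both maps: leave it out of the mean
        ious.append(inter / union)
        dices.append(2 * inter / (t.sum() + p.sum()))
    return float(np.mean(ious)), float(np.mean(dices))

y_true = np.array([[0, 0], [1, 1]])
y_pred = np.array([[0, 1], [1, 1]])
miou, mdice = mean_iou_and_dice(y_true, y_pred, nclasses=2)
```

In this toy case class 0 has IoU 1/2 and Dice 2/3, class 1 has IoU 2/3 and Dice 4/5, so the means are about 0.583 and 0.733.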

My first 'real' SegFormer model (a building detector) works quite well and trained very fast.

[Images: training history (xbd_building_RGB_768_v1_trainhist_7) and validation examples xbd_building_RGB_768_v1_val_0 through val_8]

EDIT: Maybe it's caused by my 1-band imagery?

OK, I gave it a shot, but it threw an error (see below). I again had trouble with the libcuda library not being found, similar to what I noted in #78, so I went through the process of reinstalling CUDA and the NVIDIA drivers on my machine (see https://docs.nvidia.com/datacenter/tesla/tesla-installation-notes/index.html). When I run:

python -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"

I receive confirmation that the GPUs are visible:

2023-02-11 17:34:50.922715: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]

I also verified that I was able to train using a resunet, which was successful. I'm done for today, but will try and look further into the error.

Error running train_model.py with segformer

$ python train_model.py 
/mnt/md0/SynologyDrive/Modeling/03_TrainDatasets/SpdCor_Substrate_inclShadow
/mnt/md0/SynologyDrive/Modeling/99_ForTesting/SegFormer_SpdCor_Substrate_inclShadow/config/SegFormer_config_SpdCor_Substrate_inclShadow.json
Using GPU
Using single GPU device
2023-02-11 17:27:53.489527: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Version:  2.11.0
Eager mode:  True
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Making new directory for example model outputs: /mnt/md0/SynologyDrive/Modeling/99_ForTesting/SegFormer_SpdCor_Substrate_inclShadow/modelOut
MODE not specified in config file. Setting to "all" files
MODE "all": using all augmented and non-augmented files
2023-02-11 17:27:55.677431: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-11 17:27:56.438498: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1613] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 14606 MB memory:  -> device: 0, name: Quadro RTX 5000, pci bus id: 0000:65:00.0, compute capability: 7.5
162
97
.....................................
Creating and compiling model ...
2023-02-11 17:27:58.029189: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:428] Loaded cuDNN version 8401
2023-02-11 17:27:58.796337: I tensorflow/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
2023-02-11 17:27:59.011073: I tensorflow/compiler/xla/service/service.cc:173] XLA service 0x56477d4309b0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-02-11 17:27:59.011106: I tensorflow/compiler/xla/service/service.cc:181]   StreamExecutor device (0): Quadro RTX 5000, Compute Capability 7.5
2023-02-11 17:27:59.024924: I tensorflow/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
2023-02-11 17:27:59.357196: I tensorflow/compiler/jit/xla_compilation_cache.cc:477] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
WARNING:tensorflow:5 out of the last 5 calls to <function Conv._jit_compiled_convolution_op at 0x7f47ac7a09d0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:6 out of the last 6 calls to <function Conv._jit_compiled_convolution_op at 0x7f47ac7a2290> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
Some layers from the model checkpoint at nvidia/mit-b0 were not used when initializing TFSegformerForSemanticSegmentation: ['classifier']
- This IS expected if you are initializing TFSegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFSegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some layers of TFSegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
No loss specified in compile() - the model's internal loss computation will be used as the loss. Don't panic - this is a common way to train TensorFlow models in Transformers! To disable this behaviour please pass a loss argument, or explicitly pass `loss=None` if you do not want your model to compute a loss.
INITIAL_EPOCH not specified in the config file. Setting to default of 0 ...
.....................................
Training model ...

Epoch 1: LearningRateScheduler setting learning rate to 1e-07.
Epoch 1/10
Traceback (most recent call last):
  File "/mnt/md0/SynologyDrive/Modeling/segmentation_gym/train_model.py", line 904, in <module>
    history = model.fit(train_ds, steps_per_epoch=steps_per_epoch, epochs=MAX_EPOCHS,
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/tmp/__autograph_generated_fileb2xrfpce.py", line 15, in tf__train_function
    retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.10/site-packages/transformers/modeling_tf_utils.py", line 1535, in train_step
    y_pred = self(x, training=True)
  File "/tmp/__autograph_generated_file15xswu7i.py", line 36, in tf__run_call_with_unpacked_inputs
    retval_ = ag__.converted_call(ag__.ld(func), (ag__.ld(self),), dict(**ag__.ld(unpacked_inputs)), fscope)
  File "/tmp/__autograph_generated_filepumbcoaq.py", line 40, in tf__call
    outputs = ag__.converted_call(ag__.ld(self).segformer, (ag__.ld(pixel_values),), dict(output_attentions=ag__.ld(output_attentions), output_hidden_states=True, return_dict=ag__.ld(return_dict)), fscope)
  File "/tmp/__autograph_generated_file15xswu7i.py", line 36, in tf__run_call_with_unpacked_inputs
    retval_ = ag__.converted_call(ag__.ld(func), (ag__.ld(self),), dict(**ag__.ld(unpacked_inputs)), fscope)
  File "/tmp/__autograph_generated_filespuksggp.py", line 14, in tf__call
    encoder_outputs = ag__.converted_call(ag__.ld(self).encoder, (ag__.ld(pixel_values),), dict(output_attentions=ag__.ld(output_attentions), output_hidden_states=ag__.ld(output_hidden_states), return_dict=ag__.ld(return_dict), training=ag__.ld(training)), fscope)
  File "/tmp/__autograph_generated_filefosql1x4.py", line 102, in tf__call
    ag__.for_stmt(ag__.converted_call(ag__.ld(enumerate), (ag__.converted_call(ag__.ld(zip), (ag__.ld(self).embeddings, ag__.ld(self).block, ag__.ld(self).layer_norms), None, fscope),), None, fscope), None, loop_body_1, get_state_4, set_state_4, ('all_hidden_states', 'all_self_attentions', 'hidden_states'), {'iterate_names': '(idx, x)'})
  File "/tmp/__autograph_generated_filefosql1x4.py", line 26, in loop_body_1
    (hidden_states, height, width) = ag__.converted_call(ag__.ld(embedding_layer), (ag__.ld(hidden_states),), None, fscope)
  File "/tmp/__autograph_generated_file3_8g14m8.py", line 10, in tf__call
    embeddings = ag__.converted_call(ag__.ld(self).proj, (ag__.converted_call(ag__.ld(self).padding, (ag__.ld(pixel_values),), None, fscope),), None, fscope)
ValueError: in user code:

    File "/home/cbodine/miniconda3/envs/gym/lib/python3.10/site-packages/keras/engine/training.py", line 1249, in train_function  *
        return step_function(self, iterator)
    File "/home/cbodine/miniconda3/envs/gym/lib/python3.10/site-packages/transformers/modeling_tf_utils.py", line 830, in run_call_with_unpacked_inputs  *
        return func(self, **unpacked_inputs)
    File "/home/cbodine/miniconda3/envs/gym/lib/python3.10/site-packages/transformers/models/segformer/modeling_tf_segformer.py", line 857, in call  *
        outputs = self.segformer(
    File "/home/cbodine/miniconda3/envs/gym/lib/python3.10/site-packages/transformers/modeling_tf_utils.py", line 830, in run_call_with_unpacked_inputs  *
        return func(self, **unpacked_inputs)
    File "/home/cbodine/miniconda3/envs/gym/lib/python3.10/site-packages/transformers/models/segformer/modeling_tf_segformer.py", line 481, in call  *
        encoder_outputs = self.encoder(
    File "/home/cbodine/miniconda3/envs/gym/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/tmp/__autograph_generated_filefosql1x4.py", line 102, in tf__call
        ag__.for_stmt(ag__.converted_call(ag__.ld(enumerate), (ag__.converted_call(ag__.ld(zip), (ag__.ld(self).embeddings, ag__.ld(self).block, ag__.ld(self).layer_norms), None, fscope),), None, fscope), None, loop_body_1, get_state_4, set_state_4, ('all_hidden_states', 'all_self_attentions', 'hidden_states'), {'iterate_names': '(idx, x)'})
    File "/tmp/__autograph_generated_filefosql1x4.py", line 26, in loop_body_1
        (hidden_states, height, width) = ag__.converted_call(ag__.ld(embedding_layer), (ag__.ld(hidden_states),), None, fscope)
    File "/tmp/__autograph_generated_file3_8g14m8.py", line 10, in tf__call
        embeddings = ag__.converted_call(ag__.ld(self).proj, (ag__.converted_call(ag__.ld(self).padding, (ag__.ld(pixel_values),), None, fscope),), None, fscope)

    ValueError: Exception encountered when calling layer 'encoder' (type TFSegformerEncoder).
    
    in user code:
    
        File "/home/cbodine/miniconda3/envs/gym/lib/python3.10/site-packages/transformers/models/segformer/modeling_tf_segformer.py", line 416, in call  *
            hidden_states, height, width = embedding_layer(hidden_states)
        File "/home/cbodine/miniconda3/envs/gym/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
            raise e.with_traceback(filtered_tb) from None
        File "/tmp/__autograph_generated_file3_8g14m8.py", line 10, in tf__call
            embeddings = ag__.converted_call(ag__.ld(self).proj, (ag__.converted_call(ag__.ld(self).padding, (ag__.ld(pixel_values),), None, fscope),), None, fscope)
    
        ValueError: Exception encountered when calling layer 'patch_embeddings.0' (type TFSegformerOverlapPatchEmbeddings).
        
        in user code:
        
            File "/home/cbodine/miniconda3/envs/gym/lib/python3.10/site-packages/transformers/models/segformer/modeling_tf_segformer.py", line 89, in call  *
                embeddings = self.proj(self.padding(pixel_values))
            File "/home/cbodine/miniconda3/envs/gym/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
                raise e.with_traceback(filtered_tb) from None
            File "/home/cbodine/miniconda3/envs/gym/lib/python3.10/site-packages/keras/engine/input_spec.py", line 277, in assert_input_compatibility
                raise ValueError(
        
            ValueError: Input 0 of layer "proj" is incompatible with the layer: expected axis -1 of input shape to have value 3, but received input with shape (8, None, None, 1)
        
        
        Call arguments received by layer 'patch_embeddings.0' (type TFSegformerOverlapPatchEmbeddings):
          • pixel_values=tf.Tensor(shape=(8, None, None, 1), dtype=float32)
    
    
    Call arguments received by layer 'encoder' (type TFSegformerEncoder):
      • pixel_values=tf.Tensor(shape=(8, None, None, 1), dtype=float32)
      โ€ข output_attentions=False
      โ€ข output_hidden_states=True
      โ€ข return_dict=True
      โ€ข training=True

Conda env

# packages in environment at /home/cbodine/miniconda3/envs/gym:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
absl-py                   1.4.0              pyhd8ed1ab_0    conda-forge
aiohttp                   3.8.3           py310h5764c6d_1    conda-forge
aiosignal                 1.3.1              pyhd8ed1ab_0    conda-forge
alsa-lib                  1.2.8                h166bdaf_0    conda-forge
anyio                     3.6.2              pyhd8ed1ab_0    conda-forge
aom                       3.5.0                h27087fc_0    conda-forge
appdirs                   1.4.4              pyh9f0ad1d_0    conda-forge
argon2-cffi               21.3.0             pyhd8ed1ab_0    conda-forge
argon2-cffi-bindings      21.2.0          py310h5764c6d_3    conda-forge
asttokens                 2.2.1              pyhd8ed1ab_0    conda-forge
astunparse                1.6.3              pyhd8ed1ab_0    conda-forge
async-timeout             4.0.2              pyhd8ed1ab_0    conda-forge
attr                      2.5.1                h166bdaf_1    conda-forge
attrs                     22.2.0             pyh71513ae_0    conda-forge
backcall                  0.2.0              pyh9f0ad1d_0    conda-forge
backports                 1.0                pyhd8ed1ab_3    conda-forge
backports.functools_lru_cache 1.6.4              pyhd8ed1ab_0    conda-forge
beautifulsoup4            4.11.2             pyha770c72_0    conda-forge
bleach                    6.0.0              pyhd8ed1ab_0    conda-forge
blinker                   1.5                pyhd8ed1ab_0    conda-forge
blosc                     1.21.3               hafa529b_0    conda-forge
brotli                    1.0.9                h166bdaf_8    conda-forge
brotli-bin                1.0.9                h166bdaf_8    conda-forge
brotlipy                  0.7.0           py310h5764c6d_1005    conda-forge
brunsli                   0.1                  h9c3ff4c_0    conda-forge
bzip2                     1.0.8                h7f98852_4    conda-forge
c-ares                    1.18.1               h7f98852_0    conda-forge
c-blosc2                  2.6.1                hf91038e_0    conda-forge
ca-certificates           2022.12.7            ha878542_0    conda-forge
cached-property           1.5.2                hd8ed1ab_1    conda-forge
cached_property           1.5.2              pyha770c72_1    conda-forge
cachetools                5.3.0              pyhd8ed1ab_0    conda-forge
cairo                     1.16.0            ha61ee94_1014    conda-forge
certifi                   2022.12.7          pyhd8ed1ab_0    conda-forge
cffi                      1.15.1          py310h255011f_3    conda-forge
cfitsio                   4.2.0                hd9d235c_0    conda-forge
charls                    2.4.1                hcb278e6_0    conda-forge
charset-normalizer        2.1.1              pyhd8ed1ab_0    conda-forge
click                     8.1.3           unix_pyhd8ed1ab_2    conda-forge
cloudpickle               2.2.1              pyhd8ed1ab_0    conda-forge
colorama                  0.4.6              pyhd8ed1ab_0    conda-forge
comm                      0.1.2              pyhd8ed1ab_0    conda-forge
contourpy                 1.0.7           py310hdf3cbec_0    conda-forge
cryptography              39.0.1          py310h34c0648_0    conda-forge
cudatoolkit               11.8.0              h37601d7_11    conda-forge
cudnn                     8.4.1.50             hed8a83a_0    conda-forge
cycler                    0.11.0             pyhd8ed1ab_0    conda-forge
cython                    0.29.33         py310heca2aa9_0    conda-forge
cytoolz                   0.12.0          py310h5764c6d_1    conda-forge
dask-core                 2023.2.0           pyhd8ed1ab_0    conda-forge
dav1d                     1.0.0                h166bdaf_1    conda-forge
dbus                      1.13.6               h5008d03_3    conda-forge
debugpy                   1.6.6           py310heca2aa9_0    conda-forge
decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
defusedxml                0.7.1              pyhd8ed1ab_0    conda-forge
doodleverse-utils         0.0.22                   pypi_0    pypi
entrypoints               0.4                pyhd8ed1ab_0    conda-forge
executing                 1.2.0              pyhd8ed1ab_0    conda-forge
expat                     2.5.0                h27087fc_0    conda-forge
fftw                      3.3.10          nompi_hf0379b8_106    conda-forge
filelock                  3.9.0                    pypi_0    pypi
flatbuffers               22.12.06             hcb278e6_2    conda-forge
flit-core                 3.8.0              pyhd8ed1ab_0    conda-forge
font-ttf-dejavu-sans-mono 2.37                 hab24e00_0    conda-forge
font-ttf-inconsolata      3.000                h77eed37_0    conda-forge
font-ttf-source-code-pro  2.038                h77eed37_0    conda-forge
font-ttf-ubuntu           0.83                 hab24e00_0    conda-forge
fontconfig                2.14.2               h14ed4e7_0    conda-forge
fonts-conda-ecosystem     1                             0    conda-forge
fonts-conda-forge         1                             0    conda-forge
fonttools                 4.38.0          py310h5764c6d_1    conda-forge
freetype                  2.12.1               hca18f0e_1    conda-forge
frozenlist                1.3.3           py310h5764c6d_0    conda-forge
fsspec                    2023.1.0           pyhd8ed1ab_0    conda-forge
gast                      0.4.0              pyh9f0ad1d_0    conda-forge
gettext                   0.21.1               h27087fc_0    conda-forge
giflib                    5.2.1                h36c2ea0_2    conda-forge
glib                      2.74.1               h6239696_1    conda-forge
glib-tools                2.74.1               h6239696_1    conda-forge
google-auth               2.16.0             pyh1a96a4e_1    conda-forge
google-auth-oauthlib      0.4.6              pyhd8ed1ab_0    conda-forge
google-pasta              0.2.0              pyh8c360ce_0    conda-forge
graphite2                 1.3.13            h58526e2_1001    conda-forge
grpcio                    1.51.1          py310h4a5735c_1    conda-forge
gst-plugins-base          1.22.0               h4243ec0_0    conda-forge
gstreamer                 1.22.0               h25f0c4b_0    conda-forge
gstreamer-orc             0.4.33               h166bdaf_0    conda-forge
h5py                      3.8.0           nompi_py310ha66b2ad_101    conda-forge
harfbuzz                  6.0.0                h8e241bc_0    conda-forge
hdf5                      1.14.0          nompi_hb72d44e_102    conda-forge
huggingface-hub           0.12.0                   pypi_0    pypi
icu                       70.1                 h27087fc_0    conda-forge
idna                      3.4                pyhd8ed1ab_0    conda-forge
imagecodecs               2023.1.23       py310ha3ed6a1_0    conda-forge
imageio                   2.25.0             pyh24c5eb1_0    conda-forge
importlib-metadata        6.0.0              pyha770c72_0    conda-forge
importlib_metadata        6.0.0                hd8ed1ab_0    conda-forge
importlib_resources       5.10.2             pyhd8ed1ab_0    conda-forge
ipykernel                 6.21.1             pyh210e3f2_0    conda-forge
ipython                   8.10.0             pyh41d4057_0    conda-forge
ipython_genutils          0.2.0                      py_1    conda-forge
ipywidgets                8.0.4              pyhd8ed1ab_0    conda-forge
jack                      1.9.22               h11f4161_0    conda-forge
jedi                      0.18.2             pyhd8ed1ab_0    conda-forge
jinja2                    3.1.2              pyhd8ed1ab_1    conda-forge
joblib                    1.2.0              pyhd8ed1ab_0    conda-forge
jpeg                      9e                   h166bdaf_2    conda-forge
jsonschema                4.17.3             pyhd8ed1ab_0    conda-forge
jupyter                   1.0.0           py310hff52083_8    conda-forge
jupyter_client            8.0.2              pyhd8ed1ab_0    conda-forge
jupyter_console           6.4.4              pyhd8ed1ab_0    conda-forge
jupyter_core              5.2.0           py310hff52083_0    conda-forge
jupyter_events            0.6.3              pyhd8ed1ab_0    conda-forge
jupyter_server            2.2.1              pyhd8ed1ab_0    conda-forge
jupyter_server_terminals  0.4.4              pyhd8ed1ab_1    conda-forge
jupyterlab_pygments       0.2.2              pyhd8ed1ab_0    conda-forge
jupyterlab_widgets        3.0.5              pyhd8ed1ab_0    conda-forge
jxrlib                    1.1                  h7f98852_2    conda-forge
keras                     2.11.0             pyhd8ed1ab_0    conda-forge
keras-preprocessing       1.1.2              pyhd8ed1ab_0    conda-forge
keyutils                  1.6.1                h166bdaf_0    conda-forge
kiwisolver                1.4.4           py310hbf28c38_1    conda-forge
krb5                      1.20.1               h81ceb04_0    conda-forge
lame                      3.100             h166bdaf_1003    conda-forge
lcms2                     2.14                 hfd0df8a_1    conda-forge
ld_impl_linux-64          2.40                 h41732ed_0    conda-forge
lerc                      4.0.0                h27087fc_0    conda-forge
libabseil                 20220623.0      cxx17_h05df665_6    conda-forge
libaec                    1.0.6                hcb278e6_1    conda-forge
libavif                   0.11.1               h5cdd6b5_0    conda-forge
libblas                   3.9.0           16_linux64_openblas    conda-forge
libbrotlicommon           1.0.9                h166bdaf_8    conda-forge
libbrotlidec              1.0.9                h166bdaf_8    conda-forge
libbrotlienc              1.0.9                h166bdaf_8    conda-forge
libcap                    2.66                 ha37c62d_0    conda-forge
libcblas                  3.9.0           16_linux64_openblas    conda-forge
libclang                  15.0.7          default_had23c3d_1    conda-forge
libclang13                15.0.7          default_h3e3d535_1    conda-forge
libcups                   2.3.3                h36d4200_3    conda-forge
libcurl                   7.87.0               hdc1c0ab_0    conda-forge
libdb                     6.2.32               h9c3ff4c_0    conda-forge
libdeflate                1.17                 h0b41bf4_0    conda-forge
libedit                   3.1.20191231         he28a2e2_2    conda-forge
libev                     4.33                 h516909a_1    conda-forge
libevent                  2.1.10               h28343ad_4    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libflac                   1.4.2                h27087fc_0    conda-forge
libgcc-ng                 12.2.0              h65d4601_19    conda-forge
libgcrypt                 1.10.1               h166bdaf_0    conda-forge
libgfortran-ng            12.2.0              h69a702a_19    conda-forge
libgfortran5              12.2.0              h337968e_19    conda-forge
libglib                   2.74.1               h606061b_1    conda-forge
libgomp                   12.2.0              h65d4601_19    conda-forge
libgpg-error              1.46                 h620e276_0    conda-forge
libgrpc                   1.51.1               h4fad500_1    conda-forge
libiconv                  1.17                 h166bdaf_0    conda-forge
liblapack                 3.9.0           16_linux64_openblas    conda-forge
libllvm15                 15.0.7               hadd5161_0    conda-forge
libnghttp2                1.51.0               hff17c54_0    conda-forge
libnsl                    2.0.0                h7f98852_0    conda-forge
libogg                    1.3.4                h7f98852_1    conda-forge
libopenblas               0.3.21          pthreads_h78a6416_3    conda-forge
libopus                   1.3.1                h7f98852_1    conda-forge
libpng                    1.6.39               h753d276_0    conda-forge
libpq                     15.2                 hb675445_0    conda-forge
libprotobuf               3.21.12              h3eb15da_0    conda-forge
libsndfile                1.2.0                hb75c966_0    conda-forge
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libsqlite                 3.40.0               h753d276_0    conda-forge
libssh2                   1.10.0               hf14f497_3    conda-forge
libstdcxx-ng              12.2.0              h46fd767_19    conda-forge
libsystemd0               252                  h2a991cd_0    conda-forge
libtiff                   4.5.0                h6adf6a1_2    conda-forge
libtool                   2.4.7                h27087fc_0    conda-forge
libudev1                  252                  h166bdaf_0    conda-forge
libuuid                   2.32.1            h7f98852_1000    conda-forge
libvorbis                 1.3.7                h9c3ff4c_0    conda-forge
libwebp-base              1.2.4                h166bdaf_0    conda-forge
libxcb                    1.13              h7f98852_1004    conda-forge
libxkbcommon              1.0.3                he3ba5ed_0    conda-forge
libxml2                   2.10.3               h7463322_0    conda-forge
libzlib                   1.2.13               h166bdaf_4    conda-forge
libzopfli                 1.0.3                h9c3ff4c_0    conda-forge
locket                    1.0.0              pyhd8ed1ab_0    conda-forge
lz4-c                     1.9.4                hcb278e6_0    conda-forge
markdown                  3.4.1              pyhd8ed1ab_0    conda-forge
markupsafe                2.1.2           py310h1fa729e_0    conda-forge
matplotlib                3.6.3           py310hff52083_0    conda-forge
matplotlib-base           3.6.3           py310he60537e_0    conda-forge
matplotlib-inline         0.1.6              pyhd8ed1ab_0    conda-forge
mistune                   2.0.5              pyhd8ed1ab_0    conda-forge
mpg123                    1.31.2               hcb278e6_0    conda-forge
multidict                 6.0.4           py310h1fa729e_0    conda-forge
munkres                   1.1.4              pyh9f0ad1d_0    conda-forge
mysql-common              8.0.32               ha901b37_0    conda-forge
mysql-libs                8.0.32               hd7da12d_0    conda-forge
natsort                   8.2.0              pyhd8ed1ab_0    conda-forge
nbclassic                 0.5.1              pyhd8ed1ab_0    conda-forge
nbclient                  0.7.2              pyhd8ed1ab_0    conda-forge
nbconvert                 7.2.9              pyhd8ed1ab_0    conda-forge
nbconvert-core            7.2.9              pyhd8ed1ab_0    conda-forge
nbconvert-pandoc          7.2.9              pyhd8ed1ab_0    conda-forge
nbformat                  5.7.3              pyhd8ed1ab_0    conda-forge
nccl                      2.14.3.1             h0800d71_0    conda-forge
ncurses                   6.3                  h27087fc_1    conda-forge
nest-asyncio              1.5.6              pyhd8ed1ab_0    conda-forge
networkx                  3.0                pyhd8ed1ab_0    conda-forge
notebook                  6.5.2              pyha770c72_1    conda-forge
notebook-shim             0.2.2              pyhd8ed1ab_0    conda-forge
nspr                      4.35                 h27087fc_0    conda-forge
nss                       3.88                 he45b914_0    conda-forge
numpy                     1.24.2          py310h8deb116_0    conda-forge
oauthlib                  3.2.2              pyhd8ed1ab_0    conda-forge
openjpeg                  2.5.0                hfec8fc6_2    conda-forge
openssl                   3.0.8                h0b41bf4_0    conda-forge
opt_einsum                3.3.0              pyhd8ed1ab_1    conda-forge
packaging                 23.0               pyhd8ed1ab_0    conda-forge
pandas                    1.5.3           py310h9b08913_0    conda-forge
pandoc                    2.19.2               h32600fe_1    conda-forge
pandocfilters             1.5.0              pyhd8ed1ab_0    conda-forge
parso                     0.8.3              pyhd8ed1ab_0    conda-forge
partd                     1.3.0              pyhd8ed1ab_0    conda-forge
pcre2                     10.40                hc3806b6_0    conda-forge
pexpect                   4.8.0              pyh1a96a4e_2    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    9.4.0           py310h023d228_1    conda-forge
pip                       23.0               pyhd8ed1ab_0    conda-forge
pixman                    0.40.0               h36c2ea0_0    conda-forge
pkgutil-resolve-name      1.3.10             pyhd8ed1ab_0    conda-forge
platformdirs              3.0.0              pyhd8ed1ab_0    conda-forge
plotly                    5.13.0             pyhd8ed1ab_0    conda-forge
ply                       3.11                       py_1    conda-forge
pooch                     1.6.0              pyhd8ed1ab_0    conda-forge
prometheus_client         0.16.0             pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.36             pyha770c72_0    conda-forge
prompt_toolkit            3.0.36               hd8ed1ab_0    conda-forge
protobuf                  4.21.12         py310heca2aa9_0    conda-forge
psutil                    5.9.4           py310h5764c6d_0    conda-forge
pthread-stubs             0.4               h36c2ea0_1001    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pulseaudio                16.1                 ha8d29e2_1    conda-forge
pure_eval                 0.2.2              pyhd8ed1ab_0    conda-forge
pyasn1                    0.4.8                      py_0    conda-forge
pyasn1-modules            0.2.7                      py_0    conda-forge
pycparser                 2.21               pyhd8ed1ab_0    conda-forge
pygments                  2.14.0             pyhd8ed1ab_0    conda-forge
pyjwt                     2.6.0              pyhd8ed1ab_0    conda-forge
pyopenssl                 23.0.0             pyhd8ed1ab_0    conda-forge
pyparsing                 3.0.9              pyhd8ed1ab_0    conda-forge
pyqt                      5.15.7          py310hab646b1_3    conda-forge
pyqt5-sip                 12.11.0         py310heca2aa9_3    conda-forge
pyrsistent                0.19.3          py310h1fa729e_0    conda-forge
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
python                    3.10.9          he550d4f_0_cpython    conda-forge
python-dateutil           2.8.2              pyhd8ed1ab_0    conda-forge
python-fastjsonschema     2.16.2             pyhd8ed1ab_0    conda-forge
python-flatbuffers        23.1.21            pyhd8ed1ab_0    conda-forge
python-json-logger        2.0.4              pyhd8ed1ab_0    conda-forge
python_abi                3.10                    3_cp310    conda-forge
pytz                      2022.7.1           pyhd8ed1ab_0    conda-forge
pyu2f                     0.1.5              pyhd8ed1ab_0    conda-forge
pywavelets                1.4.1           py310h0a54255_0    conda-forge
pyyaml                    6.0             py310h5764c6d_5    conda-forge
pyzmq                     25.0.0          py310h059b190_0    conda-forge
qt-main                   5.15.8               h5d23da1_6    conda-forge
qtconsole                 5.4.0              pyhd8ed1ab_0    conda-forge
qtconsole-base            5.4.0              pyha770c72_0    conda-forge
qtpy                      2.3.0              pyhd8ed1ab_0    conda-forge
re2                       2023.02.01           hcb278e6_0    conda-forge
readline                  8.1.2                h0f457ee_0    conda-forge
regex                     2022.10.31               pypi_0    pypi
requests                  2.28.2             pyhd8ed1ab_0    conda-forge
requests-oauthlib         1.3.1              pyhd8ed1ab_0    conda-forge
rfc3339-validator         0.1.4              pyhd8ed1ab_0    conda-forge
rfc3986-validator         0.1.1              pyh9f0ad1d_0    conda-forge
rsa                       4.9                pyhd8ed1ab_0    conda-forge
scikit-image              0.19.3          py310h769672d_2    conda-forge
scipy                     1.10.0          py310h8deb116_2    conda-forge
send2trash                1.8.0              pyhd8ed1ab_0    conda-forge
setuptools                67.1.0             pyhd8ed1ab_0    conda-forge
sip                       6.7.7           py310heca2aa9_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
snappy                    1.1.9                hbd366e4_2    conda-forge
sniffio                   1.3.0              pyhd8ed1ab_0    conda-forge
soupsieve                 2.3.2.post1        pyhd8ed1ab_0    conda-forge
stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
tenacity                  8.2.1              pyhd8ed1ab_0    conda-forge
tensorboard               2.11.2             pyhd8ed1ab_0    conda-forge
tensorboard-data-server   0.6.1           py310h600f1e7_4    conda-forge
tensorboard-plugin-wit    1.8.1              pyhd8ed1ab_0    conda-forge
tensorflow                2.11.0          cuda112py310he87a039_0    conda-forge
tensorflow-base           2.11.0          cuda112py310h52da4a5_0    conda-forge
tensorflow-estimator      2.11.0          cuda112py310h37add04_0    conda-forge
tensorflow-gpu            2.11.0          cuda112py310h0bbbad9_0    conda-forge
termcolor                 2.2.0              pyhd8ed1ab_0    conda-forge
terminado                 0.17.1             pyh41d4057_0    conda-forge
tifffile                  2023.2.3           pyhd8ed1ab_0    conda-forge
tinycss2                  1.2.1              pyhd8ed1ab_0    conda-forge
tk                        8.6.12               h27826a3_0    conda-forge
tokenizers                0.13.2                   pypi_0    pypi
toml                      0.10.2             pyhd8ed1ab_0    conda-forge
toolz                     0.12.0             pyhd8ed1ab_0    conda-forge
tornado                   6.2             py310h5764c6d_1    conda-forge
tqdm                      4.64.1             pyhd8ed1ab_0    conda-forge
traitlets                 5.9.0              pyhd8ed1ab_0    conda-forge
transformers              4.26.1                   pypi_0    pypi
typing-extensions         4.4.0                hd8ed1ab_0    conda-forge
typing_extensions         4.4.0              pyha770c72_0    conda-forge
tzdata                    2022g                h191b570_0    conda-forge
unicodedata2              15.0.0          py310h5764c6d_0    conda-forge
urllib3                   1.26.14            pyhd8ed1ab_0    conda-forge
versioneer                0.28                     pypi_0    pypi
wcwidth                   0.2.6              pyhd8ed1ab_0    conda-forge
webencodings              0.5.1                      py_1    conda-forge
websocket-client          1.5.1              pyhd8ed1ab_0    conda-forge
werkzeug                  2.2.2              pyhd8ed1ab_0    conda-forge
wheel                     0.38.4             pyhd8ed1ab_0    conda-forge
widgetsnbextension        4.0.5              pyhd8ed1ab_0    conda-forge
wrapt                     1.14.1          py310h5764c6d_1    conda-forge
xcb-util                  0.4.0                h516909a_0    conda-forge
xcb-util-image            0.4.0                h166bdaf_0    conda-forge
xcb-util-keysyms          0.4.0                h516909a_0    conda-forge
xcb-util-renderutil       0.3.9                h166bdaf_0    conda-forge
xcb-util-wm               0.4.1                h516909a_0    conda-forge
xorg-kbproto              1.0.7             h7f98852_1002    conda-forge
xorg-libice               1.0.10               h7f98852_0    conda-forge
xorg-libsm                1.2.3             hd9c2040_1000    conda-forge
xorg-libx11               1.7.2                h7f98852_0    conda-forge
xorg-libxau               1.0.9                h7f98852_0    conda-forge
xorg-libxdmcp             1.1.3                h7f98852_0    conda-forge
xorg-libxext              1.3.4                h7f98852_1    conda-forge
xorg-libxrender           0.9.10            h7f98852_1003    conda-forge
xorg-renderproto          0.11.1            h7f98852_1002    conda-forge
xorg-xextproto            7.3.0             h7f98852_1002    conda-forge
xorg-xproto               7.0.31            h7f98852_1007    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
yaml                      0.2.5                h7f98852_2    conda-forge
yarl                      1.8.2           py310h5764c6d_0    conda-forge
zeromq                    4.3.4                h9c3ff4c_1    conda-forge
zfp                       1.0.0                h27087fc_3    conda-forge
zipp                      3.13.0             pyhd8ed1ab_0    conda-forge
zlib                      1.2.13               h166bdaf_4    conda-forge
zlib-ng                   2.0.6                h166bdaf_0    conda-forge
zstd                      1.5.2                h3eb15da_6    conda-forge

Config file

{
    "TARGET_SIZE": [
        512,
        512
    ],
    "MODEL": "segformer",
    "NCLASSES": 7,
    "BATCH_SIZE": 8,
    "N_DATA_BANDS": 1,
    "DO_TRAIN": true,
    "PATIENCE": 10,
    "MAX_EPOCHS": 10,
    "VALIDATION_SPLIT": 0.6,
    "FILTERS": 6,
    "KERNEL": 7,
    "STRIDE": 2,
    "LOSS": "dice",
    "DROPOUT": 0.1,
    "DROPOUT_CHANGE_PER_LAYER": 0.0,
    "DROPOUT_TYPE": "standard",
    "USE_DROPOUT_ON_UPSAMPLING": false,
    "ROOT_STRING": "Substrate_inclShadow",
    "FILTER_VALUE": 3,
    "DOPLOT": true,
    "USEMASK": true,
    "RAMPUP_EPOCHS": 10,
    "SUSTAIN_EPOCHS": 0.0,
    "EXP_DECAY": 0.9,
    "START_LR": 1e-07,
    "MIN_LR": 1e-07,
    "MAX_LR": 0.0001,
    "AUG_ROT": 0,
    "AUG_ZOOM": 0.05,
    "AUG_WIDTHSHIFT": 0.05,
    "AUG_HEIGHTSHIFT": 0.05,
    "AUG_HFLIP": true,
    "AUG_VFLIP": false,
    "AUG_LOOPS": 3,
    "AUG_COPIES": 3,
    "TESTTIMEAUG": false,
    "SET_GPU": "0",
    "DO_CRF": false,
    "SET_PCI_BUS_ID": true,
    "WRITE_MODELMETADATA": true,
    "OTSU_THRESHOLD": true
}

I had the same problem initially. I fixed it with "pip install transformers"

Oh I just saw your edit. Yes, your 1 band imagery is probably the cause. There's an easy fix. One liner. I'll add it tomorrow

I'm super impressed with segformers!

I have a problem dataset that I cannot get a good Res-UNet model for. It is a heavily class-imbalanced dataset (damaged buildings), and Res-UNets tend to overfit the majority class.

Segformer worked really well with a low LR

[Figures: validation examples xbd_RGB_768_v3_val_44, _52, _60, _64 and _82]

https://zenodo.org/record/7613175#.Y-lR6xzMLRY

I just noticed more bugs in do_seg from doodleverse-utils: I need to add support for NCLASSES>2. Let me keep working on this implementation, and I'll let you know when it's ready for testing again.

Haven't tested it yet, but I added a kludge for N_DATA_BANDS=1 to read_seg_dataset_multiclass_segformer in do_train:

    if N_DATA_BANDS==1:
        image = np.dstack((image, image, image))

Ok, @CameronBodine I think I have fixed the issue with the 1-band image (see above) and I also updated doodleverse-utils so the segformer works in prediction mode when nclasses>2
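Because Segformer outputs are (W/4, H/4) with one channel per class, prediction for NCLASSES>2 comes down to an argmax over the class axis plus a resize back to TARGET_SIZE. The sketch below is a NumPy-only illustration, not the doodleverse-utils do_seg code. It assumes channels-first logits of shape (batch, NCLASSES, H/4, W/4), which is the layout TFSegformerForSemanticSegmentation documents for its logits, and it uses nearest-neighbour upsampling by an integer factor; the helper name segformer_logits_to_label is hypothetical.

```python
import numpy as np


def segformer_logits_to_label(logits, target_size):
    """Map Segformer logits (batch, n_classes, h, w) to integer labels (batch, H, W).

    Hypothetical helper: assumes target_size is an exact integer multiple of (h, w),
    e.g. h = H / 4 and w = W / 4 for Segformer outputs.
    """
    _, _, h, w = logits.shape
    fy, fx = target_size[0] // h, target_size[1] // w
    if fy * h != target_size[0] or fx * w != target_size[1]:
        raise ValueError("target_size must be an integer multiple of the logit size")
    # Argmax over the class axis works for any NCLASSES, including NCLASSES > 2
    labels = np.argmax(logits, axis=1)  # (batch, h, w)
    # Nearest-neighbour upsampling: repeat each row fy times and each column fx times
    labels = np.repeat(np.repeat(labels, fy, axis=1), fx, axis=2)
    return labels.astype(np.uint8)


# Usage: one image, 3 classes, 2x2 logits upsampled to an 8x8 label map
logits = np.zeros((1, 3, 2, 2), dtype=np.float32)
logits[0, 2, 0, 1] = 1.0  # the top-right quarter belongs to class 2
label = segformer_logits_to_label(logits, (8, 8))
print(label.shape)  # (1, 8, 8)
```

Taking the argmax before nearest-neighbour upsampling gives the same result as upsampling first, and it is cheaper. For smoother class boundaries, one option is to resize the logits bilinearly (for example with tf.image.resize after transposing to channels-last) and take the argmax afterwards.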

Please use latest dev_models branch and update to https://pypi.org/project/doodleverse-utils/0.0.23/

Ok, still having issues and can't figure out how to fix it. Running with the fix above, I receive the following error:

Traceback (most recent call last):
  File "/mnt/md0/SynologyDrive/Modeling/segmentation_gym/train_model_csb.py", line 574, in <module>
    train_ds = train_ds.map(read_seg_dataset_multiclass_segformer, num_parallel_calls=AUTO)
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 2296, in map
    return ParallelMapDataset(
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 5540, in __init__
    self._map_func = structured_function.StructuredFunctionWrapper(
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/data/ops/structured_function.py", line 263, in __init__
    self._function = fn_factory()
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 226, in get_concrete_function
    concrete_function = self._get_concrete_function_garbage_collected(
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 192, in _get_concrete_function_garbage_collected
    concrete_function, _ = self._maybe_define_concrete_function(args, kwargs)
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 157, in _maybe_define_concrete_function
    return self._maybe_define_function(args, kwargs)
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 360, in _maybe_define_function
    concrete_function = self._create_concrete_function(args, kwargs)
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 284, in _create_concrete_function
    func_graph_module.func_graph_from_py_func(
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1283, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/data/ops/structured_function.py", line 240, in wrapped_fn
    ret = wrapper_helper(*args)
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/data/ops/structured_function.py", line 171, in wrapper_helper
    ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 642, in wrapper
    return func(*args, **kwargs)
  File "/mnt/md0/SynologyDrive/Modeling/segmentation_gym/train_model_csb.py", line 258, in read_seg_dataset_multiclass_segformer
    image = np.dstack((image, image, image))
  File "<__array_function__ internals>", line 180, in dstack
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/numpy/lib/shape_base.py", line 720, in dstack
    arrs = atleast_3d(*tup)
  File "<__array_function__ internals>", line 180, in atleast_3d
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/numpy/core/shape_base.py", line 191, in atleast_3d
    ary = asanyarray(ary)
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 922, in __array__
    raise NotImplementedError(
NotImplementedError: Cannot convert a symbolic tf.Tensor (EagerPyFunc:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.

I then tried:

if N_DATA_BANDS==1:
        image = np.dstack((image.numpy(), image.numpy(), image.numpy()))

and get:

Traceback (most recent call last):
  File "/mnt/md0/SynologyDrive/Modeling/segmentation_gym/train_model_csb.py", line 574, in <module>
    train_ds = train_ds.map(read_seg_dataset_multiclass_segformer, num_parallel_calls=AUTO)
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 2296, in map
    return ParallelMapDataset(
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 5540, in __init__
    self._map_func = structured_function.StructuredFunctionWrapper(
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/data/ops/structured_function.py", line 263, in __init__
    self._function = fn_factory()
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 226, in get_concrete_function
    concrete_function = self._get_concrete_function_garbage_collected(
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 192, in _get_concrete_function_garbage_collected
    concrete_function, _ = self._maybe_define_concrete_function(args, kwargs)
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 157, in _maybe_define_concrete_function
    return self._maybe_define_function(args, kwargs)
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 360, in _maybe_define_function
    concrete_function = self._create_concrete_function(args, kwargs)
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 284, in _create_concrete_function
    func_graph_module.func_graph_from_py_func(
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1283, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/data/ops/structured_function.py", line 240, in wrapped_fn
    ret = wrapper_helper(*args)
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/data/ops/structured_function.py", line 171, in wrapper_helper
    ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 642, in wrapper
    return func(*args, **kwargs)
  File "/mnt/md0/SynologyDrive/Modeling/segmentation_gym/train_model_csb.py", line 258, in read_seg_dataset_multiclass_segformer
    image = np.dstack((image.numpy(), image.numpy(), image.numpy()))
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 444, in __getattr__
    self.__getattribute__(name)
AttributeError: 'Tensor' object has no attribute 'numpy'

I found this and tried:

if N_DATA_BANDS==1:
        image[:,:,tf.newaxis]
        tf.concat([image, image, image], axis=2)

I don't receive an error in def read_seg_dataset_multiclass_segformer(example) anymore, but do get the same error I received previously:

Traceback (most recent call last):
  File "/mnt/md0/SynologyDrive/Modeling/segmentation_gym/train_model_csb.py", line 915, in <module>
    history = model.fit(train_ds, steps_per_epoch=steps_per_epoch, epochs=MAX_EPOCHS,
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/tmp/__autograph_generated_filesubbl8qp.py", line 15, in tf__train_function
    retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
  File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/transformers/modeling_tf_utils.py", line 1535, in train_step
    y_pred = self(x, training=True)
  File "/tmp/__autograph_generated_fileslzg7p14.py", line 36, in tf__run_call_with_unpacked_inputs
    retval_ = ag__.converted_call(ag__.ld(func), (ag__.ld(self),), dict(**ag__.ld(unpacked_inputs)), fscope)
  File "/tmp/__autograph_generated_filez21ii97w.py", line 13, in tf__call
    outputs = ag__.converted_call(ag__.ld(self).segformer, (ag__.ld(pixel_values),), dict(output_attentions=ag__.ld(output_attentions), output_hidden_states=True, return_dict=ag__.ld(return_dict)), fscope)
  File "/tmp/__autograph_generated_fileslzg7p14.py", line 36, in tf__run_call_with_unpacked_inputs
    retval_ = ag__.converted_call(ag__.ld(func), (ag__.ld(self),), dict(**ag__.ld(unpacked_inputs)), fscope)
  File "/tmp/__autograph_generated_filedchqfcvm.py", line 14, in tf__call
    encoder_outputs = ag__.converted_call(ag__.ld(self).encoder, (ag__.ld(pixel_values),), dict(output_attentions=ag__.ld(output_attentions), output_hidden_states=ag__.ld(output_hidden_states), return_dict=ag__.ld(return_dict), training=ag__.ld(training)), fscope)
  File "/tmp/__autograph_generated_fileu9el03h2.py", line 102, in tf__call
    ag__.for_stmt(ag__.converted_call(ag__.ld(enumerate), (ag__.converted_call(ag__.ld(zip), (ag__.ld(self).embeddings, ag__.ld(self).block, ag__.ld(self).layer_norms), None, fscope),), None, fscope), None, loop_body_1, get_state_4, set_state_4, ('all_hidden_states', 'all_self_attentions', 'hidden_states'), {'iterate_names': '(idx, x)'})
  File "/tmp/__autograph_generated_fileu9el03h2.py", line 26, in loop_body_1
    (hidden_states, height, width) = ag__.converted_call(ag__.ld(embedding_layer), (ag__.ld(hidden_states),), None, fscope)
  File "/tmp/__autograph_generated_filezjucj2gz.py", line 10, in tf__call
    embeddings = ag__.converted_call(ag__.ld(self).proj, (ag__.converted_call(ag__.ld(self).padding, (ag__.ld(pixel_values),), None, fscope),), None, fscope)
ValueError: in user code:

    File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/keras/engine/training.py", line 1249, in train_function  *
        return step_function(self, iterator)
    File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/transformers/modeling_tf_utils.py", line 830, in run_call_with_unpacked_inputs  *
        return func(self, **unpacked_inputs)
    File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/transformers/models/segformer/modeling_tf_segformer.py", line 857, in call  *
        outputs = self.segformer(
    File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/transformers/modeling_tf_utils.py", line 830, in run_call_with_unpacked_inputs  *
        return func(self, **unpacked_inputs)
    File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/transformers/models/segformer/modeling_tf_segformer.py", line 481, in call  *
        encoder_outputs = self.encoder(
    File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/tmp/__autograph_generated_fileu9el03h2.py", line 102, in tf__call
        ag__.for_stmt(ag__.converted_call(ag__.ld(enumerate), (ag__.converted_call(ag__.ld(zip), (ag__.ld(self).embeddings, ag__.ld(self).block, ag__.ld(self).layer_norms), None, fscope),), None, fscope), None, loop_body_1, get_state_4, set_state_4, ('all_hidden_states', 'all_self_attentions', 'hidden_states'), {'iterate_names': '(idx, x)'})
    File "/tmp/__autograph_generated_fileu9el03h2.py", line 26, in loop_body_1
        (hidden_states, height, width) = ag__.converted_call(ag__.ld(embedding_layer), (ag__.ld(hidden_states),), None, fscope)
    File "/tmp/__autograph_generated_filezjucj2gz.py", line 10, in tf__call
        embeddings = ag__.converted_call(ag__.ld(self).proj, (ag__.converted_call(ag__.ld(self).padding, (ag__.ld(pixel_values),), None, fscope),), None, fscope)

    ValueError: Exception encountered when calling layer 'encoder' (type TFSegformerEncoder).
    
    in user code:
    
        File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/transformers/models/segformer/modeling_tf_segformer.py", line 416, in call  *
            hidden_states, height, width = embedding_layer(hidden_states)
        File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
            raise e.with_traceback(filtered_tb) from None
        File "/tmp/__autograph_generated_filezjucj2gz.py", line 10, in tf__call
            embeddings = ag__.converted_call(ag__.ld(self).proj, (ag__.converted_call(ag__.ld(self).padding, (ag__.ld(pixel_values),), None, fscope),), None, fscope)
    
        ValueError: Exception encountered when calling layer 'patch_embeddings.0' (type TFSegformerOverlapPatchEmbeddings).
        
        in user code:
        
            File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/transformers/models/segformer/modeling_tf_segformer.py", line 89, in call  *
                embeddings = self.proj(self.padding(pixel_values))
            File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
                raise e.with_traceback(filtered_tb) from None
            File "/home/cbodine/miniconda3/envs/gym/lib/python3.8/site-packages/keras/engine/input_spec.py", line 277, in assert_input_compatibility
                raise ValueError(
        
            ValueError: Input 0 of layer "proj" is incompatible with the layer: expected axis -1 of input shape to have value 3, but received input with shape (8, None, None, 1)
        
        
        Call arguments received by layer 'patch_embeddings.0' (type TFSegformerOverlapPatchEmbeddings):
          • pixel_values=tf.Tensor(shape=(8, None, None, 1), dtype=float32)
    
    
    Call arguments received by layer 'encoder' (type TFSegformerEncoder):
      • pixel_values=tf.Tensor(shape=(8, None, None, 1), dtype=float32)
      โ€ข output_attentions=False
      โ€ข output_hidden_states=True
      โ€ข return_dict=True
      โ€ข training=True

This makes me think I'm not actually reshaping anything. To debug, I tried printing some of the variables in def read_seg_dataset_multiclass_segformer(example) with:

image, label = tf.py_function(func=load_npz, inp=[example], Tout=[tf.float32, tf.uint8])

imdim = image.shape[0]

print('example:', example)
print('image:', image)
print('label:', label)
print('image shape:', image.shape)
print('label shape:', label.shape)

and I get:

example: Tensor("args_0:0", shape=(), dtype=string)
image: Tensor("EagerPyFunc:0", dtype=float32, device=/job:localhost/replica:0/task:0)
label: Tensor("EagerPyFunc:1", dtype=uint8, device=/job:localhost/replica:0/task:0)
image shape: <unknown>
label shape: <unknown>
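The <unknown> shapes are expected: tf.py_function runs arbitrary Python, so TensorFlow cannot infer the rank or size of its outputs while tracing the map function, and the static shape downstream layers see is whatever set_shape declares. That plausibly explains the earlier error too: imdim is None when the shape is unknown, so set_shape([N_DATA_BANDS, imdim, imdim]) declares (1, None, None), which would appear as (8, None, None, 1) after batching and the model's internal channels-last transpose. A minimal sketch, using a hypothetical fake_load_npz in place of Gym's load_npz, shows the static shape before and after set_shape:

```python
import numpy as np
import tensorflow as tf


def fake_load_npz(path):
    # Hypothetical stand-in for load_npz: a 4x4 single-band image and its label
    return np.ones((4, 4, 1), np.float32), np.zeros((4, 4), np.uint8)


shapes = {}  # records static shapes observed while the map function is traced


def read(path):
    image, label = tf.py_function(func=fake_load_npz, inp=[path], Tout=[tf.float32, tf.uint8])
    shapes["before"] = image.shape  # unknown rank: TensorFlow cannot see inside the Python call
    image.set_shape([4, 4, 1])      # declare the shape we know the loader returns
    label.set_shape([4, 4])
    shapes["after"] = image.shape
    return image, label


# Dataset.map traces `read` immediately, so `shapes` is filled at this point
ds = tf.data.Dataset.from_tensor_slices(["a.npz"]).map(read)
print(shapes["before"], shapes["after"])
```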

Any thoughts on what I can try?

As a sanity check, I am able to train a segformer on the Cape Hatteras sample dataset.

Hmmm ... thanks for reporting. In the above, I assume you meant something like

image = image[:,:,tf.newaxis]
image = tf.concat([image, image, image], axis=2)

?

If so, I don't know ... I would have to play with a 1-band dataset (which I'm happy to do)

GOT IT! First, need to stack the image:

if N_DATA_BANDS==1:
        image = tf.concat([image, image, image], axis=2)

        # Below also works:
        # image = tf.experimental.numpy.dstack((image, image, image))

AND need to add a check for number of bands when setting the tensor's shape:

if N_DATA_BANDS==1:
        image.set_shape([3, imdim, imdim])
else:
        image.set_shape([N_DATA_BANDS, imdim, imdim])
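Putting the working fix together, here is a minimal sketch of a Segformer reader for single-band imagery. It is an illustration, not the read_seg_dataset_multiclass_segformer function in do_train, and it assumes a hypothetical load_npz_stub in place of Gym's load_npz, a tiny TARGET_SIZE for the demo, and the dict keys pixel_values and labels for the Hugging Face Keras fit integration. It replicates the band with tf.concat (NumPy calls such as np.dstack cannot operate on symbolic tensors inside Dataset.map) and moves channels first, since TFSegformerForSemanticSegmentation expects pixel_values shaped (batch, channels, height, width).

```python
import numpy as np
import tensorflow as tf

N_DATA_BANDS = 1      # assumption: single-band imagery
TARGET_SIZE = (8, 8)  # assumption: tiny size for the demo (e.g. 512x512 in practice)


def load_npz_stub(path):
    # Hypothetical stand-in for load_npz: (H, W, bands) float image and (H, W) label
    image = np.random.rand(*TARGET_SIZE, N_DATA_BANDS).astype(np.float32)
    label = np.random.randint(0, 7, size=TARGET_SIZE).astype(np.uint8)
    return image, label


def read_example_segformer(path):
    image, label = tf.py_function(func=load_npz_stub, inp=[path], Tout=[tf.float32, tf.uint8])
    # py_function outputs have unknown shape, so declare the loader's real shape first
    image.set_shape([*TARGET_SIZE, N_DATA_BANDS])
    label.set_shape(TARGET_SIZE)
    if N_DATA_BANDS == 1:
        # Replicate the single band with a TensorFlow op, giving (H, W, 3)
        image = tf.concat([image, image, image], axis=-1)
    # Channels-first layout for Segformer's pixel_values: (3, H, W)
    image = tf.transpose(image, (2, 0, 1))
    return {"pixel_values": image, "labels": label}


ds = tf.data.Dataset.from_tensor_slices(["a.npz", "b.npz"]).map(read_example_segformer).batch(2)
batch = next(iter(ds))
print(batch["pixel_values"].shape)  # (2, 3, 8, 8)
```

Declaring the shape before stacking lets TensorFlow infer the 3-channel shape of the concatenated tensor, so set_shape does not need its own N_DATA_BANDS==1 branch. tf.experimental.numpy.dstack, mentioned above, would be an equivalent replacement for the tf.concat call.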

Slamma Jamma Ding Dong!

Yes, I was messing up the original concatenation also. Now I'm ready to train and will report back!

Great, I'll add it to the dev branch

Initial results are encouraging! I trained on a small subset (n=300 samples) with the following hyperparameters:

"TARGET_SIZE": [512, 512],
"MODEL": "segformer",
"NCLASSES": 7,
"BATCH_SIZE": 20,
"N_DATA_BANDS": 1,
"PATIENCE": 10,
"MAX_EPOCHS": 100,
"VALIDATION_SPLIT": 0.6,
"EXP_DECAY": 0.9,
"START_LR": 1e-7,
"MIN_LR": 1e-7,
"MAX_LR": 1e-4,

The model finished training after 25 epochs; examples are below. Moving on to training on the full dataset.

[Figures: training history SegFormer_SpdCor_Substrate_inclShadow_trainhist_20; validation examples SegFormer_SpdCor_Substrate_inclShadow_val_3, _5, _10 and _12]

Cool!

One thing I noticed was that I had to use a very low LR range for best results on an NCLASSES=2 and an NCLASSES=4 problem. I used 1e-8 to 1e-5, with RAMPUP_EPOCHS=10 and SUSTAIN_EPOCHS=5; that trains more slowly but seems to produce a better model.
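The schedule described here ramps the learning rate up from a low value, holds it, then decays it. A minimal sketch of that shape, assuming (not copied from Gym) that START_LR, MAX_LR, MIN_LR, RAMPUP_EPOCHS, SUSTAIN_EPOCHS and EXP_DECAY combine as a linear ramp-up, a flat sustain phase and an exponential decay toward MIN_LR. The defaults use the low 1e-8 to 1e-5 range quoted above; setting START_LR = MIN_LR = 1e-8 is an assumption.

```python
def lrfn(epoch, *, start_lr=1e-8, min_lr=1e-8, max_lr=1e-5,
         rampup_epochs=10, sustain_epochs=5, exp_decay=0.9):
    """Ramp-up / sustain / exponential-decay learning rate (assumed form, hypothetical helper)."""
    if epoch < rampup_epochs:
        # Linear ramp from start_lr toward max_lr
        return (max_lr - start_lr) / rampup_epochs * epoch + start_lr
    if epoch < rampup_epochs + sustain_epochs:
        # Hold at the peak rate
        return max_lr
    # Exponential decay of the excess over min_lr
    return (max_lr - min_lr) * exp_decay ** (epoch - rampup_epochs - sustain_epochs) + min_lr


for e in (0, 5, 10, 14, 15, 20):
    print(e, lrfn(e))
```

With these defaults the rate reaches 1e-5 at epoch 10, holds through epoch 14, then decays exponentially (EXP_DECAY = 0.9 per epoch) toward 1e-8. In Keras it could be wrapped as tf.keras.callbacks.LearningRateScheduler(lambda epoch, lr: lrfn(epoch)).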

Yes that's it!

I am getting excellent results after training on my entire dataset (n ≈ 12,000) with the same hyperparameters I used on the smaller training set. I will give the smaller learning rate a shot for comparison. Super pleased with the results though!

[Figures: training history SegFormer_SpdCor_Substrate_inclShadow_trainhist_20; validation examples SegFormer_SpdCor_Substrate_inclShadow_val_261, _263, _272 and _298]

I've been very impressed with Segformers so far!

Leaving this open solely to add details to the docs

I finally tested the segformer model on Windows and got the CUDA error that @CameronBodine previously got. Running pip install transformers did not fix it for me...

Ok, this is a different issue now.

I have updated the wiki. I am closing this issue and opening a new issue about the CUDA error