ValueError: axes don't match array in prediction of S1 unit at 7-synthseg+.py
Closed this issue · 3 comments
Thank you for sharing synthseg.
To run prediction with the S1 unit of SynthSeg+, I am training S1 with 7-synthseg+.py and predicting with 4-predict.py,
but prediction fails with "ValueError: axes don't match array". Here is how to reproduce the error.
Environment: requirements_python3.6.txt
- Prepare an image/label dataset by running scripts/tutorials/1-generation_visualisation.py:
cd SynthSeg/scripts/tutorials/
python 1-generation_visualisation.py
- Train the S1 unit with the part of 7-synthseg+.py shown below. This ran and saved the model weights as "scripts/tutorials/outputs_tutorial_7/training_s1/dice_001.h5".
from SynthSeg.training import training as training_s1
# ------------------ segmenter S1
labels_dir_s1 = '../../data/training_label_maps'
path_generation_labels = '../../data/labels_classes_priors/generation_labels.npy'
path_generation_classes = '../../data/labels_classes_priors/generation_classes.npy'
path_segmentation_labels_s1 = '../../data/tutorial_7/segmentation_labels_s1.npy'
model_dir_s1 = './outputs_tutorial_7/training_s1'
training_s1(labels_dir=labels_dir_s1,
            model_dir=model_dir_s1,
            generation_labels=path_generation_labels,
            segmentation_labels=None,
            n_neutral_labels=18,
            generation_classes=path_generation_classes,
            target_res=1,
            output_shape=160,
            prior_distributions='uniform',
            prior_means=[0, 255],
            prior_stds=[0, 50],
            randomise_res=True,
            dice_epochs=1,
            steps_per_epoch=1000)
- Predict with the trained S1 model using the following code, based on 4-predict.py:
from SynthSeg.predict import predict
path_images = './outputs_tutorial_1/image.nii.gz'
path_segm = './outputs_tutorial_7/predicted_segmentations-S1'
path_posteriors = './outputs_tutorial_7/predicted_segmentations-S1'
path_vol = './outputs_tutorial_7/predicted_information/volumes.csv'
path_model = './outputs_tutorial_7/training_s1/dice_001.h5'
path_segmentation_labels = '../../data/tutorial_7/segmentation_labels_s1.npy'
path_segmentation_names = '../../data/labels_classes_priors/synthseg_segmentation_names.npy'
cropping = 192
target_res = 1.
path_resampled = './outputs_tutorial_7/predicted_information'
flip = True
n_neutral_labels = 18
sigma_smoothing = 0.5
topology_classes = '../../data/labels_classes_priors/synthseg_topological_classes.npy'
keep_biggest_component = True
n_levels = 5
nb_conv_per_level = 2
conv_size = 3
unet_feat_count = 24
activation = 'elu'
feat_multiplier = 2
gt_folder = None
compute_distances = True
predict(path_images,
        path_segm,
        path_model,
        path_segmentation_labels,
        n_neutral_labels=n_neutral_labels,
        path_posteriors=path_posteriors,
        path_resampled=path_resampled,
        path_volumes=path_vol,
        names_segmentation=path_segmentation_names,
        cropping=cropping,
        target_res=target_res,
        flip=flip,
        topology_classes=topology_classes,
        sigma_smoothing=sigma_smoothing,
        keep_biggest_component=keep_biggest_component,
        n_levels=n_levels,
        nb_conv_per_level=nb_conv_per_level,
        conv_size=conv_size,
        unet_feat_count=unet_feat_count,
        feat_multiplier=feat_multiplier,
        activation=activation,
        gt_folder=gt_folder,
        compute_distances=compute_distances)
- This eventually raises the following "ValueError: axes don't match array" error:
Traceback (most recent call last):
File "4-prediction-S1.py", line 137, in <module>
predict(path_images,
File "/home/ubuntu/SynthSeg/SynthSeg/predict.py", line 157, in predict
net = build_model(path_model=path_model,
File "/home/ubuntu/SynthSeg/SynthSeg/predict.py", line 467, in build_model
net.load_weights(path_model, by_name=True)
File "/home/ubuntu/SynthSeg/synthseg_env/lib/python3.8/site-packages/keras/engine/saving.py", line 492, in load_wrapper
return load_function(*args, **kwargs)
File "/home/ubuntu/SynthSeg/synthseg_env/lib/python3.8/site-packages/keras/engine/network.py", line 1225, in load_weights
saving.load_weights_from_hdf5_group_by_name(
File "/home/ubuntu/SynthSeg/synthseg_env/lib/python3.8/site-packages/keras/engine/saving.py", line 1289, in load_weights_from_hdf5_group_by_name
weight_values = preprocess_weights_for_loading(
File "/home/ubuntu/SynthSeg/synthseg_env/lib/python3.8/site-packages/keras/engine/saving.py", line 980, in preprocess_weights_for_loading
weights[0] = np.transpose(weights[0], (3, 2, 0, 1))
File "<__array_function__ internals>", line 5, in transpose
File "/home/ubuntu/SynthSeg/synthseg_env/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 651, in transpose
return _wrapfunc(a, 'transpose', axes)
File "/home/ubuntu/SynthSeg/synthseg_env/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 61, in _wrapfunc
return bound(*args, **kwds)
ValueError: axes don't match array
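As far as I can tell from the traceback (my own reading, not something stated in the thread), Keras's by-name weight loading transposes each matched convolution kernel with a fixed 4-axis permutation, so the NumPy part of the failure can be reproduced in isolation whenever the matched weight array is not 4-D (the shape below is a hypothetical example, not the actual mismatched weight):

```python
import numpy as np

# preprocess_weights_for_loading applies a fixed 4-axis permutation to
# conv kernels. If the array matched by layer name is not 4-D (e.g. the
# saved model and the rebuilt model differ), np.transpose fails exactly
# like in the traceback above.
kernel_3d = np.zeros((3, 3, 24))  # hypothetical mismatched weight array
try:
    np.transpose(kernel_3d, (3, 2, 0, 1))
except ValueError as e:
    print(e)  # prints: axes don't match array
```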
Hi,
so if you want to use the predict function on the trained S1, you need to adapt it a bit, because the default parameters are set to work for the distributed version of SynthSeg, which differs from S1 mainly in the set of predicted labels.
Everything is explained in the docstring of predict.py, but here's how you should modify your script.
First, you should change n_neutral_labels from 18 to None, because here we only have 5 labels, and all of them correspond to non-sided regions (i.e., they are not restricted to either the left or right hemisphere).
Then, you should get rid of path_segmentation_names and topology_classes (i.e., set them to None), because 1) you won't need them, and 2) these specific numpy arrays were designed for a network that segments 30 or so labels, not 5.
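In other words, something like this (just a sketch of the three keyword arguments to change; everything else in your script can stay as it is):

```python
# The three predict() arguments to override for the S1 model.
s1_overrides = dict(
    n_neutral_labels=None,    # the 5 S1 labels are all non-sided
    names_segmentation=None,  # default names array targets ~30 labels
    topology_classes=None,    # default topology array targets ~30 labels
)
# predict(path_images, path_segm, path_model, path_segmentation_labels,
#         **s1_overrides)  # plus your remaining keyword arguments
```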
Let me know if this works :)
Thanks! It worked!
Because I used a not-fully-trained S1 (dice_epochs=1 and steps_per_epoch=10) just to check,
the prediction assigns the label "Left-Cerebral-Exterior" to every pixel :)
For now it seems to work, and I'll try prediction with a well-trained S1 later.
Glad to know it works!
More generally, please note that tutorial 7 is just there to explain how SynthSeg+ was trained. The 20 training maps given under /data/training_label_maps/ are not enough to train a robust S1 and S2, and the three examples under /data/tutorial_7 are far from enough to train a good denoiser D.
To obtain good networks, you would have to train on your own data.