Project-MONAI/tutorials

ValueError: y_pred and y should have same shapes.

CharlieFengCN opened this issue · 2 comments

Hello! I ran into the same problem that someone reported before, but the earlier answers did not give me a solution. I am using UNETR from MONAI for binary (two-class) segmentation of stroke lesions. Training runs fine, but an error appears during validation. Below are the error traceback and my code; please take a look at what the problem might be. I would also like to ask whether UNETR supports two-class segmentation. I hope to get your help as soon as possible. Thank you very much.
`---> 75 global_step, dice_val_best, global_step_best = train(global_step, train_loader, dice_val_best, global_step_best)
Cell In[70], line 39, in train(global_step, train_loader, dice_val_best, global_step_best)
37 if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
38 epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
---> 39 dice_val = validation(epoch_iterator_val)
40 epoch_loss /= step
41 epoch_loss_values.append(epoch_loss)

Cell In[70], line 14, in validation(epoch_iterator_val)
11 val_outputs_list = decollate_batch(val_outputs)
12 val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list]
---> 14 dice_metric(y_pred=val_output_convert, y=val_labels_convert)
15 epoch_iterator_val.set_description("Validate (%d / %d Steps)" % (global_step, 10.0))
16 mean_dice_val = dice_metric.aggregate().item()

File E:\mambaforge\envs\BCP\lib\site-packages\monai\metrics\metric.py:216, in CumulativeIterationMetric.call(self, y_pred, y)
202 def call(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): # type: ignore
203 """
204 Execute basic computation for model prediction and ground truth.
205 It can support both list of channel-first Tensor and batch-first Tensor.
(...)
214
215 """
--> 216 ret = super().call(y_pred=y_pred, y=y)
217 if isinstance(ret, (tuple, list)):
218 self.add(*ret)

File E:\mambaforge\envs\BCP\lib\site-packages\monai\metrics\metric.py:63, in IterationMetric.call(self, y_pred, y)
60 ret: TensorOrList
61 if isinstance(y_pred, (list, tuple)) or isinstance(y, (list, tuple)):
62 # if y_pred or y is a list of channel-first data, add batch dim and compute metric
---> 63 ret = self.compute_list(y_pred, y)
64 elif isinstance(y_pred, torch.Tensor):
65 y_ = y.detach() if y is not None and isinstance(y, torch.Tensor) else None

File E:\mambaforge\envs\BCP\lib\site-packages\monai\metrics\metric.py:82, in IterationMetric.compute_list(self, y_pred, y)
80 ret: TensorOrList
81 if y is not None:
---> 82 ret = [self.compute_tensor(p.detach().unsqueeze(0), y.detach().unsqueeze(0)) for p, y_ in zip(y_pred, y)]
83 else:
84 ret = [self.compute_tensor(p.detach().unsqueeze(0), None) for p_ in y_pred]

File E:\mambaforge\envs\BCP\lib\site-packages\monai\metrics\metric.py:82, in (.0)
80 ret: TensorOrList
81 if y is not None:
---> 82 ret = [self.compute_tensor(p.detach().unsqueeze(0), y.detach().unsqueeze(0)) for p, y_ in zip(y_pred, y)]
83 else:
84 ret = [self.compute_tensor(p.detach().unsqueeze(0), None) for p_ in y_pred]

File E:\mambaforge\envs\BCP\lib\site-packages\monai\metrics\meandice.py:80, in DiceMetric._compute_tensor(self, y_pred, y)
78 raise ValueError("y_pred should have at least three dimensions.")
79 # compute dice (BxC) for each channel for each batch
---> 80 return compute_meandice(
81 y_pred=y_pred,
82 y=y,
83 include_background=self.include_background,
84 )

File E:\mambaforge\envs\BCP\lib\site-packages\monai\metrics\meandice.py:134, in compute_meandice(y_pred, y, include_background)
131 y_pred = y_pred.float()
133 if y.shape != y_pred.shape:
--> 134 raise ValueError("y_pred and y should have same shapes.")
136 # reducing only spatial dimensions (not batch nor channels)
137 n_len = len(y_pred.shape)

ValueError: y_pred and y should have same shapes.`

my code
`train_transforms = Compose(
[
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]),
Orientationd(keys=["image", "label"], axcodes="RAS"),
ScaleIntensityRanged(
keys=["image"],
a_min=-50,
a_max=100,
b_min=0.0,
b_max=1.0,
clip=True,
),
# Spacingd(keys=["image", "label"], pixdim=(2.0, 2.0, 0.1), mode=("bilinear", "nearest")),
# CropForegroundd(keys=["image", "label"], source_key="image"),
# CenterSpatialCropd(keys=['image', 'label'], roi_size=(512,512,12)),
# Resize(spatial_size=(400, 400, 12)),
ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=(512, 512, 16)),
RandCropByPosNegLabeld(
keys=["image", "label"],
label_key="label",
spatial_size=(96, 96, 16),
pos=1,
neg=1,
num_samples=8,
image_key="image",
image_threshold=0,
),
# RandAffined(
# keys=["image", "label"],
# mode=("bilinear", "nearest"),
# prob=1.0,
# spatial_size=(512, 512, 16),
# translate_range=(40, 40, 2),
# rotate_range=(np.pi / 36, np.pi / 36, np.pi / 4),
# scale_range=(0.15, 0.15, 0.15),
# padding_mode="border",
# ),
# RandGaussianNoised(keys=["image"], prob=0.10, std=0.1),
RandFlipd(
keys=["image", "label"],
spatial_axis=[0],
prob=0.10,
),
RandFlipd(
keys=["image", "label"],
spatial_axis=[1],
prob=0.10,
),
RandFlipd(
keys=["image", "label"],
spatial_axis=[2],
prob=0.10,
),
RandRotate90d(
keys=["image", "label"],
prob=0.10,
max_k=3,
),
RandShiftIntensityd(
keys=["image"],
offsets=0.10,
prob=0.50,
),
# CenterSpatialCropd(keys=['image', 'label'], roi_size=(352,352,16))
]
)
# Deterministic validation pipeline: same loading, orientation, intensity
# scaling and 512 x 512 x 16 pad/crop as training, but no random cropping or
# augmentation, so each validation item is a full-size volume.
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 0.1), mode=("bilinear", "nearest")),
        ScaleIntensityRanged(keys=["image"], a_min=-50, a_max=100, b_min=0.0, b_max=1.0, clip=True),
        # CropForegroundd(keys=["image", "label"], source_key="image"),
        # Resize(spatial_size=(400, 400, 12)),
        ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=(512, 512, 16)),
        # Rotate90d(keys=["image", "label"], k=1),
        # CenterSpatialCropd(keys=['image', 'label'], roi_size=(352,352,16)),
        # Flipd(keys=["image", "label"], spatial_axis=[0]),
    ]
)

# Location of the Decathlon-style datalist describing the dataset splits.
data_dir = r"F:/Datasets/ATLAS_2/data"
split_json = "/dataset.json"
datasets = data_dir + split_json

# Entries listed under the "training" and "validation" keys of the datalist.
datalist = load_decathlon_datalist(datasets, True, "training")
val_files = load_decathlon_datalist(datasets, True, "validation")

device_ids = list(range(torch.cuda.device_count()))
# One item per visible GPU. Fall back to 1 on a CPU-only machine: the original
# `1 * len(device_ids)` evaluates to 0 there, which DataLoader rejects.
loader_batch_size = max(1, len(device_ids))

train_ds = CacheDataset(
    data=datalist,
    transform=train_transforms,
    cache_num=8,
    cache_rate=1.0,
    num_workers=0,
)
train_loader = DataLoader(
    train_ds,
    batch_size=loader_batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    drop_last=True,
    collate_fn=pad_list_data_collate,
)

val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_num=6, cache_rate=1.0, num_workers=0)
val_loader = DataLoader(
    val_ds,
    batch_size=loader_batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
    collate_fn=pad_list_data_collate,
)

import torch.nn as nn

# Checkpoints written during training go under this directory.
root_dir = "./run"

torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# UNETR configuration: single-channel input, two output channels
# (one per class), operating on 96 x 96 x 16 patches.
unetr_kwargs = dict(
    in_channels=1,
    out_channels=2,
    img_size=(96, 96, 16),
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    pos_embed="perceptron",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.0,
)
model = UNETR(**unetr_kwargs).to(device)

# The loss one-hot encodes the label and applies softmax to the logits itself,
# so the network output and the label are passed to it unmodified.
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

def validation(epoch_iterator_val):
    """Run sliding-window inference over the validation batches and return the mean Dice.

    Args:
        epoch_iterator_val: iterable of batches (here a tqdm-wrapped loader)
            providing "image" and "label" entries; its progress description
            is updated after every batch.

    Returns:
        The value aggregated by the module-level ``dice_metric`` over all
        batches, as a Python float. The metric is reset before returning.
    """
    model.eval()
    with torch.no_grad():
        for batch in epoch_iterator_val:
            # Use the device the model was moved to instead of
            # `.cuda(device_ids[0])`, which fails when no GPU is visible.
            val_inputs = batch["image"].to(device)
            val_labels = batch["label"].to(device)
            val_outputs = sliding_window_inference(val_inputs, (96, 96, 16), 8, model)
            # Debug output: raw label and network output shapes.
            print(val_labels.shape)
            print(val_outputs.shape)
            # Split the batch into per-sample tensors and post-process each.
            # NOTE(review): post_label and post_pred must yield the same
            # channel count, otherwise the Dice computation in the traceback
            # raises "y_pred and y should have same shapes".
            val_labels_convert = [post_label(label) for label in decollate_batch(val_labels)]
            val_output_convert = [post_pred(pred) for pred in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_output_convert, y=val_labels_convert)
            epoch_iterator_val.set_description("Validate (%d / %d Steps)" % (global_step, 10.0))
        mean_dice_val = dice_metric.aggregate().item()
        dice_metric.reset()
    return mean_dice_val

def train(global_step, train_loader, dice_val_best, global_step_best):
    """Train for one pass over ``train_loader``, validating periodically.

    Validation runs whenever ``global_step`` is a non-zero multiple of
    ``eval_num`` or equals ``max_iterations``. When the validation Dice
    improves, the model weights are saved to ``best_metric_model.pth`` under
    ``root_dir``.

    Args:
        global_step: number of optimisation steps completed so far.
        train_loader: loader yielding batches with "image" and "label" entries.
        dice_val_best: best validation Dice seen so far.
        global_step_best: step at which ``dice_val_best`` was reached.

    Returns:
        Tuple ``(global_step, dice_val_best, global_step_best)`` updated for
        the steps taken in this pass.
    """
    model.train()

    epoch_loss = 0
    epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
    # `step` counts batches from 1 so it can be used directly as a divisor.
    for step, batch in enumerate(epoch_iterator, start=1):
        # Same device as the model; `.cuda(device_ids[0])` fails without a GPU.
        x = batch["image"].to(device)
        y = batch["label"].to(device)
        logit_map = model(x)
        loss = loss_function(logit_map, y)
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        epoch_iterator.set_description("Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss))
        if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
            epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
            dice_val = validation(epoch_iterator_val)
            # validation() leaves the model in eval mode; switch back so the
            # remaining batches of this pass are trained in training mode.
            model.train()
            # Record the running mean without modifying the accumulator.
            # Dividing `epoch_loss` in place would corrupt every later average.
            epoch_loss_values.append(epoch_loss / step)
            metric_values.append(dice_val)
            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print(
                    "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(dice_val_best, dice_val)
                )
            else:
                print(
                    "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val
                    )
                )
        global_step += 1
    return global_step, dice_val_best, global_step_best

max_iterations = 30000 # 2000
eval_num = 5 # 50
num_classes = 2
post_label = AsDiscrete(to_onehot=True, num_classes=num_classes)

post_label = AsDiscrete(to_onehot=2)

post_pred = AsDiscrete(argmax=True, num_classes=num_classes)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
global_step = 0
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []

pretrained_dict = torch.load('./UNETR_model_best_acc.pth')

model.load_state_dict(pretrained_dict)

while global_step < max_iterations:
global_step, dice_val_best, global_step_best = train(global_step, train_loader, dice_val_best, global_step_best)
model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))`

Please reply to me as soon as possible. It's really urgent.

I don't know why the format is a bit messy after submission. I'm very sorry. Please forgive me.

Hi @CharlieFengCN, from the error message it seems that the shapes of your prediction and label do not match. Could you please check the shapes?
Thanks!