ValueError: operands could not be broadcast together with shapes (256,256,4) (3,) (256,256,4)
nvnvashisth opened this issue · 10 comments
I am trying to use this package, and it is throwing the error below. I am using the same pipeline as the cassava leaf detection problem, but on a different dataset where the image size is (256, 256).
Could you please help here?
Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b4-6ed6700e.pth
100% 74.4M/74.4M [00:00<00:00, 107MB/s]
Loaded pretrained weights for efficientnet-b4
0%| | 0/51 [00:00<?, ?it/s]
ValueError                                Traceback (most recent call last)
<ipython-input> in <module>()
11 epochs=10,
12 callbacks=[es],
---> 13 fp16=True,
14 )
15 model.save("model.bin")
6 frames
/usr/local/lib/python3.6/dist-packages/tez/model/model.py in fit(self, train_dataset, valid_dataset, train_sampler, valid_sampler, device, epochs, train_bs, valid_bs, n_jobs, callbacks, fp16)
295 self.train_state = enums.TrainingState.EPOCH_START
296 self.train_state = enums.TrainingState.TRAIN_EPOCH_START
--> 297 train_loss = self.train_one_epoch(self.train_loader, device)
298 self.train_state = enums.TrainingState.TRAIN_EPOCH_END
299 if self.valid_loader:
/usr/local/lib/python3.6/dist-packages/tez/model/model.py in train_one_epoch(self, data_loader, device)
176 losses = AverageMeter()
177 tk0 = tqdm(data_loader, total=len(data_loader))
--> 178 for b_idx, data in enumerate(tk0):
179 self.train_state = enums.TrainingState.TRAIN_STEP_START
180 loss, metrics = self.train_one_step(data, device)
/usr/local/lib/python3.6/dist-packages/tqdm/std.py in __iter__(self)
1102 fp_write=getattr(self.fp, 'write', sys.stderr.write))
1103
-> 1104 for obj in iterable:
1105 yield obj
1106 # Update and possibly print the progressbar.
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
433 if self._sampler_iter is None:
434 self._reset()
--> 435 data = self._next_data()
436 self._num_yielded += 1
437 if self._dataset_kind == _DatasetKind.Iterable and \
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
1083 else:
1084 del self._task_info[idx]
-> 1085 return self._process_data(data)
1086
1087 def _try_put_index(self):
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
1109 self._try_put_index()
1110 if isinstance(data, ExceptionWrapper):
-> 1111 data.reraise()
1112 return data
1113
/usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
426 # have message field
427 raise self.exc_type(message=msg)
--> 428 raise self.exc_type(msg)
429
430
ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.6/dist-packages/tez/datasets/image_classification.py", line 48, in getitem
augmented = self.augmentations(image=image)
File "/usr/local/lib/python3.6/dist-packages/albumentations/core/composition.py", line 171, in call
data = t(**data)
File "/usr/local/lib/python3.6/dist-packages/albumentations/core/transforms_interface.py", line 38, in call
res[key] = target_function(arg, **dict(params, **target_dependencies))
File "/usr/local/lib/python3.6/dist-packages/albumentations/augmentations/transforms.py", line 808, in apply
return F.normalize(image, self.mean, self.std, self.max_pixel_value)
File "/usr/local/lib/python3.6/dist-packages/albumentations/augmentations/functional.py", line 93, in normalize
img -= mean
ValueError: operands could not be broadcast together with shapes (256,256,4) (3,) (256,256,4)
It seems like your image is RGBA. Can you convert it to RGB? Or do you have to use RGBA? If the latter, try writing your own dataloader.
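A minimal sketch of such a custom dataloader, assuming the images are loaded with PIL and the training loop expects an `{"image", "targets"}` dict; the class name and dict keys below are illustrative assumptions, not necessarily tez's built-in API:

```python
import numpy as np
import torch
from PIL import Image


class RGBImageDataset(torch.utils.data.Dataset):
    """Hypothetical dataset that forces 3-channel RGB before augmentation,
    so Normalize's (3,) mean/std broadcast against an (H, W, 3) array
    instead of the (H, W, 4) RGBA array that caused the error above."""

    def __init__(self, image_paths, targets, augmentations=None):
        self.image_paths = image_paths
        self.targets = targets
        self.augmentations = augmentations

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # convert("RGB") drops the alpha channel of RGBA PNGs
        image = np.array(Image.open(self.image_paths[idx]).convert("RGB"))
        if self.augmentations is not None:
            image = self.augmentations(image=image)["image"]
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)  # HWC -> CHW
        return {
            "image": torch.tensor(image, dtype=torch.float),
            "targets": torch.tensor(self.targets[idx], dtype=torch.long),
        }
```

Alternatively, converting the files to RGB once on disk with the same `convert("RGB")` call avoids touching the dataloader at all.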
Can you provide data and code to reproduce the error?
Ok, I converted everything to RGB. I have labels from 0-9 with image size 256x256. But now I come across this CUDA error. Another thing: I am executing this in Colab.
Regarding the code, it is taken exactly from here: https://www.kaggle.com/abhishek/tez-faster-and-easier-training-for-leaf-detection ;)
Loaded pretrained weights for efficientnet-b4
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-29-de0853739e51> in <module>()
11 epochs=10,
12 callbacks=[es],
---> 13 fp16=True
14 )
15 model.save("model.bin")
6 frames
/usr/local/lib/python3.6/dist-packages/tez/model/model.py in fit(self, train_dataset, valid_dataset, train_sampler, valid_sampler, device, epochs, train_bs, valid_bs, n_jobs, callbacks, fp16)
289 n_jobs=n_jobs,
290 callbacks=callbacks,
--> 291 fp16=fp16,
292 )
293
/usr/local/lib/python3.6/dist-packages/tez/model/model.py in _init_model(self, device, train_dataset, valid_dataset, train_sampler, valid_sampler, train_bs, valid_bs, n_jobs, callbacks, fp16)
81
82 if next(self.parameters()).device != device:
---> 83 self.to(device)
84
85 if self.train_loader is None:
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in to(self, *args, **kwargs)
610 return t.to(device, dtype if t.is_floating_point() else None, non_blocking)
611
--> 612 return self._apply(convert)
613
614 def register_backward_hook(
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _apply(self, fn)
357 def _apply(self, fn):
358 for module in self.children():
--> 359 module._apply(fn)
360
361 def compute_should_use_set_data(tensor, tensor_applied):
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _apply(self, fn)
357 def _apply(self, fn):
358 for module in self.children():
--> 359 module._apply(fn)
360
361 def compute_should_use_set_data(tensor, tensor_applied):
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _apply(self, fn)
379 # `with torch.no_grad():`
380 with torch.no_grad():
--> 381 param_applied = fn(param)
382 should_use_set_data = compute_should_use_set_data(param, param_applied)
383 if should_use_set_data:
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in convert(t)
608 if convert_to_format is not None and t.dim() == 4:
609 return t.to(device, dtype if t.is_floating_point() else None, non_blocking, memory_format=convert_to_format)
--> 610 return t.to(device, dtype if t.is_floating_point() else None, non_blocking)
611
612 return self._apply(convert)
RuntimeError: CUDA error: device-side assert triggered
The code provided in the examples works quite well. This seems like some problem with the model. I can't say without having data and full code to reproduce the error :)
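For what it's worth, one common cause of `CUDA error: device-side assert triggered` in classification (an assumption here, not a confirmed diagnosis of this issue) is a target label outside the range the loss expects. A minimal sanity check with hypothetical stand-in labels:

```python
import torch

num_classes = 10  # the labels in this thread are 0-9

# hypothetical stand-in for the real target column
train_targets = torch.tensor([0, 3, 9, 5, 1])

# nn.CrossEntropyLoss expects targets in [0, num_classes); an out-of-range
# label asserts opaquely on the GPU but raises a readable error on CPU.
assert train_targets.min().item() >= 0 and train_targets.max().item() < num_classes, (
    f"labels span {train_targets.min().item()}..{train_targets.max().item()}, "
    f"but the classifier head has only {num_classes} outputs"
)
```

Running a single batch on CPU, or re-running with `CUDA_LAUNCH_BLOCKING=1`, usually turns the opaque assert into a readable Python traceback pointing at the actual failing line.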
@nvnvashisth I just added a multi-class classification example (flower classification with 104 classes). It might be useful for you: https://github.com/abhishekkrthakur/tez/blob/main/examples/image_classification/flower_classification.py
Let me know if it still doesn't work.
> The code provided in the examples works quite well. This seems like some problem with the model. I can't say without having data and full code to reproduce the error :)
I have sent the code privately on your Twitter (DM). That's the only way I could figure out how to reach you privately.
> @nvnvashisth I just added a multi-class classification example (flower classification with 104 classes). It might be useful for you: https://github.com/abhishekkrthakur/tez/blob/main/examples/image_classification/flower_classification.py
> Let me know if it still doesn't work.
I'll give it a try. Thanks
@abhishekkrthakur It is so weird. I didn't really change anything and it started working. No more CUDA error. Thanks for the quick support.
Wow, maybe you updated torch?
Not really. I was running in Colab, using the default version.
> It seems like your image is RGBA. Can you convert it to RGB? Or do you have to use RGBA? If the latter, try writing your own dataloader.
@abhishekkrthakur
Thank you! This helped me solve the above error.