BatchNorm2D Error
Closed this issue · 2 comments
kwojcicki commented
When using PyTorch 1.5 on a Colab TPU, I get an error on the first `BatchNorm2d` in a basic block (specifically this one: https://github.com/owruby/shake-drop_pytorch/blob/master/models/shake_pyramidnet.py#L33)
2020-03-09 00:05:22.038767: E tensorflow/compiler/xla/xla_client/tf_logging.cc:11] Check failed: status.status() == ::tensorflow::Status::OK() (Invalid argument: Expected array argument for operand of batch norm training, but got (f32[128,16,32,96], f32[16], f32[16], f32[16]). vs. OK)
Wondering if you have seen this before?
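For what it's worth, the shapes in the log look internally consistent: the input is NCHW with C=16 (`f32[128,16,32,96]`), and the per-channel parameters are `f32[16]`, which is exactly what batch-norm training expects. A minimal NumPy sketch of the per-channel statistics (shapes taken from the log above; not the actual XLA lowering, which is where the failure occurs):

```python
import numpy as np

# Shapes from the error log: NCHW input with C=16,
# per-channel weight/bias of shape (16,).
x = np.random.randn(128, 16, 32, 96).astype(np.float32)
gamma = np.ones(16, dtype=np.float32)
beta = np.zeros(16, dtype=np.float32)

# Training-mode batch norm reduces over N, H, W -> one value per channel.
mean = x.mean(axis=(0, 2, 3))  # shape (16,)
var = x.var(axis=(0, 2, 3))    # shape (16,)

# Normalize, then scale and shift per channel.
y = gamma[None, :, None, None] * (x - mean[None, :, None, None]) \
    / np.sqrt(var[None, :, None, None] + 1e-5) + beta[None, :, None, None]
```

So the operand shapes themselves are fine, which suggests the error comes from how the op was lowered to XLA rather than from the model definition.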
owruby commented
Thank you for reporting. Sorry, I have not run this code on a TPU.
kwojcicki commented
This has been fixed as part of pytorch/xla#1736 and pytorch/xla#1735, so it should work on TPUs again 😄