google/objax

Improve error when calling replicated function without Parallel

carlini opened this issue · 1 comment

Currently, if you call replicate() and then call a function without wrapping it in Parallel(), you get an unhelpful error message. Try this code:

import objax
import numpy as np

mod = objax.nn.Conv2D(2, 4, 3)

# Calling the module directly while its variables are replicated, without
# wrapping the call in objax.Parallel, triggers the error below.
with mod.vars().replicate():
    print(mod(np.ones((8,2,10,10))))

The error says:

Traceback (most recent call last):
  File "b.py", line 9, in <module>
    print(mod(np.ones((8,2,10,10))))
  File "/opt/conda/lib/python3.7/site-packages/objax/nn/layers.py", line 185, in __call__
    dimension_numbers=('NCHW', 'HWIO', 'NCHW'))
  File "/opt/conda/lib/python3.7/site-packages/jax/_src/lax/lax.py", line 555, in conv_general_dilated
    dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
  File "/opt/conda/lib/python3.7/site-packages/jax/_src/lax/lax.py", line 5955, in conv_dimension_numbers
    raise TypeError(msg.format(len(lhs_shape), len(rhs_shape)))
TypeError: convolution requires lhs and rhs ndim to be equal, got 4 and 5.

which is hardly an obvious way of saying that you are evaluating a function on replicated variables without wrapping the call in Parallel.
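For reference, a sketch of the intended pattern is below. It assumes the usual objax.Parallel usage (building the parallel wrapper from the module and its VarCollection, then calling it inside the replicate() context); the leading batch dimension also has to be divisible by the number of local devices.

import objax
import numpy as np

mod = objax.nn.Conv2D(2, 4, 3)
# Wrap the module call so it is pmapped across local devices.
par = objax.Parallel(mod, mod.vars())

with mod.vars().replicate():
    # The batch dimension (8) is split across devices, so it must be
    # divisible by the local device count.
    print(par(np.ones((8, 2, 10, 10))))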

One possible solution: if a function is called with a ShardedDeviceArray, then its variables must have been replicated.
In the example above, mod.__call__ would check whether its input is a ShardedDeviceArray. If it is, an exception is thrown telling the user to wrap the call in objax.Parallel.
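A minimal sketch of that check follows; the helper name and the error message are illustrative only, and the import path of ShardedDeviceArray (jax.interpreters.pxla here) depends on the JAX version in use.

import jax.interpreters.pxla as pxla

def _check_not_replicated(x, name='input'):
    # Hypothetical helper: detect that a per-device sharded array reached a
    # plain __call__ and raise a clearer error than the ndim mismatch that
    # currently comes out of lax.conv_general_dilated.
    if isinstance(x, pxla.ShardedDeviceArray):
        raise TypeError(
            f'{name} is a ShardedDeviceArray: the variables were replicated, '
            'so the call must be wrapped in objax.Parallel.')

Conv2D.__call__ (or any other module entry point) could run this check on its input before doing any computation.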