Improve error when calling replicated function without Parallel
carlini opened this issue · 1 comment
carlini commented
Currently, if you call `replicate()` and then call a function without wrapping it in `Parallel()`, you get a confusing error message. Try this code:
```python
import objax
import numpy as np

mod = objax.nn.Conv2D(2, 4, 3)
with mod.vars().replicate():
    print(mod(np.ones((8, 2, 10, 10))))
```
The error says:
```
Traceback (most recent call last):
  File "b.py", line 9, in <module>
    print(mod(np.ones((8,2,10,10))))
  File "/opt/conda/lib/python3.7/site-packages/objax/nn/layers.py", line 185, in __call__
    dimension_numbers=('NCHW', 'HWIO', 'NCHW'))
  File "/opt/conda/lib/python3.7/site-packages/jax/_src/lax/lax.py", line 555, in conv_general_dilated
    dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
  File "/opt/conda/lib/python3.7/site-packages/jax/_src/lax/lax.py", line 5955, in conv_dimension_numbers
    raise TypeError(msg.format(len(lhs_shape), len(rhs_shape)))
TypeError: convolution requires lhs and rhs ndim to be equal, got 4 and 5.
```
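which, obviously, means that you are accidentally evaluating a replicated function without wrapping it in `Parallel`. The `5` comes from the leading device axis that `replicate()` adds to every variable, so the convolution kernel ends up with one more dimension than the input.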
AlexeyKurakin commented
One possible solution: if a function is called with a `ShardedDeviceArray`, then it must be meant to run replicated. In the example above, `mod.__call__` would check whether its input is a `ShardedDeviceArray`. If it is, an exception is thrown telling the user that they have to call `objax.Parallel`.
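A minimal sketch of this check, written in plain Python so it runs without JAX. `ShardedArray`, `NotInParallelError`, `check_not_replicated`, and `ToyLinear` are all hypothetical names; in real Objax code the `isinstance` test would target JAX's `ShardedDeviceArray`, whose import location depends on the JAX version. As an extra assumption beyond the proposal, the toy module also checks its own weight, because in the original example the replicated values are the `Conv2D` variables rather than the input.

```python
class ShardedArray:
    """Hypothetical stand-in for JAX's ShardedDeviceArray: one copy per device."""

    def __init__(self, shards):
        self.shards = list(shards)


class NotInParallelError(TypeError):
    """Raised when a replicated value is used outside objax.Parallel."""


def check_not_replicated(name, *values):
    """Raise a descriptive error if any value is a (stand-in) sharded array."""
    for value in values:
        if isinstance(value, ShardedArray):
            raise NotInParallelError(
                f"{name} received a replicated (sharded) value. Variables "
                "replicated with VarCollection.replicate() must be used through "
                "objax.Parallel, e.g. objax.Parallel(module, module.vars())."
            )


class ToyLinear:
    """Hypothetical toy module that scales each input element by a weight."""

    def __init__(self, w):
        self.w = w

    def __call__(self, x):
        # Check the input (as proposed) and, as an assumption, the module's
        # own weight, which is what replicate() turned into a sharded value.
        check_not_replicated(type(self).__name__, x, self.w)
        return [xi * self.w for xi in x]
```

With a plain weight the module behaves normally; with a replicated weight the call fails immediately with a message that names `objax.Parallel`, instead of a shape mismatch deep inside `lax`.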