EntrywiseNorm does not work if axes are specified
schreon opened this issue · 1 comment
schreon commented
Specifying axes when creating an EntrywiseNorm computation leads to the following error:
ValueError: Incompatible types of the transformation parameter 'input' (Type(float32)) and the node 'output' (Type(float32, shape=(100, 100), strides=(400, 4)))
Here is the py.test code I added to test/test_linalg/test_norm.py:
@pytest.mark.parametrize('dtype', [numpy.float32, numpy.complex64], ids=['float32', 'complex64'])
@pytest.mark.parametrize('order', [0.5, 1, 2])
@pytest.mark.parametrize('num_axes', [1, 2, 3])
def test_entrywise_norm(thr, dtype, order, num_axes):
    """
    Check ``EntrywiseNorm`` when the norm is taken over the last axis only.

    Runs on 1-, 2- and 3-dimensional inputs (100 elements per dimension),
    for real and complex dtypes and several norm orders.
    """
    shape = (100,) * num_axes
    # norm over the last axis
    axes = (num_axes - 1,)
    expected_res_shape = tuple(dim for i, dim in enumerate(shape) if i not in axes)

    a_host = get_test_array(shape, dtype)
    # Entrywise norm: (sum |a|**p) ** (1/p). The absolute value is required:
    # plain ``a_host ** order`` is NaN for negative real entries when ``order``
    # is fractional, keeps the sign for ``order=1``, and stays complex-valued
    # for complex input, so it is not a valid reference.
    res_host = (numpy.abs(a_host) ** order).sum(axis=axes) ** (1. / order)

    a_dev = thr.to_device(a_host)

    norm = EntrywiseNorm(a_dev, order=order, axes=axes)
    # Allocate the result from the computation's own output parameter so the
    # buffer has exactly the dtype the computation produces (which may differ
    # from the input dtype, e.g. for complex input -- confirm in norm_const)
    # instead of assuming it equals ``dtype``.
    res_t = norm.parameter.output
    # The reported bug was a wrong (scalar) result shape when axes are given.
    assert res_t.shape == expected_res_shape
    res_dev = thr.array(res_t.shape, dtype=res_t.dtype)

    normc = norm.compile(thr)
    normc(res_dev, a_dev)

    assert diff_is_negligible(res_dev.get(), res_host)
    # the input array must not be modified by the computation
    assert diff_is_negligible(a_dev.get(), a_host)
I think the problem is in the constructor of the EntrywiseNorm class:
def __init__(self, arr_t, order=2, axes=None):
    """
    Original (buggy) ``EntrywiseNorm`` constructor, as quoted in the report.

    Builds a sum ``Reduce`` over ``axes`` with ``norm_const(..., order)``
    connected to its input and ``norm_const(..., 1 / order)`` connected to
    its output, then exposes ``output`` / ``input`` parameters.
    """
    tr_elems = norm_const(arr_t, order)
    out_dtype = tr_elems.output.dtype
    # BUG: a zero-dimensional result type is only right when ``axes`` is None
    # (reduction over everything). With explicit ``axes`` the reduction output
    # keeps the remaining dimensions, so connecting ``tr_sum`` below fails with
    # "Incompatible types of the transformation parameter 'input' ...".
    res_t = Type(out_dtype) # this is wrong if axes are specified
    tr_sum = norm_const(res_t, 1. / order)
    rd = Reduce(Type(out_dtype, arr_t.shape), predicate_sum(out_dtype), axes=axes)
    rd.parameter.input.connect(tr_elems, tr_elems.output, input_prime=tr_elems.input)
    rd.parameter.output.connect(tr_sum, tr_sum.input, output_prime=tr_sum.output)
    self._rd = rd
    Computation.__init__(self, [
        Parameter('output', Annotation(res_t, 'o')),
        Parameter('input', Annotation(arr_t, 'i'))])
Here is my fixed version. It works, but the accuracy is too poor for diff_is_negligible to pass:
def __init__(self, arr_t, order=2, axes=None):
    """
    Entrywise norm of an array, optionally taken over a subset of axes.

    Implemented as a sum ``Reduce`` with ``norm_const(..., order)`` connected
    to its input and ``norm_const(..., 1 / order)`` connected to its output
    (presumably ``|x| ** order`` and ``y ** (1 / order)`` respectively --
    confirm against ``norm_const``).

    :param arr_t: type (shape and dtype) of the input array.
    :param order: order of the norm.
    :param axes: axes to reduce over, passed straight to ``Reduce``;
        ``None`` reduces over the whole array (scalar result).
    """
    tr_elems = norm_const(arr_t, order)
    out_dtype = tr_elems.output.dtype

    rd = Reduce(Type(out_dtype, arr_t.shape), predicate_sum(out_dtype), axes=axes)

    # Take the result type from the reduction instead of recomputing it here.
    # ``Reduce`` already derives its output shape from ``axes``, so the two can
    # never disagree. A hand-written ``i not in axes`` filter can disagree: for
    # ``axes=(-1,)`` it keeps every dimension, and it fails if ``axes`` is not
    # a container.
    res_t = rd.parameter.output
    tr_sum = norm_const(res_t, 1. / order)

    rd.parameter.input.connect(tr_elems, tr_elems.output, input_prime=tr_elems.input)
    rd.parameter.output.connect(tr_sum, tr_sum.input, output_prime=tr_sum.output)

    self._rd = rd

    Computation.__init__(self, [
        Parameter('output', Annotation(res_t, 'o')),
        Parameter('input', Annotation(arr_t, 'i'))])
fjarri commented
Thank you for reporting! I fixed it a bit differently, using the fact that we already have the correct expected shape in the output parameter of rd (which processes its axes argument similarly to what you did).
As for the accuracy, it's a usual problem with single precision, which seems to work a bit differently on GPUs. It should work fine with double precision.