fjarri/reikna

EntrywiseNorm does not work if axes are specified

schreon opened this issue · 1 comment

Specifying axes when creating an EntrywiseNorm computation leads to the following error:

ValueError: Incompatible types of the transformation parameter 'input' (Type(float32)) and the node 'output' (Type(float32, shape=(100, 100), strides=(400, 4)))

Here is the py.test code I added to test/test_linalg/test_norm.py:

@pytest.mark.parametrize('dtype', [numpy.float32, numpy.complex64], ids=['float32', 'complex64'])
@pytest.mark.parametrize('order', [0.5, 1, 2])
@pytest.mark.parametrize('num_axes', [1, 2, 3])
def test_entrywise_norm(thr, dtype, order, num_axes):
    """
    Check ``EntrywiseNorm`` reducing over the last axis against a NumPy reference.

    The entrywise p-norm is ``(sum(|a| ** p)) ** (1 / p)``. The absolute value
    matters: without it a complex input gives a complex "norm", and negative
    real entries raised to a fractional ``order`` (e.g. 0.5) give NaN.
    """
    shape = (100,) * num_axes

    # Norm over the last axis only.
    axes = (num_axes - 1,)

    a_host = get_test_array(shape, dtype)
    res_host = (numpy.abs(a_host) ** order).sum(axis=axes) ** (1. / order)

    a_dev = thr.to_device(a_host)
    # The reference norm is real-valued even for complex input, so take both
    # the shape and the dtype of the result buffer from it rather than from
    # the input dtype.
    # NOTE(review): assumes EntrywiseNorm's output parameter is real for
    # complex input (its dtype comes from norm_const's output) — confirm.
    res_dev = thr.array(res_host.shape, dtype=res_host.dtype)

    norm = EntrywiseNorm(a_dev, order=order, axes=axes)
    norm = norm.compile(thr)

    norm(res_dev, a_dev)

    assert diff_is_negligible(res_dev.get(), res_host)
    # The computation must leave its input untouched.
    assert diff_is_negligible(a_dev.get(), a_host)

I think the problem is in the constructor of the EntrywiseNorm class:

    def __init__(self, arr_t, order=2, axes=None):
        """
        Build an entrywise ``order``-norm of ``arr_t``, reduced over ``axes``.

        NOTE: this is the original, buggy constructor, kept as written to
        illustrate the problem. ``res_t`` is always a shapeless (scalar) type.
        When ``axes`` is given, however, the reduction's output keeps the
        non-reduced dimensions. Connecting ``tr_sum`` (built on the scalar
        ``res_t``) to ``rd.parameter.output`` therefore fails with a
        ValueError about incompatible types.
        """
        # Element-wise part of the norm, applied to the input before summation
        # (presumably |x| ** order — confirm against norm_const).
        tr_elems = norm_const(arr_t, order)
        out_dtype = tr_elems.output.dtype
        res_t = Type(out_dtype) # this is wrong if axes are specified
        # Final 1/order power, built on the (incorrectly scalar) result type.
        tr_sum = norm_const(res_t, 1. / order)

        # Sum the transformed elements over ``axes``.
        rd = Reduce(Type(out_dtype, arr_t.shape), predicate_sum(out_dtype), axes=axes)
        rd.parameter.input.connect(tr_elems, tr_elems.output, input_prime=tr_elems.input)
        # This connect is where the shape mismatch surfaces when axes is given.
        rd.parameter.output.connect(tr_sum, tr_sum.input, output_prime=tr_sum.output)

        self._rd = rd

        Computation.__init__(self, [
            Parameter('output', Annotation(res_t, 'o')),
            Parameter('input', Annotation(arr_t, 'i'))])

Here is my fixed version. It works, but the accuracy is too poor for diff_is_negligible to pass:

    def __init__(self, arr_t, order=2, axes=None):
        """
        Build an entrywise ``order``-norm of ``arr_t``, reduced over ``axes``.

        :param arr_t: type (dtype and shape) of the input array.
        :param order: the norm order ``p``.
        :param axes: axes to reduce over, passed straight to ``Reduce``;
            ``None`` reduces over all axes and gives a scalar result.
        """
        # Element-wise part of the norm, applied to the input before summation.
        tr_elems = norm_const(arr_t, order)
        out_dtype = tr_elems.output.dtype

        rd = Reduce(Type(out_dtype, arr_t.shape), predicate_sum(out_dtype), axes=axes)

        # Take the result type from the reduction instead of re-deriving it
        # from ``axes`` by hand. Reduce already decides which dimensions
        # survive, so this cannot diverge from it. A hand-written
        # ``i not in axes`` filter breaks for an int ``axes`` (TypeError) and
        # never excludes negative axis indices (wrong result shape).
        res_t = rd.parameter.output
        # Final 1/order power applied to each reduced sum.
        tr_sum = norm_const(res_t, 1. / order)

        rd.parameter.input.connect(tr_elems, tr_elems.output, input_prime=tr_elems.input)
        rd.parameter.output.connect(tr_sum, tr_sum.input, output_prime=tr_sum.output)

        self._rd = rd

        Computation.__init__(self, [
            Parameter('output', Annotation(res_t, 'o')),
            Parameter('input', Annotation(arr_t, 'i'))])

Thank you for reporting! I fixed it a bit differently, using the fact that we already have the correct expected shape in the output parameter of rd (which processes its axes argument similarly to what you did).

As for the accuracy, it's a common problem with single precision, which seems to behave a bit differently on GPUs. It should work fine with double precision.