Quansight-Labs/numpy.net

apply_along_axis returns different results than NumPy

Happypig375 opened this issue · 3 comments

System.Console.WriteLine(np.apply_along_axis((ndarray all, ndarray view) => np.sum(view), 0, np.arange(9).reshape(3,3)))
INT32 
9
import numpy as np
print(np.apply_along_axis(np.sum, 0, np.arange(9).reshape(3,3)))
[ 9 12 15]

It seems to return only the result for the first slice: 9 is the sum of the first column (0 + 3 + 6), which is just the first element of NumPy's result.
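For reference, NumPy's apply_along_axis with axis=0 calls the function once per 1-D slice along axis 0 (i.e. once per column) and collects the results, which is why NumPy prints [ 9 12 15]. Below is a minimal pure-Python sketch of that behaviour for a 2-D nested list. The helper name `apply_along_axis_2d` is hypothetical. It is not NumPy's implementation and only handles 2-D input.

```python
def apply_along_axis_2d(func1d, axis, rows):
    """Hypothetical helper: apply func1d to every 1-D slice of a 2-D list along `axis`."""
    if axis < 0:
        axis += 2  # normalise negative axes for a 2-D input
    if axis == 0:
        # Slices along axis 0 are the columns.
        slices = [list(col) for col in zip(*rows)]
    elif axis == 1:
        # Slices along axis 1 are the rows.
        slices = [list(row) for row in rows]
    else:
        raise ValueError("axis out of range for a 2-D input")
    # One call per slice; the results are collected in slice order.
    return [func1d(s) for s in slices]


# Same values as np.arange(9).reshape(3, 3)
A = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
print(apply_along_axis_2d(sum, 0, A))  # [9, 12, 15]
```

Every slice contributes one entry to the result, so a correct implementation cannot stop after the first column.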

System.Console.WriteLine(np.apply_along_axis((ndarray all, ndarray view) => np.diag(view), -1, np.arange(9).reshape(3,3)))
INT32 
{ { 0, 0, 0 },
  { 0, 1, 0 },
  { 0, 0, 2 } }
import numpy as np
print(np.apply_along_axis(np.diag, -1, np.arange(9).reshape(3,3)))
[[[0 0 0]
  [0 1 0]
  [0 0 2]]

 [[3 0 0]
  [0 4 0]
  [0 0 5]]

 [[6 0 0]
  [0 7 0]
  [0 0 8]]]
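NumPy's documentation states the output shape rule for apply_along_axis: the shape of arr is kept, except that the axis dimension is removed and replaced by the shape of whatever func1d returns. Because np.diag turns each length-3 row into a 3x3 matrix, the (3, 3) input becomes (3, 3, 3), while numpy.net printed only the first 3x3 block. Here is a small sketch of that rule on plain shape tuples. The helper `expected_output_shape` is hypothetical and only computes shapes.

```python
def expected_output_shape(arr_shape, axis, func_out_shape):
    """Hypothetical helper: output shape of np.apply_along_axis per NumPy's documented rule."""
    ndim = len(arr_shape)
    if not -ndim <= axis < ndim:
        raise ValueError("axis out of range")
    axis %= ndim  # normalise negative axes
    # Drop the iterated axis and splice in the shape of func1d's result.
    return tuple(arr_shape[:axis]) + tuple(func_out_shape) + tuple(arr_shape[axis + 1:])


print(expected_output_shape((3, 3), -1, (3, 3)))  # np.diag along the last axis: (3, 3, 3)
print(expected_output_shape((3, 3), 0, ()))       # np.sum along axis 0: (3,)
```

A scalar-returning func1d (shape `()`) removes one dimension, which matches the sum example above.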

Your anonymous function does not cause the expected overload of apply_along_axis to be called. There are three different overloads of the same API (I implemented them a long time ago and don't remember why I did it that way).

You can set a breakpoint in each overload to see which one is being called. To get the results you expect, you will need to change your anonymous function or do it the way I demonstrate below.

    public delegate ndarray apply_along_axis_view(ndarray a, ndarray view);
    public delegate ndarray apply_along_axis_indices(ndarray a, IList<npy_intp> indices);
    public delegate object apply_along_axis_fn(ndarray a, params object[] args);

    [TestMethod]
    public void test_HadrianTang_16()
    {
        var A = np.arange(9).reshape(3, 3);

        // Your sample does not call the expected overload: a lambda with explicitly
        // typed (ndarray, ndarray) parameters binds to apply_along_axis_view.
        var X = apply_along_axis((ndarray all, ndarray view) => np.sum(view), 0, A);
        print(X);

        // Sample that works as expected: a local function with the
        // (ndarray, params object[]) signature binds to apply_along_axis_fn.
        object my_func(ndarray a, params object[] args)
        {
            return np.sum(a);
        }

        var Y = apply_along_axis(my_func, 0, A);
        print(Y);
    }

    // Stubs standing in for the three overloads; set a breakpoint in each
    // to see which one the compiler selects.
    public static ndarray apply_along_axis(apply_along_axis_indices fn, int axis, ndarray arr)
    {
        return arr;
    }

    public static ndarray apply_along_axis(apply_along_axis_view fn, int axis, ndarray arr)
    {
        return arr;
    }

    public static ndarray apply_along_axis(apply_along_axis_fn func1d, int axis, ndarray arr, params object[] args)
    {
        return arr;
    }

Hmm. Then how is the apply_along_axis_view overload supposed to work?