DifferentiableUniverseInitiative/jaxDecomp

FFTs are not working properly

Closed this issue · 12 comments

Comparing the 3D FFT computed by jaxDecomp against one computed manually in JAX, I realized that the result of pfft3d does not match the non-distributed version.
This could be due to a transposition of the pfft3d result, which is a more or less conventional trick to save two all-to-all communications in a forward-backward step, but depending on the partitioning scheme I get results in different axis orders.
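For reference, a transposed-output convention usually means the distributed FFT returns the standard result with its axes cyclically permuted. A plain-NumPy sketch (the permutation `(2, 0, 1)` here is illustrative; pinning down jaxDecomp's actual convention is the point of this issue):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 4, 4))

# A "transposed output" 3D FFT: the result is fftn(x) with its axes
# cyclically permuted, which lets a distributed library skip the final
# all-to-all transposes that would restore the input layout.
def fft3d_transposed(a, perm=(2, 0, 1)):  # perm is a guess, not jaxDecomp's convention
    return np.fft.fftn(a).transpose(perm)

# Undoing the permutation recovers the standard fftn result exactly.
k = fft3d_transposed(x)
assert np.allclose(k.transpose(1, 2, 0), np.fft.fftn(x))
```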

I have modified the FFT test to actually detect this problem in the fix_fft branch in #12

@ASKabalan can you take a look?

If we don't provide any other information to the user regarding the order of dimensions in the FFT, the user expects the following to be true:

import jax
import numpy as np
import jaxdecomp
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P
from numpy.testing import assert_allclose

key = jax.random.PRNGKey(0)
pdims = (2, 2)
mesh_shape = (4, 4, 4)

devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('z', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('z', 'y'))

local_mesh_shape = [mesh_shape[0]//pdims[0], mesh_shape[1]//pdims[1], mesh_shape[2]]

z = jax.make_array_from_single_device_arrays(shape=mesh_shape,
                                             sharding=sharding,
                                             arrays=[jax.random.normal(key, local_mesh_shape)])

with mesh:
    kfield_dist = jaxdecomp.fft.pfft3d(z)

kfield_dist = multihost_utils.process_allgather(kfield_dist, tiled=True)

kfield = np.fft.fftn(multihost_utils.process_allgather(z, tiled=True))

# This should be true to within numerical accuracy
assert_allclose(kfield_dist, kfield)

I did some more testing with this little script, and depending on the pdims the result of the FFT comes out in a different order; worse, apparently no order at all works for a pencil decomposition:

>>> import itertools
>>> dims = [0,1,2]
>>> all_perms = list(itertools.permutations(dims))

# pdims = [4, 1]
>>> for p in all_perms:
...     print(p, abs((field - fftn(z).transpose(p))).max())
(0, 1, 2) 21.092330368764046
(0, 2, 1) 16.964527287841943
(1, 0, 2) 21.092330574773722
(1, 2, 0) 8.642673492431641e-07
(2, 0, 1) 21.092330574773722
(2, 1, 0) 14.955970957875252

# pdims = [1, 4]
>>> for p in all_perms:
...     print(p, abs((field - fftn(z).transpose(p))).max())
(0, 1, 2) 21.092331293889643
(0, 2, 1) 1.2834044098934218e-06
(1, 0, 2) 21.092331293889643
(1, 2, 0) 16.964527282918105
(2, 0, 1) 14.955971911549568
(2, 1, 0) 21.092330516874384

# pdims = [2, 2]
>>> for p in all_perms:
...     print(p, abs((field - fftn(z).transpose(p))).max())
(0, 1, 2) 25.08715713705188
(0, 2, 1) 21.621456843859477
(1, 0, 2) 22.202013279408632
(1, 2, 0) 22.202013279408632
(2, 0, 1) 18.170884130091107
(2, 1, 0) 16.552610787745994

I decided not to test pfft against jax.fft.fftn because there was no easy way to compare them, but I think it is doable.

Since the round trips pfft/pifft and jax.fft.fftn/jax.fft.ifftn give the same results, I didn't test the forward pass against JAX.

The reason they are not the same is that cuDecomp transpositions happen in place: local slices keep the same shape but end up with a different leading axis.
So each slice receives some elements from other slices (this uses communication).
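To illustrate the distinction, here is a two-rank NumPy toy model (the rank-splitting with `np.split` is mine, not cuDecomp's actual API):

```python
import numpy as np

# Toy global array, "distributed" over two fake ranks along axis 0.
x = np.arange(4 * 6 * 2).reshape(4, 6, 2)
local = np.split(x, 2, axis=0)       # each rank holds shape (2, 6, 2)

# A logical (JAX-style) transpose permutes the global axes, so the
# rank-local shape changes after re-slicing:
t = x.transpose(1, 0, 2)
local_t = np.split(t, 2, axis=0)     # each rank now holds shape (3, 4, 2)

# A cuDecomp-style transpose, by contrast, keeps each rank's buffer
# geometry fixed and moves elements between ranks via all-to-all, so a
# rank ends up holding data it did not have before the transpose.
assert local[0].shape == (2, 6, 2)
assert local_t[0].shape == (3, 4, 2)
```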

Last time I tried, JAX transpositions transpose the slices themselves (so the local shape changes), but maybe I should try transposing a gathered array.

For JaxPM this is not an issue, since applying element-wise ops to a transposed array changes nothing.

But I agree that this could be unexpected for the user.

I will find a solution.

It does actually matter, because when we want to perform filtering operations in Fourier space, we need to understand the order of the array, to know what Fourier frequency corresponds to what index of the Fourier array.

So, (up to potentially a transpose operation, and in that case it should be clearly documented to the user), I need to know how to relate the results of fftn(z) and pfft3d(z).
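For example, if pfft3d(z) turned out to equal fftn(z).transpose(perm) for some fixed perm (hypothetical here; determining perm is exactly the open question), a Fourier-space filter would need the frequency grids permuted the same way:

```python
import numpy as np

mesh_shape = (4, 6, 8)
rng = np.random.default_rng(0)
z = rng.standard_normal(mesh_shape)

perm = (1, 2, 0)  # placeholder permutation, NOT jaxDecomp's actual convention
kfield_dist = np.fft.fftn(z).transpose(perm)  # stand-in for pfft3d(z)

# Frequencies along each *output* axis come from the corresponding input axis.
kvec = [np.fft.fftfreq(mesh_shape[ax]) for ax in perm]

# Example filter: zero out the zero-frequency plane along the first output axis.
filtered = np.where(kvec[0][:, None, None] == 0, 0, kfield_dist)
```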

OK, for pencils it is probably because 'transposing' changes the distributed axis and does not exactly transpose the cube.

Two things can be done

  • change the global-shape abstract eval to a transposed one, (2, 0, 1), for pencils, and by extension the lowered result shape (I am pretty sure I tried things like this and nothing worked)
  • use the cuDecomp-bound transpositions that I have not yet upgraded to JAX 0.4.x
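The first option amounts to something like this pure-Python sketch of the shape rule alone (none of these names are jaxDecomp's real internals):

```python
# Hypothetical shape rule: the primitive's abstract eval would report
# the transposed (Z-pencil) output shape rather than echoing the input
# shape, and the lowering would have to match it.
def pfft3d_output_shape(global_shape, perm=(2, 0, 1)):
    return tuple(global_shape[i] for i in perm)

# e.g. a (4, 6, 8) global mesh would lower to an (8, 4, 6) result
assert pfft3d_output_shape((4, 6, 8)) == (8, 4, 6)
```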

I'll let you know.

Hmmm, no, I don't think that's the issue. Since the slices all have the same shape in this case, we should get the right result up to a transpose at the end. Printing the results of pfft3d explicitly, the values in the pencil case are off: I see values that are nowhere to be seen in the correct FFT. So it's not a problem of data order.

I notice that in the initialize-pencils function, you didn't reproduce exactly the initialization of the cufftMakePlan calls from the original cuDecomp example:

I'm guessing that's maybe not completely without consequences.

I don't understand. The 1D plans are the same.

Are you talking about config.transpose_axis_contiguous = false branch? And consequently the strided CufftManyPlan ?

Right, so, I'm suspecting something is wrong with the striding.

I added this debug line in the FourierExecutor<real_t>::Initialize() function:

  std::cout << "Transpose axis contiguous : " << config.transpose_axis_contiguous[0] << " "
            << config.transpose_axis_contiguous[1] << " "
            << config.transpose_axis_contiguous[2] << std::endl;

which returns:

Transpose axis contiguous : 0 1 1

which shouldn't be the case: we used to always initialize it to 1, 1, 1, and I don't know why it would be different.

I added another debug print in build_fft_descriptor and there the transpose gets correctly initialized:

Initial axis contiguous : 1 1 1

OK, ignore that comment, it's normal behaviour: it is enforced at the time of creation of the grid descriptor.

OK, so after a bunch of experimentation, I think we should start by getting transpose operations to work again. I'm running out of time for now, so I'm just taking notes on things to try.

I found (by hacking) that even just the transpose operations X->Y and Y->Z don't work correctly for a (2, 2) mesh under the current way we run the scripts.

A first reason was that I needed to swap the partitioning axes like so:

mesh = Mesh(devices, axis_names=('x', 'y'))
global_array = multihost_utils.host_local_array_to_global_array(
    array, mesh, P('y', 'x'))

Because the order in which cuDecomp builds its own 2D communicator is reversed. Note: these bugs cancel out if we only test that pifft3d(pfft3d(x)) == x.
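A small NumPy illustration of why a round-trip test cannot catch this class of bug: a forward transform with a spurious extra axis permutation still round-trips exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 4, 4))

perm = (1, 0, 2)              # an unintended axis permutation in the forward pass
inv_perm = tuple(np.argsort(perm))

fwd = lambda a: np.fft.fftn(a).transpose(perm)       # "buggy" forward transform
inv = lambda a: np.fft.ifftn(a.transpose(inv_perm))  # matching inverse undoes it

assert np.allclose(inv(fwd(x)), x)                   # round trip is perfect...
assert not np.allclose(fwd(x), np.fft.fftn(x))       # ...yet the forward is wrong
```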

With this fix, data is no longer mangled after a series of transpose (in my hacking experiments).

After this, I can actually find a transpose of the output array that matches the result of an fftn, but I don't really understand it (not sure how the axes match the definition of a Z pencil), and the transpose permutation is different if I use a different mesh like (1,4), so that's not good.

So this might kind of work as long as we are working with cubes, but it doesn't pass the FFT tests, which use non-cubic volumes, because then I can't just transpose.

So, the order of things I would suggest investigating is the following:

  1. Have an x-pencil in input, do a transpose X->Y, retrieve the Y pencil as output, and get this to work for arbitrary mesh definitions and volume sizes (not necessarily cubic)
  2. Demonstrate that we can understand the shape of the pencil after 2 transposes X->Y, Y->Z
  3. Adjust the lowering of fft3d to return the correct shape in Fourier space, and always as a Z pencil
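As a reference for steps 1 and 2, the rank-local pencil shapes one would expect could be written down like this (my reading of the cuDecomp decomposition-layout docs, to be verified; the function and key names are mine):

```python
# Expected rank-local pencil shapes for a (P_rows, P_cols) process grid
# and a global shape (X, Y, Z), per my reading of cuDecomp's
# decomposition-layout documentation (to be double-checked):
def pencil_shapes(global_shape, pdims):
    X, Y, Z = global_shape
    pr, pc = pdims
    return {
        "x_pencil": (X, Y // pr, Z // pc),  # full X axis; Y, Z distributed
        "y_pencil": (X // pr, Y, Z // pc),  # full Y axis after X->Y transpose
        "z_pencil": (X // pr, Y // pc, Z),  # full Z axis after Y->Z transpose
    }

shapes = pencil_shapes((8, 8, 8), (2, 4))
assert shapes["x_pencil"] == (8, 4, 2)
```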

I will experiment with this.
But why are the axes x and y? The distributed axes are z and y, in this order, since the array is reversed,
so the array is ZYX.
The processor dimensions (as documented here) are Y/P_row and Z/P_col, so pdims are YZ.
I already do this by reversing the pdims axes.

devices = mesh_utils.create_device_mesh(pdims[::-1])
mesh = Mesh(devices, axis_names=('z', 'y'))
global_array = multihost_utils.host_local_array_to_global_array(
      local_array, mesh, P('z', 'y'))

What is devices in your example?

Ok, I understand what you did

# Create computing mesh
devices = mesh_utils.create_device_mesh(pdims[::-1])
mesh = Mesh(devices, axis_names=('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('y', 'x'))

### Create all initial distributed tensors ###
local_mesh_shape = [mesh_shape[0]//pdims[0], mesh_shape[1]//pdims[1], mesh_shape[2]]

This might have worked by accident, but I am almost sure that this

local_mesh_shape = [mesh_shape[0]//pdims[0], mesh_shape[1]//pdims[1], mesh_shape[2]]

is wrong, and that it works only because your two pencil axes are the same (2x2), right?
Or you used an XY slab, while cuDecomp is doing a YZ slab.

it should be

local_mesh_shape = [mesh_shape[0]//pdims[1], mesh_shape[1]//pdims[0], mesh_shape[2]]

Check this https://nvidia.github.io/cuDecomp/overview.html#decomposition-layout

Right?
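The square-grid coincidence is easy to check: the two formulas only agree when pdims[0] == pdims[1] (the function names below are mine, for illustration):

```python
mesh_shape = (8, 8, 8)

def local_shape_original(mesh_shape, pdims):
    # version used in the script quoted above
    return [mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]]

def local_shape_swapped(mesh_shape, pdims):
    # proposed correction with the pdims indices swapped
    return [mesh_shape[0] // pdims[1], mesh_shape[1] // pdims[0], mesh_shape[2]]

# Identical on a square process grid, which is why (2, 2) hid the issue:
assert local_shape_original(mesh_shape, (2, 2)) == local_shape_swapped(mesh_shape, (2, 2))
# But they disagree for (4, 1):
assert local_shape_original(mesh_shape, (4, 1)) != local_shape_swapped(mesh_shape, (4, 1))
```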

This trick also works with (4, 1) (see the JaxPM example). I looked at the output of pfft3d for different combinations of compute-mesh shape. Note: the order of processes cuDecomp uses to create its own compute mesh is not necessarily the same as JAX's, which can lead to confusion.

My comment above was an attempt to understand this better:

  1. Have an x-pencil in input, do a transpose X->Y, retrieve the Y pencil as output, and get this to work for arbitrary mesh definitions and volume sizes (not necessarily cubic)
  2. Demonstrate that we can understand the shape of the pencil after 2 transposes X->Y, Y->Z
  3. Adjust the lowering of fft3d to return the correct shape in Fourier space, and always as a Z pencil

We can define the device mesh however we want, but what matters is that we understand the order of the dimensions at the output of pfft3, and ideally the order of dimensions shouldn't depend on the device mesh we choose to use.

We need to be able to know what transposition to apply to jaxdecomp.fft.pfft3d(z) to recover np.fft.fftn(z).