mpi4py/mpi4py-fft

wrong output for gathering slices

Closed this issue · 1 comments

My problem is similar to the closed issue #7.

I tried to use z.get to gather slices, but when the array is decomposed into 4 or more parts along one axis, the gathered result is wrong. For the following code, running mpirun -np 4 python script.py in a terminal produces the correct figure. With 8 or more processes, or with 4 processes and a slab decomposition, the slices do not appear to be gathered in the correct order.

import numpy as np
import matplotlib.pyplot as plt
from mpi4py import MPI
from mpi4py_fft import PFFT, newDistArray

# Parameters
N = [128, 128, 128]  # Number of grid points in each dimension
L = 10.0             # Physical length of the domain
k = 5.0              # Wavenumber for Helmholtz equation
sigma = 0.5          # Standard deviation for the Gaussian source
comm = MPI.COMM_WORLD  # MPI communicator

# 3D Grid setup (distributed across processes)
grid_shape = tuple(N)
dx = L / N[0]  # Grid spacing assuming cubic domain

# Set up the FFT object and distribution
fft = PFFT(comm, grid_shape, axes=(0, 1, 2), dtype=np.complex64)
f = newDistArray(fft, False)  # Source term array (distributed)

# Get the local slice for each process
local_slice_x = fft.local_slice(False)[0]
local_slice_y = fft.local_slice(False)[1]
local_slice_z = fft.local_slice(False)[2]

# Create local grids using the slices
x = np.linspace(-L/2, L/2, N[0], endpoint=False)[local_slice_x]  # x-dimension
y = np.linspace(-L/2, L/2, N[1], endpoint=False)[local_slice_y]  # y-dimension
z = np.linspace(-L/2, L/2, N[2], endpoint=False)[local_slice_z]  # z-dimension
# Create a meshgrid (local to each process)
X_local, Y_local, Z_local = np.meshgrid(x, y, z, indexing='ij')

# Define the Gaussian point source (local data per process)
f[:] = np.exp(-((X_local**2 + Y_local**2 + Z_local**2) / (2 * sigma**2)))
f /= np.sum(f)  # Normalize

result = newDistArray(fft, False)
result[:] = f 

# Gather the data from all processes (for plotting only)
f0 = result.get((slice(None), slice(None), slice(None)))

# Visualization (only the root process will plot)
if comm.rank == 0:

    # Plot slices through the center of the domain
    slice_index = N[0] // 2

    plt.figure(figsize=(15, 5))

    # X-Y slice
    plt.subplot(1, 3, 1)
    plt.imshow(np.abs(f0[:, :, slice_index]), extent=(-L/2, L/2, -L/2, L/2), origin='lower', cmap='inferno')
    plt.colorbar(label='u(x, y)')
    plt.title('Slice through X-Y plane')
    plt.xlabel('x')
    plt.ylabel('y')

    # X-Z slice
    plt.subplot(1, 3, 2)
    plt.imshow(np.abs(f0[:, slice_index, :]), extent=(-L/2, L/2, -L/2, L/2), origin='lower', cmap='inferno')
    plt.colorbar(label='u(x, z)')
    plt.title('Slice through X-Z plane')
    plt.xlabel('x')
    plt.ylabel('z')

    # Y-Z slice
    plt.subplot(1, 3, 3)
    plt.imshow(np.abs(f0[slice_index, :, :]), extent=(-L/2, L/2, -L/2, L/2), origin='lower', cmap='inferno')
    plt.colorbar(label='u(y, z)')
    plt.title('Slice through Y-Z plane')
    plt.xlabel('y')
    plt.ylabel('z')

    plt.tight_layout()
    plt.show()

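I have figured it out: I should not normalize with f /= np.sum(f), since np.sum(f) is computed over each process's local block, not over the global array. Each process therefore divides its block by a different value, which is what made the gathered slices look wrong. And now I know that z.get is very useful :)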
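
To illustrate the difference between local and global normalization, here is a minimal sketch that needs only NumPy. It does not use MPI: the four "ranks" are simulated by splitting a small 2D Gaussian into blocks with np.array_split, and the names blocks, local_norm, global_norm and reference are made up for this example. Under MPI, the global sum could instead be obtained with something like comm.allreduce(np.sum(f), op=MPI.SUM) before dividing.

```python
import numpy as np

# Hypothetical stand-in for the distributed array: split a global Gaussian
# into 4 blocks along axis 0, as if each block lived on its own MPI rank.
n = 16
x = np.linspace(-5.0, 5.0, n, endpoint=False)
X, Y = np.meshgrid(x, x, indexing="ij")
f_global = np.exp(-(X**2 + Y**2) / (2 * 0.5**2))
blocks = np.array_split(f_global, 4, axis=0)

# Local normalization (what f /= np.sum(f) does on each rank):
# every block is divided by its own sum, so blocks are scaled differently.
local_norm = np.concatenate([b / b.sum() for b in blocks], axis=0)

# Global normalization: add up the per-block sums first, then divide every
# block by the same total. Under MPI this reduction could be written as
# comm.allreduce(np.sum(f), op=MPI.SUM).
total = sum(b.sum() for b in blocks)
global_norm = np.concatenate([b / total for b in blocks], axis=0)

# Compare both results against normalizing the undivided array.
reference = f_global / f_global.sum()
print("global normalization matches:", np.allclose(global_norm, reference))
print("local normalization matches: ", np.allclose(local_norm, reference))
```

With local normalization each block sums to 1 on its own, so the block boundaries show up as jumps in the gathered array even though the gather itself is in the right order.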