wrong output for gathering slices
Closed this issue · 1 comments
yke0305 commented
My problem is similar to the completed issue: #7 .
I have tried to use the z.get to gather slices, but when the matrix is decomposed to 4 or more along one axis, the gathered result goes wrong. which means: for the following code, if I run in terminal with mpirun -np 4 python script.py it returns the correct figure, but if I use 8 (or more) or using 4 processors but with slab decomposition, the slices are not gathered in correct order.
import numpy as np
import matplotlib.pyplot as plt
from mpi4py_fft import PFFT, newDistArray
# Parameters
N = [128, 128, 128] # Number of grid points in each dimension
L = 10.0 # Physical length of the domain
k = 5.0 # Wavenumber for Helmholtz equation
sigma = 0.5 # Standard deviation for the Gaussian source
comm = MPI.COMM_WORLD # MPI communicator
# 3D Grid setup (distributed across processes)
grid_shape = tuple(N)
dx = L / N[0] # Grid spacing assuming cubic domain
# Set up the FFT object and distribution
fft = PFFT(comm, grid_shape, axes=(0, 1, 2), dtype=np.complex64)
f = newDistArray(fft, False) # Source term array (distributed)
# Get the local slice for each process
local_slice_x = fft.local_slice(False)[0]
local_slice_y = fft.local_slice(False)[1]
local_slice_z = fft.local_slice(False)[2]
# Create local grids using the slices
x = np.linspace(-L/2, L/2, N[0], endpoint=False)[local_slice_x] # x-dimension
y = np.linspace(-L/2, L/2, N[1], endpoint=False)[local_slice_y] # y-dimension
z = np.linspace(-L/2, L/2, N[2], endpoint=False)[local_slice_z] # z-dimension
# Create a meshgrid (local to each process)
X_local, Y_local, Z_local = np.meshgrid(x, y, z, indexing='ij')
# Define the Gaussian point source (local data per process)
f[:] = np.exp(-((X_local**2 + Y_local**2 + Z_local**2) / (2 * sigma**2)))
f /= np.sum(f) # Normalize
result = newDistArray(fft, False)
result[:] = f
# Gather the data from all processes (for plotting only)
f0 = result.get((slice(None), slice(None), slice(None)))
# Visualization (only the root process will plot)
if comm.rank == 0:
# Plot slices through the center of the domain
slice_index = N[0] // 2
plt.figure(figsize=(15, 5))
# X-Y slice
plt.subplot(1, 3, 1)
plt.imshow(np.abs(f0[:, :, slice_index]), extent=(-L/2, L/2, -L/2, L/2), origin='lower', cmap='inferno')
plt.colorbar(label='u(x, y)')
plt.title('Slice through X-Y plane')
plt.xlabel('x')
plt.ylabel('y')
# X-Z slice
plt.subplot(1, 3, 2)
plt.imshow(np.abs(f0[:, slice_index, :]), extent=(-L/2, L/2, -L/2, L/2), origin='lower', cmap='inferno')
plt.colorbar(label='u(x, z)')
plt.title('Slice through X-Z plane')
plt.xlabel('x')
plt.ylabel('z')
# Y-Z slice
plt.subplot(1, 3, 3)
plt.imshow(np.abs(f0[slice_index, :, :]), extent=(-L/2, L/2, -L/2, L/2), origin='lower', cmap='inferno')
plt.colorbar(label='u(y, z)')
plt.title('Slice through Y-Z plane')
plt.xlabel('y')
plt.ylabel('z')
plt.tight_layout()
plt.show()
yke0305 commented
Now I figure it out: I should not do the normalization: f /= np.sum(f) , since this is calculated on each slices, not globally. And now I know the z.get is very useful :)