Visualization utilities
nbren12 opened this issue · 0 comments
nbren12 commented
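earth2grid can be used to accelerate plotting workflows. It might be worth adding some visualization utilities. The prototype below regrids a HEALPix field onto a regular grid in a Cartopy map projection and plots it: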
from earth2grid import healpix
import generate
import torch
import matplotlib.pyplot as plt
import cartopy.crs
import numpy as np
def create_regular_grid_in_projection(projection, nx, ny):
    """Create a regular grid in a Cartopy projection and its lat-lon coordinates.

    The grid is evenly spaced in the *projected* coordinates, spanning the
    projection's ``x_limits`` and ``y_limits``. Each grid point is then
    transformed back to geodetic longitude/latitude.

    Parameters:
        projection (cartopy.crs.Projection): The desired Cartopy projection.
        nx (int): Number of grid points along the projected x axis.
        ny (int): Number of grid points along the projected y axis.

    Returns:
        tuple: ``(lats, lons, xx, yy)``, four 2D arrays of shape ``(ny, nx)``.
        ``lats``/``lons`` are the geodetic coordinates of each grid point and
        are NaN wherever the transformed coordinates are not finite (points
        outside the projection's valid domain). ``xx``/``yy`` are the
        projected coordinates of the grid, in projection units.
    """
    x_min, x_max = projection.x_limits
    y_min, y_max = projection.y_limits

    # Regular spacing in projected space (not in lat-lon space).
    x = np.linspace(x_min, x_max, nx)
    y = np.linspace(y_min, y_max, ny)
    # meshgrid's default "xy" indexing yields arrays of shape (ny, nx).
    xx, yy = np.meshgrid(x, y)

    # Transform the projected grid back to lon/lat; the last axis of the
    # result holds (lon, lat, z).
    transformed = cartopy.crs.Geodetic().transform_points(projection, xx, yy)
    lons = transformed[..., 0]
    lats = transformed[..., 1]

    # Normalize any non-finite result (e.g. inf) to NaN in BOTH arrays, so a
    # single NaN check on either array identifies out-of-domain points.
    invalid = ~(np.isfinite(lons) & np.isfinite(lats))
    lons[invalid] = np.nan
    lats[invalid] = np.nan
    return lats, lons, xx, yy
def visualize(x, *, nx=256, ny=512, projection=None):
    """Plot a HEALPix field on a map in a Cartopy projection.

    The field is bilinearly regridded from its HEALPix grid onto a grid that
    is regular in projected coordinates. It is then drawn with ``pcolormesh``
    on a new subplot of the current matplotlib figure, with coastlines and a
    horizontal colorbar. Points outside the projection's valid domain are
    left as NaN and therefore render blank.

    Parameters:
        x (torch.Tensor): Values on a HEALPix grid. The grid level is inferred
            from ``x.shape[-1]`` (the pixel count). Assumed to be a single
            field of shape ``(npix,)``, because the regridded values are
            written into one 2D image (TODO confirm batched input is
            unsupported).
        nx (int): Number of plotting-grid points along projected x.
        ny (int): Number of plotting-grid points along projected y.
        projection (cartopy.crs.Projection, optional): Map projection to plot
            in. Defaults to ``cartopy.crs.Robinson()``.
    """
    # NOTE(review): Robinson is roughly twice as wide as it is tall, so the
    # defaults nx=256, ny=512 may be swapped relative to intent. They are
    # kept to preserve the original output.
    if projection is None:
        projection = cartopy.crs.Robinson()

    hpx = healpix.Grid(healpix.npix2level(x.shape[-1]))
    lat, lon, xx, yy = create_regular_grid_in_projection(projection, nx, ny)
    mask = ~np.isnan(lat)

    # Only regrid to the points that lie inside the projection's domain.
    regrid = hpx.get_bilinear_regridder_to(lat[mask], lon[mask])
    # Keep the returned object: nn.Module.to returns self, and if this were
    # Tensor-like, .to would not be in-place.
    regrid = regrid.to(x)

    # Allocate directly on x's dtype/device, pre-filled with NaN so that
    # out-of-domain points stay blank in the plot.
    out = torch.full(lat.shape, float("nan"), dtype=x.dtype, device=x.device)
    # Index with a bool tensor on the same device as `out`, not a numpy array.
    out[torch.as_tensor(mask, device=x.device)] = regrid(x)

    ax = plt.subplot(projection=projection)
    # detach() so inputs that require grad can still be converted for
    # matplotlib.
    im = ax.pcolormesh(xx, yy, out.detach().cpu().numpy(), transform=projection)
    ax.coastlines()
    plt.colorbar(im, orientation="horizontal")