NVlabs/earth2grid

Visualization utilities

nbren12 opened this issue · 0 comments

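earth2grid can be used to accelerate plotting workflows. It might be worth adding some visualization utilities. For example, the snippet below regrids a HEALPix field onto a regular grid in a Cartopy projection and plots it: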

from earth2grid import healpix
import generate
import torch
import matplotlib.pyplot as plt
import cartopy.crs
import numpy as np


def create_regular_grid_in_projection(projection, nx, ny):
    """
    Create a regular grid in a Cartopy projection and its lat-lon coordinates.

    The grid is evenly spaced in the projection's own coordinate system and
    spans the projection's full ``x_limits`` and ``y_limits``. Every grid
    point is then transformed back to geodetic longitude/latitude.

    Parameters:
    projection (cartopy.crs.Projection): The desired Cartopy projection.
    nx (int): Number of grid points along the projection's x axis.
    ny (int): Number of grid points along the projection's y axis.

    Returns:
    tuple: ``(lats, lons, xx, yy)``, four 2D arrays of shape ``(ny, nx)``.
        ``lats`` and ``lons`` are the geodetic coordinates of each grid
        point. They are NaN wherever the back-transform produced a
        non-finite value, i.e. the point lies outside the projection's
        valid domain. ``xx`` and ``yy`` are the grid coordinates in
        projection units.
    """
    # Get the projection's limits
    x_min, x_max = projection.x_limits
    y_min, y_max = projection.y_limits

    # Create a regular grid in the projection coordinates; with the default
    # "xy" indexing, meshgrid returns arrays of shape (ny, nx).
    x = np.linspace(x_min, x_max, nx)
    y = np.linspace(y_min, y_max, ny)
    xx, yy = np.meshgrid(x, y)

    # Transform the gridded coordinates back to lat-lon
    geodetic = cartopy.crs.Geodetic()
    transformed = geodetic.transform_points(projection, xx, yy)

    lons = transformed[..., 0]
    lats = transformed[..., 1]

    # Points outside the projection's valid domain transform to non-finite
    # values. Mark them NaN in *both* arrays so callers can build a single
    # validity mask from either one (e.g. ``~np.isnan(lats)``).
    valid = np.isfinite(lons) & np.isfinite(lats)
    lons[~valid] = np.nan
    lats[~valid] = np.nan

    return lats, lons, xx, yy


def visualize(x, *, nx=512, ny=256, projection=None):
    """
    Plot a HEALPix field on a map projection.

    The field is bilinearly regridded from HEALPix to a raster that is
    regular in the projection's own coordinates. It is then drawn with
    ``pcolormesh`` on a Cartopy axes created by ``plt.subplot``, with
    coastlines and a horizontal colorbar.

    Parameters:
    x (torch.Tensor): Field on a HEALPix grid. The HEALPix level is inferred
        from ``x.shape[-1]`` via ``healpix.npix2level``. Any leading
        dimensions must have size 1. This assumes the pixel ordering matches
        the default of ``healpix.Grid`` — TODO confirm for other orderings.
    nx (int): Raster columns along the projection's x axis.
    ny (int): Raster rows along the projection's y axis. The defaults make
        the raster about twice as wide as tall. That matches the aspect
        ratio of the default Robinson projection, so cells are roughly
        square.
    projection (cartopy.crs.Projection, optional): Map projection used for
        both the raster and the axes. Defaults to ``cartopy.crs.Robinson()``.

    Returns:
    The artist returned by ``ax.pcolormesh``.

    Raises:
    ValueError: If ``x`` holds more than one field (a leading dim > 1).
    """
    npix = x.shape[-1]
    if x.numel() != npix:
        raise ValueError(
            "visualize expects a single HEALPix field, got shape "
            f"{tuple(x.shape)}"
        )
    # Only one map is drawn, so collapse any size-1 leading dimensions.
    field = x.reshape(npix)

    hpx = healpix.Grid(healpix.npix2level(npix))
    crs = cartopy.crs.Robinson() if projection is None else projection
    lat, lon, xx, yy = create_regular_grid_in_projection(crs, nx, ny)

    # Raster points outside the projection's domain are NaN; regrid only
    # the valid ones.
    mask = ~np.isnan(lat)
    regrid = hpx.get_bilinear_regridder_to(lat[mask], lon[mask])
    # Match the field's dtype/device; keep the returned object rather than
    # relying on .to() mutating in place.
    regrid = regrid.to(field)

    # Pixels outside the domain stay NaN so pcolormesh leaves them blank.
    out = torch.full(lat.shape, torch.nan, dtype=field.dtype, device=field.device)
    out[torch.from_numpy(mask).to(field.device)] = regrid(field)

    ax = plt.subplot(projection=crs)
    im = ax.pcolormesh(xx, yy, out.detach().cpu().numpy(), transform=crs)
    ax.coastlines()
    plt.colorbar(im, orientation="horizontal")
    return im