The temperature derivative of the transmission spectra model becomes NaN
sh-tada opened this issue · 1 comments
sh-tada commented
I found that the derivative of the transmission spectra model (ArtTransPure) with respect to temperature results in NaN. Similarly, derivatives with respect to gravity_btm, radius_btm, and mean molecular weight also yield NaN.
Sample code
import jax
from jax.config import config
import pandas as pd
import numpy as np
import jax.numpy as jnp
from exojax.utils.grids import wavenumber_grid
from exojax.spec.opacalc import OpaPremodit
from exojax.spec.atmrt import ArtTransPure
from exojax.utils.constants import RJ, Rs
from exojax.spec.api import MdbHitran
from exojax.utils.astrofunc import gravity_jupiter
from exojax.spec.unitconvert import wav2nu
from exojax.spec.specop import SopRotation
from exojax.spec.specop import SopInstProfile
from exojax.utils.instfunc import resolution_to_gaussian_std
# Enable 64-bit floats. The config object is reached through the already
# imported `jax` package, because importing it from the `jax.config`
# submodule is deprecated in recent JAX releases.
jax.config.update("jax_enable_x64", True)
def read_data(filename):
    """Load a transmission spectrum and keep the 2.25-2.6 um window.

    Args:
        filename: Input handed to ``pandas.read_csv``; the table must be
            space-delimited and contain the columns ``Wavelength[um]`` and
            ``Rp/Rs``.

    Returns:
        Tuple ``(wavelength, rprs)`` of pandas Series restricted to the rows
        with 2.25 < wavelength < 2.6 (both bounds exclusive).
    """
    dat = pd.read_csv(filename, delimiter=" ")
    wav = dat["Wavelength[um]"]
    # Strict inequalities: rows exactly at 2.25 or 2.6 are dropped.
    mask = (wav > 2.25) & (wav < 2.6)
    return wav[mask], dat["Rp/Rs"][mask]
# Read data
filename = "/home/kawahara/exojax/tests/integration/comparison/transmission/spectrum/CO100percent_500K.dat"
wav, rprs = read_data(filename)
# Observed wavelengths (column "Wavelength[um]") converted with unit "um";
# the result is later used as the sampling grid in model().
inst_nus = wav2nu(np.array(wav), "um")

# Model
Nx = 300000
# NOTE: `wav` is rebound here to the second return value of wavenumber_grid;
# the observed wavelengths read above are only needed for `inst_nus`.
nu_grid, wav, res = wavenumber_grid(
    22900.0,
    26000.0,
    Nx,
    unit="AA",
    xsmode="premodit",
)
art = ArtTransPure(
    pressure_top=1.0e-15,
    pressure_btm=1.0e1,
    nlayer=100,
)
art.change_temperature_range(490.0, 510.0)
mdb = MdbHitran("CO", nu_grid, gpu_transfer=True, isotope=1)
# Same 490-510 range as passed to art.change_temperature_range above.
opa = OpaPremodit(mdb=mdb, nu_grid=nu_grid, auto_trange=[490, 510], dit_grid_resolution=1)
sop_inst = SopInstProfile(nu_grid, res, vrmax=100.0)
def model(params):
    """Compute the model spectrum sampled at the observed wavenumbers.

    Args:
        params: Sequence unpacked as ``(mmr_CO, mu_fid, T_fid, gravity_btm,
            radius_btm, RV)``.

    Returns:
        ``jnp.sqrt`` of the output of ``art.run`` after it has been resampled
        at ``inst_nus`` by ``sop_inst.sampling`` (presumably Rp/Rs, matching
        the ``Rp/Rs`` data column -- TODO confirm).
    """
    mmr_CO, mu_fid, T_fid, gravity_btm, radius_btm, RV = params

    # Uniform profiles: the same temperature and mean molecular weight on
    # every layer. jnp (rather than np) keeps the array construction inside
    # JAX while the function is traced by jax.grad.
    layer_ones = jnp.ones_like(art.pressure)
    Tarr = T_fid * layer_ones
    mmw = mu_fid * layer_ones
    mmr_arr = art.constant_mmr_profile(mmr_CO)

    gravity = art.gravity_profile(Tarr, mmw, radius_btm, gravity_btm)
    xsmatrix = opa.xsmatrix(Tarr, art.pressure)
    dtau = art.opacity_profile_xs(xsmatrix, mmr_arr, opa.mdb.molmass, gravity)

    Rp2 = art.run(dtau, Tarr, mmw, radius_btm, gravity_btm)
    Rp2_sample = sop_inst.sampling(Rp2, RV, inst_nus)
    return jnp.sqrt(Rp2_sample)
def objective(params):
    """Return the sum of squared residuals between the data and ``model``.

    Args:
        params: Parameter vector forwarded unchanged to ``model``.

    Returns:
        Scalar ``jnp.sum`` of the squared difference between the reversed
        ``rprs`` data and ``model(params)``.
    """
    # The data are reversed before comparison; presumably they are stored in
    # ascending wavelength while the model output is in ascending wavenumber
    # -- TODO confirm.
    observed = np.array(rprs[::-1])
    return jnp.sum((observed - model(params)) ** 2)
# Gradient
grad = jax.grad(objective)
# Entries follow the order unpacked in model():
# mmr_CO, mu_fid, T_fid, gravity_btm, radius_btm, RV.
params = np.array(
    [
        1,
        28.00863,
        500,
        gravity_jupiter(1.0, 1.0),
        RJ,
        0,
    ]
)
print()
print("Parameters: mmr_CO, mu_fid, T_fid, gravity_btm, radius_btm, RV")
print("Gradient", grad(params))
Result
Parameters: mmr_CO, mu_fid, T_fid, gravity_btm, radius_btm, RV
Gradient [4.57663499e+00 nan nan nan nan 3.38162608e-03]
HajimeKawahara commented
Thanks. It looks like atmprof.normalized_layer_height
is not differentiable.