Multi time step prediction
Opened this issue · 1 comments
I'm struggling to implement a multi-time-step model where I use my prediction outputs as the input for the next iteration. Can someone please help me with this?
# Load model parameters and configuration
# Read the serialized checkpoint from disk. Note that `model` here is only the
# open binary file handle, not the network itself.
with open('/data4/home/rohitsuresh/graphcast_/model/params/params-GraphCast_operational-ERA5-HRES_1979-2021-resolution_0.25-pressure_levels_13-mesh_2to6-precipitation_output_only.npz', 'rb') as model:
    ckpt = checkpoint.load(model, graphcast.CheckPoint)

# Pull out the three pieces of the checkpoint used by the rest of the script:
# the weights plus the model and task configurations.
params, model_config, task_config = (
    ckpt.params,
    ckpt.model_config,
    ckpt.task_config,
)
# Load statistics
# Each statistics file is read into memory as an xarray dataset. The three
# resulting module-level datasets are passed to
# `normalization.InputsAndResiduals` inside `construct_graphcast`.

# Dataset loaded from the "diffs_stddev_by_level" statistics file.
with open('/data4/home/rohitsuresh/graphcast_/model/stats/stats-diffs_stddev_by_level.nc', 'rb') as f:
    diffs_stddev_by_level = xarray.load_dataset(f).compute()
# Dataset loaded from the "mean_by_level" statistics file.
with open('/data4/home/rohitsuresh/graphcast_/model/stats/stats-mean_by_level.nc', 'rb') as f:
    mean_by_level = xarray.load_dataset(f).compute()
# Dataset loaded from the "stddev_by_level" statistics file.
with open('/data4/home/rohitsuresh/graphcast_/model/stats/stats-stddev_by_level.nc', 'rb') as f:
    stddev_by_level = xarray.load_dataset(f).compute()
def construct_graphcast(model_config: graphcast.ModelConfig, task_config: graphcast.TaskConfig):
    """Build the GraphCast predictor with casting and normalization applied.

    The layers are stacked from the inside out:

    1. ``graphcast.GraphCast`` built from the two configs,
    2. ``casting.Bfloat16Cast`` around it,
    3. ``normalization.InputsAndResiduals`` around that, using the
       module-level statistics datasets.

    NOTE(review): no autoregressive wrapper is applied here, so presumably the
    returned predictor produces a single target step per call -- confirm
    against how it is driven by the rollout code.

    Args:
        model_config: Model configuration taken from the checkpoint.
        task_config: Task configuration taken from the checkpoint.

    Returns:
        The outermost (normalizing) predictor wrapper.
    """
    core_model = graphcast.GraphCast(model_config, task_config)
    bf16_model = casting.Bfloat16Cast(core_model)
    return normalization.InputsAndResiduals(
        bf16_model,
        diffs_stddev_by_level=diffs_stddev_by_level,
        mean_by_level=mean_by_level,
        stddev_by_level=stddev_by_level,
    )
@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
    """Construct the predictor and run it once on the given data.

    Args:
        model_config: Model configuration forwarded to `construct_graphcast`.
        task_config: Task configuration forwarded to `construct_graphcast`.
        inputs: Input data passed positionally to the predictor.
        targets_template: Passed to the predictor as ``targets_template``.
        forcings: Passed to the predictor as ``forcings``.

    Returns:
        Whatever the predictor returns for this call.
    """
    model = construct_graphcast(model_config, task_config)
    outputs = model(
        inputs,
        targets_template=targets_template,
        forcings=forcings,
    )
    return outputs
def with_configs(fn):
    """Return `fn` with the module-level model and task configs pre-bound.

    Args:
        fn: Callable accepting ``model_config`` and ``task_config`` keyword
            arguments.

    Returns:
        A ``functools.partial`` of `fn` with both keyword arguments filled in.
    """
    config_kwargs = {"model_config": model_config, "task_config": task_config}
    return functools.partial(fn, **config_kwargs)
def with_params(fn):
    """Return `fn` with the module-level parameters and an empty state pre-bound.

    Args:
        fn: Callable accepting ``params`` and ``state`` keyword arguments.

    Returns:
        A ``functools.partial`` of `fn` with ``params`` set to the module-level
        `params` and ``state`` set to a new empty dict.
    """
    bound_kwargs = dict(params=params, state={})
    return functools.partial(fn, **bound_kwargs)
def drop_state(fn):
    """Wrap `fn` so that only the first element of its result is returned.

    Args:
        fn: Callable invoked with keyword arguments whose result supports
            indexing with ``[0]``.

    Returns:
        A keyword-only callable returning ``fn(**kw)[0]``.
    """
    def first_output_only(**kw):
        result = fn(**kw)
        return result[0]

    return first_output_only
# Assemble the callable used for prediction, innermost wrapper first:
#   1. with_configs: bind the model and task configs onto `run_forward.apply`,
#   2. jax.jit:      compile the config-bound function,
#   3. with_params:  bind the checkpoint parameters and an empty state,
#   4. drop_state:   return only the first element of the result.
run_forward_jitted = drop_state(
    with_params(
        jax.jit(
            with_configs(run_forward.apply)
        )
    )
)
class Predictor:
    """Entry point for producing a multi-step forecast with `run_forward_jitted`."""

    @classmethod
    def predict(cls, inputs, targets, forcings) -> xarray.Dataset:
        """Forecast every time step present in ``targets``.

        The rollout is delegated to ``rollout.chunked_prediction``, which calls
        the predictor once per chunk and feeds each chunk's predictions back in
        as inputs for the following chunk, so no manual loop over time steps is
        needed here.

        Args:
            inputs: Initial-condition dataset handed to the first model call.
            targets: Template dataset passed as ``targets_template``; its time
                steps are the ones that get predicted.
            forcings: Forcing dataset passed through to the rollout.

        Returns:
            The dataset returned by ``rollout.chunked_prediction``.
        """
        # One step per call: `run_forward_jitted` is built from a predictor
        # that has no autoregressive wrapper, so each call can only produce a
        # single target step. Requesting 2 steps per chunk hands the network
        # two steps of forcings at once, which widens the grid-node input
        # (184 -> 189 features) and raises
        # "retrieved shape (184, 512) does not match shape=[189, 512]".
        # Longer forecasts are obtained by supplying more time steps in
        # `targets` / `forcings`, not by raising this value.
        return rollout.chunked_prediction(
            predictor_fn=run_forward_jitted,
            rng=jax.random.PRNGKey(0),
            inputs=inputs,
            targets_template=targets,
            forcings=forcings,
            num_steps_per_chunk=1,
        )
# Assuming inputs, targets, and forcings are already prepared
# Run the rollout and keep the resulting dataset.
predictions = Predictor.predict(inputs, targets, forcings)
# Flatten the dataset into a table and write it out as comma-separated text.
# NOTE(review): presumably a very large file for a global 0.25-degree
# forecast -- confirm that a CSV export is acceptable here.
predictions.to_dataframe().to_csv('predictions2024_0.25_step3.csv', sep=',')
ValueError: 'grid2mesh_gnn/_networks_builder/encoder_nodes_grid_nodes_mlp//linear_0/w' with retrieved shape (184, 512) does not match shape=[189, 512] dtype=dtype(bfloat16)