google-deepmind/ferminet

Question about loading a checkpoint


Hello, I have a question about the checkpoint-loading function.
As far as I can tell, checkpoints are saved in /ferminet/train.py by:

      if time.time() - time_of_last_ckpt > cfg.log.save_frequency * 60:
        checkpoint.save(ckpt_save_path, t, data, params, opt_state, mcmc_width)
        time_of_last_ckpt = time.time()
        sys.exit(0)

The save function is implemented with np.savez.
However, when I try to load this checkpoint, it fails the device-count check in checkpoint.restore, specifically:

  with open(restore_filename, 'rb') as f:
    ckpt_data = np.load(f, allow_pickle=True)
    # Retrieve data from npz file. Non-array variables need to be converted back
    # to native types using .tolist().
    t = ckpt_data['t'].tolist() + 1  # Return the iterations completed.
    data = ckpt_data['data']
    params = ckpt_data['params'].tolist()
    opt_state = ckpt_data['opt_state'].tolist()
    mcmc_width = jnp.array(ckpt_data['mcmc_width'].tolist())
    if data.shape[0] != jax.device_count():
      raise ValueError(
          f'Incorrect number of devices found. Expected {data.shape[0]}, found '
          f'{jax.device_count()}.')
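As background for the `.tolist()` calls in checkpoint.restore, here is a minimal sketch of how np.savez stores values that are not arrays. It uses plain NumPy, an in-memory buffer instead of a file, and a hypothetical helper `roundtrip`; none of this is FermiNet code.

```python
import io

import numpy as np


def roundtrip(**values):
    """Hypothetical helper: np.savez to an in-memory buffer, then load back."""
    buf = io.BytesIO()
    np.savez(buf, **values)  # Each value is wrapped with np.asanyarray first.
    buf.seek(0)
    # Object arrays are stored with pickle, so loading needs allow_pickle=True.
    with np.load(buf, allow_pickle=True) as ckpt:
        return {name: ckpt[name] for name in ckpt.files}


loaded = roundtrip(t=10, params={'w': np.ones(3)}, mcmc_width=0.02)

# Python scalars come back as 0-d arrays; .tolist() turns them into natives.
t = loaded['t'].tolist() + 1                 # Python int 11
mcmc_width = loaded['mcmc_width'].tolist()   # Python float 0.02

# A dict comes back as a 0-d object array; .tolist() returns the dict itself.
params = loaded['params'].tolist()
```

This is why `t`, `params` and `opt_state` go through `.tolist()` in checkpoint.restore, while `data` is used directly as an array.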

While trying to track this down, I found that in checkpoint.save, data is a FermiNetData instance containing four arrays: positions, spins, atoms and charges. However, when I load the numpy checkpoint, data is merely an array of the strings [positions, spins, atoms, charges], so data.shape[0] is the number of fields rather than the number of devices, and the check above fails. It seems this part may be broken?
I'm wondering whether this part is correct. Many thanks for your time and effort in reading my issue.

Sorry, this was broken by mistake during some refactoring. It is now fixed.
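For reference, one way to make this round trip robust is to store the dataclass fields as a plain dict and rebuild the object on load. The sketch below is hypothetical and not necessarily how the upstream fix was done: `WalkerData` stands in for FermiNetData, `save`/`restore` are simplified versions of the checkpoint functions, and `n_devices` replaces `jax.device_count()` so the example runs with NumPy alone.

```python
import dataclasses
import io

import numpy as np


@dataclasses.dataclass
class WalkerData:
    """Hypothetical stand-in for FermiNetData: four array fields."""
    positions: np.ndarray
    spins: np.ndarray
    atoms: np.ndarray
    charges: np.ndarray


def save(f, t, data):
    # Store the fields as a plain dict of arrays. np.savez wraps the dict in a
    # single 0-d object array (pickled), so field names and arrays both survive.
    np.savez(f, t=t, data=dataclasses.asdict(data))


def restore(f, n_devices):
    # n_devices stands in for jax.device_count() in this sketch.
    with np.load(f, allow_pickle=True) as ckpt:
        t = ckpt['t'].tolist() + 1  # Return the iterations completed.
        data = WalkerData(**ckpt['data'].item())  # Rebuild from the dict.
    # Check the leading (device) axis of a real array, not of the container.
    if data.positions.shape[0] != n_devices:
        raise ValueError(
            f'Incorrect number of devices found. Expected '
            f'{data.positions.shape[0]}, found {n_devices}.')
    return t, data


# Usage: one device, 8 walkers, 4 electrons in 3D (illustrative shapes only).
buf = io.BytesIO()
walkers = WalkerData(
    positions=np.zeros((1, 8, 12)),
    spins=np.array([1.0, 1.0, -1.0, -1.0]),
    atoms=np.zeros((1, 3)),
    charges=np.array([4.0]),
)
save(buf, 99, walkers)
t, restored = restore(io.BytesIO(buf.getvalue()), n_devices=1)
```

Rebuilding an explicit object on load means the device-count check inspects an actual walker array rather than whatever NumPy made of the container.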