simoninithomas/Deep_reinforcement_learning_Course

Question on checkpoint restart

rnunziata opened this issue · 0 comments

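I was not sure why the code does not restore the observation and reward normalization statistics (`obs_rms` and `reward_rms`) on restart. I made the modifications below so that, on restart, training can pick up where it left off. Is there a reason why this is not done?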

I added two new paths (`mean_path` and `reward_rms_path`):

   # Define the model path names
    model_path = 'models/{}.model'.format(env_id)
    predictor_path = 'models/{}.pred'.format(env_id)
    target_path = 'models/{}.target'.format(env_id)
    mean_path = 'models/{}_mean.pt'.format(env_id)
    reward_rms_path = 'models/{}_rms.pt'.format(env_id)   

I changed the start-up code:

    # Load a saved model, or (on a fresh run) warm up the observation normalizer.
    #
    # On restart, the running normalization statistics (obs_rms, reward_rms)
    # are restored along with the network weights. A resumed run therefore
    # keeps normalizing observations and rewards the same way it did before
    # the checkpoint.
    if is_load_model:
        # NOTE(review): these files hold whole pickled Python objects, not
        # tensors. Recent torch releases default torch.load to
        # weights_only=True, which rejects arbitrary classes. On such versions,
        # pass weights_only=False here (only for checkpoints you trust).
        # Checkpoints written before these files existed raise
        # FileNotFoundError at this point.
        obs_rms = torch.load(mean_path)
        reward_rms = torch.load(reward_rms_path)

        # map_location=None is torch.load's default: tensors stay on the
        # device they were saved from. 'cpu' remaps GPU-saved weights so
        # they load on a CPU-only machine.
        if use_cuda:
            print("Loading PPO Saved Model using GPU")
            map_location = None
        else:
            print("Loading PPO Saved Model using CPU")
            map_location = 'cpu'
        agent.model.load_state_dict(
            torch.load(model_path, map_location=map_location))
        agent.rnd.predictor.load_state_dict(
            torch.load(predictor_path, map_location=map_location))
        agent.rnd.target.load_state_dict(
            torch.load(target_path, map_location=map_location))
    else:
        # normalize obs: step every worker with uniformly random actions to
        # seed obs_rms before training starts.
        print(" first time intialization")
        next_obs = []

        for step in range(num_step * pre_obs_norm_step):
            actions = np.random.randint(0, output_size, size=(num_worker,))

            # The send/recv/update below must run on EVERY step, so they
            # stay nested inside the step loop.
            for parent_conn, action in zip(parent_conns, actions):
                parent_conn.send(action)

            for parent_conn in parent_conns:
                s, r, d, rd, lr = parent_conn.recv()
                # Keep channel index 3 of each worker's state as a 1x84x84
                # frame (presumably the newest frame of a stack; confirm).
                next_obs.append(s[3, :, :].reshape([1, 84, 84]))

            # Flush the collected frames into obs_rms every
            # num_step * num_worker observations (i.e. every num_step steps,
            # assuming one pipe per worker).
            if len(next_obs) % (num_step * num_worker) == 0:
                next_obs = np.stack(next_obs)
                obs_rms.update(next_obs)
                next_obs = []

And in the checkpointing code:

            # Persist the three networks and the observation/reward
            # normalization objects. A restart that loads all five resumes
            # with identical weights and normalization statistics.
            for _ckpt_obj, _ckpt_path in (
                    (agent.model.state_dict(), model_path),
                    (agent.rnd.predictor.state_dict(), predictor_path),
                    (agent.rnd.target.state_dict(), target_path),
                    (obs_rms, mean_path),
                    (reward_rms, reward_rms_path),
            ):
                torch.save(_ckpt_obj, _ckpt_path)