YifeiZhou02/ArCHer

Issues with loading in `lm_optimizer_state_dict`

starship006 opened this issue · 2 comments

I am working on a lightly modified version of this repository, so I am trying to determine whether this error is on my end or not. I'm running a multi-GPU distributed setup with Accelerate, and I'm having trouble loading the lm_optimizer state. Here is my current saving and loading code inside trainer.py:

def save(self, path):
    torch.save({'model_state_dict': self.accelerator.unwrap_model(self.agent.model).state_dict(),
                'critic_state_dict': self.accelerator.unwrap_model(self.agent.critic).state_dict(),
                'target_critic_state_dict': self.accelerator.unwrap_model(self.agent.target_critic).state_dict(),
                }, path)
    # do it at the same path, but with a different name
    torch.save({'critic_optimizer_state_dict': self.critic_optimizer.state_dict()}, path.replace('.pt', '_critic_optim.pt'))
    torch.save({'lm_optimizer_state_dict': self.lm_optimizer.state_dict()}, path.replace('.pt', '_lm_optim.pt'))

def load(self, path):
    # We've modified the below to load in via the CPU, which fixes a memory issue. The agent
    # will/should be prepared down the line, and the critic/lm optimizers are re-prepared here.
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    self.agent.model.load_state_dict(checkpoint['model_state_dict'])
    self.agent.critic.load_state_dict(checkpoint['critic_state_dict'])
    self.agent.target_critic.load_state_dict(checkpoint['target_critic_state_dict'])

    critic_optim_checkpoint = torch.load(path.replace('.pt', '_critic_optim.pt'))
    self.critic_optimizer.load_state_dict(critic_optim_checkpoint['critic_optimizer_state_dict'])

    # The following crashes:
    # trainer_checkpoint = torch.load(path.replace('.pt', '_lm_optim.pt'))
    # self.lm_optimizer.load_state_dict(trainer_checkpoint['lm_optimizer_state_dict'])

    self.critic_optimizer, self.lm_optimizer = self.accelerator.prepare(self.critic_optimizer, self.lm_optimizer)
    return self.agent

The code above runs fine, but it never loads the lm_optimizer state. When I uncomment those two lines, everything works until the first call to self.lm_optimizer.step(), which fails with:

RuntimeError: Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding
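
In case it helps with debugging, a quick check like the following (a hypothetical snippet, not code from the repo) could be added right after loading, to see which device/dtype each optimizer state tensor ends up on compared to its parameter:

# Hypothetical debugging snippet: compare each loaded lm_optimizer state tensor
# against the parameter it belongs to.
for group_idx, group in enumerate(self.lm_optimizer.param_groups):
    for param in group['params']:
        state = self.lm_optimizer.state.get(param, {})
        for key, value in state.items():
            if torch.is_tensor(value):
                print(f"group {group_idx} | {key}: {value.device}/{value.dtype} "
                      f"vs param: {param.device}/{param.dtype}")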

I'm currently pretty lost as to what the bug might be; I don't think I've changed any code that touches the lm_optimizer. If this is something you recognize, I would very much appreciate any pointers!

Sorry, I don't think I have run into an issue like this before. Could it be that the device gets mixed up somewhere during checkpointing and loading? Is it only the lm_optimizer that crashes, while the critic_optimizer works fine? It sounds weird to me too.
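
If the device or dtype is getting mixed up, an untested sketch like the one below (not code from this repo, just an idea) might work around it by moving the loaded lm_optimizer state onto each parameter's device and dtype before stepping:

# Untested sketch: after load_state_dict, force the optimizer state tensors onto
# the same device/dtype as their parameters, leaving the 'step' counters alone
# (the error message says those are allowed to stay on CPU as float32/64).
for group in self.lm_optimizer.param_groups:
    for param in group['params']:
        state = self.lm_optimizer.state.get(param, {})
        for key, value in state.items():
            if torch.is_tensor(value) and key != 'step':
                state[key] = value.to(device=param.device, dtype=param.dtype)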

Yup, critic_optimizer works but lm_optimizer crashes when stepping. It might be worth noting that we are currently trying to use bfloat16. So far I'm mostly unsure about what's going on; I might check whether this replicates on this repo itself.
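
One other idea I might try (again just a sketch, not something from this repo) is loading the lm_optimizer checkpoint with map_location pointed at the accelerator's device, so the state tensors end up on the right device from the start:

# Hypothetical alternative: map the lm_optimizer checkpoint directly onto the
# accelerator's device when loading, rather than onto the CPU.
trainer_checkpoint = torch.load(path.replace('.pt', '_lm_optim.pt'),
                                map_location=self.accelerator.device)
self.lm_optimizer.load_state_dict(trainer_checkpoint['lm_optimizer_state_dict'])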