metadriverse/policydissect

How can I quickly apply your policy in a new scenario?


How can I quickly implement your policy in a new scenario?
I am looking to apply your policy in a new scenario, but I noticed that your code loads a pre-trained reinforcement learning model. How can I quickly train a model in a new environment that has the same format as your policy? For instance, how can I train a model that can be used directly with your ppo_inference_tf function?

You can apply the method to any well-trained MLP policy. After training, the weights of such a policy are usually stored locally in TF or Torch format. We convert them to numpy format, as play/torch_to_numpy.py does, so that the policy can be reconstructed and executed with numpy; this is what ppo_inference_tf does. During numpy execution, the neuron activations are recorded, which enables further analysis. Thus, if you want to use ppo_inference_tf, your policy must be an MLP policy trained with TF, and its weights must be converted to numpy format. The following example does this for a policy trained with RLlib:

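import pickle

import numpy as np

ckpt_path = "path/to/rllib/checkpoint"  # placeholder: replace with your RLlib checkpoint file
path = "expert_weights.npz"  # output file for the numpy weights
remove_value_network = True  # only the policy network is needed for inference

# The checkpoint is a pickled dict; its "worker" entry is itself a pickled worker state.
with open(ckpt_path, "rb") as f:
    unpickled = pickle.load(f)
worker = pickle.loads(unpickled.pop("worker"))
weights = worker["state"]["default_policy"]
# Optimizer state is not part of the network, so discard it if present.
weights.pop("_optimizer_variables", None)
if remove_value_network:
    # Drop value-function variables, keeping only the policy weights.
    weights = {k: v for k, v in weights.items() if "value" not in k}
np.savez_compressed(path, **weights)
print("Numpy agent weight is saved at: {}!".format(path))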
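To illustrate what "reconstructing the policy with numpy and recording neuron activations" means, here is a minimal sketch. It is not the actual ppo_inference_tf code: the function name numpy_mlp_forward is hypothetical, tanh hidden activations are assumed, and the weights are passed as an ordered list of (kernel, bias) pairs because the key names inside expert_weights.npz depend on how your model was built.

```python
import numpy as np


def numpy_mlp_forward(obs, layers, activation=np.tanh):
    """Hypothetical helper: run an MLP policy in numpy and record hidden activations.

    `layers` is a list of (kernel, bias) pairs ordered from input to output;
    each kernel has shape (in_dim, out_dim), as in TF dense layers.
    Assumes `activation` (tanh by default) on every hidden layer.
    """
    activations = []
    x = np.asarray(obs, dtype=np.float64)
    for kernel, bias in layers[:-1]:
        x = activation(x @ kernel + bias)
        activations.append(x.copy())  # neuron activations of this hidden layer
    kernel, bias = layers[-1]
    output = x @ kernel + bias  # raw policy output; no activation on the last layer
    return output, activations


# Random weights stand in for real ones loaded from the .npz file.
rng = np.random.default_rng(0)
sizes = [8, 64, 64, 2]
layers = [(rng.standard_normal((i, o)) * 0.1, np.zeros(o))
          for i, o in zip(sizes[:-1], sizes[1:])]
action, acts = numpy_mlp_forward(rng.standard_normal(8), layers)
print(action.shape, [a.shape for a in acts])
```

With real weights, you would load the file with `np.load("expert_weights.npz")`, inspect its keys, and arrange the kernel and bias arrays in layer order. What the output layer means (for example, action mean and log-std for a Gaussian policy) depends on how the policy was trained.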

Alternatively, you can record the neuron activations directly in your Torch or TF policy; for a Torch policy, for example, by logging them in the forward function. This removes the need for the numpy conversion. We convert the weights to numpy only so that TF and Torch policies share the same inference process.
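As a rough illustration of logging activations inside the forward function of a Torch policy, consider the sketch below. The class LoggingMLPPolicy, its layer sizes, and the tanh activations are all hypothetical choices, not part of the policydissect code; adapt the logging lines to your own model.

```python
import torch
import torch.nn as nn


class LoggingMLPPolicy(nn.Module):
    """Hypothetical MLP policy that records hidden-layer activations in forward()."""

    def __init__(self, obs_dim: int, hidden_sizes: tuple, act_dim: int):
        super().__init__()
        layers = []
        in_dim = obs_dim
        for h in hidden_sizes:
            layers.append(nn.Linear(in_dim, h))
            in_dim = h
        self.hidden = nn.ModuleList(layers)
        self.out = nn.Linear(in_dim, act_dim)
        self.activations = []  # overwritten on every forward pass

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        self.activations = []
        x = obs
        for layer in self.hidden:
            x = torch.tanh(layer(x))
            # Store a detached numpy copy so analysis does not hold the autograd graph.
            self.activations.append(x.detach().cpu().numpy())
        return self.out(x)


torch.manual_seed(0)
policy = LoggingMLPPolicy(obs_dim=8, hidden_sizes=(64, 64), act_dim=2)
with torch.no_grad():
    action = policy(torch.randn(1, 8))
print(action.shape, [a.shape for a in policy.activations])
```

Logging inside forward() keeps the example explicit; `nn.Module.register_forward_hook` is another way to capture the same outputs without editing the model class.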