haritheja-e/robot-utility-models

Would it be possible to train with different configurations?


  1. Suppose I wanted to train for different tasks with a different gripper, for example, the one on a Franka Emika Panda. Would it be possible to train and reproduce the method if I had a dataset with all the correct information (30 fps video, 6D pose, and gripper opening angle)? Or is the network completely dependent on the gripper type used here?

  2. Do I need the r3d files for training, or do I only need the MP4 videos with the appropriate labels?

Hello!

  1. If you believe you have a sufficient dataset, then you should certainly be able to use our code to train a VQ-BeT policy on your data.
    a. One caveat concerns the gripper: we use a ResNet34 image encoder backbone that was pre-trained with MoCo-V3 on play data collected with our gripper (Dobb-E encoder: imitation-in-homes/configs/model/resnet_dobb-e.yaml). So you may want to swap it out for an image encoder of your choice (we haven't tried our encoder on a different gripper, but a ResNet with ImageNet pre-trained weights could work).
    b. As long as the data is in the correct format (see data_format), having .mp4 videos with appropriate labels is sufficient for training.
    The rgb_rel_videos_exported.txt file can be empty (it just needs to exist). However, our dataloader reads rotations as quaternions and then converts them to axis-angle, so our code expects the rotation labels to be quaternions.
  2. Having MP4 files with the appropriate labels is sufficient (.r3d is just the file extension of the zip file created by the Record3D iPhone app). You just need to ensure that the labels are in the same format as in our dataset: xyz and quats (relative to the starting frame) and an absolute gripper value. I have attached a sample file below. The r3d files themselves are not necessary, but a file titled r3d_files.txt must exist in the data directory; it can be generated with the attached bash script.
    labels.json
    get_txt.txt
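The label conventions above (xyz and quaternions relative to the starting frame, an absolute gripper value, and quaternions converted to axis-angle by the dataloader) can be illustrated with a minimal sketch. It assumes each frame's pose is available as an absolute position plus a unit quaternion in (x, y, z, w) order, and it interprets "relative to the starting frame" as expressing every pose in the coordinate frame of the first pose; both are assumptions to check against the attached labels.json. The helper names (make_relative_labels, quat_to_axis_angle, and so on) are hypothetical and not part of the repository.

```python
import math

# ASSUMPTION: quaternions are (x, y, z, w) tuples; the order used in
# labels.json may differ.


def quat_mul(a, b):
    """Hamilton product a * b."""
    ax, ay, az, aw = a
    bx, by, bz, bw = b
    return (
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
        aw * bw - ax * bx - ay * by - az * bz,
    )


def quat_conj(q):
    """Conjugate, which is the inverse for a unit quaternion."""
    x, y, z, w = q
    return (-x, -y, -z, w)


def rotate(q, v):
    """Rotate 3-vector v by unit quaternion q (computes q * v * q^-1)."""
    x, y, z, _ = quat_mul(quat_mul(q, (v[0], v[1], v[2], 0.0)), quat_conj(q))
    return (x, y, z)


def make_relative_labels(positions, quats, grippers):
    """Hypothetical helper: express every pose in the first pose's frame.

    Position: R0^T (p_i - p0). Rotation: q0^-1 * q_i.
    The gripper value is passed through unchanged (absolute).
    """
    p0, q0_inv = positions[0], quat_conj(quats[0])
    labels = []
    for p, q, g in zip(positions, quats, grippers):
        delta = tuple(pi - p0i for pi, p0i in zip(p, p0))
        labels.append((rotate(q0_inv, delta), quat_mul(q0_inv, q), g))
    return labels


def quat_to_axis_angle(q):
    """Convert a quaternion to an axis-angle vector (axis * angle, radians)."""
    norm = math.sqrt(sum(c * c for c in q))
    x, y, z, w = (c / norm for c in q)
    if w < 0:  # q and -q encode the same rotation; keep the angle in [0, pi]
        x, y, z, w = -x, -y, -z, -w
    s = math.sqrt(x * x + y * y + z * z)
    if s < 1e-12:  # (near-)identity rotation
        return (0.0, 0.0, 0.0)
    angle = 2.0 * math.atan2(s, w)
    return (x / s * angle, y / s * angle, z / s * angle)


if __name__ == "__main__":
    s = math.sqrt(0.5)
    # Frame 0: at the origin, rotated 90 degrees about z.
    # Frame 1: at (0, 1, 0) in world coordinates, rotated 180 degrees about z.
    labels = make_relative_labels(
        positions=[(0.0, 0.0, 0.0), (0.0, 1.0, 0.0)],
        quats=[(0.0, 0.0, s, s), (0.0, 0.0, 1.0, 0.0)],
        grippers=[0.2, 0.8],
    )
    xyz, quat, gripper = labels[1]
    print(xyz)                       # approximately (1.0, 0.0, 0.0)
    print(quat_to_axis_angle(quat))  # approximately (0.0, 0.0, pi / 2)
    print(gripper)                   # 0.8
```

The thread does not say whether the repository's dataloader canonicalizes the quaternion sign or which axis-angle routine it uses; the sign flip here is just one common choice that keeps the rotation angle in [0, π].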

Furthermore, training a VQ-BeT policy is a two-step process.
First, set the dataset roots in configs/env_vars/env_vars.yaml. Then:

  1. Train the residual vector quantization (RVQ) on the actions: set include_task to the task name (e.g. Bag_Pick_Up) in train_rvq.yaml and train with python train.py --config-name=train_rvq.
  2. Train the VQ-BeT policy: set include_task, as well as vqvae_load_dir, in train_vqbet.yaml; vqvae_load_dir should point to the checkpoint.pt of the RVQ you just trained. Then train with python train.py --config-name=train_vqbet. If you have multiple GPUs, you can also train in parallel with accelerate launch --config_file configs/accelerate/accel_cfg.yaml train.py --config-name=train_vqbet, changing num_processes in configs/accelerate/accel_cfg.yaml accordingly.
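Step 1 trains a residual vector quantizer that maps continuous actions to a short sequence of discrete codes, which the VQ-BeT policy in step 2 then builds on. The minimal sketch below shows only the encoding side, with small hand-picked codebooks: each stage picks the codeword nearest to the current residual and subtracts it, so later stages refine what earlier stages missed. The function name rvq_encode and the toy codebooks are hypothetical; in the repository the codebooks are learned by the train_rvq run, and its actual implementation may differ.

```python
def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def rvq_encode(vector, codebooks):
    """Hypothetical residual VQ encoder with fixed (already learned) codebooks.

    Returns one code index per codebook and the summed reconstruction.
    """
    residual = list(vector)
    reconstruction = [0.0] * len(vector)
    codes = []
    for codebook in codebooks:
        # Pick the codeword closest to what is still unexplained.
        idx = min(range(len(codebook)),
                  key=lambda k: squared_distance(residual, codebook[k]))
        codes.append(idx)
        chosen = codebook[idx]
        residual = [r - c for r, c in zip(residual, chosen)]
        reconstruction = [s + c for s, c in zip(reconstruction, chosen)]
    return codes, reconstruction


if __name__ == "__main__":
    # Two toy stages for a 2-D "action": a coarse codebook, then a finer one.
    codebooks = [
        [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)],
        [(0.0, 0.0), (0.25, 0.0), (0.0, 0.25)],
    ]
    codes, recon = rvq_encode((1.2, 0.1), codebooks)
    print(codes)  # [1, 1]
    print(recon)  # [1.25, 0.0]
```

Because each stage only encodes the remaining residual, a few small codebooks can represent many distinct actions (here 3 × 3 code combinations from two 3-entry codebooks).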

If you have any further questions, let us know.