CTM Loss Not Actually Used in the Student Training?
Zhendong-Wang opened this issue · 14 comments
Refer to the code here: https://github.com/Aaditya-Prasad/consistency-policy/blob/4cc328ebca2299f25c7b96c8bc04866d69cfe210/consistency_policy/student/ctm_policy.py#L482C13-L495C58
```python
# t -> s
pred = self._forward(self.model, noise_traj, times, stops,
                     local_cond=local_cond, global_cond=global_cond)
# u -> s
with torch.no_grad():
    target = self._forward(self.model_ema, u_noise_traj, u_times, stops,
                           local_cond=local_cond, global_cond=global_cond)
with torch.no_grad():
    start = torch.tensor([self.noise_scheduler.time_min], device = trajectory.device).expand(times.shape)
    pred = self._forward(self.model_ema, pred, stops, start,
                         local_cond=local_cond, global_cond=global_cond)
    target = self._forward(self.model_ema, target, stops, start,
                           local_cond=local_cond, global_cond=global_cond)
loss = Huber_Loss(pred, target, delta = self.delta, weights=weights)
total_loss["ctm"] = loss * self.losses["ctm"]
```
The `pred` and `target` are both computed inside `torch.no_grad()`, so this `loss` variable has no `grad_fn` for backpropagation.
When I run the code and, in line …, simply change it to `losses: [["ctm"], [1]]`, it raises a no-`grad_fn` error. The error can also be reproduced easily with:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

dim = 10
x = torch.randn((1, dim))
y = torch.zeros_like(x)
model = nn.Linear(dim, dim)
model_ema = nn.Linear(dim, dim)

pred = model(x)
print(f'pred: {pred}')
with torch.no_grad():
    pred = model_ema(pred)
print(f'pred: {pred}')

loss = F.mse_loss(pred, y)
print(f'loss inside of no_grad: {loss}')
loss.backward()
print(f'backward successful!')
```
```
>>> python test_grad.py
pred: tensor([[-0.2391, 0.0476, -1.7864, 0.4620, -0.3655, 1.0930, -0.1256, 0.9729,
0.6107, -1.2118]], grad_fn=<AddmmBackward0>)
pred: tensor([[ 0.1322, -0.6358, 0.1975, 0.2054, -0.7371, 0.4366, -0.5071, 0.1329,
0.3269, -0.1779]])
loss inside of no_grad: 0.16503188014030457
Traceback (most recent call last):
  File "/home/zhendongw/research/fast_dp_robotics_dev/test_grad.py", line 19, in <module>
    loss.backward()
  File "/home/zhendongw/miniconda3/envs/umi-gpu/lib/python3.9/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/home/zhendongw/miniconda3/envs/umi-gpu/lib/python3.9/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
Yes, I (actually Kevin) realized this as well and fixed it in a recent commit.
I am unsure when this happened, because I ablated the CTM loss (by running just CTM with weight 1, no DSM) back when I was producing results, and it performed fine then. So I am also re-running Consistency Policy on our sim tasks to check performance.
See commit 3faa563.
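For reference, one general way to keep the second mapping through the EMA network without losing gradients is to freeze the EMA network's parameters and move only the `pred` branch out of `torch.no_grad()`. The sketch below applies that idea to the toy repro above; it is an illustration under that assumption, not necessarily what commit 3faa563 changed, and the `target` computed here is an arbitrary stand-in.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
dim = 10
x = torch.randn((1, dim))
model = nn.Linear(dim, dim)      # online (student) network
model_ema = nn.Linear(dim, dim)  # EMA copy, used only as a frozen mapping
model_ema.requires_grad_(False)  # freeze EMA weights; gradients can still flow through its inputs

pred = model(x)  # tracked by autograd

# Only the target branch is wrapped in no_grad, so it carries no graph (stand-in target).
with torch.no_grad():
    target = model_ema(model_ema(x))

# The pred branch goes through the frozen EMA net *outside* no_grad,
# so autograd records the path back to `model`.
pred = model_ema(pred)

loss = F.mse_loss(pred, target)
loss.backward()

print(loss.grad_fn is not None)       # True
print(model.weight.grad is not None)  # True: gradient reaches the online model
print(model_ema.weight.grad is None)  # True: frozen EMA weights receive no gradient
```

The EMA network still shapes the loss, but because its parameters have `requires_grad=False` only the online model is updated.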
Thanks for the response! Could you help verify the performance of the new code?
Yes. I am distilling Consistency Policy again, starting with the Square and ToolHang tasks since these were the hardest. This will take some time since I am also training the teachers from scratch again.
I am evaluating the same way as in our paper (take the highest-performing checkpoint and evaluate it 200 times on a fixed seed) and will confirm whether performance matches or exceeds what we reported.
Testing has been slow because I've been having issues on the cluster I use, which force me to restart runs.
On Square, 50 epochs of EDM followed by 300 epochs of Consistency Policy training led to a 91% test mean score and a 100% train mean score over 200 evaluations. This was with a distillation weight of 8 and a dropout probability of 0.3. Increasing the distillation loss weight and adding more regularization (or just training for longer) would likely improve performance, but I did not keep testing this because I didn't do much tuning for the Diffusion Policy baselines either.
For comparison, training EDM or DDIM for 250 epochs each and sampling with 1 step (which means 2 forward passes for EDM using the Heun scheduler) yields a 0% success rate for both. I would report the 350-epoch results (to match the combined epochs of the Consistency Policy teacher + student), but those runs haven't gotten there yet.
The toolhang teacher is still training.
Thanks for the update!!! Could you share your configs for training pusht, square, tool_hang, and transport? Will you share your pretrained checkpoints? I found that these tasks have high variance in evaluation, and changing the seed sometimes changes the performance. I mean, 200 evaluations is sometimes not enough.
BTW, do you know whether Consistency Distillation works in the DDPM case?
I've found this variance as well; let me think about the best way to handle it. A few points:
- We did not train or test on transport. We trained and tested on lift, can, square, tool hang, pusht, and kitchen.
- I didn't do a good job with keeping old runs/results/configs; this was completely my mistake and part of the reason for this whole issue. This was my first real research project and I'm carrying lessons from this forward.
- Policy configs did not vary too much between tasks. I tuned configs for Square and ToolHang since these were the hardest tasks, and I used the same policy values for the other tasks.
- I'll upload the teacher, baseline, and student ckpts + cfgs for Square (which I just trained and reported results for). I'll do the same for toolhang as it gets done. I'll also try to evaluate toolhang on more rounds (maybe 500 instead of 200); this will likely take a while though since this is the longest task. If needed, I can retrain policies for the other tasks, but again I think these are the most indicative.
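To put rough numbers on the variance point: if each rollout is treated as an independent Bernoulli trial (an assumption; fixed-seed evaluation and correlated initial states can break it), the standard error of a measured success rate p over n rollouts is sqrt(p(1 - p) / n). The counts below are hypothetical and only illustrate the scale.

```python
import math

def success_rate_ci(successes: int, n: int, z: float = 1.96):
    """Normal-approximation (Wald) confidence interval for a success rate over n rollouts."""
    p = successes / n
    se = math.sqrt(p * (1 - p) / n)
    return p, se, (p - z * se, p + z * se)

# Hypothetical counts: an ~80% policy evaluated with 200 vs 500 rollouts.
for successes, n in [(160, 200), (400, 500)]:
    p, se, (lo, hi) = success_rate_ci(successes, n)
    print(f"n={n}: p={p:.2f}, SE={se:.3f}, 95% CI=({lo:.3f}, {hi:.3f})")
```

At p ≈ 0.8 this gives a 95% interval of roughly ±5.5 percentage points at 200 rollouts and about ±3.5 at 500, so gaps of a few points between runs are within noise at 200.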
DDPM adds stochasticity to the backwards diffusion process by integrating little bits of noise into each denoising step. This breaks Consistency Distillation, since it requires not only the same marginal distribution at t=0 (which DDPM satisfies) but also deterministic trajectories between t=T and t=0 (which DDPM breaks).
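To make the determinism point concrete, here is a scalar toy comparison between an ancestral DDPM update, which injects fresh Gaussian noise at every step, and a deterministic update such as DDIM with η = 0. The beta schedule and the `toy_eps` noise predictor are hypothetical stand-ins for a trained model; the two update rules are the standard DDPM and DDIM ones.

```python
import math
import random

# Toy linear beta schedule (hypothetical values, for illustration only).
T = 50
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]
alpha_bars = []
prod = 1.0
for a in alphas:
    prod *= a
    alpha_bars.append(prod)

def toy_eps(x, t):
    """Hypothetical noise predictor eps_theta(x_t, t); any deterministic function works here."""
    return 0.5 * x

def ddpm_step(x, t, rng):
    # Ancestral DDPM update: posterior mean plus fresh noise sqrt(beta_t) * z.
    mean = (x - betas[t] / math.sqrt(1 - alpha_bars[t]) * toy_eps(x, t)) / math.sqrt(alphas[t])
    if t == 0:
        return mean
    return mean + math.sqrt(betas[t]) * rng.gauss(0.0, 1.0)

def ddim_step(x, t):
    # Deterministic DDIM update (eta = 0): predict x_0, then re-noise along the predicted eps.
    x0_hat = (x - math.sqrt(1 - alpha_bars[t]) * toy_eps(x, t)) / math.sqrt(alpha_bars[t])
    if t == 0:
        return x0_hat
    return math.sqrt(alpha_bars[t - 1]) * x0_hat + math.sqrt(1 - alpha_bars[t - 1]) * toy_eps(x, t)

def run(step, x_T, **kw):
    x = x_T
    for t in reversed(range(T)):
        x = step(x, t, **kw)
    return x

x_T = 1.0
ddpm_a = run(ddpm_step, x_T, rng=random.Random(1))
ddpm_b = run(ddpm_step, x_T, rng=random.Random(2))
ddim_a = run(ddim_step, x_T)
ddim_b = run(ddim_step, x_T)
print(ddpm_a == ddpm_b)  # False: same x_T, different endpoints
print(ddim_a == ddim_b)  # True: same x_T always maps to the same x_0
```

Consistency distillation needs that fixed x_T-to-x_0 map as its training target; the ancestral sampler only matches the final distribution, not individual trajectories.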
For now, this is a link to a zip file containing teacher, baseline, and student ckpts and cfgs for square: https://drive.google.com/file/d/1jG7HjDPu4qXs0tMRESTKJxeWpqgR-mL6/view?usp=drive_link.
Note that I was incorrect above about the student dropout and distillation weight; the config in the zip is correct.
I was hoping to use Git LFS for this, but the checkpoints are over the storage limit. When I have time, I'll pick a more permanent solution, like adding these to the website.
Thanks for providing these!!! I will have a look.
Hey Aaditya,
I am trying to reproduce Consistency Policy on my end, and I started by pretraining the teacher model. This is the curve that I obtained. All the tasks perform much worse than the reported values and the values in Diffusion Policy, especially `tool_hang`: it should reach a success rate of about 0.83, but it is now at 0.46. Do you have any insights on this?
I used the configs `configs/edm_square.yaml` and `configs/edm_th.yaml` for square and tool_hang, and copied the square settings to the other tasks, such as `pusht`, `square_mh`, `transport_mh`, and `transport_ph`. I also carefully checked that the image crop sizes are consistent with the previous Diffusion Policy configs.
When you say the values in Diffusion Policy, are you referring to the DDPM solver? Diffusion Policy has no reported results for the EDM solver. Diffusion Policy reports a 0.5 success rate on tool hang with their testing methodology, while we report a 0.79 success rate (again, for DDPM). However, we did get EDM to a ~0.8 success rate (at 50 evaluations, not a rigorous test).
On EDM success rate:
The main thing you should do is train for longer. Here are the success curves from an older EDM run, again with test_mean_score given by 50 evaluations. Training took a long time; you can see that we only reached a ~0.8 success rate after 300-350 epochs. (*Edited; see note at bottom.)
I'm currently checking distillation with a 100-epoch EDM teacher (which achieved a 0.0 success rate with 1 bin, since I wanted to sanity-check EDM-only performance). I'll keep this thread updated on how long a teacher needs to train to yield good distillation performance; I've found that ToolHang requires a lot of time to train both the teacher and the student.
Also, just wanted to note that I used a 50-epoch EDM teacher to distill the Square task and that seemed to perform well, though of course teachers trained for longer might help by a few percentage points.
*I originally thought that run used 100 bins instead of 80. This was incorrect: it used 80 bins with a 3rd-order solver, i.e. 3 forward passes per bin (compared to 2 forward passes for Heun). This might have contributed to the higher success rates. You can try this if you'd like, though I'd recommend just training for longer and testing distillation.
Thanks for sharing the curve! All I am trying to do now is reproduce the results from the Consistency Policy paper. I used all the configs provided on GitHub (`edm_th.yaml` trains for 400 epochs by default), which should usually reproduce the table results. For example, tool_hang should be around 0.79, and I agree there might be some variance.
As for the Diffusion Policy values, I mean this table, with 0.73 on tool_hang, 0.84 on square_mh, and 0.69 on transport_mh. The peak-performance values might not be useful, since they use even fewer seeds (e.g., 28). If EDM works worse than the original DDPM, then there is no reason to switch to EDM as the teacher model. As you can see from my previous figure, EDM also performs poorly on transport_mh and square_mh.
EDM is used as the teacher model because it learns the probability flow ODE (PF-ODE), i.e. it doesn't inject noise into the backwards diffusion process. DDPM is not suitable as a teacher model because it does not have deterministic trajectories. We make no claim in the paper about EDM's performance compared to DDPM; EDM is only introduced so we can use distillation.
We do make the claim that training a teacher EDM and then distilling it into a Consistency Policy provides the single-step and 3-step results in the table. Note that this is separate from the success rates you will see from EDM: EDM's training is separate from its sampler. To make teacher training faster, I have been training "1-step EDMs", which will reach a 0.0 success rate even after 200-300 epochs. You can increase the number of bins, the order of the sampler, or other noise scheduler hyperparameters to get higher success rates out of EDM, which can be helpful when you're comparing different checkpoints to distill from.
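Since the number of bins and the noise-scheduler hyperparameters are the knobs mentioned above, here is a sketch of the EDM noise-level discretization from Karras et al. (2022). It assumes the repo's "bins" roughly correspond to the number of noise levels `n`; the defaults shown (`sigma_min = 0.002`, `sigma_max = 80`, `rho = 7`) are the paper's, not necessarily this repo's config values.

```python
def karras_sigmas(n: int, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: float = 7.0):
    """EDM noise levels (Karras et al., 2022), decreasing from sigma_max to sigma_min."""
    if n < 2:
        raise ValueError("need at least two noise levels")
    max_inv = sigma_max ** (1.0 / rho)
    min_inv = sigma_min ** (1.0 / rho)
    # Interpolate linearly in sigma^(1/rho), then map back with **rho;
    # larger rho packs more levels near sigma_min.
    return [(max_inv + i / (n - 1) * (min_inv - max_inv)) ** rho for i in range(n)]

coarse = karras_sigmas(5)
fine = karras_sigmas(80)
print([round(s, 3) for s in coarse])
print(len(fine), round(fine[-2], 4), round(fine[-1], 4))
```

More levels mean smaller steps along the PF-ODE (and more forward passes), which is one reason a multi-bin EDM sampler can succeed where a 1-step EDM sampler scores 0.0.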