A general-purpose training and inference library for Consistency Models, introduced in the paper "Consistency Models" by OpenAI.
Consistency Models are a new family of generative models that achieve high sample quality without adversarial training. They support fast one-step generation by design, while still allowing for few-step sampling to trade compute for sample quality. They also support zero-shot data editing, like image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks.
Note: this library is a code base for implementing consistency models and can be used to train consistency models of all kinds. There are some design differences between the original consistency model and the improved consistency training (CT) described in the paper "Improved Techniques for Training Consistency Models".
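One concrete difference worth illustrating: the improved CT recipe replaces the LPIPS/L2 metric with a pseudo-Huber distance. A minimal sketch, assuming the constant `c = 0.00054 * sqrt(D)` suggested in that paper (the function name and scaling here are illustrative, not this library's API):

```python
import math

import torch


def pseudo_huber_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Pseudo-Huber metric d(x, y) = sqrt(||x - y||^2 + c^2) - c from
    "Improved Techniques for Training Consistency Models".

    The constant c is set proportionally to sqrt(D) for data
    dimensionality D, so the metric behaves like L2 for small errors
    and like L1 for large ones.
    """
    d = pred.flatten(1).shape[1]          # data dimensionality D per sample
    c = 0.00054 * math.sqrt(d)            # the paper's suggested scaling
    diff = (pred - target).flatten(1)
    return torch.sqrt(diff.pow(2).sum(dim=1) + c * c) - c
```

By construction `pseudo_huber_loss(x, x)` is exactly zero, and the loss grows roughly linearly for large residuals, which reduces sensitivity to outliers compared with squared error.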
Before we can use a consistency model, we must train it. Specifically, given a data point `action` and a noise sample `z`, we construct two perturbed points `x1 = action + z * t1` and `x2 = action + z * t2` at adjacent noise levels `t1 < t2` on the same trajectory. We then train the consistency model by minimizing the difference between its outputs on this pair, using the EMA (or stopped-gradient) copy of the model as the target at the smaller noise level:
```python
def loss(self, state, action, z, t1, t2, ema_model=None, weights=torch.tensor(1.0)):
    # Perturb the action to the larger noise level t2 and denoise it with
    # the online network (gradients flow through this branch).
    x2 = action + z * t2
    if self.action_norm:
        x2 = self.max_action * torch.tanh(x2)
    x2 = self.predict_consistency(state, x2, t2)

    # Build the target at the smaller noise level t1 without gradients,
    # using either the online network itself or its EMA copy.
    with torch.no_grad():
        x1 = action + z * t1
        if self.action_norm:
            x1 = self.max_action * torch.tanh(x1)
        if ema_model is None:
            x1 = self.predict_consistency(state, x1, t1)
        else:
            x1 = ema_model(state, x1, t1)

    # Consistency loss: the two denoised estimates should agree.
    loss = self.loss_fn(x2, x1, weights, take_mean=False)
    return loss
```
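The adjacent times `(t1, t2)` above are usually drawn from a fixed discretization of `[eps, T]`. A minimal sketch of the Karras-style grid commonly used for `self.t_seq` (the helper names and the default `rho = 7.0` are assumptions, not part of this library):

```python
import torch


def karras_t_seq(n: int, eps: float = 0.002, max_t: float = 80.0,
                 rho: float = 7.0) -> torch.Tensor:
    """Karras-style discretization eps = t_1 < ... < t_n = max_t:
    t_i = (eps^(1/rho) + (i-1)/(n-1) * (max_t^(1/rho) - eps^(1/rho)))^rho.
    """
    i = torch.arange(n, dtype=torch.float64)
    t = (eps ** (1 / rho)
         + i / (n - 1) * (max_t ** (1 / rho) - eps ** (1 / rho))) ** rho
    return t.to(torch.float32)


def sample_adjacent_pair(t_seq: torch.Tensor) -> tuple:
    """Pick a random adjacent pair t1 < t2 from the grid."""
    i = torch.randint(0, len(t_seq) - 1, (1,)).item()
    return t_seq[i].item(), t_seq[i + 1].item()
```

The large `rho` concentrates grid points near `eps`, where the denoising problem changes fastest.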
Starting from an initial random noise sample at the largest noise level, the sampler maps it to a sample in a single step; it can then optionally re-inject noise at decreasing levels and denoise again, trading extra compute for sample quality:
```python
def sample(self, state):
    # Walk the time grid from the largest noise level down to the smallest.
    ts = list(reversed(self.t_seq))
    action_shape = list(state.shape)
    action_shape[-1] = self.action_dim

    # Initial sample x_T ~ N(0, T^2 I).
    action = torch.randn(action_shape).to(device=state.device) * self.max_T
    if self.action_norm:
        action = self.max_action * torch.tanh(action)

    # One-step generation at the largest noise level.
    action = self.predict_consistency(state, action, ts[0])

    # Optional multistep refinement: re-inject noise at level t, denoise again.
    for t in ts[1:]:
        z = torch.randn_like(action)
        action = action + z * math.sqrt(t**2 - self.eps**2)
        if self.action_norm:
            action = self.max_action * torch.tanh(action)
        action = self.predict_consistency(state, action, t)

    action.clamp_(-self.max_action, self.max_action)
    return action
```
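The loop structure of this multistep sampler can be exercised standalone with a toy consistency function; the sketch below assumes nothing from the library, and the linear "denoiser" is purely illustrative:

```python
import math

import torch


def multistep_sample(f, t_seq, shape, eps: float = 0.002) -> torch.Tensor:
    """Multistep consistency sampling: denoise once from the largest t,
    then alternately re-inject noise at each smaller t and denoise again."""
    ts = sorted(t_seq, reverse=True)
    x = torch.randn(shape) * ts[0]              # x_T ~ N(0, T^2 I)
    x = f(x, ts[0])                             # one-step sample
    for t in ts[1:]:
        z = torch.randn_like(x)
        x = x + z * math.sqrt(t**2 - eps**2)    # perturb back to level t
        x = f(x, t)                             # denoise again
    return x


# Toy "consistency function" that merely shrinks its input (illustrative only).
x = multistep_sample(lambda x, t: x / (1.0 + t), [0.002, 1.0, 80.0], (2, 3))
```

Each extra grid point adds one perturb-then-denoise round trip, which is exactly the compute-for-quality trade mentioned above.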
```python
def predict_consistency(self, state, action, t) -> torch.Tensor:
    # Broadcast a scalar time to a (batch, 1) tensor on the right device.
    if isinstance(t, float):
        t = torch.full(
            (action.shape[0], 1), t, dtype=torch.float32, device=action.device
        )

    action_ori = action
    action = self.model(action, t, state)

    # Skip/out coefficients from the consistency models paper, chosen so
    # that c_skip(eps) = 1 and c_out(eps) = 0, i.e. f(x, eps) = x.
    sigma_data = 0.5
    t_ = t - self.eps
    c_skip_t = sigma_data**2 / (t_.pow(2) + sigma_data**2)
    c_out_t = sigma_data * t_ / (sigma_data**2 + t.pow(2)).pow(0.5)

    output = c_skip_t * action_ori + c_out_t * action
    if self.action_norm:
        output = self.max_action * torch.tanh(output)
    return output
```
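The skip/out coefficients enforce the boundary condition `f(x, eps) = x`, and this is easy to verify numerically. A small self-contained check, using the same `sigma_data = 0.5` and `eps = 0.002` as the code above:

```python
import torch


def c_skip_c_out(t: torch.Tensor, eps: float = 0.002, sigma_data: float = 0.5):
    """Boundary-condition parameterization from the consistency models paper:
    c_skip(t) = sigma_d^2 / ((t - eps)^2 + sigma_d^2),
    c_out(t)  = sigma_d * (t - eps) / sqrt(sigma_d^2 + t^2)."""
    t_ = t - eps
    c_skip = sigma_data**2 / (t_**2 + sigma_data**2)
    c_out = sigma_data * t_ / (sigma_data**2 + t**2) ** 0.5
    return c_skip, c_out


# At t = eps the network output is gated off entirely: c_skip = 1, c_out = 0,
# so the parameterized model returns its input unchanged.
c_skip, c_out = c_skip_c_out(torch.tensor(0.002))
```

Because `c_out(eps) = 0`, the raw network never has to learn the identity map at `t = eps`; the parameterization guarantees it.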