Tracking: development around Rectified Flow
yqzhishen opened this issue · 3 comments
We are introducing Rectified Flow, a new ODE-based generative model, to this repository (in the RectifiedFlow branch). The differences between Rectified Flow and the currently used DDPM will result in some API changes. Testing and adaptation may take one or more weeks. Since we are still at an early stage and the code is not yet well organized, the APIs and configurations on that branch may change over time without any backward compatibility. This issue mainly serves to inform those who are testing and researching on that branch about the changes (and possible migration steps).
TODOs
- Initial implementation with temporary configurations
- Testing, verifying and comparing
- Migrate inference APIs and corresponding configurations to the continuous acceleration profile: int64 depth to float32 depth, speedup to steps
- Re-organize the Rectified Flow code and adapt DDPM to continuous acceleration (converting to discrete settings)
- ONNX exporter adaptations and fixes
- More tests to determine a proper default configuration
- Documentation, ready to merge
The first stage of refactoring and migration to continuous acceleration has been finished.
Rectified Flow models can still run with full compatibility, but the following configurations no longer take effect for Rectified Flow at training time (they will be converted automatically at inference time if the config file does not contain the new keys; see the conversion sketch after this list):
- timesteps: replaced by time_scale_factor, which can be a float
- K_step: replaced by T_start (between 0 and 1; 0 means K_step = timesteps, 1 means K_step = 0)
- K_step_infer: replaced by T_start_infer (between 0 and 1)
- diff_speedup: replaced by sampling_steps (the actual number of sampling steps)
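For those updating old config files by hand, the hypothetical convert_legacy_reflow_config helper below sketches one possible mapping from the old discrete keys to the new continuous ones. It is an illustration only: the time_scale_factor and sampling_steps lines are assumptions (in particular, how diff_speedup relates to the actual step count), and the automatic conversion performed at inference time remains authoritative.

```python
# Illustrative sketch only; the real conversion happens automatically at inference time.
def convert_legacy_reflow_config(old: dict) -> dict:
    timesteps = old['timesteps']
    return {
        # assumption: the continuous time scale keeps the old `timesteps` value (now a float)
        'time_scale_factor': float(timesteps),
        # T_start = 0 means K_step = timesteps, T_start = 1 means K_step = 0
        'T_start': 1.0 - old['K_step'] / timesteps,
        'T_start_infer': 1.0 - old['K_step_infer'] / timesteps,
        # assumption: the old speedup factor divided the number of inference steps
        'sampling_steps': max(1, old['K_step_infer'] // old['diff_speedup']),
    }
```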
The inference API (scripts/infer.py) has been changed as follows (see the example after this list):
- --depth now accepts a float value between 0 and 1
- --speedup is removed and replaced by --steps
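For example, where an old command line passed an integer --depth together with --speedup, a hypothetical equivalent on this branch would be python scripts/infer.py ... --depth 0.8 --steps 20, where ... stands for the unchanged arguments (model, input file, etc.) and the values shown are placeholders.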
ONNX exporting is now supported, but exporting some early Rectified Flow models will result in a KeyError. Please manually add the missing keys to the configuration file.
The second stage of refactoring has been finished in dc6896b.
Due to adjustments in the state dict, models trained on this branch before that commit should be migrated with the following script:
import collections
import pathlib
from typing import Any, Dict

import click
import torch


@click.command()
@click.argument(
    'in_ckpt', type=click.Path(
        exists=True, dir_okay=False, file_okay=True, readable=True, path_type=pathlib.Path
    )
)
@click.argument(
    'out_ckpt', type=click.Path(
        exists=False, dir_okay=False, file_okay=True, writable=True, path_type=pathlib.Path
    )
)
def migrate_reflow(in_ckpt: pathlib.Path, out_ckpt: pathlib.Path):
    ckpt = torch.load(in_ckpt, map_location='cpu')
    in_state_dict: Dict[str, Any] = ckpt['state_dict']
    out_state_dict = collections.OrderedDict()
    for k, v in in_state_dict.items():
        if 'denoise_fn' in k:
            # the denoiser module has been renamed to velocity_fn
            out_state_dict[k.replace('denoise_fn', 'velocity_fn')] = v
        elif 'spec_min' in k or 'spec_max' in k:
            # drop spec_min / spec_max entries from the migrated checkpoint
            continue
        else:
            out_state_dict[k] = v
    torch.save({'category': ckpt['category'], 'state_dict': out_state_dict}, out_ckpt)


if __name__ == '__main__':
    migrate_reflow()
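Assuming the script above is saved as migrate_reflow.py (a file name chosen here for illustration), it can be run as python migrate_reflow.py old_model.ckpt migrated_model.ckpt, where the first argument is the existing checkpoint and the second is the path where the migrated checkpoint will be written.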
The following configuration keys and values are renamed (see the sketch after this list):
- diffusion_type: RectifiedFlow -> diffusion_type: reflow
- diff_decoder_type -> backbone_type
- diff_loss_type -> main_loss_type
- lognorm loss now has its own switch: main_loss_log_norm (only for Rectified Flow models)
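As a minimal sketch (not part of the repository), a hypothetical apply_renames helper applying the renames above to a configuration loaded as a plain dict could look like the following; main_loss_log_norm is a new switch rather than a rename, so it is not handled here.

```python
# Hypothetical helper: apply the key/value renames listed above to a config dict.
KEY_RENAMES = {
    'diff_decoder_type': 'backbone_type',
    'diff_loss_type': 'main_loss_type',
}


def apply_renames(config: dict) -> dict:
    migrated = {}
    for key, value in config.items():
        if key == 'diffusion_type' and value == 'RectifiedFlow':
            migrated[key] = 'reflow'  # value rename, key unchanged
        else:
            migrated[KEY_RENAMES.get(key, key)] = value
    return migrated
```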