ohayonguy/PMRF

What's the intuition behind stratified_uniform?

Closed this issue · 12 comments

Hi, what's the intuition behind the stratified_uniform timestep scheduler?

It divides the interval [0,1] into N equal-width strata, where N is the batch size, and draws one uniform sample from each stratum. For large N, this stabilizes and improves training because the loss “sees” samples from the entire interval [0,1] at every training step. For example, suppose that N=2. If you sample uniformly from [0,1], you might end up with two samples from [0.5,1]. With stratified sampling, you will always end up with one sample from [0,0.5] and one sample from [0.5,1]. Now imagine you do this for N=256. Does that make sense?

check https://en.m.wikipedia.org/wiki/Stratified_sampling
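The N=2 and N=256 cases described above can be reproduced with a minimal sketch. It uses only the Python standard library rather than torch, and the helper name `stratified_uniform_1d` is made up for illustration; it is not the function from the repository.

```python
import random

def stratified_uniform_1d(n, rng):
    """Draw one uniform sample from each of n equal-width strata of [0, 1)."""
    # Stratum i is [i/n, (i+1)/n); rng.random() returns a float in [0, 1).
    return [(i + rng.random()) / n for i in range(n)]

rng = random.Random(0)

# N=2: one sample always lands in [0, 0.5) and the other in [0.5, 1).
two = stratified_uniform_1d(2, rng)

# N=256: map every sample back to the index of the stratum it fell into.
many = stratified_uniform_1d(256, rng)
strata_hit = sorted(int(t * 256) for t in many)

print(two)
print(strata_hit == list(range(256)))  # every stratum is hit exactly once
```

With plain uniform sampling there is no such guarantee: some strata receive several samples and others none.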

Hi @ohayonguy, thank you for your explanation. As I understand it, though, this should help for small N rather than large N, since for large N the samples should already be roughly uniformly distributed in [0, 1].

You are right that it helps small N, but it will also significantly affect large N such as N=256. Try drawing 256 samples from a uniform distribution, and then try stratified sampling. Plot the histograms of the results. You'll see that stratified sampling leads to a much more uniform histogram.

Here is some code you can try:

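import torch
import matplotlib.pyplot as plt
plt.style.use('default')

def stratified_uniform(bs, group=0, groups=1, dtype=None, device=None):
    # Returns bs samples in [0, 1). [0, 1) is split into bs * groups strata,
    # and this group draws one sample from every groups-th stratum.
    if groups <= 0:
        raise ValueError(f"groups must be positive, got {groups}")
    if group < 0 or group >= groups:
        raise ValueError(f"group must be in [0, {groups})")
    n = bs * groups  # total number of strata across all groups
    # Indices of the strata owned by this group: group, group + groups, ...
    offsets = torch.arange(group, n, groups, dtype=dtype, device=device)
    u = torch.rand(bs, dtype=dtype, device=device)  # position inside each stratum
    return (offsets + u) / n

num_gpus = 1
bs = 1024
bs_per_gpu = bs // num_gpus  # each GPU (group) draws its share of the batch

a_s = []
for group in range(0, num_gpus):
    a_s.append(stratified_uniform(bs_per_gpu, group, num_gpus))
a = torch.cat(a_s)  # stratified samples from all groups
b = torch.rand(bs)  # plain uniform samples for comparison
fig, ax = plt.subplots(1, 2)
ax[0].hist(a.numpy(), density=True, bins=20)
ax[0].set_title('stratified uniform')
ax[1].hist(b.numpy(), density=True, bins=20)
ax[1].set_title('uniform')
plt.show()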

And here is the output:
image

I hope this helps
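The histogram comparison can also be checked numerically, without plotting. The following is a rough sketch in plain Python (standard library only, not torch); the helper `bin_counts` and the seed are illustrative choices. It counts how many of 1024 samples fall into each of 20 bins and compares the spread between the fullest and emptiest bin.

```python
import random

def bin_counts(samples, bins):
    """Count how many samples fall into each of `bins` equal-width bins of [0, 1)."""
    counts = [0] * bins
    for t in samples:
        counts[min(int(t * bins), bins - 1)] += 1
    return counts

rng = random.Random(0)
n, bins = 1024, 20

# One sample per stratum of width 1/n versus n independent uniform samples.
stratified = [(i + rng.random()) / n for i in range(n)]
plain = [rng.random() for _ in range(n)]

strat_counts = bin_counts(stratified, bins)
plain_counts = bin_counts(plain, bins)

# Spread between the fullest and the emptiest bin.
strat_spread = max(strat_counts) - min(strat_counts)
plain_spread = max(plain_counts) - min(plain_counts)
print("stratified:", strat_counts, "spread:", strat_spread)
print("uniform:   ", plain_counts, "spread:", plain_spread)
```

Each bin covers 51.2 strata, so the stratified counts can differ only by the few strata that straddle bin edges, whereas the plain uniform counts fluctuate with a standard deviation of about 7 around the mean of 51.2.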

Thank you very much for the detailed answer! So stratified uniform is essentially a more uniform "uniform sampling" technique. How much improvement do you observe compared to regular uniform sampling? And how does it compare to logit-normal?

If I remember correctly, I didn't find significant differences between logit-normal and uniform, but stratified uniform improved training and performance considerably. Unfortunately I didn't log these results properly. I just played with several schedulers on small problems (e.g., simple denoising), and then went with the scheduler that performed best (stratified uniform) in the rest of the tasks. Please let me know if you test these things out!
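For readers comparing the schedulers mentioned here, a small sketch of how the timestep distributions differ. It assumes "logit-normal" means t = sigmoid(z) with Gaussian z; the mean 0 and standard deviation 1 are illustrative defaults, not values taken from this thread or from the PMRF code, and both helper names are hypothetical.

```python
import math
import random

def logit_normal(n, rng, mean=0.0, std=1.0):
    """Sample t = sigmoid(z) with z ~ Normal(mean, std) (assumed parameters)."""
    return [1.0 / (1.0 + math.exp(-rng.gauss(mean, std))) for _ in range(n)]

def stratified_uniform_1d(n, rng):
    """One uniform sample from each of n equal-width strata of [0, 1)."""
    return [(i + rng.random()) / n for i in range(n)]

def frac_in_middle(ts):
    """Fraction of timesteps that fall in the middle half [0.25, 0.75)."""
    return sum(0.25 <= t < 0.75 for t in ts) / len(ts)

rng = random.Random(0)
n = 256
t_logit = logit_normal(n, rng)
t_strat = stratified_uniform_1d(n, rng)

print("logit-normal, middle half:", frac_in_middle(t_logit))
print("stratified,   middle half:", frac_in_middle(t_strat))
```

With these assumed parameters, logit-normal concentrates timesteps around t = 0.5, while stratified uniform places exactly half of the batch in the middle half of the interval at every step.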

Sure! I'm currently using uniform sampling and plan to try stratified uniform sampling. My model is trained with Real-ESRGAN's second-order degradation. I'll share the results once training is done.

By the way, I noticed that you are using a very large batch size (256) and a relatively high learning rate (5e-4). I’m currently using a batch size of 32 and a learning rate of 2e-4, and I’ve found that convergence is relatively slow. Do you have any results regarding the network's convergence speed? For example, after how many steps does the network start producing reasonably good images?

Unfortunately I didn't try smaller batch sizes. I tried to make as few changes as possible to the training paradigm of the HDiT architecture (https://crowsonkb.github.io/hourglass-diffusion-transformers/).
With a batch size of 256, PMRF starts producing good results quite fast for face images (if I remember correctly, after about 50 epochs).

I think that diffusion and flow matching models struggle to converge with small batch sizes. I am also not sure how the HDiT architecture would perform in these scenarios. It would be interesting to see if a smaller batch size is sufficient to achieve good performance with PMRF.

Dear @ohayonguy,

I've trained a PMRF model using a batch size of 32 for 200k iterations. I'd like to share my validation results at the end of training with you:
image
Here are also the results from the very beginning of training:
image

From the results, it appears that PMRF is quite challenging to train effectively with a small batch size. I'm wondering if these results align with your expectations? Alternatively, there might be some discrepancies in my implementation since I adapted your code into my own codebase.

Could you please help me understand if this is the expected behavior, or if there might be potential issues in my reproduction of your method?

Hi @Luciennnnnnn. While flow models should be trained with a larger batch size, these results do not make sense to me. Are you adding a small amount of noise to the MMSE outputs (as described in Algorithm 1 in our paper)? This noise should be added during both training and inference.
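To make the point about noise concrete, here is a rough sketch of where such noise would enter. Everything in it is an assumption for illustration: the helper names, the value of `sigma_s`, and the straight-line interpolation convention (t=0 at the noisy MMSE estimate, t=1 at the clean image) follow the usual rectified-flow setup and are not copied from the PMRF code. Algorithm 1 in the paper remains the authoritative description.

```python
import random

def add_source_noise(x_mmse, sigma_s, rng):
    """Perturb the MMSE estimate with Gaussian noise of std sigma_s (hypothetical helper)."""
    return [v + sigma_s * rng.gauss(0.0, 1.0) for v in x_mmse]

def training_pair(x_clean, x_mmse, t, sigma_s, rng):
    """Build the network input x_t and the velocity target for one timestep t."""
    z0 = add_source_noise(x_mmse, sigma_s, rng)  # noise is added during training
    x_t = [(1.0 - t) * z + t * x for z, x in zip(z0, x_clean)]  # assumed interpolation
    v_target = [x - z for z, x in zip(z0, x_clean)]  # straight-line velocity
    return x_t, v_target

def inference_start(x_mmse, sigma_s, rng):
    """Starting point of the ODE solver; the same kind of noise is added at inference."""
    return add_source_noise(x_mmse, sigma_s, rng)

rng = random.Random(0)
x_clean = [0.2, 0.8, 0.5]  # toy "image" with three pixels
x_mmse = [0.3, 0.6, 0.5]   # toy MMSE estimate of it
t, sigma_s = 0.25, 0.1     # illustrative values only

x_t, v_target = training_pair(x_clean, x_mmse, t, sigma_s, rng)
start = inference_start(x_mmse, sigma_s, rng)
print(x_t, v_target, start)
```

In this sketch the flow starts exactly at the MMSE estimate if the noise is left out, which is the configuration the comment above warns against.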

Hi @ohayonguy, I did add a small amount of noise to the MMSE outputs as you suggested, but as you can see, the results are still poor. I'll investigate further and try your code directly. Thank you for your kind support.

@Luciennnnnnn If noise is not added, then PMRF won't work. Otherwise, I find it hard to believe that the image will remain blurry (similar to the MMSE estimate), even if you use a small batch size. So I suggest that you try our code :)
But regardless, I also find it hard to believe that PMRF will work effectively for a batch size of 32. Maybe I am wrong.