particlediffusion

Master's Project - Stable Diffusion Particle Based Inference

Repulsion API

Repulsion is run through repulse.py. Run any script with -h to list its arguments.

Basic usage:

python repulse.py --method repulsive --model averagedim --strength 100

The number of particles and the prompt can also be set, either as command-line arguments or directly in the script.
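To give a rough idea of what a repulsive step does, here is a minimal NumPy sketch in the spirit of kernel-based particle methods such as SVGD. It is not the code in repulse.py: the function names, the fixed RBF bandwidth, and the way `strength` scales the repulsion term are assumptions for illustration, and the repo's own steps repulse in an embedding space (see the code glossary) rather than directly on raw latents.

```python
import numpy as np

def rbf_kernel_and_grad(x, bandwidth):
    """Pairwise RBF kernel k(x_i, x_j) = exp(-||x_i - x_j||^2 / h) and its
    gradient with respect to x_i, for particles x of shape (N, D)."""
    diff = x[:, None, :] - x[None, :, :]            # (N, N, D): x_i - x_j
    sq_dist = np.sum(diff ** 2, axis=-1)            # (N, N)
    k = np.exp(-sq_dist / bandwidth)                # (N, N)
    grad = -2.0 / bandwidth * diff * k[..., None]   # d k(x_i, x_j) / d x_i
    return k, grad

def repulsive_step(x, score, step_size, strength, bandwidth=1.0):
    """Hypothetical Euler-style update: follow the score and push particles
    apart along -sum_j grad_{x_i} k(x_i, x_j)."""
    _, grad = rbf_kernel_and_grad(x, bandwidth)
    repulsion = -grad.sum(axis=1)                   # points away from neighbours
    return x + step_size * (score + strength * repulsion)

# Two 1-D particles with zero score drift apart symmetrically.
x = np.array([[0.0], [1.0]])
x_new = repulsive_step(x, np.zeros_like(x), step_size=0.1, strength=1.0)
```

A larger `strength` trades fidelity to the score for diversity between particles, which is the same trade-off the `--strength` flag controls in the command above, though the exact scaling used by the repo may differ.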

Sample generations from the evaluation of each repulsion method are in data/final/. Be warned: these may include unsafe images.

Code glossary

averagedim: latent channel average method

cnn16: random CNN method (with 16 dimensional output embedding)

ro3: rule of thirds

vgg_noise: style classifier (VGG model conditioned on the noise level of the latent)

vgg_noisero3: style classifier with rule of thirds
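The glossary names the embeddings only briefly. As one plausible reading of "latent channel average", the sketch below averages a batch of Stable Diffusion latents over the channel axis and flattens the resulting spatial map into a feature vector that particles could repulse in. Whether embedding.py averages over channels or over spatial positions is an assumption here, and the function name is hypothetical.

```python
import numpy as np

def averagedim_embedding(latents):
    """Hypothetical 'averagedim' embedding: average latents of shape
    (N, C, H, W) over the channel axis, then flatten each (H, W) map
    into an (N, H*W) feature vector."""
    channel_mean = latents.mean(axis=1)             # (N, H, W)
    return channel_mean.reshape(latents.shape[0], -1)

# Stable Diffusion v1 latents for 512x512 images have shape (4, 64, 64).
latents = np.random.default_rng(0).standard_normal((3, 4, 64, 64))
features = averagedim_embedding(latents)             # shape (3, 4096)
```

Collapsing the channels keeps the coarse spatial layout of each latent, so repulsion in this space would mainly push particles towards different compositions rather than different colours or textures.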

Repo guide

Sampling utils

src
├───denoise_utils.py: Particle diffusion methods
├───sampling_utils.py: Step methods including repulsive steps
├───score_utils.py: Score processing and computation methods
├───steps.py: API to add score or repulsion steps to certain noise levels in diffusion
├───embedding.py: Embedding models to compute feature space to repulse in
├───kernel.py: RBF kernel methods
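For RBF kernels over particles, a common bandwidth choice is the median heuristic used in SVGD, h = med² / log N, where med is the median pairwise distance. The NumPy sketch below shows that choice; whether kernel.py uses this heuristic or a fixed bandwidth is an assumption, and the function names are hypothetical.

```python
import numpy as np

def pairwise_sq_dists(features):
    """Squared Euclidean distances between rows of an (N, D) array."""
    diff = features[:, None, :] - features[None, :, :]
    return np.sum(diff ** 2, axis=-1)

def median_bandwidth(features, eps=1e-8):
    """Median heuristic: h = median(pairwise distance)^2 / log(N)."""
    n = features.shape[0]
    if n < 2:
        raise ValueError("need at least two particles")
    iu = np.triu_indices(n, k=1)                    # distinct pairs only
    med = np.median(np.sqrt(pairwise_sq_dists(features)[iu]))
    return max(med ** 2 / np.log(n), eps)           # avoid h = 0 when particles coincide

def rbf_kernel(features, bandwidth=None):
    """Kernel matrix K_ij = exp(-||f_i - f_j||^2 / h)."""
    if bandwidth is None:
        bandwidth = median_bandwidth(features)
    return np.exp(-pairwise_sq_dists(features) / bandwidth)

feats = np.array([[0.0], [1.0], [3.0]])             # pairwise distances 1, 3, 2
K = rbf_kernel(feats)
```

The heuristic makes the kernel scale-adaptive: as particles spread out the bandwidth grows, so the repulsion does not vanish simply because the particles are far apart.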

Evaluation utils

src
├───metric_utils.py: Calculation of FID, location and style features and metrics on image datasets
├───visualise.py: tools for decoding latents to image space and visualising the results of diffusion
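Decoding for visualisation usually has two parts around the VAE call: undoing the latent scaling before decoding and mapping the decoder output from roughly [-1, 1] to 8-bit images. The sketch below assumes Stable Diffusion v1, whose VAE scaling factor is 0.18215; the decode call itself (for example `vae.decode(latents / 0.18215).sample` with a diffusers `AutoencoderKL`) is left out so the example runs without model weights. The function names are hypothetical and not taken from visualise.py.

```python
import numpy as np

# Latent scaling factor of Stable Diffusion v1.x VAEs. Latents are divided by
# it before being passed to the VAE decoder.
SD_V1_LATENT_SCALE = 0.18215

def unscale_latents(latents):
    """Undo the latent scaling before VAE decoding."""
    return latents / SD_V1_LATENT_SCALE

def to_uint8_images(decoded):
    """Map decoder output of shape (N, 3, H, W) in roughly [-1, 1] to
    (N, H, W, 3) uint8 arrays suitable for PIL or matplotlib."""
    images = np.clip(decoded / 2.0 + 0.5, 0.0, 1.0)
    images = images.transpose(0, 2, 3, 1)           # channels last
    return (images * 255).round().astype(np.uint8)

decoded = np.zeros((1, 3, 2, 2))
decoded[0, 0] = 1.0                                  # red channel saturated
decoded[0, 1] = -1.0                                 # green channel off
imgs = to_uint8_images(decoded)                      # blue stays mid-grey
```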

Style classifier

style.ipynb: experimentation notebook

style_generator.py: generate latent data for training the style classifier

style_train.py: Script to train the style classifier model with a specified configuration

src/train.py: Training method

src/datasets.py: Dataset classes for loading style data generated by style_generator.py
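Because the vgg_noise classifier is conditioned on the noise level of the latent, its training inputs need latents at many noise levels. One standard way to produce them is the DDPM forward process, x_t = sqrt(ᾱ_t) x_0 + sqrt(1 − ᾱ_t) ε, here with the scaled-linear beta schedule used by Stable Diffusion v1. This is a sketch under that assumption; whether style_generator.py noises latents this way, or instead saves intermediate latents from sampling, is not stated in this README, and the function names are hypothetical.

```python
import numpy as np

def sd_alphas_cumprod(num_steps=1000, beta_start=0.00085, beta_end=0.012):
    """Cumulative product of (1 - beta_t) for the scaled-linear schedule."""
    betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_steps) ** 2
    return np.cumprod(1.0 - betas)

def noise_latents(x0, t, alphas_cumprod, rng):
    """Sample x_t ~ q(x_t | x_0) for clean latents x0 of shape (N, C, H, W)
    and integer timesteps t of shape (N,)."""
    abar = alphas_cumprod[t].reshape(-1, 1, 1, 1)
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(abar) * x0 + np.sqrt(1.0 - abar) * eps

rng = np.random.default_rng(0)
abar = sd_alphas_cumprod()
x0 = rng.standard_normal((8, 4, 64, 64))
t = rng.integers(0, 1000, size=8)
xt = noise_latents(x0, t, abar, rng)                 # classifier input: (xt, t)
```

Pairing each noisy latent with its timestep `t` is what lets the classifier be queried at every noise level during sampling.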

Evaluation

data_generator.py: Generate the evaluation dataset selected by --mode

metrics.py, metrics.ipynb, metrics_2.ipynb: metrics calculations and figure plotting

repulse_report.sh: Generate the report figures comparing repulsion at different repulsion strengths

fid/: FID scores for the generated images are computed using [the PyTorch port](https://github.com/mseitzer/pytorch-fid).

figures/: figures for report
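The FID used in the evaluation is the Fréchet distance between two Gaussians fitted to image features: d² = ‖μ₁ − μ₂‖² + Tr(Σ₁ + Σ₂ − 2(Σ₁Σ₂)^{1/2}). pytorch-fid computes this on Inception-v3 features using `scipy.linalg.sqrtm`; the NumPy-only sketch below shows just the distance step on arbitrary feature arrays and is not the pytorch-fid code. It uses the identity that Tr((Σ₁Σ₂)^{1/2}) equals the sum of square roots of the eigenvalues of Σ₁^{1/2} Σ₂ Σ₁^{1/2}. Function names are hypothetical.

```python
import numpy as np

def _sqrtm_psd(mat):
    """Square root of a symmetric positive semi-definite matrix."""
    vals, vecs = np.linalg.eigh(mat)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2})."""
    diff = mu1 - mu2
    root1 = _sqrtm_psd(sigma1)
    middle = root1 @ sigma2 @ root1                  # same eigenvalues as S1 S2
    eig = np.linalg.eigvalsh((middle + middle.T) / 2.0)
    tr_covmean = np.sum(np.sqrt(np.clip(eig, 0.0, None)))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)

def fid_from_features(feats1, feats2):
    """Fit a Gaussian to each (N, D) feature set and compare them."""
    mu1, mu2 = feats1.mean(axis=0), feats2.mean(axis=0)
    s1 = np.cov(feats1, rowvar=False)
    s2 = np.cov(feats2, rowvar=False)
    return frechet_distance(mu1, s1, mu2, s2)

# Equal means, swapped diagonal covariances: 5 + 5 - 2 * (2 + 2) = 2.
d = frechet_distance(np.zeros(2), np.diag([1.0, 4.0]), np.zeros(2), np.diag([4.0, 1.0]))
```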

Experimentation notebooks

experiments/: Much of the exploration of sampling techniques was done in these Jupyter notebooks