Compute the Fréchet Inception Distance (FID) between two distributions from different data sources (tensor, generator model, or dataset). Raw image data is passed through an embedding model to compute ‘clean’ features. Check the cleanfeatures documentation for a list of available embedding models (default: InceptionV3). Partially builds on code from bioinf-jku/TTUR.
Apart from the conventional FID formulation, this repository implements Weighted FID (wFID) as proposed in Towards Mode Balancing of Generative Models via Diversity Weights (Berns et al., 2023, ICCC). Weight vectors can be passed alongside the data sources to quantify the weighting of individual data examples. Statistics are then computed on the weighted distribution.
- torch (PyTorch)
- numpy
- scipy
- cleanfeatures (sebastianberns/cleanfeatures)
The following examples assume that the repository is available in the working directory or on the Python path.
from cleanfid import FID # 1.
measure = FID('path/to/model/checkpoint/') # 2.
fid = measure.score(data_1, data_2) # 3.
1. Import the main class.
2. Create a new instance, providing the directory path of an embedding model. This can be either the location where the model checkpoint is already saved or the location it should be downloaded and saved to.
3. Compute the FID, given two data sources (tensor, generator model, or dataset).
A typical use case is to evaluate model performance during training. For this, it is most efficient to first calculate the mean and covariance of the dataset before the training starts and save the statistics. Then, during training, compute the model statistics. Finally, calculate the FID as the Fréchet distance between the data and model distributions.
from cleanfid import FID
measure = FID('path/to/model/checkpoint/')
num_samples = 50_000
batch_size = 128
# Before training, compute dataset mean and covariance
dataset_stats = measure.compute_feature_statistics(self.dataloader.dataset,
    num_samples=num_samples, batch_size=batch_size)
# [...]
# During training, compute the model statistics
model_stats = measure.compute_feature_statistics(generator, z_dim=generator.z_dim,
    num_samples=num_samples, batch_size=batch_size)
# Then evaluate the FID
fid = measure.frechet_distance(*dataset_stats, *model_stats)
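The workflow above mentions saving the dataset statistics, but the snippet does not show that step. The following is a minimal sketch of one way to do it, assuming the statistics are a NumPy mean vector and covariance matrix (as the `frechet_distance` parameters suggest). The random feature matrix and the file name `dataset_stats.npz` are placeholders for illustration, not part of this repository.

```python
import os
import tempfile

import numpy as np

# Stand-in for embedded dataset features [N, F]; in practice these would be
# produced by the embedding model. The shape used here is arbitrary.
rng = np.random.default_rng(0)
features = rng.normal(size=(1000, 16))

# Mean vector [F] and covariance matrix [F x F]
mean = features.mean(axis=0)
cov = np.cov(features, rowvar=False)

# Before training: save the statistics once (hypothetical file name)
stats_path = os.path.join(tempfile.mkdtemp(), "dataset_stats.npz")
np.savez(stats_path, mean=mean, cov=cov)

# During training: load the cached statistics instead of recomputing them
with np.load(stats_path) as stats:
    dataset_stats = (stats["mean"], stats["cov"])
```

The loaded tuple has the same `(mean, cov)` layout as the statistics in the example above, so it could be unpacked the same way when evaluating the distance.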
measure = FID(model_path='./models', model='InceptionV3', device=None, **kwargs)
- model_path (str or Path object, optional): path to directory where model checkpoint is saved or should be saved to. Default: './models'.
- model (str, optional): choice of pre-trained feature extraction model. Options:
  - CLIP
  - DVAE (DALL-E)
  - InceptionV3 (default)
  - Resnet50
- cf (CleanFeatures, optional): an initialized instance of CleanFeatures. If set, all other arguments will be ignored.
- device (str or torch.device, optional): device which the loaded model will be allocated to. Default: 'cuda' if a GPU is available, otherwise 'cpu'.
- kwargs (dict): additional model-specific arguments passed on to cleanfeatures. See below.
- clip_model (str, optional): choice of pre-trained CLIP model. Options: RN50, RN101, RN50x4, RN50x16, RN50x64, ViT-B/32, ViT-B/16, ViT-L/14 (default), ViT-L/14@336px
The class provides three methods to process different types of input: a data tensor, a generator network, or a dataloader.
All methods return a tensor of embedded features [B, F], where F is the number of features.
Calculate FID between two distributions from two data sources.
fid = measure.score(input1, input2, weights1, weights2, **kwargs)
- input1, input2 (Tensor or nn.Module or Dataset): data sources, can be different types (see above)
- weights1, weights2 (ndarray, optional): 1-D array of observation vector weights or probabilities
- kwargs (dict): additional data source-specific arguments passed on to the corresponding cleanfeatures method. See below.
- Tensor of samples (torch.Tensor):
  - batch_size (int, optional): Batch size for sample processing. Default: 128
- Generator model (torch.nn.Module):
  - z_dim (int): Number of generator input dimensions
  - num_samples (int): Number of samples to generate and process
  - batch_size (int, optional): Batch size for sample processing. Default: 128
- Dataset (torch.utils.data.Dataset):
  - num_samples (int): Number of samples to generate and process
  - batch_size (int, optional): Batch size for sample processing. Default: 128
  - num_workers (int, optional): Number of parallel threads. Best practice is to set to the number of CPU threads available. Default: 0
  - shuffle (bool, optional): Indicates whether samples will be randomly shuffled or not. Default: False
Calculate Fréchet distance between two multi-variate normal distributions.
distance = measure.frechet_distance(mean1, cov1, mean2, cov2, eps=1e-6)
- mean1, mean2 (ndarray): vectors of distribution means [N]
- cov1, cov2 (ndarray): distribution covariance matrices [N x N]
- eps (float, optional): small number for numerical stability
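For reference, the quantity conventionally reported as FID is the squared Fréchet distance between two Gaussians: ||mean1 - mean2||² + Tr(cov1 + cov2 - 2 (cov1 cov2)^(1/2)). The sketch below illustrates that formula with NumPy only. It is not this repository's implementation: the function name is hypothetical, the trace of the matrix square root is obtained from eigenvalues rather than an explicit matrix square root, and the `eps` stabilization is left out.

```python
import numpy as np


def frechet_distance_sketch(mean1, cov1, mean2, cov2):
    """Squared Fréchet distance between N(mean1, cov1) and N(mean2, cov2)."""
    diff = mean1 - mean2
    # Tr((cov1 cov2)^(1/2)) equals the sum of the square roots of the
    # eigenvalues of cov1 @ cov2, which are real and non-negative when both
    # covariance matrices are positive semi-definite.
    eigvals = np.linalg.eigvals(cov1 @ cov2)
    # Drop imaginary parts and clip small negatives caused by round-off
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(cov1) + np.trace(cov2) - 2.0 * tr_sqrt)


# Two unit Gaussians in 3-D whose means differ by one unit along one axis
identity = np.eye(3)
shifted = frechet_distance_sketch(np.zeros(3), identity,
                                  np.array([1.0, 0.0, 0.0]), identity)
same = frechet_distance_sketch(np.zeros(3), identity, np.zeros(3), identity)
```

With identical covariances the trace terms cancel, so `shifted` reduces to the squared distance between the means and `same` is zero up to round-off.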
Calculate statistics of multi-variate normal distributions. Return tuple of statistics: mean and covariance matrix.
mean, cov = measure.compute_statistics(features, weights)
- features (ndarray): Matrix of data features where rows are observations and columns are variables
- weights (ndarray, optional): 1-D array of observation vector weights or probabilities
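To make the role of the weights concrete, here is a minimal sketch of weighted statistics using NumPy. The function name is hypothetical, and the covariance normalization (the default bias correction of `np.cov` with `aweights`) is an assumption; the normalization used by this repository may differ.

```python
import numpy as np


def weighted_statistics_sketch(features, weights=None):
    """Weighted mean [F] and covariance [F x F] of a feature matrix [N, F]."""
    features = np.asarray(features, dtype=np.float64)
    if weights is None:
        # Uniform weights recover the ordinary sample statistics
        weights = np.ones(features.shape[0])
    weights = np.asarray(weights, dtype=np.float64)
    mean = np.average(features, axis=0, weights=weights)
    # aweights treats each weight as the relative importance of an observation
    cov = np.cov(features, rowvar=False, aweights=weights)
    return mean, cov


features = np.array([[0.0, 0.0], [2.0, 2.0], [4.0, 0.0]])
# The third observation counts twice as much as each of the other two
mean_w, cov_w = weighted_statistics_sketch(features, weights=[1.0, 1.0, 2.0])
mean_u, cov_u = weighted_statistics_sketch(features)
```

Here `mean_w` is pulled toward the third observation, while the unweighted call matches `features.mean(axis=0)` and `np.cov(features, rowvar=False)`.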
Compute features given a data source (can be of different types), handled by cleanfeatures. Return matrix of data features where rows are observations and columns are variables.
features = measure.compute_features(input, **kwargs)
- input accepts different data types:
  - (Tensor): data matrix with observations in rows and variables in columns. Processed by cleanfeatures.compute_features_from_samples()
  - (nn.Module): pre-trained generator model with tensor output [B, C, W, H]. Processed by cleanfeatures.compute_features_from_generator()
  - (Dataset): dataset with tensors in range [0, 1]. Processed by cleanfeatures.compute_features_from_dataset()
- kwargs (dict): additional data source-specific arguments passed on to the corresponding cleanfeatures method. See above.
Compute the statistics of features given a data source (can be of different types), handled by cleanfeatures. Return tuple of statistics: mean and covariance matrix.
Combines the previous two methods into one, first calling compute_features, followed by compute_statistics.
mean, cov = measure.compute_feature_statistics(input, **kwargs)
- input accepts different data types:
  - (Tensor): data matrix with observations in rows and variables in columns. Processed by cleanfeatures.compute_features_from_samples()
  - (nn.Module): pre-trained generator model with tensor output [B, C, W, H]. Processed by cleanfeatures.compute_features_from_generator()
  - (Dataset): dataset with tensors in range [0, 1]. Processed by cleanfeatures.compute_features_from_dataset()
- kwargs (dict): additional data source-specific arguments passed on to the corresponding cleanfeatures method. See above.
Heusel, M., Ramsauer, H., Unterthiner, T., Nessler, B., & Hochreiter, S. (2017). GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium. Advances in Neural Information Processing Systems, 30.