Neural Distributed Source Coding

This repository contains the official implementation of the paper Neural Distributed Source Coding.

Authors

Jay Whang, Anish Acharya, Hyeji Kim, Alexandros G. Dimakis

Abstract

Distributed source coding is the task of encoding an input in the absence of correlated side information that is only available to the decoder. Remarkably, Slepian and Wolf showed in 1973 that an encoder that has no access to the correlated side information can asymptotically achieve the same compression rate as when the side information is available at both the encoder and the decoder. While there is significant prior work on this topic in information theory, practical distributed source coding has been limited to synthetic datasets and specific correlation structures. Here we present a general framework for lossy distributed source coding that is agnostic to the correlation structure and can scale to high dimensions. Rather than relying on hand-crafted source modeling, our method utilizes a powerful conditional deep generative model to learn the distributed encoder and decoder. We evaluate our method on realistic high-dimensional datasets and show substantial improvements in distributed compression performance.
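
To make the Slepian-Wolf result concrete, here is a small worked example (our illustration, not from the paper): for a doubly symmetric binary source where the side information Y flips each bit of X with probability p, a decoder-side-information code needs only H(X|Y) bits per symbol, well below the H(X) = 1 bit needed without any side information.

import math

def h2(p):
    """Binary entropy in bits."""
    return -p * math.log2(p) - (1 - p) * math.log2(1 - p)

p = 0.1                             # P(X != Y): correlation strength
print("H(X)   =", 1.0)              # X is a fair coin: 1 bit per symbol
print("H(X|Y) =", round(h2(p), 3))  # ~0.469 bits: the Slepian-Wolf rate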

Setup

Environment

  • Ubuntu 18.04.5 LTS (Bionic Beaver)
  • Python 3.8.8, OpenMPI-2.1.1, NCCL-2.8.4
  • PyTorch 1.8.0+cu111, Tensorflow 2.4.1, Horovod 0.21.3
  • CUDA 11.0, cuDNN 8.0.5, cudatoolkit 11.0.221

Installation

We use Anaconda to manage the Python environment.

conda env create --file environment.yml
conda activate neural_dsc
pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
HOROVOD_GPU_OPERATIONS=NCCL pip install --upgrade --no-cache-dir "horovod[pytorch]"
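
Optionally, sanity-check the install before training. This one-liner (ours, not part of the repository) verifies that PyTorch sees the GPU and Horovod initializes:

python -c "import torch; import horovod.torch as hvd; hvd.init(); print(torch.__version__, torch.cuda.is_available(), hvd.size())"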

Experiment 1: Distributed Image Compression

Activate environment

conda activate neural_dsc

Prepare data (CelebA-HQ 256x256)

Download celeba-tfr.tar into the data/ directory, then run the following command:

python run_top.py prep celebahq256

Train VQ-VAE

Repeat the following with different values of the --codebook_bits argument to control the total rate. The three variants differ only in the --enc_si/--dec_si flags; a conceptual sketch of what those flags mean follows the commands.

# Joint VQ-VAE
horovodrun -n 2 python -O run_top.py train --dataset celebahq256 --arch vqvae_top_8x --ch_latent 1 \
  --root_dir checkpoints/celeba256_vqvae_joint_4bit --dec_si True --enc_si True  --codebook_bits 4

# Distributed VQ-VAE
horovodrun -n 2 python -O run_top.py train --dataset celebahq256 --arch vqvae_top_8x --ch_latent 1 \
  --root_dir checkpoints/celeba256_vqvae_dist_4bit --dec_si True --enc_si False --codebook_bits 4

# Separate VQ-VAE
horovodrun -n 2 python -O run_top.py train --dataset celebahq256 --arch vqvae_top_8x --ch_latent 1 \
  --root_dir checkpoints/celeba256_vqvae_separate_4bit --dec_si False --enc_si False --codebook_bits 4
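
The three variants map onto the side-information flags: joint gives both encoder and decoder access to the correlated image, distributed gives it to the decoder only (the Slepian-Wolf setting), and separate gives it to neither. The sketch below is purely illustrative of how such flags could gate side information y; the class and layer choices are ours, not the repository's actual model code.

import torch
import torch.nn as nn

class ToyDSC(nn.Module):
    """Illustrative only: side info y is concatenated where the flags allow."""
    def __init__(self, enc_si: bool, dec_si: bool, ch: int = 3):
        super().__init__()
        self.enc_si, self.dec_si = enc_si, dec_si
        self.enc = nn.Conv2d(ch * (2 if enc_si else 1), 8, 3, padding=1)
        self.dec = nn.Conv2d(8 + (ch if dec_si else 0), ch, 3, padding=1)

    def forward(self, x, y):
        z = self.enc(torch.cat([x, y], dim=1) if self.enc_si else x)
        # vector quantization of z (controlled by --codebook_bits) omitted
        h = torch.cat([z, y], dim=1) if self.dec_si else z
        return self.dec(h)

x = y = torch.randn(1, 3, 32, 32)
for enc_si, dec_si in [(True, True), (False, True), (False, False)]:
    print(enc_si, dec_si, ToyDSC(enc_si, dec_si)(x, y).shape)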

Evaluate VQ-VAE

horovodrun -n 2 python run_top.py eval --batch_size 250 \
  checkpoints/celeba256_vqvae_{joint,dist,separate}_4bit/ckpt_ep=020_step=0016880.pt

Plot rate-distortion curves from eval results

All generated plots will be stored in the paper/ folder.

python plot_rd_curves.py

Experiment 2: Distributed SGD

Activate environment

conda activate neural_dsc

Prepare data

# The following command may take a while to finish due to slow download speeds.
python run_mnist_grad.py prep mnist

Gather gradients

python -O run_mnist_grad.py gather_gradients --out_dir checkpoints/mnist_grad_data
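
Optionally, you can inspect the resulting dump. This snippet assumes grads.pt was written with torch.save; the exact structure is defined by run_mnist_grad.py:

import torch

grads = torch.load("checkpoints/mnist_grad_data/grads.pt")
print(type(grads))
if torch.is_tensor(grads):
    print(grads.shape, grads.dtype)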

Train VQ-VAE

# Joint VQ-VAE
python -O run_mnist_grad.py train_vqvae --grad_dump checkpoints/mnist_grad_data/grads.pt --d_latent 40 --codebook_bits 8 \
  --enc_si True  --dec_si True  --root_dir checkpoints/mnist_grad_vqvae_joint_40d_8bits

# Distributed VQ-VAE
python -O run_mnist_grad.py train_vqvae --grad_dump checkpoints/mnist_grad_data/grads.pt --d_latent 40 --codebook_bits 8 \
  --enc_si False --dec_si True  --root_dir checkpoints/mnist_grad_vqvae_dist_40d_8bits

# Separate VQ-VAE
python -O run_mnist_grad.py train_vqvae --grad_dump checkpoints/mnist_grad_data/grads.pt --d_latent 40 --codebook_bits 8 \
  --enc_si False --dec_si False --root_dir checkpoints/mnist_grad_vqvae_separate_40d_8bits
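
As a back-of-the-envelope rate check (under our assumption that each of the --d_latent dimensions is quantized independently against a 2**codebook_bits-entry codebook):

d_latent, codebook_bits = 40, 8
print("bits per compressed gradient:", d_latent * codebook_bits)  # 320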

Evaluate

for seed in $(seq 1 20); do
    python run_mnist_grad.py eval checkpoints/mnist_grad_vqvae_{joint,dist,separate}_40d_8bits/ckpt_ep=500_step=0391000.pt --seed $seed;
done

Plot

python run_mnist_grad.py plot checkpoints/mnist_grad_vqvae_{joint,dist,separate}_40d_8bits/ckpt_ep=500_step=0391000.pt \
  --seeds 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20 --out_dir paper --labels Joint,Distributed,Separate