eifuentes/swae-pytorch

CUDA issue. Exception: RuntimeError in _sliced_wasserstein_distance

Closed this issue · 5 comments

Thank you for creating a PyTorch version of the novel SWAE project.

Encountered an exception at runtime on line 48, inside _sliced_wasserstein_distance:

encoded_projections = encoded_samples.matmul(projections.transpose(0, 1))

with device being set to "cuda" in main()

Expected object of type torch.cuda.FloatTensor but found type torch.FloatTensor for argument #2 'mat2'.

encoded_samples is a <Tensor, len() = 500> with is_cuda: True whereas
projections is a <Tensor, len() = 50> with is_cuda: False

System:
PyTorch 0.4.0
CUDA 9.0
Python 3.6.3 Anaconda/Intel
VS Code 1.25.1
Ubuntu Linux 18.04 x64

Regards.

Thanks for logging the bug; that was an oversight on my part. Check out #4 in the meantime. I will test it soon.

Proceeded to modify

trainer.py

to try to address this issue. The changes are attached in the following post for your consideration and review. With these changes, the code now runs to completion (30 epochs) without any error messages.

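Added a `device` argument to sliced_wasserstein_distance (and _sliced_wasserstein_distance). The trainer now passes its device through, so the prior samples and random projections are created on the same device as the encoded samples.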

def eval_on_batch(self, x):
        . . .
        w2 = float(self.weight_swd) * sliced_wasserstein_distance(
                                        z,
                                        self._distribution_fn,
                                        self.num_projections_,
                                        self.p_,
                                        self._device)
def sliced_wasserstein_distance(encoded_samples,
                                distribution_fn=rand_cirlce2d,
                                num_projections=50,
                                p=2,
                                device='cpu'):
    . . . 
    z = distribution_fn(batch_size).to(device)
    # approximate wasserstein_distance between encoded and prior distributions
    # for average over each projection
    swd = _sliced_wasserstein_distance(
                encoded_samples,
                z,
                num_projections,
                p,
                device)
def _sliced_wasserstein_distance(encoded_samples,
                                 distribution_samples,
                                 num_projections=50,
                                 p=2,
                                 device='cpu'):
    . . . 
    # generate random projections in latent space
    projections = rand_projections(embedding_dim, num_projections).to(device)
    . . .

Modified

trainer.py

import numpy as np
import torch
import torch.nn.functional as F

from .distributions import rand_cirlce2d


def rand_projections(embedding_dim, num_samples=50):
    """This fn generates `L` random samples from the latent space's unit sphere.

        Each sample is a standard-normal draw rescaled to unit L2 norm; a
        normalised isotropic Gaussian vector is a uniformly random direction.

        Args:
            embedding_dim (int): embedding dimension size
            num_samples (int): number of random projection samples

        Return:
            torch.Tensor: float32 tensor of shape (num_samples, embedding_dim)
             whose rows have unit L2 norm (shape is kept even when
             num_samples == 0)
    """
    theta = np.random.normal(size=(num_samples, embedding_dim))
    # normalise every row in one vectorised pass rather than a Python loop;
    # keepdims makes the (num_samples, 1) norms broadcast across each row
    theta = theta / np.sqrt((theta ** 2).sum(axis=1, keepdims=True))
    return torch.from_numpy(theta).type(torch.FloatTensor)


def _sliced_wasserstein_distance(encoded_samples,
                                 distribution_samples,
                                 num_projections=50,
                                 p=2,
                                 device=None):
    """Sliced Wasserstein Distance between encoded samples and drawn
        distribution samples.

        Args:
            encoded_samples (torch.Tensor): embedded training tensor samples,
             presumably shaped (batch_size, embedding_dim) -- TODO confirm
            distribution_samples (torch.Tensor): distribution training tensor
             samples; must have the same number of rows as `encoded_samples`
             because the sorted projections are compared element-wise
            num_projections (int): number of projections to approximate sliced
             wasserstein distance
            p (int): power of distance metric
            device: 'cuda', 'cpu' or a torch.device. When None (default) the
             device of `encoded_samples` is used, so the random projections
             always live alongside the data they are multiplied with.

        Return:
            torch.Tensor: scalar mean of |sorted difference| ** p over all
             projections and samples
    """
    # default to the encoded samples' device; a fixed 'cpu' default is what
    # caused the CUDA/CPU mismatch in matmul when callers omitted `device`
    if device is None:
        device = encoded_samples.device
    # derive latent space dimension size from random samples drawn from a
    # distribution in it
    embedding_dim = distribution_samples.size(1)
    # generate random projections in latent space; match both device and
    # dtype of the encoded samples (rand_projections returns float32 on CPU)
    projections = rand_projections(embedding_dim, num_projections).to(
        device=device, dtype=encoded_samples.dtype)
    # keep the prior samples on the same device/dtype as well
    distribution_samples = distribution_samples.to(
        device=device, dtype=encoded_samples.dtype)
    # project both sample sets onto every random direction
    encoded_projections = encoded_samples.matmul(projections.transpose(0, 1))
    distribution_projections = distribution_samples.matmul(
        projections.transpose(0, 1))
    # sort the samples per projection and take the difference between the
    # encoded samples and drawn samples per projection; abs() is required so
    # that odd powers (e.g. p=1) do not let positive and negative differences
    # cancel out in the mean (no effect for even p)
    wasserstein_distance = torch.abs(
        torch.sort(encoded_projections.transpose(0, 1), dim=1)[0] -
        torch.sort(distribution_projections.transpose(0, 1), dim=1)[0])
    # distance between them (L2 by default for Wasserstein-2)
    wasserstein_distance_p = torch.pow(wasserstein_distance, p)
    # approximate wasserstein_distance averaged over each projection
    return wasserstein_distance_p.mean()


def sliced_wasserstein_distance(encoded_samples,
                                distribution_fn=rand_cirlce2d,
                                num_projections=50,
                                p=2,
                                device=None):
    """Sliced Wasserstein Distance between encoded samples and drawn
        distribution samples.

        Args:
            encoded_samples (torch.Tensor): embedded training tensor samples;
             dimension 0 is treated as the batch dimension
            distribution_fn (callable): callable taking a batch size and
             returning a tensor of prior samples
            num_projections (int): number of projections to approximate sliced
             Wasserstein distance
            p (int): power of distance metric
            device: 'cuda', 'cpu' or a torch.device. When None (default) the
             device of `encoded_samples` is used.

        Return:
            torch.Tensor
    """
    # infer the device from the encoded samples so callers that omit
    # `device` no longer get a CPU/CUDA mismatch
    if device is None:
        device = encoded_samples.device
    # derive batch size from encoded samples
    batch_size = encoded_samples.size(0)
    # draw samples from latent space prior distribution, moved to the same
    # device and dtype as the encoded samples they are compared against
    z = distribution_fn(batch_size).to(device=device,
                                       dtype=encoded_samples.dtype)
    # approximate wasserstein_distance between encoded and prior
    # distributions, averaged over each projection
    swd = _sliced_wasserstein_distance(
                encoded_samples,
                z,
                num_projections,
                p,
                device)
    return swd


class SWAEBatchTrainer:
    """Sliced Wasserstein Autoencoder Batch Trainer.

        Args:
            autoencoder (torch.nn.Module): module which implements autoencoder
            framework; it must expose `encoder.embedding_dim_` and its forward
            pass must return a `(reconstruction, latent_code)` pair
            optimizer (torch.optim.Optimizer): torch optimizer
            distribution_fn (callable): callable to draw random samples
            num_projections (int): number of projections to approximate sliced
            Wasserstein distance
            p (int): power of distance metric
            weight_swd (float): weight of divergence metric compared to
            reconstruction in loss
            device (torch.device): torch device; CPU is used when None
    """
    def __init__(self, autoencoder, optimizer, distribution_fn,
                 num_projections=50, p=2, weight_swd=10.0, device=None):
        self.model_ = autoencoder
        self.optimizer = optimizer
        self._distribution_fn = distribution_fn
        self.embedding_dim_ = self.model_.encoder.embedding_dim_
        self.num_projections_ = num_projections
        self.p_ = p
        self.weight_swd = weight_swd
        # compare against None explicitly: a bare truthiness test would map
        # the valid CUDA device index 0 to the CPU
        self._device = device if device is not None else torch.device('cpu')

    def __call__(self, x):
        return self.eval_on_batch(x)

    def train_on_batch(self, x):
        """Run one optimisation step on batch `x` and return its evaluations."""
        # reset gradients
        self.optimizer.zero_grad()
        # autoencoder forward pass and loss
        evals = self.eval_on_batch(x)
        # backpropagate loss
        evals['loss'].backward()
        # update encoder and decoder parameters
        self.optimizer.step()
        return evals

    def test_on_batch(self, x):
        """Evaluate batch `x` without updating any parameters.

        The forward pass runs under ``torch.no_grad()`` so no autograd graph
        is recorded; the returned tensors therefore do not require grad.
        """
        # reset gradients (kept so callers still observe cleared gradients)
        self.optimizer.zero_grad()
        # evaluation never back-propagates, so skip building the graph
        with torch.no_grad():
            evals = self.eval_on_batch(x)
        return evals

    def eval_on_batch(self, x):
        """Compute reconstruction and sliced-Wasserstein losses for batch `x`.

        Returns a dict with the total 'loss', its 'bce', 'l1' and weighted
        'w2' components, the latent codes ('encode') and the reconstruction
        ('decode').
        """
        x = x.to(self._device)
        recon_x, z = self.model_(x)
        # binary cross entropy assumes recon_x and x lie in [0, 1] -- TODO
        # confirm against the decoder's output activation
        bce = F.binary_cross_entropy(recon_x, x)
        l1 = F.l1_loss(recon_x, x)
        w2 = float(self.weight_swd) * sliced_wasserstein_distance(
                                        z,
                                        self._distribution_fn,
                                        self.num_projections_,
                                        self.p_,
                                        self._device)
        loss = bce + l1 + w2
        return {'loss': loss,
                'bce': bce,
                'l1': l1,
                'w2': w2,
                'encode': z,
                'decode': recon_x}

Thanks @Audrius-St! If you'd like, feel free to put this in a pull request and I can merge it into master.

@Audrius-St, I merged #5, which should resolve the issue. Thanks again.