CUDA issue. Exception: RuntimeError in _sliced_wasserstein_distance
Closed this issue · 5 comments
Thank you for creating a PyTorch version of the novel SWAE project.
Encountered an exception at runtime on line 48 of `_sliced_wasserstein_distance`:
encoded_projections = encoded_samples.matmul(projections.transpose(0, 1))
with `device` set to "cuda" in `main()`:
Expected object of type torch.cuda.FloatTensor but found type torch.FloatTensor for arguement #2 'mat2'.
`encoded_samples` is a <Tensor, len() = 500> with is_cuda: True, whereas `projections` is a <Tensor, len() = 50> with is_cuda: False.
System:
PyTorch 0.4.0
CUDA 9.0
Python 3.6.3 Anaconda/Intel
VS Code 1.25.1
Ubuntu Linux 18.04 x64
Regards.
Thanks for logging the bug; that was an oversight on my part. Check out #4 in the meantime. I will test it soon.
I modified `trainer.py` to try to address this issue; the changes are attached in the following post for your consideration and review. After these changes, the code runs to completion (30 epochs) without any error message.
I added an argument to `sliced_wasserstein_distance` that passes the device type, and propagated it through the call chain:
def eval_on_batch(self, x):
. . .
w2 = float(self.weight_swd) * sliced_wasserstein_distance(
z,
self._distribution_fn,
self.num_projections_,
self.p_,
self._device)
def sliced_wasserstein_distance(encoded_samples,
distribution_fn=rand_cirlce2d,
num_projections=50,
p=2,
device='cpu'):
. . .
z = distribution_fn(batch_size).to(device)
# approximate wasserstein_distance between encoded and prior distributions
# for average over each projection
swd = _sliced_wasserstein_distance(
encoded_samples,
z,
num_projections,
p,
device)
def _sliced_wasserstein_distance(encoded_samples,
distribution_samples,
num_projections=50,
p=2,
device='cpu'):
. . .
# generate random projections in latent space
projections = rand_projections(embedding_dim, num_projections).to(device)
. . .
Modified `trainer.py`:
import numpy as np
import torch
import torch.nn.functional as F
from .distributions import rand_cirlce2d
def rand_projections(embedding_dim, num_samples=50):
    """Draw random unit-length projection directions in the latent space.

    Each direction is a sample from an isotropic normal distribution rescaled
    to unit L2 norm, i.e. a point on the latent space's unit sphere.

    Args:
        embedding_dim (int): embedding (latent space) dimension size
        num_samples (int): number of random projection directions to draw

    Return:
        torch.Tensor: float32 CPU tensor of shape
            ``(num_samples, embedding_dim)`` whose rows have unit norm
    """
    # Keep drawing from numpy's global RNG so that seeding through
    # np.random.seed continues to reproduce the same projections.
    theta = np.random.normal(size=(num_samples, embedding_dim))
    # Normalise every row in one vectorised step; keepdims lets the per-row
    # norms broadcast back over the embedding axis.  This also preserves the
    # 2-D shape when num_samples == 0 (a row-by-row list collapses to 1-D).
    theta = theta / np.sqrt((theta ** 2).sum(axis=1, keepdims=True))
    return torch.from_numpy(theta).type(torch.FloatTensor)
def _sliced_wasserstein_distance(encoded_samples,
                                 distribution_samples,
                                 num_projections=50,
                                 p=2,
                                 device='cpu'):
    """Sliced Wasserstein distance between encoded samples and samples drawn
    from a prior distribution.

    Both sample sets are projected onto ``num_projections`` random unit
    directions. Within each 1-D projection the sorted values of the two sets
    are paired up, and the p-th power of their absolute difference is
    averaged over all pairs and all projections.

    Args:
        encoded_samples (torch.Tensor): embedded training samples, 2-D with
            the latent dimension on axis 1
        distribution_samples (torch.Tensor): samples drawn from the prior;
            axis 1 gives the latent dimension, and the row count is expected
            to match ``encoded_samples`` so the sorted projections align
        num_projections (int): number of projections used to approximate the
            sliced Wasserstein distance
        p (int): power of the distance metric
        device: 'cuda' or 'cpu' (default 'cpu'); the random projections are
            moved here, so it must match the device of both sample tensors

    Return:
        torch.Tensor: scalar mean of ``|difference| ** p`` (no p-th root is
            taken)
    """
    # derive latent space dimension size from the prior samples
    embedding_dim = distribution_samples.size(1)
    # random projection directions, moved to the caller's device so the
    # matmuls below never mix CPU and CUDA tensors
    projections = rand_projections(embedding_dim, num_projections).to(device)
    # transpose once to (embedding_dim, num_projections) and reuse it
    projections_t = projections.transpose(0, 1)
    encoded_projections = encoded_samples.matmul(projections_t)
    distribution_projections = distribution_samples.matmul(projections_t)
    # in one dimension the optimal coupling pairs sorted values, so sort each
    # projection (one row per projection after transposing) and take the
    # element-wise gap between encoded and prior samples
    wasserstein_distance = (
        torch.sort(encoded_projections.transpose(0, 1), dim=1)[0] -
        torch.sort(distribution_projections.transpose(0, 1), dim=1)[0])
    # take the absolute gap before raising to p: for odd or fractional p
    # (e.g. p=1) signed gaps would cancel, give a negative "distance" or NaN;
    # for even p this is a no-op
    wasserstein_distance_p = torch.pow(torch.abs(wasserstein_distance), p)
    # average over every projection and every sample pair
    return wasserstein_distance_p.mean()
def sliced_wasserstein_distance(encoded_samples,
                                distribution_fn=rand_cirlce2d,
                                num_projections=50,
                                p=2,
                                device='cpu'):
    """Approximate the sliced Wasserstein distance between a batch of encoded
    samples and an equally sized batch drawn from the prior.

    Args:
        encoded_samples (torch.Tensor): embedded training samples; the size
            of axis 0 sets how many prior samples are drawn
        distribution_fn (callable): called as ``distribution_fn(n)`` to draw
            ``n`` prior samples
        num_projections (int): number of random projections used for the
            approximation
        p (int): power of the distance metric
        device: 'cuda' or 'cpu' (default 'cpu'); the prior samples are moved
            here before being compared with ``encoded_samples``

    Return:
        torch.Tensor: the distance computed by ``_sliced_wasserstein_distance``
    """
    # one prior sample per encoded sample, placed on the requested device
    prior_samples = distribution_fn(encoded_samples.size(0)).to(device)
    return _sliced_wasserstein_distance(encoded_samples,
                                        prior_samples,
                                        num_projections=num_projections,
                                        p=p,
                                        device=device)
class SWAEBatchTrainer:
    """Sliced Wasserstein Autoencoder Batch Trainer.

    Wraps an autoencoder and its optimizer. For each batch it computes a loss
    made of binary cross-entropy and L1 reconstruction terms plus a weighted
    sliced Wasserstein distance between the latent codes and prior samples.

    Args:
        autoencoder (torch.nn.Module): module implementing the autoencoder;
            calling it must return ``(reconstruction, latent_codes)``, and it
            must expose ``encoder.embedding_dim_``
        optimizer (torch.optim.Optimizer): torch optimizer
        distribution_fn (callable): callable to draw random prior samples
        num_projections (int): number of projections used to approximate the
            sliced Wasserstein distance
        p (int): power of distance metric
        weight_swd (float): weight of the divergence term relative to the
            reconstruction terms in the loss
        device (torch.device): device batches are moved to; the CPU is used
            when ``None``
    """

    def __init__(self, autoencoder, optimizer, distribution_fn,
                 num_projections=50, p=2, weight_swd=10.0, device=None):
        self.model_ = autoencoder
        self.optimizer = optimizer
        self._distribution_fn = distribution_fn
        self.embedding_dim_ = self.model_.encoder.embedding_dim_
        self.num_projections_ = num_projections
        self.p_ = p
        self.weight_swd = weight_swd
        # explicit None check instead of relying on the device's truthiness
        self._device = device if device is not None else torch.device('cpu')

    def __call__(self, x):
        """Alias for ``eval_on_batch``."""
        return self.eval_on_batch(x)

    def train_on_batch(self, x):
        """Run one optimisation step on batch ``x`` and return its evals."""
        # reset gradients
        self.optimizer.zero_grad()
        # autoencoder forward pass and loss
        evals = self.eval_on_batch(x)
        # backpropagate loss
        evals['loss'].backward()
        # update encoder and decoder parameters
        self.optimizer.step()
        return evals

    def test_on_batch(self, x):
        """Evaluate batch ``x`` without updating any parameters.

        The forward pass runs under ``torch.no_grad()``, so no autograd graph
        is built for held-out batches. The returned tensors therefore carry
        no gradient history.
        """
        # clear stale gradients (kept from the original behaviour)
        self.optimizer.zero_grad()
        with torch.no_grad():
            return self.eval_on_batch(x)

    def eval_on_batch(self, x):
        """Forward ``x`` through the autoencoder and compute the loss terms.

        Return:
            dict: ``loss`` (bce + l1 + w2), ``bce``, ``l1``, ``w2`` (weighted
            sliced Wasserstein distance), ``encode`` (latent codes) and
            ``decode`` (reconstructions)
        """
        x = x.to(self._device)
        recon_x, z = self.model_(x)
        bce = F.binary_cross_entropy(recon_x, x)
        l1 = F.l1_loss(recon_x, x)
        # forward the device so prior samples and projections are created on
        # self._device (assumed to be where the model places ``z``)
        w2 = float(self.weight_swd) * sliced_wasserstein_distance(
            z,
            self._distribution_fn,
            self.num_projections_,
            self.p_,
            self._device)
        loss = bce + l1 + w2
        return {'loss': loss,
                'bce': bce,
                'l1': l1,
                'w2': w2,
                'encode': z,
                'decode': recon_x}
Thanks @Audrius-St! If you'd like, feel free to put this in a pull request and I can merge it into master.
@Audrius-St merged #5; this should resolve the issue. Thanks again.