izmailovpavel/flowgmm

sample function in SSLGaussMixture

ThyrixYang opened this issue · 1 comment

Hi,

I'm confused by the following code in the `sample` function of `SSLGaussMixture`:

```python
n_samples = sample_shape[0]
idx = np.random.choice(self.n_components, size=(n_samples, 1), p=F.softmax(self.weights))
all_samples = [g.sample(sample_shape) for g in self.gaussians]
samples = all_samples[0]
for i in range(self.n_components):
    mask = np.where(idx == i)
    samples[mask] = all_samples[i][mask]
return samples
```

I found that `mask = np.where(idx == i)` produces a two-dimensional index (because `idx` has shape `(n_samples, 1)`), so `samples[mask]` selects single elements instead of whole rows.
Shouldn't `samples[mask]` be replaced row by row instead?
For example, with 128-dimensional means I get `mask = (array([9]), array([0]))`
and `samples[mask] = tensor([1.4112])`, whereas `samples[mask]` should be a 128-dimensional tensor.
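The indexing behaviour described above can be reproduced with a small NumPy sketch. It uses plain NumPy arrays as stand-ins for the component samples (the original code uses torch tensors, whose advanced indexing with a tuple of index arrays behaves the same way, as the example in the issue shows); the array sizes are arbitrary illustration values.

```python
import numpy as np

n_samples, dim = 4, 3
# Component assignments with the (n_samples, 1) shape used in the issue
idx = np.array([[1], [0], [1], [0]])
# Stand-in for one component's samples: shape (n_samples, dim)
comp_samples = np.arange(n_samples * dim, dtype=float).reshape(n_samples, dim)

# With a 2-D idx, np.where returns (row_indices, col_indices),
# so indexing picks individual elements from column 0 only
mask_2d = np.where(idx == 1)
print([m.tolist() for m in mask_2d])   # [[0, 2], [0, 0]]
print(comp_samples[mask_2d])           # [0. 6.]

# Flattening idx first makes np.where return row indices only,
# so indexing selects whole rows of length dim
mask_1d = np.where(idx[:, 0] == 1)
print(comp_samples[mask_1d].shape)     # (2, 3)
```

Either drawing `idx` with `size=n_samples` or flattening it before calling `np.where` gives the row-wise replacement the issue asks for.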

Hi @ThyrixYang, I think you are right. It seems that in the code for the experiments we did not sample from the full GMM; instead, we sampled from each class separately:

```python
for i in range(10):
    images_cls = utils.sample(net, loss_fn.prior, args.num_samples // 10,
                              cls=i, device=device, sample_shape=img_shape)
```

I updated the code here following your suggestion.
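For reference, here is a minimal, standalone NumPy sketch of a row-wise version of the mixture sampling. It is an illustration under stated assumptions, not the code that was committed to the repository: the names `sample_mixture` and `softmax` are hypothetical, and it assumes diagonal Gaussian components given as `means` and `stds` arrays of shape `(n_components, dim)` plus unnormalized mixture `logits`.

```python
import numpy as np


def softmax(x):
    # Numerically stable softmax over a 1-D array of logits
    z = np.exp(x - np.max(x))
    return z / z.sum()


def sample_mixture(means, stds, logits, n_samples, rng=None):
    """Draw n_samples from a diagonal Gaussian mixture (hypothetical helper).

    means, stds: arrays of shape (n_components, dim)
    logits: array of shape (n_components,), unnormalized mixture weights
    Returns (samples of shape (n_samples, dim), component indices of shape (n_samples,)).
    """
    rng = np.random.default_rng() if rng is None else rng
    n_components, dim = means.shape
    # One component index per sample, kept 1-D so masks select whole rows
    idx = rng.choice(n_components, size=n_samples, p=softmax(logits))
    # Draw a candidate sample from every component, mirroring the original approach
    all_samples = [means[k] + stds[k] * rng.standard_normal((n_samples, dim))
                   for k in range(n_components)]
    # Copy so the first component's draws are not modified in place
    samples = all_samples[0].copy()
    for k in range(1, n_components):
        rows = idx == k  # boolean mask over rows, shape (n_samples,)
        samples[rows] = all_samples[k][rows]
    return samples, idx


means = np.array([[0.0] * 4, [100.0] * 4])
stds = np.full((2, 4), 0.01)
samples, idx = sample_mixture(means, stds, np.zeros(2), 50, np.random.default_rng(0))
print(samples.shape)  # (50, 4)
```

Drawing from every component and then masking mirrors the structure of the original function but does `n_components` times more sampling work than needed; indexing the parameters directly (for example `means[idx] + stds[idx] * noise`) would avoid that.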