sample function in SSLGaussMixture
ThyrixYang opened this issue · 1 comment
Hi,
I'm confused by the following code in the `sample` function of `SSLGaussMixture`:
```python
n_samples = sample_shape[0]
idx = np.random.choice(self.n_components, size=(n_samples, 1), p=F.softmax(self.weights))
all_samples = [g.sample(sample_shape) for g in self.gaussians]
samples = all_samples[0]
for i in range(self.n_components):
    mask = np.where(idx == i)
    samples[mask] = all_samples[i][mask]
return samples
```
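I found that `mask = np.where(idx == i)` returns a two-dimensional index (a tuple of row indices and column indices, because `idx` has shape `(n_samples, 1)`), so `samples[mask]` selects single elements rather than whole rows. Since the column indices are always 0, only the first coordinate of each affected sample is copied from component `i`; the remaining coordinates still come from component 0.
Shouldn't `samples` be replaced row by row instead?
For example, with 128-dimensional means I get `mask = (array([9]), array([0]))` and `samples[mask] = tensor([1.4112])`, whereas `samples[mask]` should be a 128-dimensional tensor.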
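A small NumPy reproduction of the indexing problem may make it easier to see. It assumes hypothetical sizes (3 components, 4 samples, 5 dimensions) and uses plain arrays in place of the torch tensors returned by `g.sample`. Each component's draws are filled with that component's index, so the origin of every value is visible.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
n_components, n_samples, dim = 3, 4, 5

# Stand-ins for g.sample(sample_shape): component k's draws are all equal to k.
all_samples = [np.full((n_samples, dim), float(k)) for k in range(n_components)]

# Component assignments with shape (n_samples, 1), as in the quoted code.
idx = np.array([[0], [2], [1], [2]])

samples = all_samples[0].copy()
for i in range(n_components):
    # np.where on a 2-D array returns (row_indices, col_indices); col_indices are always 0 here.
    mask = np.where(idx == i)
    # Element-wise assignment: only samples[row, 0] is overwritten, not the whole row.
    samples[mask] = all_samples[i][mask]

print(samples)
```

After the loop, rows 1, 2 and 3 have only their first entry replaced (by 2, 1 and 2 respectively); every other entry is still 0, i.e. it still comes from component 0.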
Hi @ThyrixYang, I think you are right. It seems that in the code for the experiments we did not sample from the GMM and instead sampled from each class separately:
flowgmm/experiments/train_flows/train_semisup_cons.py, lines 422 to 424 in commit 422ff5d
I updated the code here following your suggestion.
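The thread does not show the updated code. As a rough sketch of row-wise selection only, the hypothetical `sample_mixture` helper below assumes diagonal Gaussian components given as NumPy arrays of means and standard deviations (not the `self.gaussians` objects in the repository), and picks whole rows with advanced indexing instead of calling `np.where` on a 2-D index.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    z = np.exp(x - np.max(x))
    return z / z.sum()


def sample_mixture(means, stds, logits, n_samples, rng=None):
    """Draw n_samples points from a diagonal Gaussian mixture.

    Hypothetical helper, not the flowgmm implementation.
    means, stds: arrays of shape (n_components, dim); logits: shape (n_components,).
    """
    rng = np.random.default_rng() if rng is None else rng
    n_components, dim = means.shape
    # One component index per sample, as a 1-D array of shape (n_samples,).
    idx = rng.choice(n_components, size=n_samples, p=softmax(logits))
    # A candidate draw from every component: shape (n_components, n_samples, dim).
    noise = rng.standard_normal((n_components, n_samples, dim))
    all_samples = means[:, None, :] + stds[:, None, :] * noise
    # Row-wise selection: output row j is the full vector all_samples[idx[j], j, :].
    return all_samples[idx, np.arange(n_samples)]
```

Drawing from every component and then discarding most of the draws mirrors the structure of the original code. Sampling only from the chosen component (`means[idx] + stds[idx] * noise` with `noise` of shape `(n_samples, dim)`) would be cheaper and yields the same distribution.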