AdaptiveMotorControlLab/CEBRA

Sampling scheme for continuous auxiliary variables clarification

mariakesa opened this issue · 3 comments

Hello,

Thank you so much for your work! I am definitely having fun with CEBRA:-)

I just wanted to verify that I understood the sampling scheme for shaping embeddings with an auxiliary variable. I read the supplement to the paper and looked at the code. Do I understand correctly that if I specify a continuous auxiliary variable, the model builds a distribution of the deltas between the variable's values at neighboring time points, and then uses this distribution to pick as the positive sample the sample closest to the reference sample's value plus such a delta? (The model still only learns from the main data, i.e. the data I would apply the time-only algorithm to, except that the positive samples are now chosen using the auxiliary variable.) And are the negative samples still sampled uniformly, as in time-only learning?

Thanks again!

Thanks @mariakesa for the question! Just a small note first: we have Discussions open for questions (https://github.com/AdaptiveMotorControlLab/CEBRA/discussions), so in the future that would be the ideal place to post questions, leaving issues for bug reports.

  • And yes, the negative samples are still sampled uniformly, unless you change that as well (see the sketch below).
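As a minimal sketch of that uniform sampling (hypothetical sizes, not the CEBRA API), the negatives are simply random time indices, drawn independently of the reference:

import torch

num_timesteps = 10000  # hypothetical length of the recording
num_negatives = 512    # hypothetical number of negatives per batch

# Uniform over all time steps, independent of the reference sample
negative_idx = torch.randint(0, num_timesteps, (num_negatives,))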

I apologize; I didn't notice the Discussions forum.

Thank you so much! I'm glad I understood correctly!

stes commented

To this point,

Do I understand correctly that if I specify a continuous auxiliary variable, the model builds a distribution of the deltas between the variable's values at neighboring time points, and then uses this distribution to pick as the positive sample the sample closest to the reference sample's value plus such a delta?

That is correct. What you describe is implemented here:

def sample_conditional(self, reference_idx: torch.Tensor) -> torch.Tensor:
    """Return indices from the conditional distribution."""
    if reference_idx.dim() != 1:
        raise ValueError(
            f"Reference indices have wrong shape: {reference_idx.shape}. "
            "Pass a 1D array of indices of reference samples.")
    num_samples = reference_idx.size(0)
    diff_idx = self.randint(len(self.time_difference), (num_samples,))
    query = self.data[reference_idx] + self.time_difference[diff_idx]
    return self.index.search(query)
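To make the mechanics concrete, here is a self-contained toy version of the scheme above, with hypothetical names (aux, etc.) rather than CEBRA's actual classes, and a brute-force nearest-neighbor lookup standing in for self.index.search:

import torch

# Hypothetical 1D continuous auxiliary variable, one value per time step
aux = torch.cumsum(torch.randn(1000), dim=0)

# Empirical distribution of deltas between neighboring time points
time_difference = aux[1:] - aux[:-1]

# Draw a random delta for each reference and form a query value
reference_idx = torch.randint(0, len(aux), (32,))
diff_idx = torch.randint(0, len(time_difference), (32,))
query = aux[reference_idx] + time_difference[diff_idx]

# The positive is the sample whose auxiliary value is closest to the query
positive_idx = torch.argmin((aux[None, :] - query[:, None]).abs(), dim=1)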

However, there are other possibilities, like sampling around the reference samples, as done in this case:

def sample_conditional(self, reference_idx: torch.Tensor) -> torch.Tensor:
    """Return indices from the conditional distribution."""
    if reference_idx.dim() != 1:
        raise ValueError(
            f"Reference indices have wrong shape: {reference_idx.shape}. "
            "Pass a 1D array of indices of reference samples.")
    # TODO(stes): Set seed
    query = torch.distributions.Normal(
        self.data[reference_idx].squeeze(),
        torch.ones_like(reference_idx, device=self.device) * self.std,
    ).sample()
    return self.index.search(query.unsqueeze(-1))
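In the same toy setting as above, this Gaussian variant only changes how the query is formed: the reference's auxiliary value is perturbed with noise of standard deviation std (a hypothetical bandwidth below), so std controls how far, in auxiliary-variable space, positives can stray from the reference:

import torch

aux = torch.cumsum(torch.randn(1000), dim=0)  # hypothetical auxiliary variable
std = 0.5                                     # hypothetical bandwidth

reference_idx = torch.randint(0, len(aux), (32,))
query = torch.distributions.Normal(aux[reference_idx], std).sample()
positive_idx = torch.argmin((aux[None, :] - query[:, None]).abs(), dim=1)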

(The model still only learns from the main data, i.e. the data I would apply the time-only algorithm to, except that the positive samples are now chosen using the auxiliary variable.)

Yes, it simply "re-indexes" the dataset you provide to fit().
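As a sketch of what that re-indexing means (hypothetical names again): reference, positive, and negative samples are all just rows of the main data array, and the auxiliary variable only influences which row indices get chosen.

import torch

neural = torch.randn(1000, 120)  # hypothetical main data: time steps x neurons

reference_idx = torch.randint(0, len(neural), (32,))
positive_idx = torch.randint(0, len(neural), (32,))  # stand-in for the conditional sampling above
negative_idx = torch.randint(0, len(neural), (32,))  # uniform in time, as noted above

# The model only ever sees slices of the main data
reference, positive, negative = neural[reference_idx], neural[positive_idx], neural[negative_idx]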