qAp/gresearch_crypto_forecasting_kaggle

What does time2vec do?

qAp opened this issue · 2 comments

qAp commented

It turns time features into vectors, but how exactly does it do that?

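```python
import torch
import torch.nn as nn


class Time2Vec(nn.Module):
    def __init__(self, input_dim=6, embed_dim=512, act_function=torch.sin):
        # embed_dim is the total output length and must be a multiple of input_dim.
        # Note: the default 512 is not a multiple of 6, so pass embed_dim explicitly.
        assert embed_dim % input_dim == 0, "embed_dim must be a multiple of input_dim"
        super().__init__()
        self.enabled = embed_dim > 0
        if self.enabled:
            self.embed_dim = embed_dim // input_dim  # vector length per time feature
            self.input_dim = input_dim
            # One trainable vector (row) per time feature ...
            self.embed_weight = nn.Parameter(torch.randn(self.input_dim, self.embed_dim))
            # ... and one bias vector shared by all time features
            self.embed_bias = nn.Parameter(torch.randn(self.embed_dim))
            self.act_function = act_function

    def forward(self, x):
        if self.enabled:
            # size of x = [bs, sample, input_dim]
            x = torch.diag_embed(x)
            # size of x = [bs, sample, input_dim, input_dim]
            x_affine = torch.matmul(x, self.embed_weight) + self.embed_bias
            # size of x_affine = [bs, sample, input_dim, embed_dim]
            x_affine_0, x_affine_remain = torch.split(
                x_affine, [1, self.embed_dim - 1], dim=-1
            )
            # element 0 stays linear; the activation is applied to the rest
            x_affine_remain = self.act_function(x_affine_remain)
            x_output = torch.cat([x_affine_0, x_affine_remain], dim=-1)
            # flatten to [bs, sample, input_dim * embed_dim]
            x_output = x_output.view(x_output.size(0), x_output.size(1), -1)
        else:
            x_output = x
        return x_output
```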
qAp commented

An example of a timestamp in this Kaggle competition is 2019-01-23 15:05. Six features can be derived from it, namely ['Year', 'Month', 'Day', 'Weekday', 'Hour', 'Minute']. When the timestamp is part of a sequence, it also has a relative position in that sequence, so altogether a single timestamp has 6 + 1 = 7 features. These are normalised and passed to Time2Vec (with input_dim=7). x is therefore of shape (batch size, sequence length, 7). For example, x[:,:,0] contains the Year feature for every element of every sequence in the batch.
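To make that shape concrete, here is a minimal sketch that turns a short sequence of timestamps into a (batch size, sequence length, 7) tensor. The issue does not say how the features are normalised, so the divide-by-a-rough-maximum scaling, the year range and the helper name `timestamp_features` are assumptions for illustration only.

```python
from datetime import datetime, timedelta

import torch

# Hypothetical scaling: the issue only says the features are normalised, not how.
# Here each calendar feature is divided by a rough maximum value.
YEAR_OFFSET, YEAR_SPAN = 2018, 4  # assumed range of years in the data


def timestamp_features(ts: datetime, position: int, seq_len: int) -> list[float]:
    """Hypothetical helper: 6 calendar features plus the relative position in the sequence."""
    return [
        (ts.year - YEAR_OFFSET) / YEAR_SPAN,  # Year
        (ts.month - 1) / 11,                  # Month, 1..12 -> 0..1
        (ts.day - 1) / 30,                    # Day, 1..31 -> 0..1
        ts.weekday() / 6,                     # Weekday, Monday=0 .. Sunday=6
        ts.hour / 23,                         # Hour
        ts.minute / 59,                       # Minute
        position / max(seq_len - 1, 1),       # relative position in the sequence
    ]


start = datetime(2019, 1, 23, 15, 5)
seq_len = 4
# One sequence of 4 consecutive one-minute timestamps
seq = [timestamp_features(start + timedelta(minutes=i), i, seq_len) for i in range(seq_len)]
x = torch.tensor([seq])  # a batch containing 1 sequence

print(x.shape)     # torch.Size([1, 4, 7])
print(x[:, :, 0])  # the Year feature for every element of every sequence
```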

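From this point on, each time feature can be treated separately as it passes through Time2Vec. Each time feature has its own vector of trainable parameters (one row of self.embed_weight). torch.diag_embed places each feature value on the diagonal of an input_dim × input_dim matrix, so in the matrix product below, row i of the result is simply the value of feature i times row i of self.embed_weight. In other words, the feature's vector is scaled by the feature value, and a bias is then added: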

```python
x_affine = torch.matmul(x, self.embed_weight) + self.embed_bias
```

(The bias, self.embed_bias, is shared between all time features.) self.embed_dim is the length of each feature's vector, i.e. the embed_dim argument divided by input_dim. In this example it is 6 (embed_dim=42 with 7 features), but it is set by the user.
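The diagonal-matrix trick can be checked against plain broadcasting. This is a small sketch with random tensors whose sizes match the example (7 features, 6 elements per feature); the broadcasting form is an equivalent rewrite for illustration, not what the original code does.

```python
import torch

torch.manual_seed(0)
bs, seq_len, input_dim, embed_dim = 2, 5, 7, 6  # embed_dim here is the per-feature length

x = torch.randn(bs, seq_len, input_dim)
embed_weight = torch.randn(input_dim, embed_dim)  # one trainable vector per time feature
embed_bias = torch.randn(embed_dim)               # one bias vector shared by all features

# As in Time2Vec.forward: [bs, seq, 7] -> [bs, seq, 7, 7] -> [bs, seq, 7, 6]
x_affine = torch.matmul(torch.diag_embed(x), embed_weight) + embed_bias

# Same result with broadcasting: feature f's row of weights scaled by x[..., f], plus the shared bias
x_affine_bcast = x.unsqueeze(-1) * embed_weight + embed_bias

print(x_affine.shape)                            # torch.Size([2, 5, 7, 6])
print(torch.allclose(x_affine, x_affine_bcast))  # True
```

The broadcasting form avoids building the input_dim × input_dim diagonal matrix, which only matters when there are many features.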

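Then element 0 of this vector is split off from the remaining elements, the sine function (act_function, torch.sin by default) is applied to the remaining elements, and the result is joined back onto element 0. Element 0 therefore stays a linear function of the feature value: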

```python
x_affine_0, x_affine_remain = torch.split(
    x_affine, [1, self.embed_dim - 1], dim=-1
)
x_affine_remain = self.act_function(x_affine_remain)
x_output = torch.cat([x_affine_0, x_affine_remain], dim=-1)
```
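On a single per-feature vector, the split / sine / concatenate step is equivalent to slicing off element 0. A small sketch with a hand-picked vector (the values are arbitrary):

```python
import math

import torch

embed_dim = 6
x_affine = torch.tensor([0.5, 0.0, math.pi / 2, math.pi, 1.0, -1.0])  # arbitrary example values

# As in Time2Vec.forward, with act_function=torch.sin
x_affine_0, x_affine_remain = torch.split(x_affine, [1, embed_dim - 1], dim=-1)
x_output = torch.cat([x_affine_0, torch.sin(x_affine_remain)], dim=-1)

# Equivalent slicing form
x_sliced = torch.cat([x_affine[..., :1], torch.sin(x_affine[..., 1:])], dim=-1)

print(x_output)                         # element 0 is unchanged (0.5); the rest are sines
print(torch.equal(x_output, x_sliced))  # True
```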

Doing this for each time feature gives 7 such vectors. Finally, the .view call concatenates them end-to-end, starting with the vector for Year, then Month, and so on, giving a vector of length 7 × 6 = 42. This vector is the time2vec representation of the timestamp.
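Because .view flattens the last two dimensions, feature f ends up in elements f * 6 to f * 6 + 5 of the output. The sketch below mirrors the module's forward pass with plain tensors (the function name `time2vec` is hypothetical) and checks that the Month feature (index 1) occupies elements 6 to 11:

```python
import torch


def time2vec(x, embed_weight, embed_bias, act_function=torch.sin):
    """Hypothetical functional version of Time2Vec.forward. x: [bs, seq, input_dim]."""
    per_feature = embed_weight.shape[1]
    x_affine = torch.matmul(torch.diag_embed(x), embed_weight) + embed_bias
    first, rest = torch.split(x_affine, [1, per_feature - 1], dim=-1)
    out = torch.cat([first, act_function(rest)], dim=-1)  # [bs, seq, input_dim, per_feature]
    return out.view(out.size(0), out.size(1), -1)         # [bs, seq, input_dim * per_feature]


torch.manual_seed(0)
x = torch.rand(2, 5, 7)  # 7 normalised time features
embed_weight = torch.randn(7, 6)
embed_bias = torch.randn(6)

out = time2vec(x, embed_weight, embed_bias)
print(out.shape)  # torch.Size([2, 5, 42])

# Month is feature 1, so its vector should be elements 6..11 of the flattened output
month = x[..., 1:2] * embed_weight[1] + embed_bias  # [2, 5, 6]
month = torch.cat([month[..., :1], torch.sin(month[..., 1:])], dim=-1)
print(torch.allclose(out[..., 6:12], month))  # True
```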

qAp commented

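When self.embed_dim equals 6, a time feature value x is converted to the time2vec vector

[p_0 * x + b_0, sin(p_1 * x + b_1), sin(p_2 * x + b_2), ..., sin(p_5 * x + b_5)]

where p_0, ..., p_5 are the trainable weights for that feature (its row of self.embed_weight) and b_0, ..., b_5 are the shared bias self.embed_bias.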
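For a single feature value the formula can be written out directly. A sketch in plain Python, where the weights and bias are made-up numbers standing in for one row of embed_weight and for embed_bias:

```python
import math

p = [0.3, 1.0, 2.0, 3.0, 4.0, 5.0]  # made-up stand-in for one feature's row of embed_weight
b = [0.1, 0.0, 0.0, 0.0, 0.0, 0.0]  # made-up stand-in for the shared embed_bias


def t2v_scalar(x: float) -> list[float]:
    """[p_0 * x + b_0, sin(p_1 * x + b_1), ..., sin(p_5 * x + b_5)]"""
    return [p[0] * x + b[0]] + [math.sin(p[i] * x + b[i]) for i in range(1, len(p))]


print(t2v_scalar(0.0))  # [0.1, 0.0, 0.0, 0.0, 0.0, 0.0]
```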

Since p_1 to p_5 are randomly initialised and then trained, they are unlikely to be equal to each other, so the sine terms oscillate at different frequencies and the vector generally points in a different direction for different values of x. If x is the Month feature, for example, the different months will have time2vec vectors pointing in different directions.
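One way to see this is to compute the pairwise cosine similarity between the 12 month vectors. This sketch assumes months are normalised as month / 12 and uses one random draw of parameters in place of a trained Month row; it checks a single example and is not a proof (sine is periodic, so coinciding directions are possible in principle).

```python
import torch

torch.manual_seed(0)
p = torch.randn(6)  # random stand-in for the Month row of embed_weight (as at initialisation)
b = torch.randn(6)  # random stand-in for embed_bias

months = torch.arange(1, 13, dtype=torch.float32) / 12  # assumed normalisation of months 1..12
v = months.unsqueeze(-1) * p + b                         # [12, 6]
v = torch.cat([v[:, :1], torch.sin(v[:, 1:])], dim=-1)   # time2vec vector for each month

unit = v / v.norm(dim=-1, keepdim=True)
cos = unit @ unit.T                                      # pairwise cosine similarity, [12, 12]
off_diag = cos[~torch.eye(12, dtype=torch.bool)]         # the 132 pairs of different months
print(off_diag.max())  # a value below 1 means no two months share the same direction
```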