What does time2vec do?
qAp opened this issue · 2 comments
It turns time features into vectors, but how exactly does it do that?
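import torch
import torch.nn as nn


class Time2Vec(nn.Module):
    def __init__(self, input_dim=6, embed_dim=512, act_function=torch.sin):
        super().__init__()
        # Each input feature gets embed_dim // input_dim output values, so
        # embed_dim must be a multiple of input_dim. (The defaults 512 and 6
        # do not satisfy this; pass e.g. input_dim=7, embed_dim=42.)
        assert embed_dim % input_dim == 0, "embed_dim must be a multiple of input_dim"
        self.enabled = embed_dim > 0
        if self.enabled:
            self.embed_dim = embed_dim // input_dim  # values per feature
            self.input_dim = input_dim
            # One trainable vector per feature (one row each)...
            self.embed_weight = nn.parameter.Parameter(
                torch.randn(self.input_dim, self.embed_dim)
            )
            # ...and one bias vector shared by all features.
            self.embed_bias = nn.parameter.Parameter(torch.randn(self.embed_dim))
            self.act_function = act_function

    def forward(self, x):
        if self.enabled:
            # size of x = [bs, sample, input_dim]
            x = torch.diag_embed(x)
            # size of x = [bs, sample, input_dim, input_dim]
            x_affine = torch.matmul(x, self.embed_weight) + self.embed_bias
            # size of x_affine = [bs, sample, input_dim, embed_dim]
            x_affine_0, x_affine_remain = torch.split(
                x_affine, [1, self.embed_dim - 1], dim=-1
            )
            x_affine_remain = self.act_function(x_affine_remain)
            x_output = torch.cat([x_affine_0, x_affine_remain], dim=-1)
            # Flatten the last two dims: [bs, sample, input_dim * embed_dim]
            x_output = x_output.view(x_output.size(0), x_output.size(1), -1)
        else:
            x_output = x
        return x_output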
An example of a timestamp in this Kaggle competition could be 2019-01-23 15:05. From this, 6 features can be derived, namely ['Year', 'Month', 'Day', 'Weekday', 'Hour', 'Minute']. When this timestamp is part of a sequence, it also has a relative position in that sequence. Altogether, a single timestamp has 6 + 1 = 7 features. These are normalised and passed to Time2Vec (so input_dim is 7 here, not the default 6). x is therefore of shape (batch size, sequence length, 7); e.g. x[:,:,0] contains the value of the Year feature for all the elements of all the sequences in the batch.
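A minimal sketch of how the 7 features might be derived for one timestamp, using only the standard library. The function name timestamp_features and the normalisation (scaling each feature to [0, 1] with assumed ranges, including an assumed 2010 to 2030 year range) are hypothetical; the issue only says the features are normalised, not how.

```python
from __future__ import annotations

from datetime import datetime

# Hypothetical normalisation: each feature is scaled to [0, 1] using assumed
# ranges. The original issue does not specify the scheme.
YEAR_MIN, YEAR_MAX = 2010, 2030  # assumed year range


def timestamp_features(ts: datetime, position: int, seq_len: int) -> list[float]:
    """Return 7 normalised features for one timestamp in a sequence.

    Order: Year, Month, Day, Weekday, Hour, Minute, relative position.
    """
    return [
        (ts.year - YEAR_MIN) / (YEAR_MAX - YEAR_MIN),
        (ts.month - 1) / 11,              # 1..12 -> 0..1
        (ts.day - 1) / 30,                # 1..31 -> 0..1
        ts.weekday() / 6,                 # Monday=0 .. Sunday=6 -> 0..1
        ts.hour / 23,                     # 0..23 -> 0..1
        ts.minute / 59,                   # 0..59 -> 0..1
        position / max(seq_len - 1, 1),   # relative position in the sequence
    ]


feats = timestamp_features(datetime(2019, 1, 23, 15, 5), position=3, seq_len=10)
print(len(feats))  # 7
```

Stacking such lists over a batch of sequences gives the (batch size, sequence length, 7) input described above.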
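From this point on, each time feature can be considered separately as it passes through Time2Vec. Each time feature has an associated vector of trainable parameters that represents it: its row of self.embed_weight. torch.diag_embed places the features of each timestamp on the diagonal of an input_dim x input_dim matrix, so when that matrix is multiplied by self.embed_weight, row k of the result is feature k's value times feature k's vector. In other words, as a time feature value passes through Time2Vec, its associated vector is scaled by the value, and a bias is then added: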
x_affine = torch.matmul(x, self.embed_weight) + self.embed_bias
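To see why torch.diag_embed lets each feature be treated separately, here is a pure-Python sketch for a single timestamp (batch and sequence dimensions dropped). The toy sizes, the random stand-ins W and b for self.embed_weight and self.embed_bias, and the helpers diag and matmul are illustrative, not part of the original code.

```python
import random

# Toy sizes (arbitrary): 3 time features, 4 values per feature.
INPUT_DIM, EMBED_DIM = 3, 4
rng = random.Random(0)
# Hypothetical stand-ins for self.embed_weight (one row per feature)
# and self.embed_bias (one vector shared by all features).
W = [[rng.gauss(0, 1) for _ in range(EMBED_DIM)] for _ in range(INPUT_DIM)]
b = [rng.gauss(0, 1) for _ in range(EMBED_DIM)]


def diag(x):
    """Pure-Python stand-in for torch.diag_embed on one feature vector."""
    n = len(x)
    return [[x[i] if i == j else 0.0 for j in range(n)] for i in range(n)]


def matmul(a_mat, b_mat):
    """Plain matrix product of two lists of lists."""
    return [
        [sum(a_mat[i][k] * b_mat[k][j] for k in range(len(b_mat)))
         for j in range(len(b_mat[0]))]
        for i in range(len(a_mat))
    ]


x = [0.2, 0.5, 0.9]  # one timestamp's normalised features (illustrative)
x_affine = [[v + b[j] for j, v in enumerate(row)] for row in matmul(diag(x), W)]

# Row k depends only on feature k: it equals x[k] * W[k] + b.
for k in range(INPUT_DIM):
    expected = [x[k] * W[k][j] + b[j] for j in range(EMBED_DIM)]
    assert all(abs(got - want) < 1e-12 for got, want in zip(x_affine[k], expected))
print("each row is x[k] * W[k] + b")
```

In PyTorch itself, the same x_affine can be computed without building the diagonal matrix, as x.unsqueeze(-1) * self.embed_weight + self.embed_bias for x of shape [bs, sample, input_dim]; broadcasting produces the same [bs, sample, input_dim, embed_dim] tensor.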
(The bias, self.embed_bias, is shared between all time features.) self.embed_dim is the length of each of these vectors. It equals embed_dim // input_dim, so the user sets it indirectly through the embed_dim argument; in this example it is 6.
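The relationship between the embed_dim argument and self.embed_dim is plain integer arithmetic. The hypothetical helper below mirrors the constructor's check (raising ValueError instead of using assert). It shows that the 7-feature, 42-dim setup gives 6 values per feature, while the constructor's default arguments (embed_dim=512, input_dim=6) fail the check, since 512 is not a multiple of 6.

```python
def per_feature_dim(embed_dim: int, input_dim: int) -> int:
    """Hypothetical helper: the value Time2Vec stores in self.embed_dim."""
    # Same condition as the constructor's assert, raised as ValueError here.
    if embed_dim % input_dim != 0:
        raise ValueError(
            f"embed_dim={embed_dim} is not a multiple of input_dim={input_dim}"
        )
    return embed_dim // input_dim


print(per_feature_dim(42, 7))  # 6 values per feature, 7 * 6 = 42 in total

try:
    per_feature_dim(512, 6)  # the constructor's default arguments
except ValueError as err:
    print(err)  # 512 % 6 == 2, so the defaults are rejected
```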
Then element 0 of this vector is separated from the remaining elements, the activation function (sine by default, act_function=torch.sin) is applied to the remaining elements, and the two parts are joined back together:
x_affine_0, x_affine_remain = torch.split(
    x_affine, [1, self.embed_dim - 1], dim=-1
)
x_affine_remain = self.act_function(x_affine_remain)
x_output = torch.cat([x_affine_0, x_affine_remain], dim=-1)
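The split / activate / concatenate step for one feature's affine vector can be sketched in plain Python as below; the function name time2vec_activation and the input row are illustrative only.

```python
import math


def time2vec_activation(affine, act=math.sin):
    """Keep element 0 as-is (linear) and apply `act` to the remaining elements.

    Mirrors torch.split(..., [1, embed_dim - 1]), act_function, torch.cat.
    """
    first, remain = affine[:1], affine[1:]
    return first + [act(v) for v in remain]


row = [0.7, 0.0, math.pi / 2, math.pi]  # hypothetical x_affine row for one feature
out = time2vec_activation(row)
print(out)  # element 0 stays 0.7; the rest become sin(0), sin(pi/2), sin(pi)
```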
Doing the above for each time feature results in 7 such vectors in total. Finally, these 7 vectors are concatenated end-to-end in the same order as the features in x: the vector for Year first, then Month, and so on (this is what x_output.view(...) does). This gives a vector of length 7 x 6 = 42, which is the time2vec representation of the timestamp.
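Putting the steps together for one timestamp, a plain-Python sketch (with hypothetical, randomly initialised stand-ins for self.embed_weight and self.embed_bias, and illustrative feature values) shows the concatenation order and the final length of 42. In the module, the same ordering comes from x_output.view(...): flattening the last two dimensions in row-major order puts feature 0's six values first, then feature 1's, and so on.

```python
import math
import random

INPUT_DIM = 7    # Year, Month, Day, Weekday, Hour, Minute, relative position
PER_FEATURE = 6  # self.embed_dim when embed_dim=42

# Hypothetical, untrained stand-ins for embed_weight / embed_bias.
rng = random.Random(42)
W = [[rng.gauss(0, 1) for _ in range(PER_FEATURE)] for _ in range(INPUT_DIM)]
b = [rng.gauss(0, 1) for _ in range(PER_FEATURE)]


def time2vec_timestamp(features):
    """Return the concatenated Time2Vec vector for one timestamp's features."""
    out = []
    for k, x in enumerate(features):  # Year first, then Month, ...
        affine = [x * W[k][j] + b[j] for j in range(PER_FEATURE)]
        out.append(affine[0])                         # element 0 stays linear
        out.extend(math.sin(v) for v in affine[1:])   # the rest go through sine
    return out


# Example normalised feature values (illustrative).
vec = time2vec_timestamp([0.45, 0.0, 0.733, 0.333, 0.652, 0.085, 0.333])
print(len(vec))  # 42 = 7 features x 6 values each
```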
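When self.embed_dim equals 6, a time feature value x is converted to the time2vec vector:
[p_0 * x + b_0, sin(p_1 * x + b_1), sin(p_2 * x + b_2), ..., sin(p_5 * x + b_5)]
where p_0 to p_5 are the elements of that feature's row of self.embed_weight, and b_0 to b_5 are the elements of the shared self.embed_bias.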
Since p_1 to p_5 are unlikely to be equal to each other, the sine terms oscillate at different frequencies, and the vector generally points in a different direction for different values of x. So if x is the Month feature, for example, different months will generally get time2vec vectors pointing in different directions.
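A small sketch of the formula above for a single feature, using arbitrary hypothetical parameters P and B (not trained values), compares the vectors for normalised Month values by cosine similarity. Because element 0 is linear in x with a nonzero p_0, distinct x values always give distinct vectors; the sine terms with different p_i are what let their directions differ.

```python
import math

# Hypothetical parameters for one feature (e.g. Month); not trained values.
P = [0.5, 1.0, 2.0, 3.0, 5.0, 7.0]  # p_0 .. p_5 (the feature's weight row)
B = [0.1] * 6                       # b_0 .. b_5 (the shared bias)


def t2v(x):
    """[p_0*x + b_0, sin(p_1*x + b_1), ..., sin(p_5*x + b_5)]"""
    return [P[0] * x + B[0]] + [math.sin(p * x + bias) for p, bias in zip(P[1:], B[1:])]


def cosine(u, v):
    """Cosine of the angle between two vectors (1.0 means same direction)."""
    dot = sum(a * c for a, c in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(c * c for c in v)))


months = [m / 11 for m in range(12)]  # Month normalised to [0, 1]
vecs = [t2v(x) for x in months]
print(round(cosine(vecs[0], vecs[1]), 3))  # January vs February
```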