What is the meaning of the dataset tensors?
celsofranssa opened this issue · 1 comment
When iterating over dataset samples we get the following dictionary of tensors:
for sample in train_dataset:
    for key in sample.keys():
        print(f"\n{key} ({sample[key].shape}):\n {sample[key]}")
# encodings (torch.Size([35])):
# tensor([ 101, 3780, 1036, 7607, 1005, 1057, 1012, 1055, 1012, 5426,
# 2930, 2824, 13109, 16932, 2692, 28332, 15136, 2683, 2549, 15278,
# 2557, 2128, 4135, 3501, 2897, 1999, 3009, 12875, 2692, 13938,
# 2102, 2410, 13114, 6365, 102])
# context_masks (torch.Size([35])):
# tensor([True, True, True, True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True, True, True, True])
# entity_masks (torch.Size([105, 35])):
# tensor([[False, False, False, ..., False, False, False],
# [False, False, False, ..., False, False, False],
# [False, False, False, ..., False, False, False],
# ...,
# [False, False, False, ..., False, False, False],
# [False, False, False, ..., False, False, False],
# [False, True, True, ..., False, False, False]])
# entity_sizes (torch.Size([105])):
# tensor([ 1, 1, 3, 2, 3, 8, 9, 8, 2, 4, 4, 1, 9, 1, 4, 10, 10, 1,
# 1, 10, 8, 2, 4, 2, 3, 3, 6, 1, 2, 6, 1, 10, 10, 9, 3, 5,
# 3, 8, 8, 5, 1, 3, 1, 3, 5, 7, 8, 1, 3, 5, 2, 7, 8, 6,
# 10, 4, 4, 7, 3, 5, 5, 8, 6, 5, 8, 2, 6, 4, 6, 9, 9, 9,
# 10, 7, 1, 7, 9, 10, 5, 5, 2, 3, 1, 3, 7, 3, 5, 2, 6, 2,
# 7, 8, 3, 1, 6, 2, 1, 4, 6, 10, 4, 1, 7, 6, 9])
# entity_types (torch.Size([105])):
# tensor([1, 1, 2, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0])
# rels (torch.Size([20, 2])):
# tensor([[2, 1],
# [0, 1],
# [1, 0],
# [1, 2],
# [2, 0],
# [1, 4],
# [1, 3],
# [0, 2],
# [3, 1],
# [2, 4],
# [2, 3],
# [3, 4],
# [3, 2],
# [4, 3],
# [4, 2],
# [0, 4],
# [4, 1],
# [4, 0],
# [0, 3],
# [3, 0]])
# rel_masks (torch.Size([20, 35])):
# tensor([[False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, True,
# True, True, True, True, True, True, True, True, True, False,
# False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, True,
# True, True, True, True, True, True, True, True, True, False,
# False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, True,
# True, True, True, True, True, True, True, True, True, True,
# False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False, False,
# True, True, True, True, True, True, True, True, True, True,
# True, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False, False,
# True, True, True, True, True, True, True, False, False, False,
# False, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, True,
# True, True, True, True, True, True, True, True, True, True,
# False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False, False,
# True, True, True, True, True, True, True, False, False, False,
# False, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, True, True, True, True, True,
# True, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, True, True, False, False, False,
# False, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, True, True, False, False, False,
# False, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, True, True, True, True, True,
# True, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, True,
# True, True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True, True, True,
# True, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False, False,
# True, True, True, True, True, True, True, True, True, True,
# True, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, True,
# True, True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True, True, True,
# True, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, True,
# True, True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, False, False, False,
# False, False, False, False, False],
# [False, False, False, False, False, False, False, False, False, True,
# True, True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, False, False, False,
# False, False, False, False, False]])
# rel_types (torch.Size([20, 5])):
# tensor([[0., 0., 1., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.]])
# entity_sample_masks (torch.Size([105])):
# tensor([True, True, True, True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True, True])
# rel_sample_masks (torch.Size([20])):
# tensor([True, True, True, True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True])
Could you provide the meaning of these tensors? For instance, `encodings` and `context_masks` map directly to the `input_ids` and `attention_mask` arguments of BERT's `forward` method. So, what are the semantics of the other tensors?
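For reference, this is how I understand that mapping (a minimal sketch; the model name and the added batch dimension are my own assumptions, not taken from the repository):

```python
from transformers import BertModel

# Assumption: "bert-base-uncased" is just an example checkpoint.
bert = BertModel.from_pretrained("bert-base-uncased")

input_ids = sample["encodings"].unsqueeze(0)                   # [1, 35]
attention_mask = sample["context_masks"].unsqueeze(0).long()   # [1, 35]
h = bert(input_ids=input_ids, attention_mask=attention_mask)[0]  # [1, 35, 768] token embeddings
```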
Hi,
`entity_masks` is an ExC tensor (E := number of positive+negative entity mention samples, C := context size), used for accessing the tokens belonging to an entity span (and masking any other token).
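A minimal sketch of such masked max-pooling over a span (not the repository's exact code; `h` is assumed to be the BERT token embeddings for one sentence, shape [C, H], and the -1e30 additive trick is one common way to implement the masking):

```python
# entity_masks: [E, C] bool, True on the tokens of each span.
neg_inf = (~entity_masks).unsqueeze(-1).float() * -1e30  # [E, C, 1], -inf off-span
span_tokens = h.unsqueeze(0) + neg_inf                   # [E, C, H] via broadcasting
entity_repr = span_tokens.max(dim=1)[0]                  # [E, H], one vector per span
```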
`entity_sizes` is a tensor of size E, containing the size of each entity mention span (which is later mapped to an embedding).
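For illustration, that mapping could look like a learned lookup table (sketch; the cap of 10 matches the largest size in the dump above, and the embedding width is arbitrary):

```python
import torch.nn as nn

size_embedding = nn.Embedding(num_embeddings=11, embedding_dim=25)  # sizes 0..10, width assumed
size_repr = size_embedding(entity_sizes)  # [E, 25]
```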
`entity_types` is a tensor of size E, containing the id of the corresponding entity type (also mapped to an embedding).
`rels` is an Rx2 tensor (R := number of positive+negative relation samples, i.e. pairs of related (or unrelated) entity mentions), which contains the indices of the corresponding entity mentions in `entity_masks` (and `entity_sizes` + `entity_types`). It is used to retrieve the entity mention representations of each pair after max-pooling is applied via `entity_masks`.
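A sketch of that retrieval, reusing `entity_repr` from the pooling sketch above (the plain concatenation is illustrative, not necessarily the model's exact pair representation):

```python
import torch

head_repr = entity_repr[rels[:, 0]]                    # [R, H] first mention of each pair
tail_repr = entity_repr[rels[:, 1]]                    # [R, H] second mention
pair_repr = torch.cat([head_repr, tail_repr], dim=-1)  # [R, 2H] pair representation
```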
`rel_masks` is an RxC tensor, used to access the tokens between two entity mentions (and mask any other token).
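The same masked max-pooling idea can be applied to this in-between context (sketch under the same assumptions as above; an all-False row means the two mentions are adjacent, so the pooled vector is zeroed out):

```python
ctx = h.unsqueeze(0) + (~rel_masks).unsqueeze(-1).float() * -1e30  # [R, C, H]
ctx_repr = ctx.max(dim=1)[0]                                       # [R, H]
ctx_repr[ctx_repr < -1e15] = 0  # no tokens between the mentions -> zero vector
```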
`rel_types` is an RxT tensor (T := number of relation types), which contains the multi-hot encoding of the relation types of each pair (all zeros -> strong negative sample).
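Multi-hot targets pair naturally with an independent per-type sigmoid objective, e.g. (sketch with a hypothetical linear classifier; T = 5 matches the dump above):

```python
import torch.nn as nn

rel_classifier = nn.Linear(pair_repr.shape[-1], 5)  # hypothetical head, T = 5
rel_logits = rel_classifier(pair_repr)              # [R, 5]
bce = nn.BCEWithLogitsLoss(reduction="none")
per_pair_loss = bce(rel_logits, rel_types)          # [R, 5], rel_types as float targets
```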
`entity_sample_masks` is a tensor of size E, used for masking 'padding' entity mention samples (since we need to introduce 'padding' mentions due to batching over sentences).
`rel_sample_masks` is a tensor of size R, used for masking 'padding' relation samples (since we need to introduce 'padding' relations due to batching over sentences).
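A sketch of how such sample masks typically enter the loss, continuing the relation example above (the averaging over non-padding samples is my assumption; `entity_sample_masks` would be used the same way for the entity classification loss):

```python
rel_loss = per_pair_loss.sum(dim=-1)          # [R], one loss term per candidate pair
mask = rel_sample_masks.float()               # [R], 0 for padded relations
rel_loss = (rel_loss * mask).sum() / mask.sum().clamp(min=1)  # ignore padding
```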