Question related to input data
chandayh opened this issue · 4 comments
Dear authors,
Thank you for your work. I am new to this field and am trying to understand the code. When I print an input data sample (print(sample) after here), I get the following output (sorry, it is a little long):
{'id': array(['2799386', '73473', '4854053', '6051283', '2232440', '1620037',
'5308384', '3507770', '4748278', '2919324', '3509232', '1664255',
'4152397', '3699919', '619370', '6110939', '1955315', '912612',
'355299', '4966370'], dtype='<U7'), 'nsentences': 20, 'ntokens': 80, 'net_input': {'src_tokens': tensor([[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 2665, 12402, 1997,
10529, 1000, 6235, 1029, 102, 0, 0, 0, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 3861, 1997, 1037,
16373, 3496, 1000, 6235, 1029, 102, 0, 0, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 1037, 3940, 1997,
7753, 1000, 6235, 1029, 102, 0, 0, 0, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 7167, 12824, 3162,
1000, 6235, 1029, 102, 0, 0, 0, 0, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 2023, 2003, 5568,
2181, 1000, 6235, 1029, 102, 0, 0, 0, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 23060, 3788, 1999,
2392, 1997, 3121, 1000, 6235, 1029, 102, 0, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 1996, 2158, 2038,
7877, 2006, 1000, 6235, 1029, 102, 0, 0, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 2158, 20497, 12701,
2006, 1996, 2250, 1012, 1000, 6235, 1029, 102, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 14068, 2006, 1037,
5127, 1000, 6235, 1029, 102, 0, 0, 0, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 3392, 3589, 2007,
3727, 1000, 6235, 1029, 102, 0, 0, 0, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 3756, 6546, 1999,
18781, 1000, 6235, 1029, 102, 0, 0, 0, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 1037, 2304, 5189,
5882, 1000, 6235, 1029, 102, 0, 0, 0, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 2158, 2635, 1037,
3861, 1000, 6235, 1029, 102, 0, 0, 0, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 9092, 1998, 2829,
7282, 21025, 27528, 7959, 1000, 6235, 1029, 102, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 2028, 20581, 2003,
4987, 2000, 1996, 5894, 1012, 1000, 6235, 1029, 102, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 2158, 1000, 6235,
1029, 102, 0, 0, 0, 0, 0, 0, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 2630, 17983, 15723,
11876, 1000, 6235, 1029, 102, 0, 0, 0, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 1037, 3221, 1997,
2300, 1000, 6235, 1029, 102, 0, 0, 0, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 2158, 2007, 28799,
13383, 1000, 6235, 1029, 102, 0, 0, 0, 0, 0],
[ 101, 2029, 2555, 2515, 1996, 3793, 1000, 2217, 1997, 1996,
2311, 5507, 2039, 2011, 1996, 3103, 1000, 6235, 1029, 102]],
device='cuda:0'), 'src_lengths': tensor(315, device='cuda:0'), 'att_masks': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
device='cuda:0'), 'patch_images': tensor([[[[-0.6782, -0.6548, -0.6313, ..., -0.8115, -0.8115, -0.8037],
[-0.7100, -0.6782, -0.6548, ..., -0.8115, -0.8115, -0.8198],
[-0.7334, -0.7100, -0.6782, ..., -0.8037, -0.7959, -0.8037],
...,
[-0.4509, -0.4431, -0.4431, ..., -1.0000, -1.0000, -1.0000],
[-0.4509, -0.4431, -0.4353, ..., -1.0000, -1.0000, -1.0000],
[-0.4587, -0.4509, -0.4353, ..., -1.0000, -1.0000, -1.0000]],
[[-0.7881, -0.7725, -0.7568, ..., -0.9292, -0.9292, -0.9214],
[-0.8037, -0.7959, -0.7881, ..., -0.9292, -0.9292, -0.9370],
[-0.8115, -0.8037, -0.7959, ..., -0.9370, -0.9370, -0.9370],
...,
[-0.6626, -0.6548, -0.6548, ..., -1.0000, -1.0000, -1.0000],
[-0.6626, -0.6548, -0.6470, ..., -1.0000, -1.0000, -1.0000],
[-0.6704, -0.6626, -0.6470, ..., -1.0000, -1.0000, -1.0000]],
[[-0.8745, -0.8667, -0.8667, ..., -0.9844, -0.9844, -0.9766],
[-0.8979, -0.8901, -0.8901, ..., -0.9844, -0.9844, -0.9922],
[-0.8901, -0.8901, -0.8979, ..., -0.9688, -0.9766, -0.9922],
...,
[-0.8901, -0.8823, -0.8823, ..., -1.0000, -1.0000, -1.0000],
[-0.8901, -0.8823, -0.8823, ..., -1.0000, -1.0000, -1.0000],
[-0.9058, -0.8979, -0.8823, ..., -1.0000, -1.0000, -1.0000]]],
[[[ 0.3333, 0.3254, 0.3098, ..., 0.5605, 0.5605, 0.5688],
[ 0.3254, 0.3098, 0.2942, ..., 0.5605, 0.5688, 0.5688],
[ 0.3098, 0.2864, 0.2705, ..., 0.5688, 0.5767, 0.5688],
...,
[-0.1608, -0.1921, -0.2157, ..., 0.3176, 0.0432, -0.4353],
[-0.1765, -0.2000, -0.2000, ..., 0.3098, 0.0039, -0.5767],
[-0.1843, -0.2079, -0.2079, ..., 0.3098, -0.0039, -0.6235]],
[[ 0.4353, 0.4275, 0.4119, ..., 0.5923, 0.6001, 0.6079],
[ 0.4197, 0.4119, 0.3960, ..., 0.5923, 0.6001, 0.6079],
[ 0.3960, 0.3804, 0.3647, ..., 0.5923, 0.6079, 0.6157],
...,
[-0.5605, -0.5923, -0.6079, ..., 0.2313, -0.0588, -0.5527],
[-0.5845, -0.6079, -0.6079, ..., 0.2313, -0.0745, -0.6626],
[-0.5923, -0.6235, -0.6235, ..., 0.2393, -0.0667, -0.6943]],
[[ 0.3020, 0.2942, 0.2783, ..., 0.5449, 0.5449, 0.5527],
[ 0.2942, 0.2783, 0.2627, ..., 0.5449, 0.5449, 0.5527],
[ 0.2705, 0.2471, 0.2313, ..., 0.5371, 0.5449, 0.5371],
...,
[-0.8115, -0.8276, -0.8433, ..., 0.0588, -0.2235, -0.7178],
[-0.8589, -0.8745, -0.8589, ..., 0.0823, -0.2157, -0.7881],
[-0.8901, -0.9058, -0.8745, ..., 0.0980, -0.2000, -0.8037]]],
[[[ 0.7646, 0.7803, 0.8037, ..., 0.5605, 0.5449, 0.5449],
[ 0.7646, 0.7803, 0.8037, ..., 0.5605, 0.5449, 0.5449],
[ 0.7646, 0.7803, 0.8115, ..., 0.5605, 0.5449, 0.5449],
...,
[-0.7646, -0.6313, -0.5137, ..., -0.6548, -0.6470, -0.7959],
[-0.5767, -0.7568, -0.6860, ..., -0.5923, -0.6079, -0.6782],
[-0.3569, -0.8433, -0.8511, ..., -0.5527, -0.5923, -0.5688]],
[[ 0.8901, 0.8901, 0.8979, ..., 0.7178, 0.7021, 0.7021],
[ 0.8901, 0.8901, 0.9058, ..., 0.7178, 0.7021, 0.7021],
[ 0.8901, 0.8901, 0.9136, ..., 0.7178, 0.7021, 0.7021],
...,
[-0.7412, -0.6235, -0.4902, ..., -0.6626, -0.6548, -0.8037],
[-0.5767, -0.7490, -0.6704, ..., -0.6157, -0.6235, -0.7021],
[-0.3726, -0.8354, -0.8433, ..., -0.5767, -0.6157, -0.5923]],
[[ 0.9922, 0.9922, 0.9922, ..., 0.9058, 0.8979, 0.8979],
[ 0.9922, 0.9844, 0.9922, ..., 0.9058, 0.8979, 0.8979],
[ 0.9922, 0.9766, 0.9844, ..., 0.9136, 0.8979, 0.8979],
...,
[-0.7412, -0.6001, -0.4587, ..., -0.7021, -0.6943, -0.8433],
[-0.6313, -0.7490, -0.6392, ..., -0.6548, -0.6626, -0.7412],
[-0.4746, -0.8745, -0.8276, ..., -0.6157, -0.6548, -0.6313]]],
...,
[[[-0.4038, -0.3882, -0.4746, ..., -0.0353, -0.0902, -0.2313],
[-0.4666, -0.4038, -0.4509, ..., -0.1451, -0.0353, -0.1059],
[-0.5293, -0.5527, -0.5527, ..., -0.3333, -0.1686, -0.1608],
...,
[-0.1059, 0.1059, 0.1765, ..., -0.8037, -0.9058, -1.0000],
[-0.0196, 0.0823, 0.0667, ..., -0.7959, -0.8433, -0.9058],
[-0.0275, -0.1843, -0.2000, ..., -0.8667, -0.7803, -0.7568]],
[[-0.7568, -0.7568, -0.7412, ..., 0.3411, 0.3098, 0.0667],
[-0.7803, -0.7021, -0.7100, ..., 0.2313, 0.2393, 0.0432],
[-0.7725, -0.6704, -0.6548, ..., 0.0353, 0.0902, -0.0039],
...,
[-0.1843, -0.2627, -0.2157, ..., -0.8901, -0.9136, -0.8901],
[-0.2079, -0.1765, -0.1372, ..., -0.9292, -0.9214, -0.8745],
[-0.2705, 0.0432, -0.0745, ..., -0.8901, -0.8433, -0.8433]],
[[-0.7412, -0.7334, -0.6001, ..., 0.5137, 0.4509, 0.2549],
[-0.8823, -0.7334, -0.7412, ..., 0.4038, 0.4824, 0.4275],
[-0.9453, -0.8198, -0.8901, ..., 0.1921, 0.3333, 0.4353],
...,
[-0.2157, -0.2393, -0.1137, ..., -0.8823, -0.9609, -1.0000],
[-0.4902, -0.4509, -0.1921, ..., -0.8276, -0.8901, -0.9370],
[-0.8667, -0.6548, -0.4119, ..., -0.7334, -0.7334, -0.8276]]],
[[[-0.5923, -0.6001, -0.6079, ..., 0.8433, 0.8354, 0.8354],
[-0.6313, -0.6157, -0.6079, ..., 0.8511, 0.8511, 0.8433],
[-0.6235, -0.6235, -0.6235, ..., 0.8667, 0.8589, 0.8589],
...,
[-0.4746, -0.5137, -0.5293, ..., -0.0902, -0.1216, -0.1372],
[-0.4824, -0.5137, -0.5059, ..., -0.0432, -0.0902, -0.1294],
[-0.5293, -0.5605, -0.4980, ..., -0.0275, -0.0353, -0.0823]],
[[-0.6860, -0.6943, -0.6943, ..., 0.8037, 0.8115, 0.8037],
[-0.7021, -0.7021, -0.7021, ..., 0.8115, 0.8198, 0.8115],
[-0.7021, -0.7256, -0.7412, ..., 0.8276, 0.8198, 0.8198],
...,
[-0.5371, -0.5527, -0.5527, ..., 0.0745, 0.0432, 0.0275],
[-0.5293, -0.5449, -0.5527, ..., 0.1216, 0.0902, 0.0432],
[-0.5527, -0.5767, -0.5527, ..., 0.1451, 0.1294, 0.0823]],
[[-0.7725, -0.7803, -0.7959, ..., 0.7803, 0.7881, 0.7803],
[-0.8037, -0.8037, -0.8037, ..., 0.7881, 0.7959, 0.7881],
[-0.7646, -0.7881, -0.8037, ..., 0.8037, 0.7959, 0.7959],
...,
[-0.6001, -0.5845, -0.6157, ..., 0.2235, 0.1843, 0.1608],
[-0.5688, -0.5845, -0.6001, ..., 0.2864, 0.2393, 0.1921],
[-0.5767, -0.5923, -0.6001, ..., 0.3098, 0.2942, 0.2471]]],
[[[-0.2627, -0.3020, -0.2549, ..., -0.2235, -0.2157, -0.2000],
[-0.2705, -0.3098, -0.2705, ..., -0.2393, -0.2235, -0.2079],
[-0.2705, -0.3020, -0.2864, ..., -0.2393, -0.2157, -0.2157],
...,
[-0.4666, -0.4666, -0.6782, ..., 0.4038, 0.3882, 0.3960],
[-0.4587, -0.5293, -0.6548, ..., 0.4197, 0.4038, 0.4038],
[-0.4666, -0.5845, -0.6313, ..., 0.4275, 0.4119, 0.4119]],
[[ 0.1686, 0.1843, 0.1608, ..., 0.2313, 0.2235, 0.2235],
[ 0.1608, 0.1608, 0.1451, ..., 0.2313, 0.2313, 0.2313],
[ 0.1608, 0.1294, 0.1216, ..., 0.2313, 0.2393, 0.2313],
...,
[-0.6943, -0.6626, -0.8589, ..., 0.6392, 0.6235, 0.6313],
[-0.6943, -0.7100, -0.8354, ..., 0.6548, 0.6392, 0.6392],
[-0.7021, -0.7646, -0.8115, ..., 0.6626, 0.6470, 0.6470]],
[[ 0.6079, 0.5845, 0.5845, ..., 0.6235, 0.6079, 0.5923],
[ 0.6079, 0.5767, 0.5767, ..., 0.6001, 0.5923, 0.5845],
[ 0.6079, 0.5767, 0.5767, ..., 0.5845, 0.5845, 0.5845],
...,
[-0.6548, -0.6157, -0.8115, ..., 0.8433, 0.8276, 0.8354],
[-0.6392, -0.6548, -0.7881, ..., 0.8667, 0.8433, 0.8433],
[-0.6392, -0.7021, -0.7646, ..., 0.8823, 0.8511, 0.8511]]]],
device='cuda:0', dtype=torch.float16), 'patch_masks': tensor([True, True, True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True], device='cuda:0'), 'prev_output_tokens_11': tensor([[ 0, 718, 1493],
[ 0, 1616, 2011],
[ 0, 619, 887],
[ 0, 11, 2238],
[ 0, 1209, 1469],
[ 0, 27, 2431],
[ 0, 2140, 2403],
[ 0, 2324, 3296],
[ 0, 241, 503],
[ 0, 2271, 3115],
[ 0, 2638, 3287],
[ 0, 3342, 3672],
[ 0, 782, 1836],
[ 0, 1514, 2044],
[ 0, 1670, 2701],
[ 0, 1233, 2546],
[ 0, 2658, 3948],
[ 0, 1412, 2123],
[ 0, 1816, 2880],
[ 0, 936, 3009]], device='cuda:0'), 'prev_output_tokens_12': tensor([[ 0, 719, 1494],
[ 0, 1617, 2012],
[ 0, 620, 888],
[ 0, 12, 2239],
[ 0, 1210, 1470],
[ 0, 28, 2432],
[ 0, 2141, 2404],
[ 0, 2325, 3297],
[ 0, 242, 504],
[ 0, 2272, 3116],
[ 0, 2639, 3288],
[ 0, 3343, 3673],
[ 0, 783, 1837],
[ 0, 1515, 2045],
[ 0, 1671, 2702],
[ 0, 1234, 2547],
[ 0, 2659, 3949],
[ 0, 1413, 2124],
[ 0, 1817, 2881],
[ 0, 937, 3010]], device='cuda:0'), 'prev_output_tokens_21': tensor([[ 0, 782, 1557],
[ 0, 1680, 2075],
[ 0, 683, 951],
[ 0, 75, 2302],
[ 0, 1273, 1533],
[ 0, 91, 2495],
[ 0, 2204, 2467],
[ 0, 2388, 3360],
[ 0, 305, 567],
[ 0, 2335, 3179],
[ 0, 2702, 3351],
[ 0, 3406, 3736],
[ 0, 846, 1900],
[ 0, 1578, 2108],
[ 0, 1734, 2765],
[ 0, 1297, 2610],
[ 0, 2722, 4012],
[ 0, 1476, 2187],
[ 0, 1880, 2944],
[ 0, 1000, 3073]], device='cuda:0'), 'prev_output_tokens_22': tensor([[ 0, 783, 1558],
[ 0, 1681, 2076],
[ 0, 684, 952],
[ 0, 76, 2303],
[ 0, 1274, 1534],
[ 0, 92, 2496],
[ 0, 2205, 2468],
[ 0, 2389, 3361],
[ 0, 306, 568],
[ 0, 2336, 3180],
[ 0, 2703, 3352],
[ 0, 3407, 3737],
[ 0, 847, 1901],
[ 0, 1579, 2109],
[ 0, 1735, 2766],
[ 0, 1298, 2611],
[ 0, 2723, 4013],
[ 0, 1477, 2188],
[ 0, 1881, 2945],
[ 0, 1001, 3074]], device='cuda:0'), 'delta_x1': tensor([[0.0000, 0.9233, 0.0409],
[0.0000, 0.2000, 0.0275],
[0.0000, 0.5760, 0.7340],
[0.0000, 0.1260, 0.3980],
[0.0000, 0.6480, 0.8060],
[0.0000, 0.5040, 0.6740],
[0.0000, 0.2640, 0.0440],
[0.0000, 0.4037, 0.8628],
[0.0000, 0.9060, 0.8120],
[0.0000, 0.2800, 0.6360],
[0.0000, 0.5800, 0.6600],
[0.0000, 0.1640, 0.4560],
[0.0000, 0.2220, 0.2240],
[0.0000, 0.5620, 0.3740],
[0.0000, 0.6757, 0.3784],
[0.0000, 0.6560, 0.4380],
[0.0000, 0.6856, 0.3024],
[0.0000, 0.0500, 0.2640],
[0.0000, 0.1777, 0.6045],
[0.0000, 0.8680, 0.6200]], device='cuda:0', dtype=torch.float64), 'delta_y1': tensor([[0.0000, 0.7100, 0.3880],
[0.0000, 0.8616, 0.5433],
[0.0000, 0.7297, 0.4595],
[0.0000, 0.0731, 0.8877],
[0.0000, 0.9135, 0.3462],
[0.0000, 0.8198, 0.4382],
[0.0000, 0.2941, 0.2941],
[0.0000, 0.7580, 0.2240],
[0.0000, 0.6497, 0.0401],
[0.0000, 0.8460, 0.8160],
[0.0000, 0.0576, 0.8359],
[0.0000, 0.2480, 0.3280],
[0.0000, 0.9200, 0.3200],
[0.0000, 0.7838, 0.1892],
[0.0000, 0.1420, 0.0720],
[0.0000, 0.2432, 0.7297],
[0.0000, 0.6180, 0.3200],
[0.0000, 0.1680, 0.2240],
[0.0000, 0.6921, 0.6906],
[0.0000, 0.2440, 0.8614]], device='cuda:0', dtype=torch.float64), 'delta_x2': tensor([[1.0000, 0.0767, 0.9591],
[1.0000, 0.8000, 0.9725],
[1.0000, 0.4240, 0.2660],
[1.0000, 0.8740, 0.6020],
[1.0000, 0.3520, 0.1940],
[1.0000, 0.4960, 0.3260],
[1.0000, 0.7360, 0.9560],
[1.0000, 0.5963, 0.1372],
[1.0000, 0.0940, 0.1880],
[1.0000, 0.7200, 0.3640],
[1.0000, 0.4200, 0.3400],
[1.0000, 0.8360, 0.5440],
[1.0000, 0.7780, 0.7760],
[1.0000, 0.4380, 0.6260],
[1.0000, 0.3243, 0.6216],
[1.0000, 0.3440, 0.5620],
[1.0000, 0.3144, 0.6976],
[1.0000, 0.9500, 0.7360],
[1.0000, 0.8223, 0.3955],
[1.0000, 0.1320, 0.3800]], device='cuda:0', dtype=torch.float64), 'delta_y2': tensor([[1.0000, 0.2900, 0.6120],
[1.0000, 0.1384, 0.4567],
[1.0000, 0.2703, 0.5405],
[1.0000, 0.9269, 0.1123],
[1.0000, 0.0865, 0.6538],
[1.0000, 0.1802, 0.5618],
[1.0000, 0.7059, 0.7059],
[1.0000, 0.2420, 0.7760],
[1.0000, 0.3503, 0.9599],
[1.0000, 0.1540, 0.1840],
[1.0000, 0.9424, 0.1641],
[1.0000, 0.7520, 0.6720],
[1.0000, 0.0800, 0.6800],
[1.0000, 0.2162, 0.8108],
[1.0000, 0.8580, 0.9280],
[1.0000, 0.7568, 0.2703],
[1.0000, 0.3820, 0.6800],
[1.0000, 0.8320, 0.7760],
[1.0000, 0.3079, 0.3094],
[1.0000, 0.7560, 0.1386]], device='cuda:0', dtype=torch.float64)}, 'target': tensor([[[0.1892, 0.1700],
[0.3657, 0.2759],
[1.0000, 1.0000]],
[[0.3999, 0.2041],
[0.4924, 0.3738],
[1.0000, 1.0000]],
[[0.1520, 0.6309],
[0.2180, 0.8169],
[1.0000, 1.0000]],
[[0.0020, 0.1122],
[0.5459, 0.9346],
[1.0000, 1.0000]],
[[0.2959, 0.8560],
[0.3621, 0.9102],
[1.0000, 1.0000]],
[[0.0080, 0.3782],
[0.5981, 0.9434],
[1.0000, 1.0000]],
[[0.5278, 0.3857],
[0.5879, 0.4968],
[1.0000, 1.0000]],
[[0.5776, 0.2661],
[0.8232, 0.4480],
[1.0000, 1.0000]],
[[0.0620, 0.7246],
[0.1240, 0.8101],
[1.0000, 1.0000]],
[[0.5601, 0.4419],
[0.7720, 0.6318],
[1.0000, 1.0000]],
[[0.6602, 0.1597],
[0.8198, 0.3149],
[1.0000, 1.0000]],
[[0.8281, 0.1627],
[0.9121, 0.3228],
[1.0000, 1.0000]],
[[0.1940, 0.1733],
[0.4480, 0.6401],
[1.0000, 1.0000]],
[[0.3740, 0.6157],
[0.4980, 0.8921],
[1.0000, 1.0000]],
[[0.4233, 0.0340],
[0.6729, 0.1440],
[1.0000, 1.0000]],
[[0.3120, 0.2102],
[0.6260, 0.7417],
[1.0000, 1.0000]],
[[0.6616, 0.4861],
[0.9731, 0.6401],
[1.0000, 1.0000]],
[[0.3501, 0.0027],
[0.5278, 0.1147],
[1.0000, 1.0000]],
[[0.4473, 0.3284],
[0.7080, 0.9634],
[1.0000, 1.0000]],
[[0.2360, 0.5752],
[0.7402, 0.9819],
[1.0000, 1.0000]]], device='cuda:0', dtype=torch.float16), 'token_type': tensor([[0, 0, 2],
[0, 0, 2],
[0, 0, 2],
[0, 0, 2],
[0, 0, 2],
[0, 0, 2],
[0, 0, 2],
[0, 0, 2],
[0, 0, 2],
[0, 0, 2],
[0, 0, 2],
[0, 0, 2],
[0, 0, 2],
[0, 0, 2],
[0, 0, 2],
[0, 0, 2],
[0, 0, 2],
[0, 0, 2],
[0, 0, 2],
[0, 0, 2]], device='cuda:0'), 'w_resize_ratios': tensor([1.3096, 1.2803, 1.0244, 1.0244, 1.0244, 1.0244, 1.0244, 1.3506, 1.0244,
1.0244, 1.0244, 1.0244, 1.0244, 1.0244, 1.5371, 1.0244, 1.5332, 1.0244,
0.5000, 1.0244], device='cuda:0', dtype=torch.float16), 'h_resize_ratios': tensor([1.0244, 1.7715, 1.5371, 1.3369, 1.6406, 1.8096, 1.6729, 1.0244, 1.3691,
1.0244, 1.1357, 1.3652, 1.3652, 1.5371, 1.0244, 1.5371, 1.0244, 1.3652,
0.7510, 1.5420], device='cuda:0', dtype=torch.float16), 'region_coords': tensor([[ 74., 85., 143., 138.],
[160., 59., 197., 108.],
[ 76., 210., 109., 272.],
[ 1., 43., 273., 358.],
[148., 267., 181., 284.],
[ 4., 107., 299., 267.],
[264., 118., 294., 152.],
[219., 133., 312., 224.],
[ 31., 271., 62., 303.],
[280., 221., 386., 316.],
[330., 72., 410., 142.],
[414., 61., 456., 121.],
[ 97., 65., 224., 240.],
[187., 205., 249., 297.],
[141., 17., 224., 72.],
[156., 70., 313., 247.],
[221., 243., 325., 320.],
[175., 1., 264., 43.],
[458., 224., 725., 657.],
[118., 191., 370., 326.]], device='cuda:0', dtype=torch.float16)}
From my understanding, id is the index of each sample. In net_input, src_tokens holds the padded input tokens (20 sentences, each padded to 20 tokens), att_masks marks which positions are real tokens rather than padding, and patch_images holds the input images. But what are prev_output_tokens_{1, 2}{1, 2}? I think these are grid coordinates, but there are leading zeros that I don't quite understand. delta_{x, y}{1, 2} also has leading zeros. And what do target and region_coords mean? They look like coordinates, but the format is a little confusing.
I would appreciate it if you could answer these questions. Thank you!
Best,
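As a small check of the reading above, the sketch below rebuilds att_masks from three rows of the printed src_tokens. It assumes that 0 is the padding id (suggested by the trailing zeros in the dump, not confirmed by the authors), and the names PAD_ID and lengths are only illustrative.

```python
PAD_ID = 0  # assumption: 0 is the padding id, as the trailing zeros in src_tokens suggest

# Three rows copied from the printed src_tokens (rows 1, 4 and 16 of the batch).
src_tokens = [
    [101, 2029, 2555, 2515, 1996, 3793, 1000, 2665, 12402, 1997,
     10529, 1000, 6235, 1029, 102, 0, 0, 0, 0, 0],
    [101, 2029, 2555, 2515, 1996, 3793, 1000, 7167, 12824, 3162,
     1000, 6235, 1029, 102, 0, 0, 0, 0, 0, 0],
    [101, 2029, 2555, 2515, 1996, 3793, 1000, 2158, 1000, 6235,
     1029, 102, 0, 0, 0, 0, 0, 0, 0, 0],
]

# 1 marks a real token, 0 marks padding.
att_masks = [[int(tok != PAD_ID) for tok in row] for row in src_tokens]

# Number of real tokens per row.
lengths = [sum(row) for row in att_masks]
print(lengths)
```

The masks produced this way match the corresponding rows of the printed att_masks. Summing the real-token counts over all 20 rows of the dump gives 315, which is the value shown for src_lengths.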
Hi, prev_output_tokens_{1, 2}{1, 2} are the coordinates of the 4 nearest grid points. The leading zeros are the bos token. target is the ground truth coordinate and region_coords are the coordinates of the bounding boxes.
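The relationship between target, the four grid tokens, and the delta values can be sketched as follows. This is a reconstruction inferred only from the printed numbers, not the repository's code: the grid size of 64, the offset of 4 for reserved token ids, and the function name grid_tokens_and_deltas are all assumptions that happen to reproduce the first sample in the dump. Boundary handling (for example a coordinate of exactly 1.0) is left out.

```python
import math

NUM_BINS = 64     # assumption: the _21 ids are the _11 ids plus 64 in the dump
TOKEN_OFFSET = 4  # assumption: the lowest ids look reserved for special tokens such as bos

def grid_tokens_and_deltas(x, y):
    """Map a normalized point (x, y) to its 4 surrounding grid tokens
    and the fractional offsets inside that grid cell."""
    gx = x * (NUM_BINS - 1)
    gy = y * (NUM_BINS - 1)
    x1, y1 = math.floor(gx), math.floor(gy)
    dx1, dy1 = gx - x1, gy - y1

    base = x1 * NUM_BINS + y1 + TOKEN_OFFSET
    tokens = {
        "11": base,                 # (x1, y1)
        "12": base + 1,             # (x1, y1 + 1)
        "21": base + NUM_BINS,      # (x1 + 1, y1)
        "22": base + NUM_BINS + 1,  # (x1 + 1, y1 + 1)
    }
    deltas = {"x1": dx1, "y1": dy1, "x2": 1 - dx1, "y2": 1 - dy1}
    return tokens, deltas

# First point of the first sample in the printed target: (0.1892, 0.1700).
tokens, deltas = grid_tokens_and_deltas(0.1892, 0.1700)
print(tokens)
print(deltas)
```

For this point the sketch yields the ids 718, 719, 782 and 783, which are the second-column entries of prev_output_tokens_11, _12, _21 and _22 for the first sample. The deltas come out close to the printed 0.9233 and 0.7100; the small gap is plausibly because target is shown in float16 while the deltas are float64.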
Thank you so much! So prev_output_tokens and delta together make up the input data, and the output of the model should be close to target, with the trailing ones denoting the eos token, right?
Yes
Perfect, thank you!
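For readers comparing target with region_coords: in the printed sample, the target points appear to be the box corners scaled by w_resize_ratios and h_resize_ratios and then divided by 512. The sketch below checks this on the first sample. The value 512 (presumably the side length of patch_images) and the helper name normalize_box are assumptions inferred from the numbers, not confirmed by the authors.

```python
PATCH_SIZE = 512  # assumption: inferred from the printed values, not stated in the thread

def normalize_box(box, w_ratio, h_ratio, patch_size=PATCH_SIZE):
    """Scale an (x1, y1, x2, y2) box by the resize ratios and
    normalize it to [0, 1], returning its two corner points."""
    x1, y1, x2, y2 = box
    return [
        (x1 * w_ratio / patch_size, y1 * h_ratio / patch_size),
        (x2 * w_ratio / patch_size, y2 * h_ratio / patch_size),
    ]

# First sample: region_coords, w_resize_ratios[0], h_resize_ratios[0].
points = normalize_box([74.0, 85.0, 143.0, 138.0], 1.3096, 1.0244)
print(points)
```

The two points land within float16 rounding of the first two rows of the first target entry, (0.1892, 0.1700) and (0.3657, 0.2759). The third row, (1.0000, 1.0000), is the eos position discussed above.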