amazon-science/polygon-transformer

Question related to input data

chandayh opened this issue · 4 comments

Dear authors,
Thank you for your work. I am new to this field and am trying to understand the code. When I print an input data sample (print(sample) after here), I get the following output (sorry, it is a little long):

{'id': array(['2799386', '73473', '4854053', '6051283', '2232440', '1620037',
       '5308384', '3507770', '4748278', '2919324', '3509232', '1664255',
       '4152397', '3699919', '619370', '6110939', '1955315', '912612',
       '355299', '4966370'], dtype='<U7'), 'nsentences': 20, 'ntokens': 80, 'net_input': {'src_tokens': tensor([[  101,  2029,  2555,  2515,  1996,  3793,  1000,  2665, 12402,  1997,
         10529,  1000,  6235,  1029,   102,     0,     0,     0,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000,  3861,  1997,  1037,
         16373,  3496,  1000,  6235,  1029,   102,     0,     0,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000,  1037,  3940,  1997,
          7753,  1000,  6235,  1029,   102,     0,     0,     0,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000,  7167, 12824,  3162,
          1000,  6235,  1029,   102,     0,     0,     0,     0,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000,  2023,  2003,  5568,
          2181,  1000,  6235,  1029,   102,     0,     0,     0,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000, 23060,  3788,  1999,
          2392,  1997,  3121,  1000,  6235,  1029,   102,     0,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000,  1996,  2158,  2038,
          7877,  2006,  1000,  6235,  1029,   102,     0,     0,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000,  2158, 20497, 12701,
          2006,  1996,  2250,  1012,  1000,  6235,  1029,   102,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000, 14068,  2006,  1037,
          5127,  1000,  6235,  1029,   102,     0,     0,     0,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000,  3392,  3589,  2007,
          3727,  1000,  6235,  1029,   102,     0,     0,     0,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000,  3756,  6546,  1999,
         18781,  1000,  6235,  1029,   102,     0,     0,     0,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000,  1037,  2304,  5189,
          5882,  1000,  6235,  1029,   102,     0,     0,     0,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000,  2158,  2635,  1037,
          3861,  1000,  6235,  1029,   102,     0,     0,     0,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000,  9092,  1998,  2829,
          7282, 21025, 27528,  7959,  1000,  6235,  1029,   102,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000,  2028, 20581,  2003,
          4987,  2000,  1996,  5894,  1012,  1000,  6235,  1029,   102,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000,  2158,  1000,  6235,
          1029,   102,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000,  2630, 17983, 15723,
         11876,  1000,  6235,  1029,   102,     0,     0,     0,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000,  1037,  3221,  1997,
          2300,  1000,  6235,  1029,   102,     0,     0,     0,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000,  2158,  2007, 28799,
         13383,  1000,  6235,  1029,   102,     0,     0,     0,     0,     0],
        [  101,  2029,  2555,  2515,  1996,  3793,  1000,  2217,  1997,  1996,
          2311,  5507,  2039,  2011,  1996,  3103,  1000,  6235,  1029,   102]],
       device='cuda:0'), 'src_lengths': tensor(315, device='cuda:0'), 'att_masks': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0'), 'patch_images': tensor([[[[-0.6782, -0.6548, -0.6313,  ..., -0.8115, -0.8115, -0.8037],
          [-0.7100, -0.6782, -0.6548,  ..., -0.8115, -0.8115, -0.8198],
          [-0.7334, -0.7100, -0.6782,  ..., -0.8037, -0.7959, -0.8037],
          ...,
          [-0.4509, -0.4431, -0.4431,  ..., -1.0000, -1.0000, -1.0000],
          [-0.4509, -0.4431, -0.4353,  ..., -1.0000, -1.0000, -1.0000],
          [-0.4587, -0.4509, -0.4353,  ..., -1.0000, -1.0000, -1.0000]],

         [[-0.7881, -0.7725, -0.7568,  ..., -0.9292, -0.9292, -0.9214],
          [-0.8037, -0.7959, -0.7881,  ..., -0.9292, -0.9292, -0.9370],
          [-0.8115, -0.8037, -0.7959,  ..., -0.9370, -0.9370, -0.9370],
          ...,
          [-0.6626, -0.6548, -0.6548,  ..., -1.0000, -1.0000, -1.0000],
          [-0.6626, -0.6548, -0.6470,  ..., -1.0000, -1.0000, -1.0000],
          [-0.6704, -0.6626, -0.6470,  ..., -1.0000, -1.0000, -1.0000]],

         [[-0.8745, -0.8667, -0.8667,  ..., -0.9844, -0.9844, -0.9766],
          [-0.8979, -0.8901, -0.8901,  ..., -0.9844, -0.9844, -0.9922],
          [-0.8901, -0.8901, -0.8979,  ..., -0.9688, -0.9766, -0.9922],
          ...,
          [-0.8901, -0.8823, -0.8823,  ..., -1.0000, -1.0000, -1.0000],
          [-0.8901, -0.8823, -0.8823,  ..., -1.0000, -1.0000, -1.0000],
          [-0.9058, -0.8979, -0.8823,  ..., -1.0000, -1.0000, -1.0000]]],


        [[[ 0.3333,  0.3254,  0.3098,  ...,  0.5605,  0.5605,  0.5688],
          [ 0.3254,  0.3098,  0.2942,  ...,  0.5605,  0.5688,  0.5688],
          [ 0.3098,  0.2864,  0.2705,  ...,  0.5688,  0.5767,  0.5688],
          ...,
          [-0.1608, -0.1921, -0.2157,  ...,  0.3176,  0.0432, -0.4353],
          [-0.1765, -0.2000, -0.2000,  ...,  0.3098,  0.0039, -0.5767],
          [-0.1843, -0.2079, -0.2079,  ...,  0.3098, -0.0039, -0.6235]],

         [[ 0.4353,  0.4275,  0.4119,  ...,  0.5923,  0.6001,  0.6079],
          [ 0.4197,  0.4119,  0.3960,  ...,  0.5923,  0.6001,  0.6079],
          [ 0.3960,  0.3804,  0.3647,  ...,  0.5923,  0.6079,  0.6157],
          ...,
          [-0.5605, -0.5923, -0.6079,  ...,  0.2313, -0.0588, -0.5527],
          [-0.5845, -0.6079, -0.6079,  ...,  0.2313, -0.0745, -0.6626],
          [-0.5923, -0.6235, -0.6235,  ...,  0.2393, -0.0667, -0.6943]],

         [[ 0.3020,  0.2942,  0.2783,  ...,  0.5449,  0.5449,  0.5527],
          [ 0.2942,  0.2783,  0.2627,  ...,  0.5449,  0.5449,  0.5527],
          [ 0.2705,  0.2471,  0.2313,  ...,  0.5371,  0.5449,  0.5371],
          ...,
          [-0.8115, -0.8276, -0.8433,  ...,  0.0588, -0.2235, -0.7178],
          [-0.8589, -0.8745, -0.8589,  ...,  0.0823, -0.2157, -0.7881],
          [-0.8901, -0.9058, -0.8745,  ...,  0.0980, -0.2000, -0.8037]]],


        [[[ 0.7646,  0.7803,  0.8037,  ...,  0.5605,  0.5449,  0.5449],
          [ 0.7646,  0.7803,  0.8037,  ...,  0.5605,  0.5449,  0.5449],
          [ 0.7646,  0.7803,  0.8115,  ...,  0.5605,  0.5449,  0.5449],
          ...,
          [-0.7646, -0.6313, -0.5137,  ..., -0.6548, -0.6470, -0.7959],
          [-0.5767, -0.7568, -0.6860,  ..., -0.5923, -0.6079, -0.6782],
          [-0.3569, -0.8433, -0.8511,  ..., -0.5527, -0.5923, -0.5688]],

         [[ 0.8901,  0.8901,  0.8979,  ...,  0.7178,  0.7021,  0.7021],
          [ 0.8901,  0.8901,  0.9058,  ...,  0.7178,  0.7021,  0.7021],
          [ 0.8901,  0.8901,  0.9136,  ...,  0.7178,  0.7021,  0.7021],
          ...,
          [-0.7412, -0.6235, -0.4902,  ..., -0.6626, -0.6548, -0.8037],
          [-0.5767, -0.7490, -0.6704,  ..., -0.6157, -0.6235, -0.7021],
          [-0.3726, -0.8354, -0.8433,  ..., -0.5767, -0.6157, -0.5923]],

         [[ 0.9922,  0.9922,  0.9922,  ...,  0.9058,  0.8979,  0.8979],
          [ 0.9922,  0.9844,  0.9922,  ...,  0.9058,  0.8979,  0.8979],
          [ 0.9922,  0.9766,  0.9844,  ...,  0.9136,  0.8979,  0.8979],
          ...,
          [-0.7412, -0.6001, -0.4587,  ..., -0.7021, -0.6943, -0.8433],
          [-0.6313, -0.7490, -0.6392,  ..., -0.6548, -0.6626, -0.7412],
          [-0.4746, -0.8745, -0.8276,  ..., -0.6157, -0.6548, -0.6313]]],


        ...,


        [[[-0.4038, -0.3882, -0.4746,  ..., -0.0353, -0.0902, -0.2313],
          [-0.4666, -0.4038, -0.4509,  ..., -0.1451, -0.0353, -0.1059],
          [-0.5293, -0.5527, -0.5527,  ..., -0.3333, -0.1686, -0.1608],
          ...,
          [-0.1059,  0.1059,  0.1765,  ..., -0.8037, -0.9058, -1.0000],
          [-0.0196,  0.0823,  0.0667,  ..., -0.7959, -0.8433, -0.9058],
          [-0.0275, -0.1843, -0.2000,  ..., -0.8667, -0.7803, -0.7568]],

         [[-0.7568, -0.7568, -0.7412,  ...,  0.3411,  0.3098,  0.0667],
          [-0.7803, -0.7021, -0.7100,  ...,  0.2313,  0.2393,  0.0432],
          [-0.7725, -0.6704, -0.6548,  ...,  0.0353,  0.0902, -0.0039],
          ...,
          [-0.1843, -0.2627, -0.2157,  ..., -0.8901, -0.9136, -0.8901],
          [-0.2079, -0.1765, -0.1372,  ..., -0.9292, -0.9214, -0.8745],
          [-0.2705,  0.0432, -0.0745,  ..., -0.8901, -0.8433, -0.8433]],

         [[-0.7412, -0.7334, -0.6001,  ...,  0.5137,  0.4509,  0.2549],
          [-0.8823, -0.7334, -0.7412,  ...,  0.4038,  0.4824,  0.4275],
          [-0.9453, -0.8198, -0.8901,  ...,  0.1921,  0.3333,  0.4353],
          ...,
          [-0.2157, -0.2393, -0.1137,  ..., -0.8823, -0.9609, -1.0000],
          [-0.4902, -0.4509, -0.1921,  ..., -0.8276, -0.8901, -0.9370],
          [-0.8667, -0.6548, -0.4119,  ..., -0.7334, -0.7334, -0.8276]]],


        [[[-0.5923, -0.6001, -0.6079,  ...,  0.8433,  0.8354,  0.8354],
          [-0.6313, -0.6157, -0.6079,  ...,  0.8511,  0.8511,  0.8433],
          [-0.6235, -0.6235, -0.6235,  ...,  0.8667,  0.8589,  0.8589],
          ...,
          [-0.4746, -0.5137, -0.5293,  ..., -0.0902, -0.1216, -0.1372],
          [-0.4824, -0.5137, -0.5059,  ..., -0.0432, -0.0902, -0.1294],
          [-0.5293, -0.5605, -0.4980,  ..., -0.0275, -0.0353, -0.0823]],

         [[-0.6860, -0.6943, -0.6943,  ...,  0.8037,  0.8115,  0.8037],
          [-0.7021, -0.7021, -0.7021,  ...,  0.8115,  0.8198,  0.8115],
          [-0.7021, -0.7256, -0.7412,  ...,  0.8276,  0.8198,  0.8198],
          ...,
          [-0.5371, -0.5527, -0.5527,  ...,  0.0745,  0.0432,  0.0275],
          [-0.5293, -0.5449, -0.5527,  ...,  0.1216,  0.0902,  0.0432],
          [-0.5527, -0.5767, -0.5527,  ...,  0.1451,  0.1294,  0.0823]],

         [[-0.7725, -0.7803, -0.7959,  ...,  0.7803,  0.7881,  0.7803],
          [-0.8037, -0.8037, -0.8037,  ...,  0.7881,  0.7959,  0.7881],
          [-0.7646, -0.7881, -0.8037,  ...,  0.8037,  0.7959,  0.7959],
          ...,
          [-0.6001, -0.5845, -0.6157,  ...,  0.2235,  0.1843,  0.1608],
          [-0.5688, -0.5845, -0.6001,  ...,  0.2864,  0.2393,  0.1921],
          [-0.5767, -0.5923, -0.6001,  ...,  0.3098,  0.2942,  0.2471]]],


        [[[-0.2627, -0.3020, -0.2549,  ..., -0.2235, -0.2157, -0.2000],
          [-0.2705, -0.3098, -0.2705,  ..., -0.2393, -0.2235, -0.2079],
          [-0.2705, -0.3020, -0.2864,  ..., -0.2393, -0.2157, -0.2157],
          ...,
          [-0.4666, -0.4666, -0.6782,  ...,  0.4038,  0.3882,  0.3960],
          [-0.4587, -0.5293, -0.6548,  ...,  0.4197,  0.4038,  0.4038],
          [-0.4666, -0.5845, -0.6313,  ...,  0.4275,  0.4119,  0.4119]],

         [[ 0.1686,  0.1843,  0.1608,  ...,  0.2313,  0.2235,  0.2235],
          [ 0.1608,  0.1608,  0.1451,  ...,  0.2313,  0.2313,  0.2313],
          [ 0.1608,  0.1294,  0.1216,  ...,  0.2313,  0.2393,  0.2313],
          ...,
          [-0.6943, -0.6626, -0.8589,  ...,  0.6392,  0.6235,  0.6313],
          [-0.6943, -0.7100, -0.8354,  ...,  0.6548,  0.6392,  0.6392],
          [-0.7021, -0.7646, -0.8115,  ...,  0.6626,  0.6470,  0.6470]],

         [[ 0.6079,  0.5845,  0.5845,  ...,  0.6235,  0.6079,  0.5923],
          [ 0.6079,  0.5767,  0.5767,  ...,  0.6001,  0.5923,  0.5845],
          [ 0.6079,  0.5767,  0.5767,  ...,  0.5845,  0.5845,  0.5845],
          ...,
          [-0.6548, -0.6157, -0.8115,  ...,  0.8433,  0.8276,  0.8354],
          [-0.6392, -0.6548, -0.7881,  ...,  0.8667,  0.8433,  0.8433],
          [-0.6392, -0.7021, -0.7646,  ...,  0.8823,  0.8511,  0.8511]]]],
       device='cuda:0', dtype=torch.float16), 'patch_masks': tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True], device='cuda:0'), 'prev_output_tokens_11': tensor([[   0,  718, 1493],
        [   0, 1616, 2011],
        [   0,  619,  887],
        [   0,   11, 2238],
        [   0, 1209, 1469],
        [   0,   27, 2431],
        [   0, 2140, 2403],
        [   0, 2324, 3296],
        [   0,  241,  503],
        [   0, 2271, 3115],
        [   0, 2638, 3287],
        [   0, 3342, 3672],
        [   0,  782, 1836],
        [   0, 1514, 2044],
        [   0, 1670, 2701],
        [   0, 1233, 2546],
        [   0, 2658, 3948],
        [   0, 1412, 2123],
        [   0, 1816, 2880],
        [   0,  936, 3009]], device='cuda:0'), 'prev_output_tokens_12': tensor([[   0,  719, 1494],
        [   0, 1617, 2012],
        [   0,  620,  888],
        [   0,   12, 2239],
        [   0, 1210, 1470],
        [   0,   28, 2432],
        [   0, 2141, 2404],
        [   0, 2325, 3297],
        [   0,  242,  504],
        [   0, 2272, 3116],
        [   0, 2639, 3288],
        [   0, 3343, 3673],
        [   0,  783, 1837],
        [   0, 1515, 2045],
        [   0, 1671, 2702],
        [   0, 1234, 2547],
        [   0, 2659, 3949],
        [   0, 1413, 2124],
        [   0, 1817, 2881],
        [   0,  937, 3010]], device='cuda:0'), 'prev_output_tokens_21': tensor([[   0,  782, 1557],
        [   0, 1680, 2075],
        [   0,  683,  951],
        [   0,   75, 2302],
        [   0, 1273, 1533],
        [   0,   91, 2495],
        [   0, 2204, 2467],
        [   0, 2388, 3360],
        [   0,  305,  567],
        [   0, 2335, 3179],
        [   0, 2702, 3351],
        [   0, 3406, 3736],
        [   0,  846, 1900],
        [   0, 1578, 2108],
        [   0, 1734, 2765],
        [   0, 1297, 2610],
        [   0, 2722, 4012],
        [   0, 1476, 2187],
        [   0, 1880, 2944],
        [   0, 1000, 3073]], device='cuda:0'), 'prev_output_tokens_22': tensor([[   0,  783, 1558],
        [   0, 1681, 2076],
        [   0,  684,  952],
        [   0,   76, 2303],
        [   0, 1274, 1534],
        [   0,   92, 2496],
        [   0, 2205, 2468],
        [   0, 2389, 3361],
        [   0,  306,  568],
        [   0, 2336, 3180],
        [   0, 2703, 3352],
        [   0, 3407, 3737],
        [   0,  847, 1901],
        [   0, 1579, 2109],
        [   0, 1735, 2766],
        [   0, 1298, 2611],
        [   0, 2723, 4013],
        [   0, 1477, 2188],
        [   0, 1881, 2945],
        [   0, 1001, 3074]], device='cuda:0'), 'delta_x1': tensor([[0.0000, 0.9233, 0.0409],
        [0.0000, 0.2000, 0.0275],
        [0.0000, 0.5760, 0.7340],
        [0.0000, 0.1260, 0.3980],
        [0.0000, 0.6480, 0.8060],
        [0.0000, 0.5040, 0.6740],
        [0.0000, 0.2640, 0.0440],
        [0.0000, 0.4037, 0.8628],
        [0.0000, 0.9060, 0.8120],
        [0.0000, 0.2800, 0.6360],
        [0.0000, 0.5800, 0.6600],
        [0.0000, 0.1640, 0.4560],
        [0.0000, 0.2220, 0.2240],
        [0.0000, 0.5620, 0.3740],
        [0.0000, 0.6757, 0.3784],
        [0.0000, 0.6560, 0.4380],
        [0.0000, 0.6856, 0.3024],
        [0.0000, 0.0500, 0.2640],
        [0.0000, 0.1777, 0.6045],
        [0.0000, 0.8680, 0.6200]], device='cuda:0', dtype=torch.float64), 'delta_y1': tensor([[0.0000, 0.7100, 0.3880],
        [0.0000, 0.8616, 0.5433],
        [0.0000, 0.7297, 0.4595],
        [0.0000, 0.0731, 0.8877],
        [0.0000, 0.9135, 0.3462],
        [0.0000, 0.8198, 0.4382],
        [0.0000, 0.2941, 0.2941],
        [0.0000, 0.7580, 0.2240],
        [0.0000, 0.6497, 0.0401],
        [0.0000, 0.8460, 0.8160],
        [0.0000, 0.0576, 0.8359],
        [0.0000, 0.2480, 0.3280],
        [0.0000, 0.9200, 0.3200],
        [0.0000, 0.7838, 0.1892],
        [0.0000, 0.1420, 0.0720],
        [0.0000, 0.2432, 0.7297],
        [0.0000, 0.6180, 0.3200],
        [0.0000, 0.1680, 0.2240],
        [0.0000, 0.6921, 0.6906],
        [0.0000, 0.2440, 0.8614]], device='cuda:0', dtype=torch.float64), 'delta_x2': tensor([[1.0000, 0.0767, 0.9591],
        [1.0000, 0.8000, 0.9725],
        [1.0000, 0.4240, 0.2660],
        [1.0000, 0.8740, 0.6020],
        [1.0000, 0.3520, 0.1940],
        [1.0000, 0.4960, 0.3260],
        [1.0000, 0.7360, 0.9560],
        [1.0000, 0.5963, 0.1372],
        [1.0000, 0.0940, 0.1880],
        [1.0000, 0.7200, 0.3640],
        [1.0000, 0.4200, 0.3400],
        [1.0000, 0.8360, 0.5440],
        [1.0000, 0.7780, 0.7760],
        [1.0000, 0.4380, 0.6260],
        [1.0000, 0.3243, 0.6216],
        [1.0000, 0.3440, 0.5620],
        [1.0000, 0.3144, 0.6976],
        [1.0000, 0.9500, 0.7360],
        [1.0000, 0.8223, 0.3955],
        [1.0000, 0.1320, 0.3800]], device='cuda:0', dtype=torch.float64), 'delta_y2': tensor([[1.0000, 0.2900, 0.6120],
        [1.0000, 0.1384, 0.4567],
        [1.0000, 0.2703, 0.5405],
        [1.0000, 0.9269, 0.1123],
        [1.0000, 0.0865, 0.6538],
        [1.0000, 0.1802, 0.5618],
        [1.0000, 0.7059, 0.7059],
        [1.0000, 0.2420, 0.7760],
        [1.0000, 0.3503, 0.9599],
        [1.0000, 0.1540, 0.1840],
        [1.0000, 0.9424, 0.1641],
        [1.0000, 0.7520, 0.6720],
        [1.0000, 0.0800, 0.6800],
        [1.0000, 0.2162, 0.8108],
        [1.0000, 0.8580, 0.9280],
        [1.0000, 0.7568, 0.2703],
        [1.0000, 0.3820, 0.6800],
        [1.0000, 0.8320, 0.7760],
        [1.0000, 0.3079, 0.3094],
        [1.0000, 0.7560, 0.1386]], device='cuda:0', dtype=torch.float64)}, 'target': tensor([[[0.1892, 0.1700],
         [0.3657, 0.2759],
         [1.0000, 1.0000]],

        [[0.3999, 0.2041],
         [0.4924, 0.3738],
         [1.0000, 1.0000]],

        [[0.1520, 0.6309],
         [0.2180, 0.8169],
         [1.0000, 1.0000]],

        [[0.0020, 0.1122],
         [0.5459, 0.9346],
         [1.0000, 1.0000]],

        [[0.2959, 0.8560],
         [0.3621, 0.9102],
         [1.0000, 1.0000]],

        [[0.0080, 0.3782],
         [0.5981, 0.9434],
         [1.0000, 1.0000]],

        [[0.5278, 0.3857],
         [0.5879, 0.4968],
         [1.0000, 1.0000]],

        [[0.5776, 0.2661],
         [0.8232, 0.4480],
         [1.0000, 1.0000]],

        [[0.0620, 0.7246],
         [0.1240, 0.8101],
         [1.0000, 1.0000]],

        [[0.5601, 0.4419],
         [0.7720, 0.6318],
         [1.0000, 1.0000]],

        [[0.6602, 0.1597],
         [0.8198, 0.3149],
         [1.0000, 1.0000]],

        [[0.8281, 0.1627],
         [0.9121, 0.3228],
         [1.0000, 1.0000]],

        [[0.1940, 0.1733],
         [0.4480, 0.6401],
         [1.0000, 1.0000]],

        [[0.3740, 0.6157],
         [0.4980, 0.8921],
         [1.0000, 1.0000]],

        [[0.4233, 0.0340],
         [0.6729, 0.1440],
         [1.0000, 1.0000]],

        [[0.3120, 0.2102],
         [0.6260, 0.7417],
         [1.0000, 1.0000]],

        [[0.6616, 0.4861],
         [0.9731, 0.6401],
         [1.0000, 1.0000]],

        [[0.3501, 0.0027],
         [0.5278, 0.1147],
         [1.0000, 1.0000]],

        [[0.4473, 0.3284],
         [0.7080, 0.9634],
         [1.0000, 1.0000]],

        [[0.2360, 0.5752],
         [0.7402, 0.9819],
         [1.0000, 1.0000]]], device='cuda:0', dtype=torch.float16), 'token_type': tensor([[0, 0, 2],
        [0, 0, 2],
        [0, 0, 2],
        [0, 0, 2],
        [0, 0, 2],
        [0, 0, 2],
        [0, 0, 2],
        [0, 0, 2],
        [0, 0, 2],
        [0, 0, 2],
        [0, 0, 2],
        [0, 0, 2],
        [0, 0, 2],
        [0, 0, 2],
        [0, 0, 2],
        [0, 0, 2],
        [0, 0, 2],
        [0, 0, 2],
        [0, 0, 2],
        [0, 0, 2]], device='cuda:0'), 'w_resize_ratios': tensor([1.3096, 1.2803, 1.0244, 1.0244, 1.0244, 1.0244, 1.0244, 1.3506, 1.0244,
        1.0244, 1.0244, 1.0244, 1.0244, 1.0244, 1.5371, 1.0244, 1.5332, 1.0244,
        0.5000, 1.0244], device='cuda:0', dtype=torch.float16), 'h_resize_ratios': tensor([1.0244, 1.7715, 1.5371, 1.3369, 1.6406, 1.8096, 1.6729, 1.0244, 1.3691,
        1.0244, 1.1357, 1.3652, 1.3652, 1.5371, 1.0244, 1.5371, 1.0244, 1.3652,
        0.7510, 1.5420], device='cuda:0', dtype=torch.float16), 'region_coords': tensor([[ 74.,  85., 143., 138.],
        [160.,  59., 197., 108.],
        [ 76., 210., 109., 272.],
        [  1.,  43., 273., 358.],
        [148., 267., 181., 284.],
        [  4., 107., 299., 267.],
        [264., 118., 294., 152.],
        [219., 133., 312., 224.],
        [ 31., 271.,  62., 303.],
        [280., 221., 386., 316.],
        [330.,  72., 410., 142.],
        [414.,  61., 456., 121.],
        [ 97.,  65., 224., 240.],
        [187., 205., 249., 297.],
        [141.,  17., 224.,  72.],
        [156.,  70., 313., 247.],
        [221., 243., 325., 320.],
        [175.,   1., 264.,  43.],
        [458., 224., 725., 657.],
        [118., 191., 370., 326.]], device='cuda:0', dtype=torch.float16)}

From my understanding, id is the index of each sample. In net_input, src_tokens holds the padded input tokens (20 sentences, each padded to 20 tokens), att_masks marks which positions are real tokens (1) and which are padding (0), and patch_images holds the input images. But what are prev_output_tokens_{1, 2}{1, 2}? I think these are grid coordinates, but there are leading zeros that I don't quite understand. delta_{x, y}{1, 2} also have leading zeros. And what do target and region_coords mean? They look like coordinates, but the format is a little confusing.

I would appreciate it if you could answer my questions. Thank you!

Best,

Hi, prev_output_tokens_{1, 2}{1, 2} are the coordinates of the 4 nearest grid points. The leading zeros are the bos token. target holds the ground-truth coordinates, and region_coords holds the coordinates of the bounding boxes.
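
For readers following along, the numbers in the dump are consistent with the reading sketched below. This is only a sketch under assumptions inferred from the printed values, not taken from the repository code: a 64 x 64 grid of coordinate bins, a token offset of 4, and deltas equal to the fractional distance from the lower grid line. The function name grid_tokens_and_deltas and both constants are hypothetical.

```python
import math

NUM_BINS = 64     # assumption: token_21 - token_11 == 64 in the printed sample
TOKEN_OFFSET = 4  # assumption: inferred from the dump (likely reserved special tokens)


def grid_tokens_and_deltas(x, y):
    """Map a normalized point (x, y) in [0, 1] to the tokens of its 4
    surrounding grid points and its fractional offsets inside the cell."""
    gx = x * (NUM_BINS - 1)
    gy = y * (NUM_BINS - 1)
    x_lo, y_lo = math.floor(gx), math.floor(gy)
    # Clamp the upper neighbours to the grid (an assumption; the dump
    # contains no point on the last grid line to confirm this).
    x_hi = min(x_lo + 1, NUM_BINS - 1)
    y_hi = min(y_lo + 1, NUM_BINS - 1)

    def token(xi, yi):
        # Flatten the 2-D grid index into a single vocabulary id.
        return xi * NUM_BINS + yi + TOKEN_OFFSET

    tokens = {
        "11": token(x_lo, y_lo),
        "12": token(x_lo, y_hi),
        "21": token(x_hi, y_lo),
        "22": token(x_hi, y_hi),
    }
    delta_x1 = gx - x_lo
    delta_y1 = gy - y_lo
    deltas = {
        "x1": delta_x1,
        "x2": 1.0 - delta_x1,  # in the dump, delta_x1 + delta_x2 == 1
        "y1": delta_y1,
        "y2": 1.0 - delta_y1,  # in the dump, delta_y1 + delta_y2 == 1
    }
    return tokens, deltas


# First point of the first sample: target[0][0] == [0.1892, 0.1700]
tokens, deltas = grid_tokens_and_deltas(0.1892, 0.1700)
print(tokens)  # {'11': 718, '12': 719, '21': 782, '22': 783}
```

The printed tokens match column 1 of prev_output_tokens_11, _12, _21 and _22 for the first sample, and the deltas agree with delta_x1 = 0.9233 and delta_y1 = 0.7100 up to float16 rounding of target. Column 0 of each tensor is the bos position, which explains the leading 0 tokens and the 0/1 deltas there.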

Thank you so much! So prev_output_tokens and the delta values together make up the input data, and the output of the model should be close to target, with the trailing ones denoting the eos token, right?

Yes
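
The link between region_coords, the resize ratios, and target can also be checked against the printed sample. Below is a minimal sketch, assuming a 512 x 512 patch image; that size is inferred from the numbers and is not stated in this thread. The helper name box_to_target is hypothetical.

```python
PATCH_SIZE = 512  # assumption: inferred from the dump, not confirmed by the authors


def box_to_target(region, w_ratio, h_ratio, patch_size=PATCH_SIZE):
    """Turn a pixel box [x1, y1, x2, y2] into normalized target rows.

    Each corner is scaled by the resize ratio and divided by the patch
    size; a final row of ones stands for the eos position.
    """
    x1, y1, x2, y2 = region
    top_left = [x1 * w_ratio / patch_size, y1 * h_ratio / patch_size]
    bottom_right = [x2 * w_ratio / patch_size, y2 * h_ratio / patch_size]
    return [top_left, bottom_right, [1.0, 1.0]]


# First sample in the dump: region_coords[0], w_resize_ratios[0], h_resize_ratios[0]
rows = box_to_target([74.0, 85.0, 143.0, 138.0], 1.3096, 1.0244)
for row in rows:
    print([round(v, 4) for v in row])
```

Up to float16 rounding, the rows agree with target[0] in the dump: [0.1892, 0.1700], [0.3657, 0.2759], [1.0000, 1.0000]. Compared with prev_output_tokens, which starts with bos, target is the same sequence shifted by one position and ending with eos.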

Perfect, thank you!