ThomasDelteil/HandwrittenTextRecognition_MXNet

Question on the shape of feature map of OCR_LSTM_CTC

Closed this issue · 3 comments

In handwriting_recognition.ipynb:

    def forward(self, x):
        x = x.transpose((0, 3, 1, 2))
        x = x.flatten()
        x = x.split(num_outputs=max_seq_len, axis=1) # (SEQ_LEN, N, CHANNELS)
        x = nd.concat(*[elem.expand_dims(axis=0) for elem in x], dim=0)
        x = self.lstm(x)
        x = x.transpose((1, 0, 2)) #(N, SEQ_LEN, HIDDEN_UNITS)
        return x

I notice that the input feature map for EncoderLayer is first reshaped by `x = x.transpose((0, 3, 1, 2))`, but I think this code may be unnecessary: this kind of transpose is usually applied to image arrays that store the channel in the last dimension, not to feature maps that are already in NCHW layout. Is there a special reason for it?

In addition, regarding the reshape done before the LSTM, I first replaced the code:

 x = x.split(num_outputs=max_seq_len, axis=1) # (SEQ_LEN, N, CHANNELS)
 x = nd.concat(*[elem.expand_dims(axis=0) for elem in x], dim=0)

with `x = x.reshape(SEQ_LEN, BATCH_SIZE, -1)`, and found that although the final shapes are the same, the elements end up in a different order. So I wonder: is there a reason for reshaping it the way you did?

The idea is to reshape from (N, C, H, W) to (T, N, C), where T is the "temporal dimension": in this case, the top-to-bottom, left-to-right reading order.

Let's assume you have a batch size of 1 with a single 100x100 feature map (to keep things simple).

import mxnet as mx
from mxnet import gluon, nd
from mxnet.gluon import nn

import numpy as np


x = nd.arange(0, 100 * 100).reshape(1, 1, 100, 100).astype(np.int64)  # N=1, C=1, H=100, W=100

x
[[[[   0    1    2 ...,   97   98   99]
   [ 100  101  102 ...,  197  198  199]
   [ 200  201  202 ...,  297  298  299]
   ..., 
   [9700 9701 9702 ..., 9797 9798 9799]
   [9800 9801 9802 ..., 9897 9898 9899]
   [9900 9901 9902 ..., 9997 9998 9999]]]]
<NDArray 1x1x100x100 @cpu(0)>

For illustration purposes, our feature map is numbered from 0 to 9999, top left to bottom right. This helps visualize what happens.

First we transpose to move the width axis right after the batch axis. That way we essentially have a "list" of the top-to-bottom column values, from left to right, respecting the "temporal" order: (N, C, H, W) -> (N, W, C, H).

x = x.transpose((0,3,1,2))
x
[[[[   0  100  200 ..., 9700 9800 9900]]

  [[   1  101  201 ..., 9701 9801 9901]]

  [[   2  102  202 ..., 9702 9802 9902]]

  ..., 
  [[  97  197  297 ..., 9797 9897 9997]]

  [[  98  198  298 ..., 9798 9898 9998]]

  [[  99  199  299 ..., 9799 9899 9999]]]]
<NDArray 1x100x1x100 @cpu(0)>

Then we flatten everything, so that the elements are read column-wise instead of row-wise.

x = x.flatten()
x
[[   0  100  200 ..., 9799 9899 9999]]
<NDArray 1x10000 @cpu(0)>

We split this into SEQ_LEN chunks. Notice the top-to-bottom, left-to-right order of the elements: the first chunk effectively contains the first 1000 elements of the feature map in temporal order.

x = x.split(num_outputs=10, axis=1)  # a list of SEQ_LEN arrays, each of shape (N, CHANNELS)
x[0]
[[   0  100  200  300  400  500  600  700  800  900 1000 1100 1200 1300
  1400 1500 1600 1700 1800 1900 2000 2100 2200 2300 2400 2500 2600 2700
  2800 2900 3000 3100 3200 3300 3400 3500 3600 3700 3800 3900 4000 4100
  4200 4300 4400 4500 4600 4700 4800 4900 5000 5100 5200 5300 5400 5500
  5600 5700 5800 5900 6000 6100 6200 6300 6400 6500 6600 6700 6800 6900
  7000 7100 7200 7300 7400 7500 7600 7700 7800 7900 8000 8100 8200 8300
  8400 8500 8600 8700 8800 8900 9000 9100 9200 9300 9400 9500 9600 9700
  9800 9900    1  101  201  301  401  501  601  701  801  901 1001 1101
  1201 1301 1401 1501 1601 1701 1801 1901 2001 2101 2201 2301 2401 2501
  2601 2701 2801 2901 3001 3101 3201 3301 3401 3501 3601 3701 3801 3901
  4001 4101 4201 4301 4401 4501 4601 4701 4801 4901 5001 5101 5201 5301
  5401 5501 5601 5701 5801 5901 6001 6101 6201 6301 6401 6501 6601 6701
  6801 6901 7001 7101 7201 7301 7401 7501 7601 7701 7801 7901 8001 8101
  8201 8301 8401 8501 8601 8701 8801 8901 9001 9101 9201 9301 9401 9501
  9601 9701 9801 9901    2  102  202  302  402  502  602  702  802  902
  1002 1102 1202 1302 1402 1502 1602 1702 1802 1902 2002 2102 2202 2302
  2402 2502 2602 2702 2802 2902 3002 3102 3202 3302 3402 3502 3602 3702
  3802 3902 4002 4102 4202 4302 4402 4502 4602 4702 4802 4902 5002 5102
  5202 5302 5402 5502 5602 5702 5802 5902 6002 6102 6202 6302 6402 6502
  6602 6702 6802 6902 7002 7102 7202 7302 7402 7502 7602 7702 7802 7902
  8002 8102 8202 8302 8402 8502 8602 8702 8802 8902 9002 9102 9202 9302
  9402 9502 9602 9702 9802 9902    3  103  203  303  403  503  603  703
   803  903 1003 1103 1203 1303 1403 1503 1603 1703 1803 1903 2003 2103
  2203 2303 2403 2503 2603 2703 2803 2903 3003 3103 3203 3303 3403 3503
  3603 3703 3803 3903 4003 4103 4203 4303 4403 4503 4603 4703 4803 4903
  5003 5103 5203 5303 5403 5503 5603 5703 5803 5903 6003 6103 6203 6303
  6403 6503 6603 6703 6803 6903 7003 7103 7203 7303 7403 7503 7603 7703
  7803 7903 8003 8103 8203 8303 8403 8503 8603 8703 8803 8903 9003 9103
  9203 9303 9403 9503 9603 9703 9803 9903    4  104  204  304  404  504
   604  704  804  904 1004 1104 1204 1304 1404 1504 1604 1704 1804 1904
  2004 2104 2204 2304 2404 2504 2604 2704 2804 2904 3004 3104 3204 3304
  3404 3504 3604 3704 3804 3904 4004 4104 4204 4304 4404 4504 4604 4704
  4804 4904 5004 5104 5204 5304 5404 5504 5604 5704 5804 5904 6004 6104
  6204 6304 6404 6504 6604 6704 6804 6904 7004 7104 7204 7304 7404 7504
  7604 7704 7804 7904 8004 8104 8204 8304 8404 8504 8604 8704 8804 8904
  9004 9104 9204 9304 9404 9504 9604 9704 9804 9904    5  105  205  305
   405  505  605  705  805  905 1005 1105 1205 1305 1405 1505 1605 1705
  1805 1905 2005 2105 2205 2305 2405 2505 2605 2705 2805 2905 3005 3105
  3205 3305 3405 3505 3605 3705 3805 3905 4005 4105 4205 4305 4405 4505
  4605 4705 4805 4905 5005 5105 5205 5305 5405 5505 5605 5705 5805 5905
  6005 6105 6205 6305 6405 6505 6605 6705 6805 6905 7005 7105 7205 7305
  7405 7505 7605 7705 7805 7905 8005 8105 8205 8305 8405 8505 8605 8705
  8805 8905 9005 9105 9205 9305 9405 9505 9605 9705 9805 9905    6  106
   206  306  406  506  606  706  806  906 1006 1106 1206 1306 1406 1506
  1606 1706 1806 1906 2006 2106 2206 2306 2406 2506 2606 2706 2806 2906
  3006 3106 3206 3306 3406 3506 3606 3706 3806 3906 4006 4106 4206 4306
  4406 4506 4606 4706 4806 4906 5006 5106 5206 5306 5406 5506 5606 5706
  5806 5906 6006 6106 6206 6306 6406 6506 6606 6706 6806 6906 7006 7106
  7206 7306 7406 7506 7606 7706 7806 7906 8006 8106 8206 8306 8406 8506
  8606 8706 8806 8906 9006 9106 9206 9306 9406 9506 9606 9706 9806 9906
     7  107  207  307  407  507  607  707  807  907 1007 1107 1207 1307
  1407 1507 1607 1707 1807 1907 2007 2107 2207 2307 2407 2507 2607 2707
  2807 2907 3007 3107 3207 3307 3407 3507 3607 3707 3807 3907 4007 4107
  4207 4307 4407 4507 4607 4707 4807 4907 5007 5107 5207 5307 5407 5507
  5607 5707 5807 5907 6007 6107 6207 6307 6407 6507 6607 6707 6807 6907
  7007 7107 7207 7307 7407 7507 7607 7707 7807 7907 8007 8107 8207 8307
  8407 8507 8607 8707 8807 8907 9007 9107 9207 9307 9407 9507 9607 9707
  9807 9907    8  108  208  308  408  508  608  708  808  908 1008 1108
  1208 1308 1408 1508 1608 1708 1808 1908 2008 2108 2208 2308 2408 2508
  2608 2708 2808 2908 3008 3108 3208 3308 3408 3508 3608 3708 3808 3908
  4008 4108 4208 4308 4408 4508 4608 4708 4808 4908 5008 5108 5208 5308
  5408 5508 5608 5708 5808 5908 6008 6108 6208 6308 6408 6508 6608 6708
  6808 6908 7008 7108 7208 7308 7408 7508 7608 7708 7808 7908 8008 8108
  8208 8308 8408 8508 8608 8708 8808 8908 9008 9108 9208 9308 9408 9508
  9608 9708 9808 9908    9  109  209  309  409  509  609  709  809  909
  1009 1109 1209 1309 1409 1509 1609 1709 1809 1909 2009 2109 2209 2309
  2409 2509 2609 2709 2809 2909 3009 3109 3209 3309 3409 3509 3609 3709
  3809 3909 4009 4109 4209 4309 4409 4509 4609 4709 4809 4909 5009 5109
  5209 5309 5409 5509 5609 5709 5809 5909 6009 6109 6209 6309 6409 6509
  6609 6709 6809 6909 7009 7109 7209 7309 7409 7509 7609 7709 7809 7909
  8009 8109 8209 8309 8409 8509 8609 8709 8809 8909 9009 9109 9209 9309
  9409 9509 9609 9709 9809 9909]]
<NDArray 1x1000 @cpu(0)>

We then turn this list of NDArrays back into a single NDArray by giving each element a dummy first dimension and concatenating along it.

x = nd.concat(*[elem.expand_dims(axis=0) for elem in x], dim=0)
x.shape
(10, 1, 1000)

We have (T, N, C)

There might be a simpler way to achieve the same split and reading order, but that's basically the idea. Does that help?

@ThomasDelteil, thanks for such a concrete explanation. Now I totally understand.
So in general: if the feature maps are treated as a raw picture of shape (N, C, H, W), and we want to slice that picture VERTICALLY, from left to right, into SEQ_LEN columns BEFORE sending them into the RNN, then we need to preprocess the array the way you did.
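To double-check this reading, here is a quick sanity check in NumPy (whose transpose/reshape/split semantics match `mxnet.nd` for this purpose; the sizes simply mirror the 1x1x100x100 example above). It verifies that each time step produced by the split/concat pipeline is exactly one vertical slab of the image, read column by column, top to bottom:

```python
import numpy as np

N, C, H, W, T = 1, 1, 100, 100, 10
x = np.arange(N * C * H * W).reshape(N, C, H, W)

# The notebook's pipeline: NCHW -> NWCH -> flatten -> split along axis 1 -> stack.
t_steps = np.stack(np.split(x.transpose(0, 3, 1, 2).reshape(N, -1), T, axis=1))

# Direct construction of time step t from a vertical slab of W // T columns.
cols = W // T
for t in range(T):
    slab = x[:, :, :, t * cols:(t + 1) * cols]            # shape (N, C, H, cols)
    expected = slab.transpose(0, 3, 1, 2).reshape(N, -1)  # columns left to right, top to bottom
    assert np.array_equal(t_steps[t], expected)
```

So "time step t" really is the t-th vertical strip of the feature map, which is what the LSTM consumes in reading order.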

And here is my own example to make this manipulation clearer:

import numpy as np
from mxnet import nd

# x = nd.arange(0, 100 * 100).reshape(1, 100, 100).repeat(1, axis=0).astype(np.int64).expand_dims(axis=0) 
batch_size = 3
channel_num = 2
height_size = 5
width_size = 10
x = nd.arange(0, 300).reshape(batch_size, channel_num, height_size, width_size).astype(np.int64)
x

here the output would be:

[[[[  0   1   2   3   4   5   6   7   8   9]
   [ 10  11  12  13  14  15  16  17  18  19]
   [ 20  21  22  23  24  25  26  27  28  29]
   [ 30  31  32  33  34  35  36  37  38  39]
   [ 40  41  42  43  44  45  46  47  48  49]]

  [[ 50  51  52  53  54  55  56  57  58  59]
   [ 60  61  62  63  64  65  66  67  68  69]
   [ 70  71  72  73  74  75  76  77  78  79]
   [ 80  81  82  83  84  85  86  87  88  89]
   [ 90  91  92  93  94  95  96  97  98  99]]]


 [[[100 101 102 103 104 105 106 107 108 109]
   [110 111 112 113 114 115 116 117 118 119]
   [120 121 122 123 124 125 126 127 128 129]
   [130 131 132 133 134 135 136 137 138 139]
   [140 141 142 143 144 145 146 147 148 149]]

  [[150 151 152 153 154 155 156 157 158 159]
   [160 161 162 163 164 165 166 167 168 169]
   [170 171 172 173 174 175 176 177 178 179]
   [180 181 182 183 184 185 186 187 188 189]
   [190 191 192 193 194 195 196 197 198 199]]]


 [[[200 201 202 203 204 205 206 207 208 209]
   [210 211 212 213 214 215 216 217 218 219]
   [220 221 222 223 224 225 226 227 228 229]
   [230 231 232 233 234 235 236 237 238 239]
   [240 241 242 243 244 245 246 247 248 249]]

  [[250 251 252 253 254 255 256 257 258 259]
   [260 261 262 263 264 265 266 267 268 269]
   [270 271 272 273 274 275 276 277 278 279]
   [280 281 282 283 284 285 286 287 288 289]
   [290 291 292 293 294 295 296 297 298 299]]]]
<NDArray 3x2x5x10 @cpu(0)>

Then, after the manipulation:

x = x.transpose((0, 3, 1, 2))
x = x.flatten()
x = x.split(num_outputs=10, axis=1)  # (SEQ_LEN, N, CHANNELS)
x = nd.concat(*[elem.expand_dims(axis=0) for elem in x], dim=0)
x

we get:

[[[  0  10  20  30  40  50  60  70  80  90]
  [100 110 120 130 140 150 160 170 180 190]
  [200 210 220 230 240 250 260 270 280 290]]

 [[  1  11  21  31  41  51  61  71  81  91]
  [101 111 121 131 141 151 161 171 181 191]
  [201 211 221 231 241 251 261 271 281 291]]

 [[  2  12  22  32  42  52  62  72  82  92]
  [102 112 122 132 142 152 162 172 182 192]
  [202 212 222 232 242 252 262 272 282 292]]

 [[  3  13  23  33  43  53  63  73  83  93]
  [103 113 123 133 143 153 163 173 183 193]
  [203 213 223 233 243 253 263 273 283 293]]

 [[  4  14  24  34  44  54  64  74  84  94]
  [104 114 124 134 144 154 164 174 184 194]
  [204 214 224 234 244 254 264 274 284 294]]

 [[  5  15  25  35  45  55  65  75  85  95]
  [105 115 125 135 145 155 165 175 185 195]
  [205 215 225 235 245 255 265 275 285 295]]

 [[  6  16  26  36  46  56  66  76  86  96]
  [106 116 126 136 146 156 166 176 186 196]
  [206 216 226 236 246 256 266 276 286 296]]

 [[  7  17  27  37  47  57  67  77  87  97]
  [107 117 127 137 147 157 167 177 187 197]
  [207 217 227 237 247 257 267 277 287 297]]

 [[  8  18  28  38  48  58  68  78  88  98]
  [108 118 128 138 148 158 168 178 188 198]
  [208 218 228 238 248 258 268 278 288 298]]

 [[  9  19  29  39  49  59  69  79  89  99]
  [109 119 129 139 149 159 169 179 189 199]
  [209 219 229 239 249 259 269 279 289 299]]]
<NDArray 10x3x10 @cpu(0)>

Here, since the batch contains 3 samples with 2 feature maps (channels) each, every column has 10 elements: 5 from each feature map. It is also clear that the elements are ordered feature map by feature map: all the elements of one feature map come first, then those of the second, and so on until the last one.

So this manipulation is indeed CORRECT.

But is there a simpler way to do it?

Test 1

x = x.transpose((0, 3, 1, 2))  # the transpose is essential for slicing vertically
x = x.reshape(10, batch_size, -1)
x

The result is:

[[[  0  10  20  30  40  50  60  70  80  90]
  [  1  11  21  31  41  51  61  71  81  91]
  [  2  12  22  32  42  52  62  72  82  92]]

 [[  3  13  23  33  43  53  63  73  83  93]
  [  4  14  24  34  44  54  64  74  84  94]
  [  5  15  25  35  45  55  65  75  85  95]]

 [[  6  16  26  36  46  56  66  76  86  96]
  [  7  17  27  37  47  57  67  77  87  97]
  [  8  18  28  38  48  58  68  78  88  98]]

 [[  9  19  29  39  49  59  69  79  89  99]
  [100 110 120 130 140 150 160 170 180 190]
  [101 111 121 131 141 151 161 171 181 191]]

 [[102 112 122 132 142 152 162 172 182 192]
  [103 113 123 133 143 153 163 173 183 193]
  [104 114 124 134 144 154 164 174 184 194]]

 [[105 115 125 135 145 155 165 175 185 195]
  [106 116 126 136 146 156 166 176 186 196]
  [107 117 127 137 147 157 167 177 187 197]]

 [[108 118 128 138 148 158 168 178 188 198]
  [109 119 129 139 149 159 169 179 189 199]
  [200 210 220 230 240 250 260 270 280 290]]

 [[201 211 221 231 241 251 261 271 281 291]
  [202 212 222 232 242 252 262 272 282 292]
  [203 213 223 233 243 253 263 273 283 293]]

 [[204 214 224 234 244 254 264 274 284 294]
  [205 215 225 235 245 255 265 275 285 295]
  [206 216 226 236 246 256 266 276 286 296]]

 [[207 217 227 237 247 257 267 277 287 297]
  [208 218 228 238 248 258 268 278 288 298]
  [209 219 229 239 249 259 269 279 289 299]]]
<NDArray 10x3x10 @cpu(0)>

Notice that the 2nd column of the 1st batch is [3 13 23 33 43 53 63 73 83 93], which is WRONG: reshape reinterprets the flat buffer in row-major order, so the leading axis of size 10 slices straight through each sample's flattened data (and even across samples) instead of across time steps. This modification has to be ruled out.

Test 2

What if we just swap the axes and then reshape?

x = x.swapaxes(1, 3)
x = x.reshape(10, batch_size, -1)
x

[[[  0  50  10  60  20  70  30  80  40  90]
  [  1  51  11  61  21  71  31  81  41  91]
  [  2  52  12  62  22  72  32  82  42  92]]

 [[  3  53  13  63  23  73  33  83  43  93]
  [  4  54  14  64  24  74  34  84  44  94]
  [  5  55  15  65  25  75  35  85  45  95]]

 [[  6  56  16  66  26  76  36  86  46  96]
  [  7  57  17  67  27  77  37  87  47  97]
  [  8  58  18  68  28  78  38  88  48  98]]

 [[  9  59  19  69  29  79  39  89  49  99]
  [100 150 110 160 120 170 130 180 140 190]
  [101 151 111 161 121 171 131 181 141 191]]

 [[102 152 112 162 122 172 132 182 142 192]
  [103 153 113 163 123 173 133 183 143 193]
  [104 154 114 164 124 174 134 184 144 194]]

 [[105 155 115 165 125 175 135 185 145 195]
  [106 156 116 166 126 176 136 186 146 196]
  [107 157 117 167 127 177 137 187 147 197]]

 [[108 158 118 168 128 178 138 188 148 198]
  [109 159 119 169 129 179 139 189 149 199]
  [200 250 210 260 220 270 230 280 240 290]]

 [[201 251 211 261 221 271 231 281 241 291]
  [202 252 212 262 222 272 232 282 242 292]
  [203 253 213 263 223 273 233 283 243 293]]

 [[204 254 214 264 224 274 234 284 244 294]
  [205 255 215 265 225 275 235 285 245 295]
  [206 256 216 266 226 276 236 286 246 296]]

 [[207 257 217 267 227 277 237 287 247 297]
  [208 258 218 268 228 278 238 288 248 298]
  [209 259 219 269 229 279 239 289 249 299]]]
<NDArray 10x3x10 @cpu(0)>

It is WRONG, too. Interestingly, the elements of each column are now combined in yet another order, but either way it is wrong.

So it seems there is no simpler way to do this manipulation. Or is there? If not, I think it would be worthwhile to extend MXNet's reshape function to support this kind of slightly involved manipulation.
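For what it's worth, one more candidate did occur to me afterwards (sketched in NumPy, whose transpose/reshape semantics match `mxnet.nd` here, and only verified on this toy example, so take it with a grain of salt): keep the batch axis outermost while reshaping, and only then move the time axis to the front.

```python
import numpy as np

batch_size, channel_num, height_size, width_size, seq_len = 3, 2, 5, 10, 10
x = np.arange(300).reshape(batch_size, channel_num, height_size, width_size)

# Reference: the split/concat pipeline from the notebook.
ref = np.stack(np.split(x.transpose(0, 3, 1, 2).reshape(batch_size, -1),
                        seq_len, axis=1))                    # (T, N, C)

# Candidate: reshape with the batch axis kept outermost, then move T in front.
y = x.transpose(0, 3, 1, 2).reshape(batch_size, seq_len, -1)  # (N, T, C)
y = y.swapaxes(0, 1)                                          # (T, N, C)

assert np.array_equal(y, ref)
```

This works because reshaping to (N, T, -1) slices each sample's flattened buffer into T chunks without ever crossing a sample boundary, which is exactly what split-along-axis-1 did; the earlier Test 1 failed only because it put the size-10 axis before the batch axis.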

Anyway, thanks for your concrete explanation!!

No worries, thanks for trying out simpler solutions 👍. I agree that it would be great if there were a cleaner API to perform this reshaping operation.