prasunroy/stefann

How many epochs does it take for either network to converge?

ThomasDelteil opened this issue · 4 comments

Hi,

Thanks for sharing your code. I am trying to reproduce your results. How many epochs does it take to converge to results similar to those shown in your paper?

Thanks

Hi,

Thanks for your interest in our work. During training, we set the number of epochs to 1000 for both models. Training each model took about 2 days on a single NVIDIA TITAN X GPU. After training, we kept the weights with the lowest validation loss.
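Keeping the weights with the lowest validation loss comes down to tracking the best value seen so far and snapshotting the weights whenever it improves. In Keras, `tf.keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', save_best_only=True)` handles this; the framework-free sketch below only illustrates the logic. `BestWeightsTracker` is a hypothetical name, and a plain dict stands in for real model weights.

```python
import copy
import math


class BestWeightsTracker:
    """Keeps a copy of the weights from the epoch with the lowest validation loss."""

    def __init__(self):
        self.best_val_loss = math.inf
        self.best_epoch = None
        self.best_weights = None

    def update(self, epoch, val_loss, weights):
        # Snapshot only on strict improvement, like a "save best only" checkpoint.
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.best_epoch = epoch
            # Deep copy so later in-place weight updates don't alter the snapshot.
            self.best_weights = copy.deepcopy(weights)
            return True
        return False


if __name__ == "__main__":
    tracker = BestWeightsTracker()
    # Validation losses for epochs 0-3 from the first log posted in this thread.
    val_losses = [24.720771616555044, 24.215766965136115,
                  24.196309418226956, 24.275966902367863]
    for epoch, val_loss in enumerate(val_losses):
        placeholder_weights = {"epoch_marker": epoch}  # stand-in for model weights
        tracker.update(epoch, val_loss, placeholder_weights)
    print(tracker.best_epoch, tracker.best_weights)  # prints: 2 {'epoch_marker': 2}
```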

NOTE: Although we did not use early stopping while training the models, we recommend including it as a callback, which may reduce training time significantly.
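In Keras this would typically be `tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)`. To estimate how much time such a callback would save, its stopping rule can be replayed on a logged validation-loss curve. The sketch below is a simplified, framework-free approximation of the patience/`min_delta` rule (it ignores options such as `baseline`); `early_stopping_epoch` is a hypothetical helper, and `patience=10` is an arbitrary example value.

```python
import math


def early_stopping_epoch(val_losses, patience=10, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch validation losses.

    Training stops once `patience` consecutive epochs pass without improving
    the best validation loss by more than `min_delta`.
    """
    best = math.inf
    best_epoch = None
    wait = 0
    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best - min_delta:
            # Improvement: remember it and reset the patience counter.
            best = val_loss
            best_epoch = epoch
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # The rule never triggered: training runs to the last logged epoch.
    return len(val_losses) - 1, best_epoch
```

Applied to the `val_loss` column of the first log posted later in this thread with `patience=10`, this returns `(12, 2)`: training would stop at epoch 12, and the weights from epoch 2 would be the ones to keep.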

Hi @prasunroy,

I am trying to reproduce your results. During training, I found that the training loss kept decreasing, while the validation loss decreased slightly and then started to increase. Did you encounter this problem during your training? Thanks.

This is the log file:

epoch,loss,val_loss
0,25.873983780795303,24.720771616555044
1,21.214732255166023,24.215766965136115
2,19.5305793424886,24.196309418226956
3,18.712535324936503,24.275966902367863
4,18.146149668061067,24.374509240923594
5,17.724842212435494,24.453728989652156
6,17.391642734110228,24.61247360260889
7,17.121832768067257,24.62370510996124
8,16.892489922195285,24.729061182281118
9,16.7110024330354,24.8554647670165
10,16.53455253376735,24.82867559362341
11,16.394012556903693,24.864671244680146
12,16.26253899460196,24.97646004539458
13,16.14363468705843,24.860293950210384
14,16.04869938655353,25.060472141469948
15,15.947671655328838,25.00020275539822
16,15.860044397464208,25.148321718757536
17,15.775175130818944,25.084358497902198
18,15.699415777072211,25.23869942182376
19,15.625746090766077,25.22325478573395
20,15.56532149765821,25.17339846638495
21,15.49498631669274,25.15990415879238
22,15.441376395870181,25.262783289465883
23,15.387716065353267,25.268257591341747
24,15.340794180663755,25.454775524060913
25,15.292626717475864,25.305134811401366
26,15.24900266108762,25.348525802863477
27,15.209791906206362,25.482600516174067
28,15.169450110080707,25.418516927648472
29,15.133638902868457,25.444777194364573
30,15.092550883809052,25.50373841368122
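One way to quantify the trend in a log like this is to find the epoch with the lowest `val_loss` and compare the train/validation gap there with the gap at the last epoch. The sketch assumes a CSV log with an `epoch,loss,val_loss` header, as above; `summarize_log` is a hypothetical helper.

```python
import csv
import io


def summarize_log(csv_text):
    """Return best-epoch statistics from an 'epoch,loss,val_loss' CSV log."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    epochs = [int(r["epoch"]) for r in rows]
    loss = [float(r["loss"]) for r in rows]
    val_loss = [float(r["val_loss"]) for r in rows]
    # Index of the row with the lowest validation loss.
    best_i = min(range(len(val_loss)), key=val_loss.__getitem__)
    return {
        "best_epoch": epochs[best_i],
        "best_val_loss": val_loss[best_i],
        # Epochs elapsed since validation loss last improved.
        "epochs_since_best": epochs[-1] - epochs[best_i],
        # Generalization gap at the best epoch and at the last logged epoch.
        "gap_at_best": val_loss[best_i] - loss[best_i],
        "gap_at_end": val_loss[-1] - loss[-1],
    }


if __name__ == "__main__":
    sample = (
        "epoch,loss,val_loss\n"
        "0,25.873983780795303,24.720771616555044\n"
        "1,21.214732255166023,24.215766965136115\n"
        "2,19.5305793424886,24.196309418226956\n"
        "3,18.712535324936503,24.275966902367863\n"
    )
    summary = summarize_log(sample)
    print(summary["best_epoch"], summary["epochs_since_best"])  # prints: 2 1
```

On the full log above, it reports epoch 2 as the best epoch, 28 epochs since the best value, and a gap that widens from about 4.67 to about 10.41; a widening gap like this is the usual signature of overfitting.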

I'm assuming you are trying to train FANNet. I did not encounter this issue in my case. The model seems to be overfitting the data heavily. Make sure you are training and validating on the correct data. Initializing the weights or fine-tuning pretrained weights may help as well. Check this notebook for how to train FANNet.
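One way to act on the "correct data" check is to compare the training and validation folders before training: both should be non-empty, contain the same character sub-folders, and not share files. The sketch below assumes a hypothetical layout with one sub-folder per character inside each split; the actual dataset layout may differ, and `compare_splits` is an illustrative name.

```python
from pathlib import Path


def compare_splits(train_dir, val_dir):
    """Sanity report on two dataset folders.

    Hypothetical layout: one sub-folder per character, image files inside.
    """
    train_root, val_root = Path(train_dir), Path(val_dir)

    def files(root):
        # Set of file paths relative to the split root.
        return {p.relative_to(root).as_posix() for p in root.rglob("*") if p.is_file()}

    def subdirs(root):
        # Names of the immediate sub-folders (one per character, by assumption).
        return {p.name for p in root.iterdir() if p.is_dir()}

    train_files, val_files = files(train_root), files(val_root)
    return {
        "train_count": len(train_files),
        "val_count": len(val_files),
        # Sub-folders present in one split but missing from the other.
        "only_in_train": sorted(subdirs(train_root) - subdirs(val_root)),
        "only_in_val": sorted(subdirs(val_root) - subdirs(train_root)),
        # Identical relative paths in both splits may indicate a copy mistake.
        "shared_paths": sorted(train_files & val_files),
    }
```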

Hi @prasunroy,

Yes, I am trying to train FANNet. Previously, I followed this section https://github.com/prasunroy/stefann#training-networks to train FANNet. After reading your reply, I downloaded the dataset from the Kaggle notebook and retrained the network. Please see the loss log below. I also trained FANNet for editing digits by changing:

SOURCE_CHARS = '0123456789'
TARGET_CHARS = '0123456789'

and observed a similar loss trend.
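Assuming FANNet receives the target character as a one-hot vector over `TARGET_CHARS` (the STEFANN paper describes a one-hot encoding of the target character), switching to digits shrinks that vector to 10 entries, and one font can supply at most `len(SOURCE_CHARS) * len(TARGET_CHARS)` = 100 source/target pairs. The helpers below are hypothetical and only illustrate those two quantities; they are not taken from the notebook.

```python
SOURCE_CHARS = '0123456789'
TARGET_CHARS = '0123456789'


def one_hot(char, alphabet=TARGET_CHARS):
    """One-hot vector (list of 0/1) selecting a single `char` within `alphabet`."""
    if len(char) != 1 or char not in alphabet:
        raise ValueError(f"{char!r} is not a single character of {alphabet!r}")
    return [1 if c == char else 0 for c in alphabet]


def training_pairs(source_chars=SOURCE_CHARS, target_chars=TARGET_CHARS):
    """Every (source, target) character pair that one font could contribute."""
    return [(s, t) for s in source_chars for t in target_chars]


if __name__ == "__main__":
    print(one_hot('3'))            # prints: [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
    print(len(training_pairs()))   # prints: 100
```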

Did you use any specific weight initialization? Could you share a training log with me?
Here are some hyperparameters I used for my training:

batch_size=64, learning_rate=1e-3, num_epochs=1000, random_seed=99999999
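For reproducibility, these hyperparameters can be kept in one place and the seed applied before building the model. The sketch below covers only the standard-library RNG; with NumPy and TensorFlow 2.x you would additionally call `np.random.seed(seed)` and `tf.random.set_seed(seed)`. `TrainConfig` and `seed_everything` are hypothetical names, and even with fixed seeds, non-deterministic GPU kernels may still make runs differ slightly.

```python
import os
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
    # Values as reported in this comment.
    batch_size: int = 64
    learning_rate: float = 1e-3
    num_epochs: int = 1000
    random_seed: int = 99999999


def seed_everything(seed):
    """Seed Python-level randomness (NumPy/TensorFlow need their own calls)."""
    # Setting this at runtime only affects Python processes started afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)


if __name__ == "__main__":
    config = TrainConfig()
    seed_everything(config.random_seed)
    print(config)
```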

Thank you very much!

Loss log:
epoch,loss,val_loss
0,26.26617855483769,24.714188140176926
1,22.092690324656843,23.90586946521285
2,20.228913092326202,23.823645403709637
3,18.52181828288936,23.419035493899614
4,17.856711517101814,23.36808690193372
5,17.376590267742586,23.59961173815605
6,16.98921621295223,23.582613836134207
7,16.681834468229436,23.679680378780326
8,16.402573309690528,23.820174634132158
9,16.168961282607178,23.710408119939018
10,15.970573168615115,23.85237797558425
11,15.791060625273605,23.83962789065269
12,15.64507364414898,23.981840452988003
13,15.501352553170664,24.06868635535005
14,15.366264881330066,24.018727514052532
15,15.233981397413459,24.150950007391636
16,15.13249607948483,24.147842947303428
17,15.029789179467445,24.119907183825852
18,14.941484995366713,24.26204340887728
19,14.858203082390853,24.081106002175595
20,14.777942465712865,24.05389933749769
21,14.703543624085123,24.1952846188517
22,14.639047020793111,24.14757335279115
23,14.60939291607313,24.358510151097526
24,14.515080870801457,24.472659191455126
25,14.46598849499358,24.435821577745546
26,14.402234116554622,24.183679646900185
27,14.357933266363645,24.43247356031068
28,14.304078304327353,24.323559466781465
29,14.521008158837862,24.435517751930732
30,14.282474018938322,24.3830908538698
31,14.18283783984406,24.49984305011215
32,14.137610166087509,24.588876989992883
33,14.102619508340823,24.5291731230606
34,14.071577536101787,24.669347497311804
35,14.030894595484845,24.467897058370315
36,13.994722770747403,24.445620752217973
37,13.963140455469414,24.81086515227248
38,13.978784107261959,24.54192369871121
39,13.909097001882477,24.625660411977673
40,13.87162385008699,24.525976675511348
41,13.838730952337128,24.731286375508507
42,13.81940902624676,24.786532296626525
43,13.77954439393433,24.626132281452005
44,13.759489284302669,24.65723967918983
45,13.735157028712177,24.950002682025616
46,13.701066372251193,24.668061384844357
47,13.890552218062446,24.714612914316753
48,13.69564190964442,24.766836412869967
49,13.641727469998544,24.858267988998744
50,13.619227478324675,24.859715143237594
51,13.606633979435319,24.789710685684835
52,13.596047972148387,24.94700794889612
53,13.63080885655281,24.811601585222654
54,13.542940427526798,24.951465888540657
55,13.519473331282514,25.268309912295972
56,13.513307161320336,24.935002160909377
57,13.645894066748923,25.156722570265067
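The validation loss in this log is noisy (its raw minimum is 23.368 at epoch 4, with jumps such as the one at epoch 37), so a trailing moving average can make the upward drift easier to judge than individual epochs. `moving_average` is a hypothetical helper, and the window size is an arbitrary choice.

```python
def moving_average(values, window=5):
    """Trailing moving average; the first window-1 points average what is available."""
    if window < 1:
        raise ValueError("window must be >= 1")
    out = []
    running = 0.0
    for i, v in enumerate(values):
        running += v
        # Drop the value that just left the window.
        if i >= window:
            running -= values[i - window]
        out.append(running / min(i + 1, window))
    return out
```

Plotting `moving_average(val_loss)` next to the raw `loss` column shows the two curves moving apart, which is easier to read than the epoch-to-epoch fluctuations.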