mgrankin/fast_tabnet

Cannot apply "EarlyStoppingCallback" and "SaveModelCallback" in fast_tabnet

zjgbz opened this issue · 3 comments

zjgbz commented

Hi,

I tried to apply "EarlyStoppingCallback" and "SaveModelCallback" to a fast_tabnet regressor as follows:

cycle_num = 10
model = TabNetModel(emb_szs, len(to.cont_names), dls.c, n_d=64, n_a=64, n_steps=5, virtual_batch_size=256)
opt_func = partial(Adam, wd=0.01, eps=1e-5)
learn = Learner(dls, model, MSELossFlat(), opt_func=opt_func, lr=1e-2, metrics = rmse,
                path=Path('../model/fast_tabnet/'),
                callback_fns=[
                    partial(SaveModelCallback, monitor ='val_loss', every='epoch', mode ='min', name = f'emb_szs_{cycle_num}'),
                    partial(EarlyStoppingCallback, monitor='val_loss', patience=2) 
                ]
               )
learn.lr_find()
# output of learn.lr_find(): SuggestedLRs(lr_min=0.06309573650360108, lr_steep=1.9054607491852948e-06)
learn.fit_one_cycle(4, 0.06309573650360108)

However, after learn.fit_one_cycle() finishes, I cannot find any saved model in the path I specified, and training does not stop early even though the validation error increases for more than two consecutive epochs. Do you have any suggestions? Also, I used lr_min as the learning rate for learn.fit_one_cycle(); is that correct? Thank you very much.

Hello,

Your questions relate more to the fastai library itself. TabNet is just an architecture; fastai provides the Learner and the learning-rate finder (lr_find). You can check their forum at https://forums.fast.ai; it is welcoming and full of people who can help.
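On the learning-rate question: lr_find and its suggestions come from fastai, not fast_tabnet. As a rough illustration only, here is a standalone sketch that recomputes the two values in SuggestedLRs the way I understand fastai 2.x releases of this period did it: lr_min is one tenth of the learning rate at the lowest recorded loss, and lr_steep is where the loss falls fastest against log(lr). The helper name suggest_lrs and the sweep data are made up, and newer fastai releases changed how suggestions are computed, so check the source of your installed version.

```python
# A minimal, standalone sketch (pure Python, no fastai needed) of the two
# suggestions in SuggestedLRs, as I understand fastai 2.x versions of this
# period computed them. suggest_lrs is a hypothetical helper, not a fastai API.
import math
from typing import List, Tuple


def suggest_lrs(lrs: List[float], losses: List[float]) -> Tuple[float, float]:
    """Return (lr_min, lr_steep) for a learning-rate sweep.

    fastai also discarded a few points at each end of the sweep before
    doing this; that trimming is omitted here.
    """
    # lr_min: the learning rate at the lowest recorded loss, divided by 10.
    i_min = min(range(len(losses)), key=losses.__getitem__)
    lr_min = lrs[i_min] / 10
    # lr_steep: the learning rate where the loss falls fastest with respect
    # to log(lr), i.e. the most negative finite-difference slope.
    slopes = [
        (losses[i + 1] - losses[i]) / (math.log(lrs[i + 1]) - math.log(lrs[i]))
        for i in range(len(lrs) - 1)
    ]
    i_steep = min(range(len(slopes)), key=slopes.__getitem__)
    lr_steep = lrs[i_steep]
    return lr_min, lr_steep


# Made-up sweep: the loss drops fastest between 1e-3 and 1e-2,
# bottoms out at 1e-1, then diverges.
lrs = [1e-4, 1e-3, 1e-2, 1e-1, 1.0]
losses = [1.00, 0.95, 0.60, 0.40, 2.00]
lr_min, lr_steep = suggest_lrs(lrs, losses)
print(lr_min, lr_steep)
```

Because lr_min already steps back by a factor of 10 from the loss minimum, it is meant as a usable rate rather than the point where training diverges; whether it or lr_steep works better for a particular model is an empirical question.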

I think it should be monitor='valid_loss', not 'val_loss'; fastai records the validation loss under the name 'valid_loss'. https://docs.fast.ai/callback.tracker.html#EarlyStoppingCallback
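To illustrate why the monitored name matters, here is a simplified, standalone sketch of what a tracker callback does each epoch. It is not fastai's implementation, and SketchTracker is a hypothetical name: the monitored value is looked up by name among the recorded columns (in fastai these include train_loss and valid_loss), the best value so far is kept, and training stops once the value has not improved for `patience` consecutive epochs. How a missing name is reported depends on the library and version; this sketch simply raises a ValueError.

```python
# Simplified, standalone sketch of a "tracker" callback in the spirit of
# fastai's SaveModelCallback / EarlyStoppingCallback. NOT fastai's code;
# SketchTracker and its methods are hypothetical.
import operator
from typing import List, Optional


class SketchTracker:
    def __init__(self, metric_names: List[str], monitor: str = "valid_loss",
                 patience: int = 2, comp=operator.lt):
        # The monitored value is found by name, so it must match one of
        # the recorded column names exactly ('val_loss' will not).
        if monitor not in metric_names:
            raise ValueError(
                f"{monitor!r} is not a recorded metric; available: {metric_names}")
        self.idx = metric_names.index(monitor)
        self.patience = patience
        self.comp = comp              # operator.lt: lower is better (a loss)
        self.best: Optional[float] = None
        self.best_epoch: Optional[int] = None
        self.wait = 0                 # epochs since the last improvement

    def after_epoch(self, epoch: int, values: List[float]) -> bool:
        """Record one epoch's values; return True if training should stop."""
        current = values[self.idx]
        if self.best is None or self.comp(current, self.best):
            self.best, self.best_epoch, self.wait = current, epoch, 0
            # A save-best callback would write the model to disk here.
        else:
            self.wait += 1
        return self.wait >= self.patience


names = ["train_loss", "valid_loss"]
tracker = SketchTracker(names, monitor="valid_loss", patience=2)
# Made-up history of [train_loss, valid_loss] per epoch.
history = [[0.9, 0.50], [0.7, 0.40], [0.6, 0.45], [0.5, 0.47], [0.4, 0.30]]
stopped_at = None
for epoch, vals in enumerate(history):
    if tracker.after_epoch(epoch, vals):
        stopped_at = epoch
        break

try:
    SketchTracker(names, monitor="val_loss")
except ValueError as err:
    print(err)  # prints: 'val_loss' is not a recorded metric; available: ['train_loss', 'valid_loss']
```

With patience=2, the sketch stops after epoch 3 (two epochs without improving on epoch 1's 0.40), and epoch 1 is the one a save-best callback would have kept.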

zjgbz commented

> it should be monitor ='valid_loss', not 'val_loss', I guess. https://docs.fast.ai/callback.tracker.html#EarlyStoppingCallback

It works now. Thank you very much for your help!