How to use BatchLossFilter callback with blurr
I've been experimenting with tsai's BatchLossFilter callback. If I run training with this callback:
```python
model = HF_BaseModelWrapper(hf_model)

learn = Learner(dls,
                model,
                loss_func=LabelSmoothingCrossEntropyFlat(),
                metrics=[accuracy],
                cbs=[HF_BaseModelCallback],
                splitter=hf_splitter).to_fp16()

learn.unfreeze()

cbs = [BatchLossFilter(loss_perc=0.4)]
learn.fit_one_cycle(
    3,
    lr_max=3e-5,
    cbs = cbs
)
```
I get the following error, which is caused by the `SequenceClassifierOutput` object returned by Hugging Face:
```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-13-ce9f3d20977f> in <module>()
      2     3,
      3     lr_max=3e-5,
----> 4     cbs = cbs
      5 )

18 frames
/usr/local/lib/python3.7/dist-packages/fastai/losses.py in __call__(self, inp, targ, **kwargs)
     32         if self.floatify and targ.dtype!=torch.float16: targ = targ.float()
     33         if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
---> 34         if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
     35         return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
     36 

AttributeError: 'SequenceClassifierOutput' object has no attribute 'view'
```
Is there any way I can adapt BatchLossFilter to work with blurr? I haven't had any issues using other callbacks with blurr.
First, cool library... I wasn't aware of that one!
Second, it would be helpful if you could post a gist I can run, so I can see the full stack trace here. `inp` is expected to be a tensor, but with Blurr (and Hugging Face) what the model outputs at this point is an object carrying several fields, such as the loss and the logits. I'm sure this can be adapted with callbacks to work with Blurr/HF.
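A minimal sketch of the idea, assuming the only problem is that `loss_func` receives the whole Hugging Face output object instead of its `logits` tensor. It swaps in a loss-function wrapper rather than a callback, and `FakeClassifierOutput`, `unwrap_logits` and `UnwrapLogitsLoss` are hypothetical names for illustration, not part of blurr, tsai or fastai. Plain lists stand in for tensors so it runs without those libraries.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class FakeClassifierOutput:
    """Hypothetical stand-in for Hugging Face's SequenceClassifierOutput,
    keeping only the two fields this sketch touches."""
    logits: Any
    loss: Optional[Any] = None


def unwrap_logits(pred: Any) -> Any:
    """Return `pred.logits` for HF-style output objects; pass anything else through."""
    return getattr(pred, "logits", pred)


class UnwrapLogitsLoss:
    """Hypothetical loss wrapper: hands the wrapped loss the logits instead of
    the whole output object, so it works even when called on raw model output."""

    def __init__(self, loss_func: Callable[..., Any]):
        self.loss_func = loss_func

    def __call__(self, pred: Any, targ: Any, **kwargs: Any) -> Any:
        return self.loss_func(unwrap_logits(pred), targ, **kwargs)


# Usage with a dummy loss; plain lists stand in for tensors.
def first_logit_minus_target(inp, targ):
    return inp[0] - targ


loss = UnwrapLogitsLoss(first_logit_minus_target)
print(loss(FakeClassifierOutput(logits=[2.5, -1.0]), 1))  # HF-style output
print(loss([2.5, -1.0], 1))                                 # plain "tensor"
```

In the training snippet from the issue, this would mean passing something like `loss_func=UnwrapLogitsLoss(LabelSmoothingCrossEntropyFlat())` to the `Learner`, which is untested with blurr. The wrapper does not forward attributes of the wrapped loss (for example `reduction`, or the `activation`/`decodes` hooks fastai can look up on `loss_func`), so code relying on those would need them forwarded too.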
Btw, any particular reason you're using BatchLossFilter? Just curious :)
Here is a gist using the BatchLossFilter callback with the standard blurr training script. I'm experimenting with BatchLossFilter because it got some traction on Twitter a while back, and intuitively, focusing on the harder examples could improve performance. Extra tools in the toolbox never hurt 😄
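For reference, the "focus on the harder examples" idea can be sketched in plain Python. This assumes `loss_perc` means "keep the highest-loss samples until they account for that share of the batch's total loss"; that reading and the helper name `hardest_indices` are assumptions for illustration, not tsai's actual implementation.

```python
from typing import List, Sequence


def hardest_indices(losses: Sequence[float], loss_perc: float) -> List[int]:
    """Hypothetical helper: indices of the highest-loss samples, taken in
    descending order of loss until they cover at least `loss_perc` of the
    batch's total loss."""
    if not 0 < loss_perc <= 1:
        raise ValueError("loss_perc must be in (0, 1]")
    # Sort sample indices from hardest (largest loss) to easiest.
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)
    total = sum(losses)
    if total <= 0:
        return order  # nothing to discriminate on: keep the whole batch
    kept: List[int] = []
    running = 0.0
    for i in order:
        kept.append(i)
        running += losses[i]
        if running >= loss_perc * total:
            break
    return kept


batch_losses = [0.1, 2.0, 0.5, 1.4]
print(hardest_indices(batch_losses, 0.4))  # [1]
print(hardest_indices(batch_losses, 0.6))  # [1, 3]
```

The kept indices would then be used to subset the batch before the real forward and backward pass. Always keeping at least one sample, and the whole batch when every loss is zero, are design choices of this sketch.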
I was able to use the example implementation with other kinds of data (images, time series, etc.), so it looks like an issue specific to Hugging Face, if that helps.