ohmeow/blurr

How to use BatchLossFilter callback with blurr


I've been experimenting with tsai's BatchLossFilter callback. When I run training with this callback:

model = HF_BaseModelWrapper(hf_model)

learn = Learner(dls, 
                model,
                loss_func=LabelSmoothingCrossEntropyFlat(),
                metrics=[accuracy],
                cbs=[HF_BaseModelCallback],
                splitter=hf_splitter).to_fp16()

learn.unfreeze()

cbs = [BatchLossFilter(loss_perc=0.4)]

learn.fit_one_cycle(
    3,
    lr_max=3e-5,
    cbs = cbs
)

I get the following error, caused by the SequenceClassifierOutput object returned by the Hugging Face model:

---------------------------------------------------------------------------

AttributeError                            Traceback (most recent call last)

<ipython-input-13-ce9f3d20977f> in <module>()
      2     3,
      3     lr_max=3e-5,
----> 4     cbs = cbs
      5 )

18 frames

/usr/local/lib/python3.7/dist-packages/fastai/losses.py in __call__(self, inp, targ, **kwargs)
     32         if self.floatify and targ.dtype!=torch.float16: targ = targ.float()
     33         if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
---> 34         if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
     35         return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
     36 

AttributeError: 'SequenceClassifierOutput' object has no attribute 'view'

Is there any way to adapt BatchLossFilter so it works with blurr? I haven't had any issues using other callbacks with blurr.

First, cool library... I wasn't aware of that one!

Second, it would help if you could post a gist I can run, so I can see the full stack trace. The loss function expects `inp` to be a tensor, but with blurr (and Hugging Face), what comes out of the model at this point is an object holding several pieces of information, such as the loss. With callbacks, I'm sure this can be adapted to work with blurr/HF.
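To make the tensor-versus-object mismatch concrete, here is a minimal pure-Python sketch (not blurr's, tsai's, or Hugging Face's code). `FakeSequenceClassifierOutput` is a hypothetical stand-in for the Hugging Face output container, and `unwrap_logits`, `logits_only`, and `mse` are illustrative names. It assumes the fix amounts to handing the loss the `logits` field instead of the whole object: a loss that expects a bare sequence fails on the container, while a wrapper that pulls out `.logits` first accepts both kinds of input.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, Sequence


@dataclass
class FakeSequenceClassifierOutput:
    """Hypothetical stand-in for HF's SequenceClassifierOutput: a container
    with several fields instead of a bare tensor."""
    logits: Any
    loss: Optional[Any] = None


def unwrap_logits(pred: Any) -> Any:
    """Return pred.logits for HF-style output objects; pass anything else through."""
    return pred.logits if hasattr(pred, "logits") else pred


def logits_only(loss_func: Callable[[Any, Any], Any]) -> Callable[[Any, Any], Any]:
    """Wrap a loss so it accepts either raw predictions or an HF-style output."""
    def wrapped(pred: Any, targ: Any) -> Any:
        return loss_func(unwrap_logits(pred), targ)
    return wrapped


def mse(preds: Sequence[float], targs: Sequence[float]) -> float:
    # Iterates over preds, so passing the container object raises TypeError,
    # much like fastai's loss raises AttributeError when it calls .view() on it.
    return sum((p - t) ** 2 for p, t in zip(preds, targs)) / len(targs)


out = FakeSequenceClassifierOutput(logits=[0.5, 1.0])
safe_mse = logits_only(mse)
print(safe_mse(out, [0.0, 1.0]))         # 0.125
print(safe_mse([0.5, 1.0], [0.0, 1.0]))  # 0.125: plain predictions still work
```

Wrapping the loss in a plain function hides any attributes of the original loss object, so a real adaptation would still need to check what BatchLossFilter expects from `learn.loss_func`; this sketch only shows the unwrapping step.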

Btw, any particular reason you're using BatchLossFilter? Just curious :)

Here is a gist using the BatchLossFilter callback with the standard blurr training script. I'm experimenting with BatchLossFilter because it got some traction on Twitter a while back, and, intuitively, focusing on the harder examples could improve performance. Besides, an extra tool in the toolbox never hurts 😄
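The "focus on the harder examples" idea can be sketched in plain Python. This is a simplified, hypothetical re-creation of the selection step, not tsai's implementation: rank the samples in a batch by loss, keep the highest-loss ones until their share of the batch's total loss reaches `loss_perc`, then index the batch with the kept positions. `hardest_indices` and `select_rows` are illustrative names; `select_rows` also handles a dict of per-sample values, on the assumption that a Hugging Face batch may arrive as a dict (e.g. `input_ids`) rather than a single tensor.

```python
from typing import Any, Dict, List, Sequence, Union


def hardest_indices(losses: Sequence[float], loss_perc: float) -> List[int]:
    """Indices of the highest-loss samples whose share of the batch's total
    loss first reaches loss_perc (hypothetical, simplified selection step)."""
    n = len(losses)
    total = sum(losses)
    if loss_perc >= 1.0 or total == 0:
        return list(range(n))  # nothing to filter: keep the whole batch
    order = sorted(range(n), key=lambda i: losses[i], reverse=True)
    kept: List[int] = []
    share = 0.0
    for i in order:
        kept.append(i)
        share += losses[i] / total  # cumulative fraction of total loss
        if share >= loss_perc:
            break
    return kept


Batch = Union[List[Any], Dict[str, List[Any]]]


def select_rows(batch: Batch, idxs: Sequence[int]) -> Batch:
    """Keep only the rows at idxs, for a list of samples or a dict of per-sample lists."""
    if isinstance(batch, dict):
        return {k: [v[i] for i in idxs] for k, v in batch.items()}
    return [batch[i] for i in idxs]


losses = [0.1, 2.0, 0.5, 1.4]                     # per-sample losses for a batch of 4
keep = hardest_indices(losses, loss_perc=0.4)
print(keep)                                        # [1]
print(hardest_indices(losses, loss_perc=0.8))      # [1, 3]
xb = {"input_ids": [[101], [102], [103], [104]]}   # hypothetical dict-style batch
print(select_rows(xb, keep))                       # {'input_ids': [[102]]}
```

With `loss_perc=0.4`, as in the snippet at the top of this thread, only one sample of this toy batch survives, because it alone carries half of the total loss.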

I was able to use the example implementation with other kinds of data (images, time series, etc.), so this looks like an issue specific to Hugging Face, if that helps.