ohmeow/blurr

How to make predictions on a batch?

Closed this issue · 8 comments

Hi,

I'm running into errors when trying to generate predictions on a batch with fastai functions: dls.test_dl() to create the test dataloader, then learn.get_preds() to get all the predictions. I searched through the documentation but couldn't find any examples of batch prediction. Could you show me an example of how to do it?

Update:

I just saw the example of batch inference in the Question Answering example, and managed to extend it to Classification tasks.

However, for tasks such as Token Classification and Summarization, I could create the test_dl but then hit this error when calling learn.get_preds(dl=test_dl):

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-15-289917d065fe> in <module>
----> 1 learn.get_preds(dl=test_dl)

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, **kwargs)
    246             pred_i = 1 if with_input else 0
    247             if res[pred_i] is not None:
--> 248                 res[pred_i] = act(res[pred_i])
    249                 if with_decoded: res.insert(pred_i+2, getattr(self.loss_func, 'decodes', noop)(res[pred_i]))
    250             if reorder and hasattr(dl, 'get_idxs'): res = nested_reorder(res, tensor(idxs).argsort())

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/losses.py in activation(self, x)
     43     def __init__(self, *args, axis=-1, **kwargs): super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)
     44     def decodes(self, x):    return x.argmax(dim=self.axis)
---> 45     def activation(self, x): return F.softmax(x, dim=self.axis)
     46 
     47 # Cell

/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/nn/functional.py in softmax(input, dim, _stacklevel, dtype)
   1510         dim = _get_softmax_dim('softmax', input.dim(), _stacklevel)
   1511     if dtype is None:
-> 1512         ret = input.softmax(dim)
   1513     else:
   1514         ret = input.softmax(dim, dtype=dtype)

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastcore/basics.py in __getattr__(self, k)
    386         if self._component_attr_filter(k):
    387             attr = getattr(self,self._default,None)
--> 388             if attr is not None: return getattr(attr,k)
    389         raise AttributeError(k)
    390     def __dir__(self): return custom_dir(self,self._dir())

AttributeError: 'list' object has no attribute 'softmax'

Sorry to disturb you, but could you help me take a look at this? @ohmeow

fastai makes some assumptions about what predictions look like that don't hold out of the box for things like token classification or summarization (which is essentially the same thing). Try passing this activation function into get_preds:

def my_act(x):
    # x is a list of tensors here, so apply softmax to each one separately
    return L([F.softmax(i, dim=-1) for i in x])

... like this ...

res = learn.get_preds(act=my_act)

Let me know how that works, and if you're able to get batch inferencing working with token classification. I think the same custom activation function might work with summarization as well since the task is essentially the same, i.e., predict a bunch of tokens.
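As a minimal sketch of what that activation does, here is the same idea in plain PyTorch, without fastai or blurr. The function name `softmax_each` and the random logits are made up for illustration, and a regular Python list stands in for fastai's `L`:

```python
import torch
import torch.nn.functional as F

def softmax_each(preds, dim=-1):
    """Apply softmax to every tensor in a list, one tensor at a time."""
    return [F.softmax(p, dim=dim) for p in preds]

# Hypothetical logits: two sequences of different lengths (3 and 5 tokens),
# each token scored against 4 labels.
logits = [torch.randn(3, 4), torch.randn(5, 4)]
probs = softmax_each(logits)

# Each sequence keeps its own shape ...
shapes = [tuple(p.shape) for p in probs]

# ... and every token's probabilities sum to 1.
row_sums_ok = all(
    torch.allclose(p.sum(dim=-1), torch.ones(p.shape[0])) for p in probs
)
```

Calling `F.softmax` on the list itself fails because a list has no `softmax` method, which is the `AttributeError` in the traceback above.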

Oh, is it because the final tensors for those tasks are not all the same length, so fastai puts them into a list instead of a tensor?

Thanks, it seems to be working!

Hi @ohmeow, sorry to bother you again. Now that the predictions are a list, using with_decoded=True results in this error. Is there a simple way to fix it?

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-131-ac518f356e89> in <module>
----> 1 learn.get_preds(dl=test_dl, act=my_act, with_decoded=True)

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, **kwargs)
    247             if res[pred_i] is not None:
    248                 res[pred_i] = act(res[pred_i])
--> 249                 if with_decoded: res.insert(pred_i+2, getattr(self.loss_func, 'decodes', noop)(res[pred_i]))
    250             if reorder and hasattr(dl, 'get_idxs'): res = nested_reorder(res, tensor(idxs).argsort())
    251             return tuple(res)

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/losses.py in decodes(self, x)
     42     @use_kwargs_dict(keep=True, weight=None, ignore_index=-100, reduction='mean')
     43     def __init__(self, *args, axis=-1, **kwargs): super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)
---> 44     def decodes(self, x):    return x.argmax(dim=self.axis)
     45     def activation(self, x): return F.softmax(x, dim=self.axis)
     46 

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastcore/basics.py in __getattr__(self, k)
    386         if self._component_attr_filter(k):
    387             attr = getattr(self,self._default,None)
--> 388             if attr is not None: return getattr(attr,k)
    389         raise AttributeError(k)
    390     def __dir__(self): return custom_dir(self,self._dir())

AttributeError: 'list' object has no attribute 'argmax'

Yah if you look at CrossEntropyLossFlat you'll see this code:

def decodes(self, x):    return x.argmax(dim=self.axis)

Because x is a list, and a list has no argmax method, you get that error. One idea is to write a custom loss function that operates like CrossEntropyLoss but includes activation and decodes functions that work properly for tasks like this, where you are predicting multiple tokens, and even a different number of tokens in each batch. I already provided the activation function, so now we just need a decodes inside a subclass of CrossEntropyLossFlat.

Give it a shot if you want, else I'll take a look at putting something together tomorrow ... but the loss function should look something like this:

class TokenCrossEntropyLossFlat(CrossEntropyLossFlat):
    def decodes(self, x): return L([i.argmax(dim=self.axis) for i in x])
    def activation(self, x): return L([F.softmax(i, dim=-1) for i in x])

Not 100% sure the above is right, but it should give you an idea of what the decodes should look like if you want to give it a try. Lmk.
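To illustrate what a list-aware decodes would return, here is a small standalone sketch in plain PyTorch. The function name `decode_each` and the hand-written probabilities are hypothetical, and this is not the blurr or fastai implementation:

```python
import torch

def decode_each(preds, axis=-1):
    """Pick the highest-scoring label for every token, one sequence at a time."""
    return [p.argmax(dim=axis) for p in preds]

# Hypothetical per-token probabilities over 2 labels for two sequences
# of different lengths (2 tokens and 3 tokens).
preds = [
    torch.tensor([[0.1, 0.9], [0.8, 0.2]]),
    torch.tensor([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8]]),
]

# One tensor of label ids per sequence, converted to lists for readability.
label_ids = [t.tolist() for t in decode_each(preds)]
```

Here `label_ids` contains one list of label ids per sequence, each as long as the sequence it came from, so nothing has to be padded or stacked.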

Hi @ohmeow, the new loss function idea works. I just needed to modify a few lines in blurr_predict, and both single inference and batch inference are working fine now.

Now I'm looking for a way to apply blurr_predict_tokens on a batch for the Token Classification task. Do you have any ideas or suggestions?

Oh, never mind: I managed to run batch inference with blurr_predict by passing in a list of inputs, and the same works with blurr_predict_tokens. So I guess get_preds is no longer necessary. Thank you very much for your help!

Ok cool. Yah the blurr predict methods take a list of inputs ... so you can definitely use them for batch inference (see blurr_predict_tokens for an example).

If you wanted, you could create a blurr_get_token_preds and something similar for summarization that would replace the standard get_preds with something that works with transformers. You're actually pretty close with the above. Another option, and the one I usually use for batch inference, is simply to loop through my test DataLoader and call into the model directly. Here's some batch inference code I used with blurr for a multilabel classification model:

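import numpy as np
import torch
from fastai.text.all import *  # load_learner, L, to_detach

# learner_export_path, cpu, device and inf_df are assumed to be defined earlier
inf_learn = load_learner(fname=learner_export_path, cpu=cpu)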
inf_learn.model = inf_learn.model.to(device)
inf_learn.model = inf_learn.model.eval()

inf_dl = inf_learn.dls.test_dl(inf_df, rm_type_tfms=None, bs=16)

test_probs = []
with torch.no_grad():
    for index, b in enumerate(inf_dl):
        if index % 1000 == 0: print(index)

        # note: even though there are no targets, each batch is a tuple!
        probs = torch.sigmoid(inf_learn.model(b[0])[0])

        # why "detach"? so the stored predictions don't keep a reference to the
        # computation graph (a safeguard; torch.no_grad() already disables gradient tracking).
        test_probs.append(to_detach(probs))

all_probs = L(torch.cat(test_probs))

# ensure results are returned in order
# test_dl.get_idxs() => unsorted/original order items
all_probs = all_probs[0][np.argsort(inf_dl.get_idxs())]

I actually like this approach because it's 1) faster than get_preds and 2) clearer as to what is going on, since I work directly with the predictions I get back from the model. Sure, it's a bit more code, but when it comes to debugging and looking back at your work a month or two later ... I think it's a bit easier to decipher than the magic happening in get_preds.
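On the reordering step at the end of that snippet: a small pure-Python sketch of why indexing with the argsort of the served indices restores the original order. It assumes, as the comment in the snippet suggests, that get_idxs() gives the original index of the item served at each position; the `argsort` helper and the sample data are made up for illustration.

```python
def argsort(values):
    """Return the positions that would sort `values` (like np.argsort)."""
    return sorted(range(len(values)), key=values.__getitem__)

items = ["a", "b", "c", "d"]   # original order of the test items
served_idxs = [2, 0, 3, 1]     # hypothetical order in which the loader served them

# Predictions come back in served order, not in the original order.
served = [items[i] for i in served_idxs]

# argsort(served_idxs) is the inverse permutation: position k of the result
# says where original item k ended up in the served sequence.
restore = argsort(served_idxs)
restored = [served[j] for j in restore]
```

After this, `restored` matches `items` again, which is what indexing the concatenated predictions with np.argsort(inf_dl.get_idxs()) is meant to achieve.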