BloodAxe/pytorch-toolbelt

How to implement TTA for binary segmentation

chefkrym opened this issue · 3 comments

Would anyone be kind enough to share code showing how to use TTA (test-time augmentation) for binary segmentation with this library?
I have my trained model weights but can't figure out how to use pytorch-toolbelt.

Thank you.

Assuming you have a model that takes a tensor of shape [B, Cin, H, W] and outputs logits as a single tensor of shape [B, Cout, H, W], where Cout is 1 for binary segmentation (Cout > 1 works as well):

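from torch import nn
from pytorch_toolbelt.inference import tta

# `model` is your trained segmentation network that returns logits
model = nn.Sequential(model, nn.Sigmoid())  # Apply sigmoid activation to the logits
# Wrap the model so every forward pass also runs on a left-right flipped copy of the input
model = tta.GeneralizedTTA(model, augment_fn=tta.fliplr_image_augment, deaugment_fn=tta.fliplr_image_deaugment)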

If your model already applies a sigmoid, skip the external one; it is shown here only for reference.
Once the model is wrapped in tta.GeneralizedTTA, that's it. You run inference as you normally would and TTA is done for you inside the wrapper. The class is even traceable with torch.jit, so you can export the wrapped model if you need to.
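To make it clearer what the wrapper does conceptually, here is a minimal hand-rolled sketch of horizontal-flip TTA in plain PyTorch. It is not pytorch-toolbelt's implementation: the FlipTTA class, the toy convolutional model, the 64x64 input size and the 0.5 threshold are all assumptions made up for illustration. The idea is to predict on the image and on its mirrored copy, flip the second prediction back so both masks align, and average them.

```python
import torch
from torch import nn


class FlipTTA(nn.Module):
    """Hand-rolled horizontal-flip TTA (illustrative, not the library class)."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Predict on the original batch and on its left-right mirrored copy.
        pred = self.model(x)
        pred_flipped = self.model(torch.flip(x, dims=[3]))
        # Undo the flip on the second prediction so both masks align, then average.
        return 0.5 * (pred + torch.flip(pred_flipped, dims=[3]))


# Hypothetical stand-in for a trained binary segmentation network.
toy_model = nn.Sequential(nn.Conv2d(3, 1, kernel_size=3, padding=1), nn.Sigmoid())
tta_model = FlipTTA(toy_model).eval()

# One test image with a batch dimension: [1, Cin, H, W].
image = torch.rand(1, 3, 64, 64)
with torch.no_grad():
    probs = tta_model(image)  # averaged probabilities, shape [1, 1, 64, 64]

# Binarize the probabilities; the 0.5 threshold is an assumption, tune as needed.
mask = probs > 0.5
```

For a single test image the only extra step is adding the batch dimension (for example with unsqueeze(0)) before calling the wrapped model; everything else is the same as ordinary inference.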

Thank you so very much @BloodAxe, I am really grateful. I'm rather new to this and unsure how to wrap my model. I'm sharing my code below; could you kindly guide me? I want to apply TTA to just one test image (the last block of cells in my notebook).
Thank you, sir.

https://colab.research.google.com/drive/1xJ62lFGlaVpbw6WPFfZCAR8_IKX4jclR?usp=sharing

The best way to learn new things is to try to understand how they work. Take a look at the code and the corresponding tests; they should give you a good intuition for how it works and how to use it in your case.