This is an example of using the torchsample package for data augmentation. torchsample provides many augmentation transforms, such as rotation and zooming in or out.
The CIFAR-10 classification task is used to show how to apply this package for data augmentation.
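To make "rotation" and "zoom in or out" concrete, here is a minimal NumPy sketch of what an affine augmentation does to a single image. Each output pixel is mapped back through the inverse rotation and scaling to a source pixel (nearest-neighbour sampling, zero fill outside the input). This is an illustrative re-implementation written for this README, not torchsample's code; the function name `affine_augment` and the example ranges (±20°, zoom 0.9 to 1.1) are assumptions.

```python
import numpy as np


def affine_augment(img, angle_deg=0.0, zoom=1.0):
    """Rotate and zoom an (H, W) or (H, W, C) image about its centre.

    Positive angles turn the image clockwise when displayed with row 0 at the top;
    zoom > 1 zooms in, zoom < 1 zooms out. Uses nearest-neighbour sampling, and
    output pixels whose source falls outside the input are filled with 0.
    """
    h, w = img.shape[:2]
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    t = np.deg2rad(angle_deg)
    c, s = np.cos(t), np.sin(t)
    ys, xs = np.mgrid[0:h, 0:w]
    dy, dx = ys - cy, xs - cx  # offsets of each output pixel from the centre
    # Inverse mapping: undo the zoom and the rotation to find each source pixel.
    src_x = np.rint((c * dx + s * dy) / zoom + cx).astype(int)
    src_y = np.rint((-s * dx + c * dy) / zoom + cy).astype(int)
    valid = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
    out = np.zeros_like(img)
    out[ys[valid], xs[valid]] = img[src_y[valid], src_x[valid]]
    return out


# Example: one random rotation in [-20, 20] degrees and zoom in [0.9, 1.1] (assumed ranges).
rng = np.random.default_rng(0)
image = rng.random((32, 32, 3))  # stand-in for a 32x32 RGB CIFAR-10 image
augmented = affine_augment(image, angle_deg=rng.uniform(-20, 20), zoom=rng.uniform(0.9, 1.1))
```

Inverse mapping (output to source) is used so that every output pixel receives exactly one value, leaving no holes.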
```bash
python main.py
```
and
```bash
python advanced_main.py
```
Standard method (random horizontal flip data augmentation):
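```python
import torchvision.transforms as transforms

train_tf = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # data augmentation: random horizontal flip (p=0.5)
    transforms.Resize(256),             # resize the shorter side to 256 (formerly transforms.Scale)
    transforms.CenterCrop(224),         # keep the central 224x224 region
    transforms.ToTensor(),              # PIL image -> float tensor in [0, 1], shape (C, H, W)
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet mean/std
])
```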
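The standard pipeline above boils down to a few array operations. The NumPy sketch below mimics RandomHorizontalFlip, CenterCrop, ToTensor and Normalize on an (H, W, C) uint8 image so that the shapes and the per-channel formula (x - mean) / std are visible. It is a simplified illustration, not torchvision's implementation: it omits Resize, and the helper names are hypothetical.

```python
import numpy as np

# ImageNet channel statistics, as passed to transforms.Normalize above.
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def random_hflip(img, rng, p=0.5):
    """Mirror an (H, W, C) image left-right with probability p."""
    return img[:, ::-1] if rng.random() < p else img


def center_crop(img, size):
    """Keep the central size x size window of an (H, W, C) image."""
    h, w = img.shape[:2]
    top, left = (h - size) // 2, (w - size) // 2
    return img[top:top + size, left:left + size]


def to_tensor_and_normalize(img):
    """uint8 (H, W, C) -> float (C, H, W) scaled to [0, 1], then (x - mean) / std per channel."""
    x = img.astype(np.float64).transpose(2, 0, 1) / 255.0
    return (x - MEAN[:, None, None]) / STD[:, None, None]


# Example on a random 256x256 RGB image (as if Resize(256) had already run on a square input).
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(256, 256, 3)).astype(np.uint8)
out = to_tensor_and_normalize(center_crop(random_hflip(img, rng), 224))
print(out.shape)  # (3, 224, 224)
```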
Adding rotation data augmentation:
```python
import torchvision.transforms as transforms
import torchsample as ts

train_tf = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # torchsample transforms operate on tensors, so they come after ToTensor().
    # data augmentation: rotate by a random angle drawn from [-20, 20] degrees
    ts.transforms.RandomRotate(20),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
```
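Sampling the rotation angle per image matters: rotations about the same centre compose by adding their angles, so a fixed +20° rotation followed by a fixed -20° rotation returns (up to interpolation and corner clipping) the original image and adds no variety. The NumPy sketch below checks that composition with 2x2 rotation matrices and then draws an independent angle per image from [-20, 20). The helper names and the ±20° range are illustrative assumptions; this is not torchsample's implementation.

```python
import numpy as np


def rotation_matrix(angle_deg):
    """2x2 matrix that rotates (x, y) vectors by angle_deg degrees."""
    t = np.deg2rad(angle_deg)
    return np.array([[np.cos(t), -np.sin(t)],
                     [np.sin(t), np.cos(t)]])


# Rotations about the same centre compose by adding angles,
# so +20 degrees followed by -20 degrees is the identity.
chained = rotation_matrix(-20) @ rotation_matrix(20)
print(np.allclose(chained, np.eye(2)))  # True


def sample_angle(rng, max_deg=20.0):
    """Draw one angle uniformly from [-max_deg, max_deg) for each image."""
    return rng.uniform(-max_deg, max_deg)


rng = np.random.default_rng(0)
angles = [sample_angle(rng) for _ in range(5)]  # a different angle for each image
```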
The final accuracy is shown as follows:
| Network | Baseline | Data Augmentation (Rotation) |
|---|---|---|
| AlexNet (not pretrained) | | |
torchsample is an excellent package written by Nick Cullen.