A PyTorch implementation of Xception
This repository is a PyTorch reimplementation of Xception and is almost an op-to-op translation of the official implementation. It also provides a function to convert the official TensorFlow pretrained weights (which can be downloaded here) to PyTorch weights, making it convenient to run inference or to fine-tune on your own datasets.
Following the official version, the Xception implemented here makes a few changes to the original architecture:
- Fully convolutional: all max-pooling layers are replaced with separable conv2d layers with stride = 2. This allows atrous convolution to be used to extract feature maps at any resolution.
- We support adding ReLU and BatchNorm after each depthwise convolution, motivated by the design of MobileNetV1.
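To make the fully convolutional change concrete, the pure-Python sketch below computes the spatial output size of a 3×3 convolution under 'same'-style padding (an assumption; the repository's exact padding scheme may differ). A stride-2 layer halves the resolution, while a stride-1 atrous (dilated) layer keeps it. The helper `final_feature_size` is hypothetical: it assumes five stride-2 stages (Xception's overall output stride of 32) and, once a requested output stride is reached, replaces striding with doubling dilation rates, in the spirit of how DeepLab controls the output stride.

```python
def conv_output_size(n, kernel=3, stride=1, dilation=1):
    """Spatial output size of a conv layer with 'same'-style padding.

    Assumption: padding = (effective_kernel - 1) // 2, where a dilated
    (atrous) kernel covers kernel + (kernel - 1) * (dilation - 1) pixels.
    """
    effective_kernel = kernel + (kernel - 1) * (dilation - 1)
    pad = (effective_kernel - 1) // 2
    return (n + 2 * pad - effective_kernel) // stride + 1


def final_feature_size(n, num_stride2_stages=5, output_stride=32):
    """Hypothetical helper: feature-map size after the downsampling stages.

    Stages use stride 2 until the requested output stride is reached; each
    remaining stage uses stride 1 and doubles the dilation rate, so it keeps
    the resolution while still enlarging the receptive field.
    """
    current_stride, rate = 1, 1
    for _ in range(num_stride2_stages):
        if current_stride < output_stride:
            n = conv_output_size(n, stride=2)
            current_stride *= 2
        else:
            rate *= 2
            n = conv_output_size(n, stride=1, dilation=rate)
    return n


print(conv_output_size(64, stride=2))              # 32: resolution halved
print(conv_output_size(64, stride=1, dilation=2))  # 64: resolution kept
print(final_feature_size(299, output_stride=32))   # 10
print(final_feature_size(299, output_stride=16))   # 19
```

Trading the last stride for dilation doubles the feature-map resolution (10 to 19 for a 299×299 input in this sketch) at the cost of more computation in the later layers.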
At the moment, you can easily:
- Load pretrained Xception models
- Use Xception models for classification or feature extraction.
First, download the official pretrained weights listed at the bottom of the page. There are three pretrained models, xception_xx_imagenet, where xx is one of 41, 65, or 71. Then run the following command:
python3 xception_test.py --tf_checkpoint_path "xxxx.....xxx/model.ckpt" --model_name "xception_xx"
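If you want to convert all three models, a small helper can build the same command for each one. This is only a sketch: the directory layout `<ckpt_root>/xception_xx_imagenet/model.ckpt` and the helper name `build_convert_commands` are hypothetical, so point them at wherever you extracted the downloaded checkpoints. The flags are exactly those of the command above.

```python
import shlex

MODEL_DEPTHS = (41, 65, 71)


def build_convert_commands(ckpt_root):
    """Hypothetical helper: one conversion command per pretrained model.

    Assumes each checkpoint was extracted to
    <ckpt_root>/xception_<depth>_imagenet/model.ckpt (adjust as needed).
    """
    commands = []
    for depth in MODEL_DEPTHS:
        ckpt = f"{ckpt_root}/xception_{depth}_imagenet/model.ckpt"
        commands.append([
            "python3", "xception_test.py",
            "--tf_checkpoint_path", ckpt,
            "--model_name", f"xception_{depth}",
        ])
    return commands


if __name__ == "__main__":
    for cmd in build_convert_commands("checkpoints"):
        # Print the command; use subprocess.run(cmd, check=True) to execute it.
        print(shlex.join(cmd))
```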
A new folder 'pretrained_models' will be created to store the converted PyTorch model file, and a few lines like the following will be printed to the console (if model_name is not specified or model_name == xception_65):
TensorFlow prediction:
[286]
[[279 288 282 283 286]]
PyTorch prediction:
[286]
[[279 288 282 283 286]]
Save model to: ./pretrained_models/xception_65.pth
Load pretrained weights successfully.
PyTorch prediction:
[286]
[[279 288 282 283 286]]
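The check above compares the top-1 class and a list of five class indices from both frameworks. The NumPy sketch below shows one way such lists could be computed; it assumes the five-element list holds the top-5 classes in ascending order of score, which is consistent with its last entry matching the top-1 prediction (286). `top1_top5` and `predictions_agree` are hypothetical helpers, not functions from this repository, and the `atol` tolerance is an arbitrary choice.

```python
import numpy as np


def top1_top5(logits):
    """Return (top1, top5) class indices for logits of shape (batch, num_classes).

    top5 is in ascending order of score, so its last column equals top1.
    """
    logits = np.asarray(logits)
    top1 = np.argmax(logits, axis=1)
    top5 = np.argsort(logits, axis=1)[:, -5:]
    return top1, top5


def predictions_agree(tf_logits, torch_logits, atol=1e-4):
    """True if both frameworks give the same top-1/top-5 and close logits."""
    tf_top1, tf_top5 = top1_top5(tf_logits)
    pt_top1, pt_top5 = top1_top5(torch_logits)
    return bool(
        np.array_equal(tf_top1, pt_top1)
        and np.array_equal(tf_top5, pt_top5)
        and np.allclose(tf_logits, torch_logits, atol=atol)
    )


# Synthetic logits for one image; real ones would come from the two models.
logits = np.array([[0.1, 2.0, 0.3, 1.5, -1.0, 0.7, 0.2]])
print(top1_top5(logits))
print(predictions_agree(logits, logits + 1e-6))  # True
```

Comparing the raw logits with a tolerance, in addition to the class indices, also catches small numerical drift that would not change the argmax.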
Load an Xception model (randomly initialized):
import xception
xception_65 = xception.xception_65(pretrained=False)
Load a pretrained Xception model:
import xception
xception_65 = xception.xception_65(pretrained=True)
To fine-tune on your own dataset, num_classes must be specified, like this:
import xception
model = xception.xception_65(num_classes=8, pretrained=True)
To use Xception for feature extraction, set the keyword arguments num_classes=None and global_pool=False:
import xception
model = xception.xception_65(num_classes=None, global_pool=False, pretrained=True)
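With global_pool=False the model should return a spatial feature map rather than a pooled vector. If you later need one vector per image, you can average over the spatial axes yourself. The NumPy sketch below assumes the features come in PyTorch's (N, C, H, W) layout and uses a random stand-in array whose shape (2048 channels, 10×10 for a 299×299 input) is an assumption; the helper name `global_average_pool` is hypothetical. For a `torch.Tensor`, the equivalent is `features.mean(dim=(2, 3))`.

```python
import numpy as np


def global_average_pool(features):
    """Average an (N, C, H, W) feature map over H and W, returning (N, C)."""
    features = np.asarray(features)
    if features.ndim != 4:
        raise ValueError(f"expected (N, C, H, W), got shape {features.shape}")
    return features.mean(axis=(2, 3))


# Random stand-in for the model output (assumed shape, not real features).
features = np.random.default_rng(0).normal(size=(2, 2048, 10, 10))
embeddings = global_average_pool(features)
print(embeddings.shape)  # (2, 2048)
```

Keeping the spatial map and pooling it only when needed lets the same extractor serve both dense tasks (which need H and W) and per-image retrieval or classification.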