This is a study-oriented implementation of VGG16 in Python with TensorFlow.
VGG16 is an influential convolutional neural network introduced by Karen Simonyan and Andrew Zisserman. Their paper is available here:
https://arxiv.org/abs/1409.1556
To train VGG16 with this project, you only need to provide the path to a directory. The directory must contain a file named "model.json", which holds all the information needed to train the network.
A sample model.json is shown below:
{
    "learning_rate": 0.01,
    "momentum": 0.9,
    "batchsize": 8,
    "batches": 70,
    "channel": 3,
    "classes": 2,
    "classnamelist": "labels.txt",
    "trainlist": "train.txt",
    "labellist": "labellist.txt"
}
The classnamelist file contains the names of all classes. The number of names must equal the value of "classes" in "model.json".
The trainlist file contains the absolute paths of all pictures in the training set.
The labellist file contains the class index of each picture, in the same order as the pictures in the trainlist.
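As a sanity check before training, the constraints described above can be verified with a short script. The sketch below is not part of the project: check_training_dir and read_lines are hypothetical helper names, and it assumes one entry per line in each list file, zero-based integer class indices, and file names in model.json that resolve relative to the training directory.

```python
import json
import os


def read_lines(path):
    """Return the non-empty, stripped lines of a text file."""
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def check_training_dir(src_dir):
    """Return a list of problems found in src_dir (empty if none were found)."""
    with open(os.path.join(src_dir, "model.json"), encoding="utf-8") as f:
        cfg = json.load(f)

    # Assumption: the file names in model.json are relative to src_dir.
    names = read_lines(os.path.join(src_dir, cfg["classnamelist"]))
    images = read_lines(os.path.join(src_dir, cfg["trainlist"]))
    # Assumption: each label is an integer class index, one per line.
    labels = [int(x) for x in read_lines(os.path.join(src_dir, cfg["labellist"]))]

    problems = []
    if len(names) != cfg["classes"]:
        problems.append(
            "classnamelist has %d names but model.json declares %d classes"
            % (len(names), cfg["classes"])
        )
    if len(images) != len(labels):
        problems.append(
            "trainlist has %d entries but labellist has %d"
            % (len(images), len(labels))
        )
    # Assumption: class indices are zero-based.
    out_of_range = [x for x in labels if not 0 <= x < cfg["classes"]]
    if out_of_range:
        problems.append("label indices out of range: %s" % out_of_range)
    return problems
```

Calling check_training_dir("./example") returns an empty list when the three files agree with model.json, and a list of human-readable messages otherwise.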
To train a VGG16 network with this project:
# coding=utf-8
from src.vgg16 import Vgg16


def main():
    # Create a new Vgg16 instance
    net = Vgg16()
    # Load the training configuration from model.json in srcDir
    net.loadWithUntrainedJson(srcDir="./example")
    # Train the CNN; the trained model is written to tgtDir
    net.train(tgtDir='./example')


if __name__ == '__main__':
    main()
We provide a predict API that runs prediction on a list of images. Its signature is shown below:
def predict(self, tgtDir, imageListFile)
Before using the model for prediction, load it with the following API:
def loadWithTrainedJson(self, modelDir)
modelDir is a directory containing a model trained with this project or created by the user.
imageListFile is a file listing the images to predict. Each line contains the path of one image (absolute paths are recommended).
Finally, the results are written to result.txt in tgtDir.
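The prediction steps can be put together roughly as follows. This is a minimal sketch, not code shipped with the project: write_image_list and run_prediction are hypothetical helpers, the set of image extensions is an assumption, and only Vgg16, loadWithTrainedJson, and predict come from the API described here.

```python
import os

# Assumption: which file extensions count as images is chosen for illustration.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def write_image_list(image_dir, list_path):
    """Write the absolute path of every image in image_dir to list_path, one per line."""
    paths = sorted(
        os.path.abspath(os.path.join(image_dir, name))
        for name in os.listdir(image_dir)
        if name.lower().endswith(IMAGE_EXTENSIONS)
    )
    with open(list_path, "w", encoding="utf-8") as f:
        for path in paths:
            f.write(path + "\n")
    return paths


def run_prediction(model_dir, tgt_dir, image_list_file):
    """Load a trained model and predict every image listed in image_list_file."""
    # Imported here so write_image_list can be used without the project on the path.
    from src.vgg16 import Vgg16

    net = Vgg16()
    net.loadWithTrainedJson(modelDir=model_dir)
    net.predict(tgtDir=tgt_dir, imageListFile=image_list_file)
    # The project writes its output to result.txt in tgt_dir.
    return os.path.join(tgt_dir, "result.txt")
```

A typical call would build the list with write_image_list("./images", "images.txt") and then pass "images.txt" to run_prediction; both paths are placeholders.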
For further questions, please contact wyc8094@gmail.com