Wrappers around the Sequential and Functional API of Keras
by Alexander Braekevelt.
from keras_wrappers import SequentialWrapper, ModelWrapper
class MyModel(SequentialWrapper):
    """Example subclass showing how to build a model on top of SequentialWrapper."""

    def __init__(self):
        """Create the wrapped model, add its layers and compile it."""
        # NOTE(review): this relies on the wrapper's __init__ returning the
        # object that exposes add()/compile() -- confirm against SequentialWrapper.
        net = super().__init__(name='my_model')

        # Layers are added in order: a Dense layer followed by a softmax activation.
        net.add(Dense(32, input_dim=784))
        net.add(Activation('softmax'))

        net.compile(
            loss='categorical_crossentropy',
            optimizer='adam',
            metrics=['accuracy'],
        )

    @Override
    def preprocess_x(self, data):
        """Optional hook for preprocessing x data; defers to the wrapper's version."""
        # Insert custom preprocessing here if needed.
        prepared = super().preprocess_x(data)
        return prepared

    @Override
    def preprocess_y(self, data):
        """Optional hook for preprocessing y data; defers to the wrapper's version."""
        # Insert custom preprocessing here if needed.
        prepared = super().preprocess_y(data)
        return prepared

    @Override
    def postprocess(self, data):
        """Optional hook for postprocessing; defers to the wrapper's version."""
        # Insert custom postprocessing here if needed.
        finished = super().postprocess(data)
        return finished
# Instantiate the example model defined above.
my_model = MyModel()
class MyModel(ModelWrapper):
    """Example subclass showing how to build a model on top of ModelWrapper."""

    def __init__(self):
        """Define the input/output tensors, create the wrapped model and compile it."""
        # Connect an input tensor to a Dense layer; these become the model's
        # inputs and outputs.
        inputs = Input(shape=(32,))
        outputs = Dense(32)(inputs)

        # NOTE(review): this relies on the wrapper's __init__ returning the
        # object that exposes compile() -- confirm against ModelWrapper.
        net = super().__init__(inputs=inputs, outputs=outputs, name='my_model')

        net.compile(
            loss='categorical_crossentropy',
            optimizer='adam',
            metrics=['accuracy'],
        )

    @Override
    def preprocess_x(self, data):
        """Optional hook for preprocessing x data; defers to the wrapper's version."""
        # Insert custom preprocessing here if needed.
        prepared = super().preprocess_x(data)
        return prepared

    @Override
    def preprocess_y(self, data):
        """Optional hook for preprocessing y data; defers to the wrapper's version."""
        # Insert custom preprocessing here if needed.
        prepared = super().preprocess_y(data)
        return prepared

    @Override
    def postprocess(self, data):
        """Optional hook for postprocessing; defers to the wrapper's version."""
        # Insert custom postprocessing here if needed.
        finished = super().postprocess(data)
        return finished
# MyModel now refers to the ModelWrapper subclass above; this rebinds
# my_model to an instance of it.
my_model = MyModel()
Training saves epochs (if not interrupted) and applies both types of preprocessing (preprocess_x and preprocess_y).
# x and y are placeholders for your own training data.
my_model.train(x, y)
from keras.preprocessing.image import ImageDataGenerator
# Augmentation settings handed to Keras' ImageDataGenerator: width/height
# shift ranges and horizontal flipping.
generator = ImageDataGenerator(
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
)

# Train from the generator instead of from the raw arrays directly.
my_model.train_generator(
    generator,
    x,
    y,
    batch_size=64,
    epochs=5,
)
Plots loss and accuracy of all epochs combined.
# log_y=False presumably keeps the y-axis linear rather than logarithmic -- confirm.
my_model.plot_history(log_y=False)
Predicting applies preprocessing and postprocessing.
# Prediction takes model inputs (x), not targets (y): single_x is a placeholder
# for one sample, multiple_x for a collection of samples.
prediction = my_model.predict_one(single_x)
predictions = my_model.predict_all(multiple_x)
Saves both the model weights and the history.
# No file name is passed here; presumably the wrapper derives one itself
# (compare the name given to load_model) -- confirm.
my_model.save_model()
Restores the model weights and the history. (Requires the same model architecture.)
# 'my_model_35_epochs' is presumably the model name followed by the number of
# trained epochs, as written by save_model() -- confirm.
my_model.load_model('my_model_35_epochs')
All methods of the original Keras model are also available on the wrappers.