A lightweight utility for training multiple Keras models in parallel and comparing their final loss and last-epoch time.
- **ParallelFinder**
  - Accepts a list of zero-argument functions, each returning a compiled `keras.Model`.
  - Spawns one `multiprocessing.Process` per model so all models train simultaneously.
  - Uses a shared, process-safe dictionary (`logs`) to record per-model and global "best" metrics.
- **FinderCallback**
  - Measures the duration of each epoch.
  - On the final epoch, writes the model's final loss and epoch time into `ParallelFinder.logs`.
  - Updates the global best-loss and best-time entries if surpassed (a minimal sketch of both components follows this list).
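As a rough mental model of how the two pieces fit together (not the package's exact source), they could be implemented along the following lines. The shared dict is assumed to be a `multiprocessing.Manager().dict()`, and names like `shared_logs` and `_train_one` are illustrative:

```python
import multiprocessing as mp
import time

from tensorflow import keras


class FinderCallback(keras.callbacks.Callback):
    """Times each epoch and, on the last one, records results in the shared dict."""

    def __init__(self, shared_logs, model_idx, total_epochs):
        super().__init__()
        self.shared_logs = shared_logs
        self.model_idx = model_idx
        self.total_epochs = total_epochs

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        elapsed = time.time() - self.epoch_start
        if epoch == self.total_epochs - 1:  # final epoch only
            loss = logs["loss"]  # swap in logs.get("val_loss") to track validation
            self.shared_logs[f"model_{self.model_idx}_loss"] = loss
            self.shared_logs[f"model_{self.model_idx}_time"] = elapsed
            # Update global bests if surpassed (fine for a handful of workers,
            # though check-then-set on a Manager dict is not atomic).
            if loss < self.shared_logs.get("best_loss", float("inf")):
                self.shared_logs["best_loss"] = loss
                self.shared_logs["best_loss_model_idx"] = self.model_idx
                self.shared_logs["time_for_best_loss"] = elapsed
            if elapsed < self.shared_logs.get("best_time", float("inf")):
                self.shared_logs["best_time"] = elapsed
                self.shared_logs["best_time_model_idx"] = self.model_idx
                self.shared_logs["loss_for_best_time"] = loss


class ParallelFinder:
    def __init__(self, model_constructors):
        self.constructors = model_constructors
        self.logs = mp.Manager().dict()  # process-safe shared dict

    def _train_one(self, idx, x, y, epochs, **kw_fit):
        model = self.constructors[idx]()  # fresh compiled model in this process
        callback = FinderCallback(self.logs, idx, epochs)
        model.fit(x, y, epochs=epochs, callbacks=[callback], **kw_fit)

    def find(self, train_data, train_labels, epochs, **kw_fit):
        procs = [
            mp.Process(target=self._train_one,
                       args=(i, train_data, train_labels, epochs),
                       kwargs=kw_fit)
            for i in range(len(self.constructors))
        ]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
```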
**Define Model Constructors**

Each function should return a compiled Keras model:
```python
from tensorflow import keras

def build_dense():
    model = keras.Sequential([
        keras.Input(shape=(784,)),
        keras.layers.Dense(128, activation="relu"),
        keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(optimizer="adam", loss="categorical_crossentropy")
    return model

def build_cnn():
    model = keras.Sequential([
        keras.Input(shape=(784,)),
        keras.layers.Reshape((28, 28, 1)),
        keras.layers.Conv2D(16, (3, 3), activation="relu"),
        keras.layers.Flatten(),
        keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(optimizer="adam", loss="categorical_crossentropy")
    return model
```
**Instantiate and Run**
```python
from parallel_finder import ParallelFinder
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

# Load data
(x_train, y_train), _ = mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255
y_train = to_categorical(y_train, 10)

# List of model-building functions
model_constructors = [build_dense, build_cnn]

# Create finder
finder = ParallelFinder(model_constructors)

# Train each model for 5 epochs, batch size 64
finder.find(
    train_data=x_train,
    train_labels=y_train,
    epochs=5,
    batch_size=64,
)

# Inspect results
for key, val in finder.logs.items():
    print(f"{key}: {val}")
```
**Interpret `finder.logs`**

- Per-model entries (for N models): `model_0_loss`, `model_0_time`, ..., `model_{N-1}_loss`, `model_{N-1}_time`
- Global best entries:
  - `best_loss` → smallest final loss
  - `best_loss_model_idx` → index of the model with `best_loss`
  - `time_for_best_loss` → epoch time for that model
  - `best_time` → shortest final-epoch time
  - `best_time_model_idx` → index of the model with `best_time`
  - `loss_for_best_time` → loss corresponding to `best_time`
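Using these keys, you can, for example, recover which constructor won (continuing the run from the previous step):

```python
best_idx = finder.logs["best_loss_model_idx"]
print(
    f"Best final loss {finder.logs['best_loss']:.4f} came from "
    f"{model_constructors[best_idx].__name__} "
    f"(last epoch took {finder.logs['time_for_best_loss']:.2f}s)"
)
```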
- **GPU Usage**: If multiple processes share a GPU, configure each worker process to allow memory growth (e.g., via `tf.config.experimental.set_memory_growth`; see the snippet after this list).
- **Data Overhead**: Large datasets are pickled for each process; consider lightweight datasets or shared-memory solutions if this becomes a bottleneck.
- **Customization**:
  - Modify `FinderCallback.on_epoch_end` to track validation metrics (e.g., `'val_loss'`).
  - Pass additional `fit(...)` arguments via `**kw_fit` (e.g., `validation_data=(x_val, y_val)`).
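For the GPU note above, a typical memory-growth setup looks like the following. It must run in each child process before any model is built; exactly where to hook it in depends on how you launch the workers (one simple option is the top of each model-constructor function):

```python
import tensorflow as tf

# Ask TensorFlow to allocate GPU memory on demand instead of
# reserving it all up front, so several processes can share one GPU.
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)
```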