Compare runtimes of `atx` vs `t2`
tallamjr opened this issue
The new custom callback in commit 4c9b1be should allow logging of the wall time for each epoch, as well as the mean epoch runtime over training.
@tallamjr Can you link to the commit or, even better, the lines of code that do this? I wouldn't mind taking a look.
Hi @jasonmcewen, apologies, I've updated the link above.
Anyway, the main thing is to add a custom callback to the Keras model, like the one below (housed in `astronet/custom_callbacks.py`), which logs the time taken by each epoch:
```python
import time
import tensorflow as tf

class TimeHistoryCallback(tf.keras.callbacks.Callback):
    """Record the wall-clock duration (in seconds) of each training epoch."""

    def on_train_begin(self, logs=None):
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_time_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time.time() - self.epoch_time_start)
```
Then, just before calling `model.fit()` inside `astronet/atx/train.py`, I create an instance of it and pass it in via `callbacks` (see the comments in the code):
```python
...
time_callback = TimeHistoryCallback()  # <--- Create instance here
history = model.fit(
    train_input,
    y_train,
    batch_size=BATCH_SIZE,
    epochs=self.epochs,
    shuffle=True,
    validation_data=(test_input, y_test),
    validation_batch_size=VALIDATION_BATCH_SIZE,
    verbose=False,
    callbacks=[
        time_callback,  # <--- Pass the timing callback in here so its times can be read later
        SGEBreakoutCallback(
            threshold=44  # Stop training if it has run for more than `threshold` hours
        ),
        CSVLogger(
            csv_logger_file,
            separator=',',
            append=False,
        ),
        ...
```
Finally, after training and saving the model, I log the per-epoch times and compute the mean epoch time like so:
```python
log.info(f"PER EPOCH TIMING: {time_callback.times}")
log.info(f"AVERAGE EPOCH TIMING: {np.array(time_callback.times).mean()}")
```
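If more than the mean is wanted, the list in `time_callback.times` could be summarised with a small helper like the one below. This is only a sketch: `summarize_epoch_times` is a hypothetical name (not something in `astronet`), it uses the standard-library `statistics` module in place of NumPy, and the epoch times in the usage line are made-up values, not measurements.

```python
import statistics


def summarize_epoch_times(times):
    """Summarise per-epoch wall times in seconds (e.g. `time_callback.times`).

    Hypothetical helper for illustration only; not part of astronet.
    """
    if not times:
        raise ValueError("no epoch times recorded")
    mean_s = statistics.fmean(times)
    return {
        "epochs": len(times),
        "total_s": sum(times),
        "mean_s": mean_s,
        "mean_min": mean_s / 60.0,
        # The sample standard deviation needs at least two epochs
        "stdev_s": statistics.stdev(times) if len(times) > 1 else 0.0,
    }


# Made-up epoch times, purely to show the call
summary = summarize_epoch_times([100.0, 110.0, 120.0])
print(f"AVERAGE EPOCH TIMING: {summary['mean_s']:.2f} s ({summary['mean_min']:.1f} min)")
```

Reporting the spread alongside the mean makes it easier to tell whether a slow average comes from every epoch or from a few outliers (for example, a slow first epoch).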
I already have results for `t2`, which reports ~7.8 minutes per epoch; `atx` is running now.
Closed with #80
`t2`: ---> 468.15 seconds per epoch (7.8 mins)
```
Model: "t2_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv_embedding (ConvEmbeddin multiple                  224
_________________________________________________________________
positional_encoding (Positio multiple                  0
_________________________________________________________________
transformer_block (Transform multiple                  12704
_________________________________________________________________
global_average_pooling1d (Gl multiple                  0
_________________________________________________________________
dropout_2 (Dropout)          multiple                  0
_________________________________________________________________
dense_6 (Dense)              multiple                  462
=================================================================
Total params: 13,390
Trainable params: 13,390
Non-trainable params: 0
```
`atx`: ---> 1791.1 seconds per epoch (~29.9 mins)
```
Model: "atx_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
entry_flow (EntryFlow)       multiple                  71790
_________________________________________________________________
middle_flow (MiddleFlow)     multiple                  103194
_________________________________________________________________
middle_flow_1 (MiddleFlow)   multiple                  103194
_________________________________________________________________
middle_flow_2 (MiddleFlow)   multiple                  103194
_________________________________________________________________
middle_flow_3 (MiddleFlow)   multiple                  103194
_________________________________________________________________
middle_flow_4 (MiddleFlow)   multiple                  103194
_________________________________________________________________
middle_flow_5 (MiddleFlow)   multiple                  103194
_________________________________________________________________
middle_flow_6 (MiddleFlow)   multiple                  103194
_________________________________________________________________
middle_flow_7 (MiddleFlow)   multiple                  103194
_________________________________________________________________
exit_flow (ExitFlow)         multiple                  437802
=================================================================
Total params: 1,335,144
Trainable params: 1,321,512
Non-trainable params: 13,632
```
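To put the two runs side by side, the arithmetic can be checked with a few lines of Python. This is an illustrative sketch: the constants are copied from the epoch times and model summaries reported in this thread, and the names (`to_minutes`, `T2_LAYER_PARAMS`, etc.) are made up for the example rather than taken from `astronet`.

```python
# Epoch times (seconds) reported in this thread
T2_SECONDS_PER_EPOCH = 468.15
ATX_SECONDS_PER_EPOCH = 1791.1

# Per-layer parameter counts copied from the model summaries
T2_LAYER_PARAMS = [224, 0, 12704, 0, 0, 462]
ATX_LAYER_PARAMS = [71790] + [103194] * 8 + [437802]  # entry, 8 x middle, exit


def to_minutes(seconds):
    return seconds / 60.0


t2_total = sum(T2_LAYER_PARAMS)
atx_total = sum(ATX_LAYER_PARAMS)

# The per-layer counts should add up to the "Total params" lines
assert t2_total == 13_390
assert atx_total == 1_335_144
assert 1_321_512 + 13_632 == atx_total  # trainable + non-trainable

time_ratio = ATX_SECONDS_PER_EPOCH / T2_SECONDS_PER_EPOCH
param_ratio = atx_total / t2_total
print(f"t2:  {to_minutes(T2_SECONDS_PER_EPOCH):.1f} min/epoch, {t2_total:,} params")
print(f"atx: {to_minutes(ATX_SECONDS_PER_EPOCH):.1f} min/epoch, {atx_total:,} params")
print(f"atx/t2: {time_ratio:.2f}x time per epoch, {param_ratio:.1f}x parameters")
```

On these figures `atx` takes about 3.8 times as long per epoch as `t2` while having roughly 100 times as many parameters.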
Note: with the updated, scaled-down version of `atx`, the average epoch time drops to 1550 seconds (~26 mins), with a reduced number of parameters:
```
Model: "atx_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
entry_flow (EntryFlow)       multiple                  71790
middle_flow (MiddleFlow)     multiple                  103194
exit_flow (ExitFlow)         multiple                  437774
=================================================================
Total params: 612,758
Trainable params: 606,770
Non-trainable params: 5,988
_________________________________________________________________
```
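The effect of scaling `atx` down can be quantified with a little arithmetic. Again, this is a hedged sketch: the numbers are the reported epoch times and parameter totals for the full and scaled-down `atx` models, and `percent_reduction` is a hypothetical helper defined just for this example.

```python
# Epoch times (seconds) and total parameter counts reported in this thread
FULL_ATX = {"seconds_per_epoch": 1791.1, "params": 1_335_144}
SCALED_ATX = {"seconds_per_epoch": 1550.0, "params": 612_758}

# Scaled-down model: entry, 1 x middle, exit should sum to the reported total
assert 71_790 + 103_194 + 437_774 == SCALED_ATX["params"]
assert 606_770 + 5_988 == SCALED_ATX["params"]  # trainable + non-trainable


def percent_reduction(before, after):
    """Relative reduction from `before` to `after`, in percent."""
    return 100.0 * (before - after) / before


time_cut = percent_reduction(FULL_ATX["seconds_per_epoch"], SCALED_ATX["seconds_per_epoch"])
param_cut = percent_reduction(FULL_ATX["params"], SCALED_ATX["params"])
print(f"Epoch time: -{time_cut:.1f}% ({SCALED_ATX['seconds_per_epoch'] / 60:.1f} min/epoch)")
print(f"Parameters: -{param_cut:.1f}%")
```

On these numbers the parameter count falls by about 54%, while the epoch time falls by only about 13.5%.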