tallamjr/astronet

Compare runtimes of `atx` vs `t2`

tallamjr opened this issue · 4 comments

The new custom callback in commit 4c9b1be should allow logging the wall-clock time of each epoch, as well as the mean epoch runtime over training.

@tallamjr Can you link to the commit or, even better, the lines of code that do this? I wouldn't mind taking a look.

Hi @jasonmcewen, apologies, I've updated the link above.

Anyway, the main thing is to add a custom callback to the Keras model like the one below (housed in astronet/custom_callbacks.py), which logs the time for each epoch:

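import time

import tensorflow as tf


class TimeHistoryCallback(tf.keras.callbacks.Callback):
    """Record the wall-clock duration of every training epoch."""

    def on_train_begin(self, logs=None):
        self.times = []  # one entry (in seconds) per completed epoch

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_time_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time.time() - self.epoch_time_start)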
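
The same timing pattern can be tried out without TensorFlow. The sketch below is a minimal, framework-free illustration and not the code in astronet/custom_callbacks.py: `EpochTimer` and `run_fake_training` are hypothetical names, and `time.sleep` stands in for real training work. It uses `time.perf_counter()`, which is monotonic, rather than `time.time()`, which can jump if the system clock is adjusted mid-run.

```python
import time


class EpochTimer:
    """Hypothetical stand-in mirroring the hooks of TimeHistoryCallback."""

    def on_train_begin(self):
        self.times = []

    def on_epoch_begin(self, epoch):
        self._start = time.perf_counter()

    def on_epoch_end(self, epoch):
        self.times.append(time.perf_counter() - self._start)


def run_fake_training(timer, epochs, seconds_per_epoch):
    """Drive the hooks in the order a training loop would call them."""
    timer.on_train_begin()
    for epoch in range(epochs):
        timer.on_epoch_begin(epoch)
        time.sleep(seconds_per_epoch)  # placeholder for one epoch of work
        timer.on_epoch_end(epoch)
    return timer.times


timer = EpochTimer()
times = run_fake_training(timer, epochs=3, seconds_per_epoch=0.01)
print(f"recorded {len(times)} epoch durations: {times}")
```

Swapping `time.time()` for `time.perf_counter()` in the real callback would be a one-line change in each hook; whether it matters depends on how stable the cluster clock is over a multi-hour run.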

Then, just before I call model.fit() inside astronet/atx/train.py, I create an instance of it (see the comments in the code):

        ...
        time_callback = TimeHistoryCallback() # <--- Create instance here

        history = model.fit(
            train_input,
            y_train,
            batch_size=BATCH_SIZE,
            epochs=self.epochs,
            shuffle=True,
            validation_data=(test_input, y_test),
            validation_batch_size=VALIDATION_BATCH_SIZE,
            verbose=False,
            callbacks=[
                time_callback,  # <--- Pass the timing callback here so its times can be read later
                SGEBreakoutCallback(
                    threshold=44     # Stop training if running for more than threshold number of hours
                ),
                CSVLogger(
                    csv_logger_file,
                    separator=',',
                    append=False,
                ),
            ...

Finally, after training and saving the model, I log the per-epoch timings and their mean like so:

        log.info(f"PER EPOCH TIMING: {time_callback.times}")
        log.info(f"AVERAGE EPOCH TIMING: {np.array(time_callback.times).mean()}")
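
The mean alone can hide outliers, for example a first epoch that includes one-off setup costs. Below is a minimal sketch of a slightly fuller summary using only the standard library. `summarise_epoch_times` and the sample durations are hypothetical and not part of astronet.

```python
import statistics


def summarise_epoch_times(times, skip_first=False):
    """Return summary statistics for a list of epoch durations in seconds."""
    if skip_first and len(times) > 1:
        times = times[1:]  # optionally drop the first epoch as warm-up
    mean_s = statistics.mean(times)
    return {
        "epochs": len(times),
        "total_s": sum(times),
        "mean_s": mean_s,
        "median_s": statistics.median(times),
        "mean_min": mean_s / 60,
    }


# Made-up durations for illustration only, not measurements from this thread.
example_times = [480.0, 465.0, 470.0, 455.0]
summary = summarise_epoch_times(example_times)
print(summary)
```

In the training script, the same function could be applied to `time_callback.times` in place of the `np.array(...).mean()` call.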

I already have results for t2, which reports ~7.8 minutes per epoch; atx is running now.

Closed with #80

t2: ---> 468.15 seconds per epoch (7.8 mins)

Model: "t2_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv_embedding (ConvEmbeddin multiple                  224
_________________________________________________________________
positional_encoding (Positio multiple                  0
_________________________________________________________________
transformer_block (Transform multiple                  12704
_________________________________________________________________
global_average_pooling1d (Gl multiple                  0
_________________________________________________________________
dropout_2 (Dropout)          multiple                  0
_________________________________________________________________
dense_6 (Dense)              multiple                  462
=================================================================
Total params: 13,390
Trainable params: 13,390
Non-trainable params: 0

atx: ---> 1791.1 seconds per epoch (~29.9 mins)

Model: "atx_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
entry_flow (EntryFlow)       multiple                  71790
_________________________________________________________________
middle_flow (MiddleFlow)     multiple                  103194
_________________________________________________________________
middle_flow_1 (MiddleFlow)   multiple                  103194
_________________________________________________________________
middle_flow_2 (MiddleFlow)   multiple                  103194
_________________________________________________________________
middle_flow_3 (MiddleFlow)   multiple                  103194
_________________________________________________________________
middle_flow_4 (MiddleFlow)   multiple                  103194
_________________________________________________________________
middle_flow_5 (MiddleFlow)   multiple                  103194
_________________________________________________________________
middle_flow_6 (MiddleFlow)   multiple                  103194
_________________________________________________________________
middle_flow_7 (MiddleFlow)   multiple                  103194
_________________________________________________________________
exit_flow (ExitFlow)         multiple                  437802
=================================================================
Total params: 1,335,144
Trainable params: 1,321,512
Non-trainable params: 13,632

Note: with the updated, scaled-down version of atx, the average epoch time is reduced to 1550 seconds (~25.8 mins), with a reduced number of parameters:

Model: "atx_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 entry_flow (EntryFlow)      multiple                  71790

 middle_flow (MiddleFlow)    multiple                  103194

 exit_flow (ExitFlow)        multiple                  437774

=================================================================
Total params: 612,758
Trainable params: 606,770
Non-trainable params: 5,988
_________________________________________________________________
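
To put the three results side by side, the arithmetic below uses only the mean epoch times and total parameter counts quoted in this thread. The `runs` dictionary and the `relative_to` helper are illustrative names, not part of astronet.

```python
# Mean seconds per epoch and total parameter counts, as reported above.
runs = {
    "t2": {"seconds": 468.15, "params": 13_390},
    "atx": {"seconds": 1791.1, "params": 1_335_144},
    "atx (scaled down)": {"seconds": 1550.0, "params": 612_758},
}


def relative_to(baseline, runs):
    """Express each run's epoch time and size as a multiple of the baseline."""
    base = runs[baseline]
    return {
        name: {
            "minutes": run["seconds"] / 60,
            "time_ratio": run["seconds"] / base["seconds"],
            "param_ratio": run["params"] / base["params"],
        }
        for name, run in runs.items()
    }


for name, stats in relative_to("t2", runs).items():
    print(
        f"{name}: {stats['minutes']:.1f} min/epoch, "
        f"{stats['time_ratio']:.2f}x t2 time, "
        f"{stats['param_ratio']:.1f}x t2 params"
    )
```

On these figures, atx takes roughly 3.8x as long per epoch as t2 and the scaled-down atx roughly 3.3x, while having about 100x and 46x as many parameters respectively.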