philipperemy/keras-tcn

Model Fit deadlocks when training in SageMaker with PipeModeDataset

murphycrosby opened this issue · 2 comments

`Model.fit` deadlocks when training on SageMaker with `PipeModeDataset`: CPUUtilization, MemoryUtilization and DiskUtilization all drop to 0 on the training instance. The same model trains fine when `PipeModeDataset` is swapped out for `tf.data.TFRecordDataset`. The `for` loop in the example shows that at least one batch is downloaded and parsed correctly before `fit` is called.

Example TCN Model

```python
import argparse

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Input, Model
from tcn import TCN, tcn_full_summary
from sagemaker_tensorflow import PipeModeDataset


def parse(record):
    features = {
        'data': tf.io.FixedLenFeature([], tf.string),
        'labels': tf.io.FixedLenFeature([], tf.float32),
    }
    parsed = tf.io.parse_single_example(record, features)
    # 'data' holds a serialized float64 tensor of shape (10, 1) per example
    return ({
        'data': tf.io.parse_tensor(parsed['data'], out_type=tf.float64)
    }, parsed['labels'])


if __name__ == '__main__':
    # SageMaker passes hyperparameters to the entry point as command-line
    # arguments; the 'adam' default is a hypothetical placeholder
    parser = argparse.ArgumentParser()
    parser.add_argument('--optimizer', type=str, default='adam')
    args, _ = parser.parse_known_args()

    input_dim = 1

    i = Input(batch_shape=(None, 10, input_dim))
    o = TCN(return_sequences=False)(i)  # The TCN layers are here.
    o = Dense(1)(o)
    m = Model(inputs=[i], outputs=[o])
    m.compile(optimizer=args.optimizer, loss='mse')
    tcn_full_summary(m, expand_residual_blocks=False)

    # Records are streamed from S3 through a named pipe (Pipe mode)
    ds = PipeModeDataset(channel='train', record_format='TFRecord')

    ds = ds.map(parse, num_parallel_calls=10)
    ds = ds.prefetch(10)
    ds = ds.repeat(10)
    ds = ds.batch(25, drop_remainder=True)

    # Sanity check: pull one batch to confirm data arrives through the pipe
    for row in ds:
        print(f'x_shape: {row[0]["data"].numpy().shape}')
        print(f'y_shape: {row[1].numpy().shape}')
        break

    m.fit(ds, epochs=10)  # hangs here; utilization drops to 0
```
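For comparison, here is a minimal local sketch of the `tf.data.TFRecordDataset` variant that reportedly trains fine. It writes a few synthetic records in the format `parse` expects and reads them back. The helpers `write_records` and `make_dataset`, the synthetic data and the file path are hypothetical; the `tf.ensure_shape` call is an optional addition, not in the original, that restores the static `(10, 1)` shape which `tf.io.parse_tensor` drops.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def parse(record):
    # Same feature spec as the repro
    features = {
        'data': tf.io.FixedLenFeature([], tf.string),
        'labels': tf.io.FixedLenFeature([], tf.float32),
    }
    parsed = tf.io.parse_single_example(record, features)
    data = tf.io.parse_tensor(parsed['data'], out_type=tf.float64)
    # Optional: parse_tensor returns an unknown static shape; pin it to (10, 1)
    data = tf.ensure_shape(data, (10, 1))
    return {'data': data}, parsed['labels']


def write_records(path, n=100, timesteps=10, input_dim=1, seed=0):
    """Hypothetical helper: write n synthetic examples in the format parse() expects."""
    rng = np.random.default_rng(seed)
    with tf.io.TFRecordWriter(path) as writer:
        for _ in range(n):
            x = rng.standard_normal((timesteps, input_dim))  # float64 array
            example = tf.train.Example(features=tf.train.Features(feature={
                'data': tf.train.Feature(bytes_list=tf.train.BytesList(
                    value=[tf.io.serialize_tensor(x).numpy()])),
                'labels': tf.train.Feature(float_list=tf.train.FloatList(
                    value=[float(x.sum())])),
            }))
            writer.write(example.SerializeToString())


def make_dataset(path, batch_size=25):
    """Hypothetical helper: TFRecordDataset in place of PipeModeDataset."""
    ds = tf.data.TFRecordDataset(path)
    ds = ds.map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds.prefetch(tf.data.experimental.AUTOTUNE)


if __name__ == '__main__':
    path = os.path.join(tempfile.mkdtemp(), 'train.tfrecord')
    write_records(path)
    x, y = next(iter(make_dataset(path)))
    print(x['data'].numpy().shape, y.numpy().shape)  # (25, 10, 1) (25,)
```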

tensorflow==2.3.1

Output

```
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 10, 1)]           0
_________________________________________________________________
residual_block_0 (ResidualBl [(None, 10, 64), (None, 1 8576
_________________________________________________________________
residual_block_1 (ResidualBl [(None, 10, 64), (None, 1 16512
_________________________________________________________________
residual_block_2 (ResidualBl [(None, 10, 64), (None, 1 16512
_________________________________________________________________
residual_block_3 (ResidualBl [(None, 10, 64), (None, 1 16512
_________________________________________________________________
residual_block_4 (ResidualBl [(None, 10, 64), (None, 1 16512
_________________________________________________________________
residual_block_5 (ResidualBl [(None, 10, 64), (None, 1 16512
_________________________________________________________________
lambda (Lambda)              (None, 64)                0
_________________________________________________________________
dense (Dense)                (None, 1)                 65
=================================================================
Total params: 91,201
Trainable params: 91,201
Non-trainable params: 0
_________________________________________________________________
x_shape: (25, 10, 1)
y_shape: (25,)
```
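The parameter counts in the summary can be reproduced by hand, which at least suggests the model itself is built as intended. This sketch assumes the keras-tcn defaults of the time: `nb_filters=64`, `kernel_size=2`, dilations `(1, 2, 4, 8, 16, 32)`, one stack, and a 1x1 convolution on the residual path only where the channel count changes (the first block, 1 to 64). Dilation does not change the count.

```python
def conv1d_params(kernel_size, in_channels, filters):
    # kernel weights plus one bias per filter
    return kernel_size * in_channels * filters + filters


def residual_block_params(kernel_size, in_channels, filters):
    # two dilated causal convolutions per residual block
    total = conv1d_params(kernel_size, in_channels, filters)
    total += conv1d_params(kernel_size, filters, filters)
    # 1x1 convolution on the residual path only when channel counts differ
    if in_channels != filters:
        total += conv1d_params(1, in_channels, filters)
    return total


nb_filters, kernel_size, input_dim = 64, 2, 1
dilations = [1, 2, 4, 8, 16, 32]  # assumed default: one block per dilation

blocks = []
in_ch = input_dim
for _ in dilations:
    blocks.append(residual_block_params(kernel_size, in_ch, nb_filters))
    in_ch = nb_filters

dense = nb_filters * 1 + 1  # Dense(1) on the 64-dim TCN output
total = sum(blocks) + dense
print(blocks)  # [8576, 16512, 16512, 16512, 16512, 16512]
print(total)   # 91201
```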

@murphycrosby have you had a chance to try this on a more recent version of TF?

krzim commented

@murphycrosby Have you tried swapping out the TCN for another model to see if the issue persists? I'd be curious whether it is actually related to training the TCN in particular or is an oddity of PipeModeDataset. I would also like to see how you are setting up model training in SageMaker and which instance type, options, etc. you are using.
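A minimal sketch of that experiment, assuming an LSTM as the stand-in model (any layer mapping `(None, 10, 1)` to a fixed-size vector would do). `build_baseline` is a hypothetical helper: in the SageMaker script it would replace the TCN model while the PipeModeDataset pipeline stays unchanged. If `fit` still hangs, the TCN is probably not the cause.

```python
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense, LSTM


def build_baseline(timesteps=10, input_dim=1, optimizer='adam'):
    """Hypothetical helper: same input/output contract as the TCN repro."""
    i = Input(batch_shape=(None, timesteps, input_dim))
    o = LSTM(64)(i)  # stand-in for TCN(return_sequences=False)
    o = Dense(1)(o)
    m = Model(inputs=[i], outputs=[o])
    m.compile(optimizer=optimizer, loss='mse')
    return m


# In the SageMaker script (not runnable locally), keep the same pipeline:
#   m = build_baseline(optimizer=args.optimizer)
#   m.fit(ds, epochs=10)
```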

It sounds somewhat similar to this issue here, which is external to the TCN portion of the code. I don't know that it matters much for a normal TFRecordDataset, but the order of operations might matter for PipeModeDataset. You could try rearranging the parse/prefetch/batch ops to match the AWS example here.
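One ordering worth trying is sketched below: repeat first, then parse and batch, with prefetch moved to the end of the pipeline as the tf.data performance guide recommends. `build_pipeline` is a hypothetical helper and the epoch and batch values are carried over from the repro; locally, any dataset of serialized records can stand in for PipeModeDataset.

```python
import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE


def build_pipeline(ds, parse_fn, batch_size=25, epochs=10):
    """Hypothetical helper: repeat -> map -> batch -> prefetch (prefetch last).

    `ds` would be the PipeModeDataset in the SageMaker script.
    """
    ds = ds.repeat(epochs)
    ds = ds.map(parse_fn, num_parallel_calls=AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds.prefetch(AUTOTUNE)


# In the SageMaker script (not runnable locally):
#   ds = build_pipeline(PipeModeDataset(channel='train', record_format='TFRecord'), parse)
```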

Lastly, you could switch to the new FastFile mode with something like this example here.
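With FastFile mode the training script reads the S3 objects as ordinary files, so the pipeline can go back to `tf.data.TFRecordDataset`, which already works. The sketch below assumes the launcher sets `input_mode='FastFile'` (for example on the SageMaker Python SDK's `TrainingInput`), a channel named `train`, and files ending in `.tfrecord`; `fastfile_dataset` is a hypothetical helper. SageMaker exposes each channel directory through an `SM_CHANNEL_<NAME>` environment variable, and `/opt/ml/input/data/<channel>` is used as the fallback here.

```python
import glob
import os

import tensorflow as tf


def fastfile_dataset(parse_fn, channel='train', pattern='*.tfrecord', batch_size=25):
    """Hypothetical helper: read the TFRecord files FastFile mode exposes."""
    # e.g. SM_CHANNEL_TRAIN -> /opt/ml/input/data/train
    channel_dir = os.environ.get(f'SM_CHANNEL_{channel.upper()}',
                                 f'/opt/ml/input/data/{channel}')
    # '**' with recursive=True also matches files directly in channel_dir
    files = sorted(glob.glob(os.path.join(channel_dir, '**', pattern), recursive=True))
    if not files:
        raise FileNotFoundError(f'no files matching {pattern} in {channel_dir}')
    ds = tf.data.TFRecordDataset(files)
    ds = ds.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds.prefetch(tf.data.experimental.AUTOTUNE)
```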