alibaba/EasyParallelLibrary

训练时,除chief worker外,其余worker在每次save checkpoint 后 step归0,且在第二次save checkpoint 后 整个进程卡死

walkingwindy opened this issue · 1 comments

代码:

"""Run downstream classification"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tensorflow as tf
import utils.optimizer as optimizer
import epl

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_integer("task_index", None, "Worker or server index")
tf.flags.DEFINE_string("worker_hosts", "", "worker hosts")

tf.flags.DEFINE_string("buckets", "", "tables info")
tf.flags.DEFINE_string("train_table", "", "tables info")
tf.flags.DEFINE_string("val_table", "", "tables info")
tf.flags.DEFINE_string("checkpoint_dir", '',
                       """Path to checkpoint folder""")

tf.flags.DEFINE_integer("num_epochs", 100,
                        """Number of training epochs (default: 20)""")
tf.flags.DEFINE_integer("max_steps", 10000, "")
tf.flags.DEFINE_integer("batch_size", 256, """Batch size (default: 64)""")
tf.flags.DEFINE_integer("display_step", 200,
                        """Number of steps to display log into TensorBoard (default: 20)""")
tf.flags.DEFINE_integer("save_checkpoints_steps", 1000,
                        "How often to save the model checkpoint.")
tf.flags.DEFINE_float("learning_rate", 0.001,
                      """Learning rate (default: 0.0005)""")
tf.flags.DEFINE_float("max_grad_norm", 5.0,
                      """Maximum value of the global norm of the gradients for clipping (default: 5.0)""")

tf.flags.DEFINE_integer("num_pipe_stages", 1, "number of pipeline stages")
tf.flags.DEFINE_integer("num_micro_batch", 1, "number of pipeline micro batches")

def str2list(str_in, shape, separator=' ', dtype=tf.int32):
    data = tf.string_split([str_in], separator)
    data = tf.string_to_number(data.values, dtype)
    return tf.reshape(data, shape)

def file_based_input_fn_builder(input_file, slice_id, slice_count, is_training, drop_remainder):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""
    def _decode_record(*record):
        """Decodes a record to a TensorFlow example."""
        (cert_no, coll_case_no, embedding, dt, label) = record

        embedding = str2list(embedding, shape=[512], separator='\002', dtype=tf.float32)

        example = {'input_embed': embedding,
                   'label': label,
                   'dt': dt,
                   'cert_no': cert_no,
                   'coll_case_no': coll_case_no}
        return example

    def input_fn(params):
        """The actual input function."""
        d = tf.data.TableRecordDataset([input_file], record_defaults=['', '', '', '', 0])
        if is_training:
            d = d.repeat(FLAGS.num_epochs)
            d = d.shuffle(buffer_size=1000)

        d = d.apply(tf.contrib.data.map_and_batch(
                        lambda v1, v2, v3, v4, v5: _decode_record(v1, v2, v3, v4, v5),
                        batch_size=FLAGS.batch_size,
                        drop_remainder=drop_remainder))
        return d

    return input_fn

def create_model(input_embed, label):

    with tf.variable_scope("loss", reuse=tf.AUTO_REUSE):
        with tf.variable_scope("cls"):            
            logits = tf.layers.dense(
                input_embed,
                2,
                activation=None,
                kernel_initializer=tf.truncated_normal_initializer())

        one_hot_label = tf.one_hot(label, depth=2, dtype=tf.float32)
        loss = tf.losses.softmax_cross_entropy(one_hot_label, logits)
        probs = tf.nn.softmax(logits, axis=-1)
        predict = tf.argmax(probs, axis=-1, output_type=tf.int32)

        acc = tf.metrics.accuracy(label, predict)
        auc = tf.metrics.auc(label, probs[:,-1])
        return (loss, acc, auc)


def model_fn_builder(checkpoint_dir, learning_rate):
    """Returns `model_fn` closure for TPUEstimator."""
    def model_fn(features, mode):
        """The `model_fn` for Estimator."""

        input_embed = features['input_embed']
        label = features["label"]

        # create loss
        (loss, acc, auc) = create_model(input_embed, label)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            #rms optimizer
            tvars = tf.trainable_variables()
            grads = tf.gradients(loss, tvars)
            clipped_grads, global_norm = tf.clip_by_global_norm(grads, FLAGS.max_grad_norm)
            tf.summary.scalar('global_grad_norm', global_norm)

            global_step = tf.train.get_or_create_global_step()
            optimizer = tf.train.RMSPropOptimizer(learning_rate)
            train_op = optimizer.apply_gradients(zip(clipped_grads, tvars),
                                            name='train_op',
                                            global_step=global_step)

            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=loss,
                train_op=train_op)

        elif mode == tf.estimator.ModeKeys.EVAL:
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metric_ops={'Accuracy':acc, "AUC":auc})
        else:
            raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))

        return output_spec

    return model_fn
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info("############## Start #####################")
    checkpoint_dir = os.path.join(FLAGS.buckets, FLAGS.checkpoint_dir)
    train_file = FLAGS.train_table
    val_file = FLAGS.val_table

    worker_spec = FLAGS.worker_hosts.split(",")
    worker_count = len(worker_spec)
    task_index = FLAGS.task_index

    epl_env = epl.Env.get()
    total_device = len(epl_env.cluster.available_devices)
    num_replica = total_device // FLAGS.num_pipe_stages
    micro_batch = FLAGS.batch_size // epl_env.config.pipeline.num_micro_batch
    micro_batch = micro_batch // num_replica
  
    print("total_batch: {}, num_micro_batch: {}, num_replica: {}, micro_batch: {}".format(
            FLAGS.batch_size,
            epl_env.config.pipeline.num_micro_batch,
            num_replica,
            micro_batch))
    print("task_index:", task_index)
    print("total_device:", total_device)

    model_fn = model_fn_builder(checkpoint_dir, FLAGS.learning_rate)

    train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        slice_id=task_index,
        slice_count=worker_count,
        is_training=True,
        drop_remainder=True
    )

    val_input_fn = file_based_input_fn_builder(
        input_file=val_file,
        slice_id=task_index,
        slice_count=worker_count,
        is_training=False,
        drop_remainder=False
    )

    sess_config = tf.ConfigProto(allow_soft_placement=True)
    config = tf.estimator.RunConfig(session_config=sess_config,
                                    save_checkpoints_steps=FLAGS.save_checkpoints_steps)

    estimator = tf.estimator.Estimator(
                model_fn=model_fn,
                config=config,
                model_dir=checkpoint_dir)

    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=FLAGS.max_steps)
    eval_spec = tf.estimator.EvalSpec(input_fn=val_input_fn, start_delay_secs=6, throttle_secs=1)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    tf.logging.info("#################  All process done.  ########################")

if __name__ == '__main__':
    env_dist = os.environ
    print(env_dist.get('TF_CONFIG'))
    config_json = {}
    config_json["pipeline.num_micro_batch"] = FLAGS.num_micro_batch
    epl.init(epl.Config(config_json))
    if FLAGS.num_pipe_stages == 1:
        epl.set_default_strategy(epl.replicate(device_count=1))
    tf.app.run()

训练提交worker sql:


pai -name tensorflow1120_py3
-Dscript="***/resources/***.tar.gz"
-DentryFile="train_downstream_cls.py"
-Dbuckets="***"
-DuserDefinedParameters="--num_epochs=10 --max_steps=100000 --buckets=*** --checkpoint_dir=*** --train_table=*** --val_table=*** “
-Dtables="***, ***"
-Dcluster="{\"worker\":{\"count\":8,\"cpu\":400,\"gpu\":100}}"

You can set different checkpoint dir for different workers, e.g.

worker_index = Env.get().cluster.worker_index
checkpoint_dir = os.path.join(FLAGS.buckets, FLAGS.checkpoint_dir, worker_index)