tensorflow/ecosystem

Error loading SequenceExample TFRecords in TensorFlow.

niumeng07 opened this issue · 1 comments

Generating TFRecords in PySpark succeeds; the code is as follows:

fields = []
field = StructField('seqlen', IntegerType())
fields.append(field)
field = StructField('label', ArrayType(LongType(), True))
fields.append(field)
schema = StructType(fields)
ret = []
ret.append(len(seq_list))
labels = [1 for item in seq_list]
ret.append(labels)

df = spark.createDataFrame(rdd, schema)
df.write.format("tfrecords").mode('overwrite').option("recordType", "SequenceExample").save(...)

Decoding the records in Python TensorFlow fails with:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Name: , Feature list 'label' is required but could not be found. Did you mean to include it in feature_list_dense_missing_assumed_empty or feature_list_dense_defaults?

Python code:

def reader(args, filenames, num_workers=1, worker_index=0, decode_func = None):
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=args.num_parallel_reads,
                          buffer_size = args.reader_buffer_size)
    dataset = dataset.map(decode_func)
    dataset = dataset.repeat(args.num_epochs)
    dataset = dataset.shuffle(1000 + 3 * args.batch_size)
    dataset = dataset.batch(args.batch_size)
    iterator = tf.data.make_one_shot_iterator(dataset)
    dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
    return iterator.get_next(), dataset_init_op
 def decode_squence_list(serialized_example):
     feature = {'label': tf.FixedLenSequenceFeature([], tf.int64),
                'seqlen': tf.FixedLenSequenceFeature([], tf.int64),
     }
    context_example, sequence_example = tf.parse_single_sequence_example(
       serialized_example,
       context_features=None,
       sequence_features=feature)
    labels = sequence_example['label']
    seqlen = sequence_example['seqlen']
    return labels, seqlen


iterator_next_op, dataset_init_op  = reader(args, filenames, 1, 0, decode_func=decode_squence_list)
labels, seqlen = iterator_next_op
sess.run([labels, seqlen])

Solved — the scalar 'seqlen' column had to be parsed as a context feature rather than a feature list. Closing.