DeepRec-AI/HybridBackend

hb.keras.Model's fit() should support datasets with multiple labels

Closed this issue · 1 comment

User Story

I want to train a model on data with multiple labels, but fit() throws an exception like:
Error when checking model target: expected no data....

When the dataset contains only one label, it works fine.

Detailed requirements

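  • The dataset looks like this; the labels may be a tuple or a dict:
import hybridbackend.tensorflow as hb

label_names = ['label_a', 'label_b']  # hypothetical label column names

ds = hb.data.ParquetDataset(XXX)
def map_fn(batch):
    # Collect every label column into a tuple.
    labels = tuple([batch[l] for l in label_names])
    features = {}
    # ... fill features from the remaining columns
    return features, labels
ds = ds.map(map_fn)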
  • The fit() call looks like this:
m.fit(
    x=train_ds,
    validation_data=valid_ds,
    #XXX
    verbose=0)

I would like fit() to support datasets like the one above.

API Compatibility

hb.keras.Model's fit()

Willing to contribute

Yes

tf.keras supports this usage by default, as the following code snippet shows:

import tensorflow as tf
import hybridbackend.tensorflow as hb

def input_dataset(args, filenames, batch_size):
  r'''Get input dataset.
  '''
  with tf.device('/cpu:0'):
    ds = hb.data.ParquetDataset(
      filenames,
      batch_size=batch_size,
      num_parallel_reads=len(filenames),
      num_parallel_parser_calls=args.num_parsers,
      drop_remainder=True)
    ds = ds.apply(hb.data.parse())
    ds = ds.map(
      lambda batch: (
        {f: batch[f] for f in batch if f not in ('ts', 'label')},
        {'output_1': tf.reshape(batch['label'], shape=[-1, 1]),
         'output_2': tf.reshape(batch['label'], shape=[-1, 1])}))
    ds = ds.prefetch(args.num_prefetches)
    return ds
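For readers without Parquet files at hand, here is a minimal self-contained sketch of the same pattern, assuming TensorFlow 2.x with the Keras 2 tf.keras API (not checked against Keras 3). tf.data.Dataset.from_tensor_slices stands in for hb.data.ParquetDataset, and the random data, layer sizes, and helper names make_dataset and build_model are hypothetical. The output layers are named output_1 and output_2 explicitly so the label keys line up with the model's output names.

```python
import numpy as np
import tensorflow as tf


def make_dataset(num_rows=64, batch_size=16):
  """Hypothetical in-memory stand-in for the Parquet input."""
  rng = np.random.default_rng(0)
  x = rng.normal(size=(num_rows, 4)).astype('float32')
  # One label column reshaped to [N, 1], reused for both outputs as above.
  label = (x.sum(axis=1, keepdims=True) > 0).astype('float32')
  labels = {'output_1': label, 'output_2': label}
  return tf.data.Dataset.from_tensor_slices((x, labels)).batch(
      batch_size, drop_remainder=True)


def build_model():
  inputs = tf.keras.Input(shape=(4,))
  hidden = tf.keras.layers.Dense(8, activation='relu')(inputs)
  # Explicit layer names make the model's output names output_1/output_2.
  output_a = tf.keras.layers.Dense(
      1, activation='sigmoid', name='output_1')(hidden)
  output_b = tf.keras.layers.Dense(
      1, activation='sigmoid', name='output_2')(hidden)
  model = tf.keras.Model(inputs=inputs, outputs=[output_a, output_b])
  # A single loss string is applied to every output.
  model.compile(optimizer='adam', loss='binary_crossentropy')
  return model


model = build_model()
history = model.fit(make_dataset(), epochs=2, verbose=0)
print(sorted(history.history))
```

Naming the output layers is a choice made for this sketch: it keeps the label dict independent of whatever names tf.keras would otherwise assign automatically.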

Note that the labels must be a dict whose keys follow the pattern {'output_1': label1, 'output_2': label2, ...}. The keys must be written as output_${i} (as required by tf.keras), where ${i} is the 1-based index of the i-th output in the functional API's outputs list. Accordingly, the model is created via the functional API as follows:

your_model = tf.keras.Model(inputs=xxx, outputs=[output_a, output_b])

Therefore, tf.keras implicitly builds the dict {'output_1': output_a, 'output_2': output_b} and matches it against the multiple labels {'output_1': label1, 'output_2': label2}.
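The matching rule described above can be illustrated in plain Python. This is a hypothetical sketch of the convention, not tf.keras's internal code: name_outputs and pair_outputs_with_labels are made-up helpers, and strings stand in for tensors. The same name_outputs helper shows how a tuple of labels, as produced by the map_fn in the issue, could be rekeyed into the dict form.

```python
def name_outputs(values):
  """Key a sequence positionally as output_1, output_2, ... (1-based)."""
  return {'output_{}'.format(i): v for i, v in enumerate(values, start=1)}


def pair_outputs_with_labels(outputs, labels):
  """Match list-style model outputs to a label dict by output_${i} key."""
  named = name_outputs(outputs)
  if set(named) != set(labels):
    raise KeyError('label keys {} do not match output keys {}'.format(
        sorted(labels), sorted(named)))
  return {key: (named[key], labels[key]) for key in named}


# Rekey a tuple of labels, then pair it with two list-style outputs.
labels = name_outputs(('label1', 'label2'))
pairs = pair_outputs_with_labels(['output_a', 'output_b'], labels)
print(pairs)
# {'output_1': ('output_a', 'label1'), 'output_2': ('output_b', 'label2')}
```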