Support datasets with multiple labels in hb.keras.Model's fit()
User Story
I want to train a model on data with multiple labels, but fit() throws an exception like:
Error when checking model target:expected no data....
When the dataset contains only one label, it works fine.
Detailed requirements
- The dataset looks like this; the labels may be a tuple or a dict:
ds = hb.data.ParquetDataset(XXX)
def map_fn(batch):
    # label_names: list of label column names, assumed to be defined elsewhere
    labels = tuple([batch[l] for l in label_names])
    features = {}
    #pass
    return features, labels
ds = ds.map(map_fn)
- The fit() call looks like this:
m.fit(
    x=train_ds,
    validation_data=valid_ds,
    #XXX
    verbose=0)
I would like fit() to support datasets like the one above.
API Compatibility
hb.keras.Model's fit()
Willing to contribute
Yes
tf.keras supports this usage by default. Take the following code snippet as an example:
def input_dataset(args, filenames, batch_size):
    r'''Get input dataset.
    '''
    with tf.device('/cpu:0'):
        ds = hb.data.ParquetDataset(
            filenames,
            batch_size=batch_size,
            num_parallel_reads=len(filenames),
            num_parallel_parser_calls=args.num_parsers,
            drop_remainder=True)
        ds = ds.apply(hb.data.parse())
        ds = ds.map(
            lambda batch: (
                {f: batch[f] for f in batch if f not in ('ts', 'label')},
                {'output_1': tf.reshape(batch['label'], shape=[-1, 1]),
                 'output_2': tf.reshape(batch['label'], shape=[-1, 1])}))
        ds = ds.prefetch(args.num_prefetches)
        return ds
Note that the labels must be a dict whose keys follow the pattern {'output_1': label1, 'output_2': label2, ...}. The keys must be written as output_${i} (required by tf.keras), where ${i} is the index of the i-th output in the functional API, counting from 1. Accordingly, the model is created via the functional API as:
your_model = tf.keras.Model(inputs=xxx, outputs=[output_a, output_b])
tf.keras therefore implicitly produces a dict {'output_1': output_a, 'output_2': output_b} and associates it with the multiple labels {'output_1': label1, 'output_2': label2}.
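The key convention above can be sketched in plain Python. This is only an illustration: the helper name split_features_and_labels and the column names 'click' and 'buy' are hypothetical, and plain lists stand in for the tensors a real batch would hold.

```python
def split_features_and_labels(batch, label_names):
    """Split a batch dict into (features, labels).

    Labels are keyed as 'output_1', 'output_2', ... in the order given by
    label_names, matching the output_${i} pattern described above.
    This helper is hypothetical and not part of hb.keras or tf.keras.
    """
    labels = {
        'output_%d' % (i + 1): batch[name]
        for i, name in enumerate(label_names)
    }
    # Everything that is not a label column is treated as a feature.
    features = {k: v for k, v in batch.items() if k not in label_names}
    return features, labels


# Plain lists stand in for tensors; the column names are made up.
batch = {'f1': [0.1, 0.2], 'f2': [3, 4], 'click': [0, 1], 'buy': [1, 0]}
features, labels = split_features_and_labels(batch, ['click', 'buy'])
print(sorted(labels))    # label keys follow the output_${i} pattern
print(sorted(features))  # remaining columns are features
```

Since the helper only indexes into a dict, it could presumably be passed to ds.map through a lambda that fixes label_names, in place of the inline lambda in input_dataset. The order of label_names would then need to match the order of outputs=[output_a, output_b] in the model.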