keras-team/keras

StringLookup does not return expected dtype for multi_hot

Toku11 opened this issue · 1 comment

The documentation states that we should expect float32 when using 'multi_hot'; however, an int64 tensor is returned.

```python
import tensorflow as tf

vocab = ["a", "b", "c", "d"]
data = [["a", "c", "d", "d"], ["d", "z", "b", "z"]]
layer = tf.keras.layers.StringLookup(vocabulary=vocab, output_mode='multi_hot')
layer(data)
```

TensorFlow version: 2.16.1

Output:

<tf.Tensor: shape=(2, 5), dtype=int64, numpy=
array([[0, 1, 0, 1, 1],
       [1, 0, 1, 0, 1]])>

Hi @Toku11,

I tested the code snippets from the API docs, and the output dtype is int64 when output_mode is one_hot, multi_hot, or count. The documentation needs to be changed to say int64 instead of float32. Thanks!
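
Until the docs are updated, one possible workaround if downstream code needs float32 is to cast the layer output explicitly. This is only a minimal sketch: the names `encoded` and `encoded_f32` are made up for illustration, and wrapping the input in `tf.constant` is my own choice rather than something the issue requires.

```python
import tensorflow as tf

vocab = ["a", "b", "c", "d"]
data = tf.constant([["a", "c", "d", "d"], ["d", "z", "b", "z"]])

layer = tf.keras.layers.StringLookup(vocabulary=vocab, output_mode="multi_hot")
encoded = layer(data)

# Cast explicitly instead of relying on the layer's default output dtype,
# which this issue reports as int64 on TF 2.16.1.
encoded_f32 = tf.cast(encoded, tf.float32)
print(encoded_f32.dtype)
```

The cast changes only the dtype; the multi-hot values are the same as in the output reported above.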