StringLookup does not return expected dtype for multi_hot
Toku11 opened this issue · 1 comment
Toku11 commented
The documentation states that we should expect float32 when using 'multi_hot'; however, an int64 tensor is returned.
```python
import tensorflow as tf

vocab = ["a", "b", "c", "d"]
data = [["a", "c", "d", "d"], ["d", "z", "b", "z"]]
layer = tf.keras.layers.StringLookup(vocabulary=vocab, output_mode='multi_hot')
layer(data)
```
tf version: 2.16.1
```
<tf.Tensor: shape=(2, 5), dtype=int64, numpy=
array([[0, 1, 0, 1, 1],
       [1, 0, 1, 0, 1]])>
```
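For anyone reading the output above: the values can be reproduced by hand, which may help confirm that only the dtype is in question and not the encoding. The following is a minimal pure-Python sketch, not the Keras implementation. The helper name `multi_hot` is hypothetical, and it assumes the layer's default of a single out-of-vocabulary (OOV) slot at index 0, which is consistent with the reported shape of (2, 5) for a four-token vocabulary.

```python
def multi_hot(rows, vocab):
    """Illustrative multi-hot encoding with one OOV slot at index 0.

    Hypothetical helper for explanation only; it does not call Keras.
    """
    oov_index = 0
    # Vocabulary tokens are numbered starting at 1, after the OOV slot.
    index = {token: i + 1 for i, token in enumerate(vocab)}
    size = len(vocab) + 1

    encoded = []
    for row in rows:
        out = [0] * size
        for token in row:
            # Repeated tokens still set the slot to 1 (multi-hot, not count).
            out[index.get(token, oov_index)] = 1
        encoded.append(out)
    return encoded


vocab = ["a", "b", "c", "d"]
data = [["a", "c", "d", "d"], ["d", "z", "b", "z"]]
result = multi_hot(data, vocab)
print(result)  # [[0, 1, 0, 1, 1], [1, 0, 1, 0, 1]]
```

In the second row, the unknown token "z" lands in slot 0, which is why the first column is 1 there. The sketch produces plain Python integers, matching the integer values in the reported tensor.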
SuryanarayanaY commented
Hi @Toku11,
I have tested the code snippets from the API docs, and the dtype should be int64 when output_mode is `one_hot`, `multi_hot`, or `count`. The documentation needs to be changed to say int64 instead of float32. Thanks!