TFSequenceClassificationLoss for MultiLabel classification
ds-mike opened this issue · 1 comment
System Info
- transformers version: 4.40.2
- Platform: Linux-6.1.84-x86_64-with-glibc2.31
- Python version: 3.11.9
- Huggingface_hub version: 0.23.0
- Safetensors version: 0.4.3
- Accelerate version: not installed
- Accelerate config: not found
- PyTorch version (GPU?): not installed (NA)
- Tensorflow version (GPU?): 2.16.1 (False)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: no
- Using distributed or parallel set-up in script?: no
Who can help?
Tagging (up to) three individuals who may be able to help:
@ArthurZucker (Text Models), @gante (Tensorflow), @Rocketknight1 (Tensorflow examples)
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Hi,
I have a multi-label classification task that I am experimenting with, using BertForSequenceClassification as the base model. I declare the model as follows, where my number of classes is 36; an individual training sample can have several labels.
```python
model = TFBertForSequenceClassification.from_pretrained(
    <my path>,
    problem_type="multi_label_classification",
    label2id=labels["label2id"],
    id2label=labels["id2label"],
    num_labels=num_classes,
)
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))
```
My inputs are as follows:

```
Inputs: {'input_ids': <tf.Tensor: shape=(1, 9), dtype=int32, numpy=
array([[ 101, 2061, 2061, 2172, 3793, 2302, 2151, 3574,  102]],
      dtype=int32)>,
 'token_type_ids': <tf.Tensor: shape=(1, 9), dtype=int32, numpy=array([[0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)>,
 'attention_mask': <tf.Tensor: shape=(1, 9), dtype=int32, numpy=array([[1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)>}

Labels: <tf.Tensor: shape=(1, 36), dtype=int64, numpy=
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]])>
```
Only when I call the model with the labels present do I get the following error:
`labels.shape` must equal `logits.shape` except for the last dimension. Received: labels.shape=(36,) and logits.shape=(1, 36)
It seems to be because SparseCategoricalCrossentropy() is being set as the loss function here. If I manually override this with what I understand to be a more appropriate loss function, in this case binary crossentropy, the loss computes as expected.
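To make the workaround concrete, here is a minimal sketch of the manual override, assuming `model`, `inputs`, and `labels` are the objects shown above:

```python
import tensorflow as tf

# Call the model without `labels` so the built-in
# SparseCategoricalCrossentropy path is never triggered.
outputs = model(inputs)

# Compute the multi-label loss manually on the raw logits.
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
loss = loss_fn(tf.cast(labels, tf.float32), outputs.logits)
```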
I tested the torch code here and it seems to work fine; I get something like the following: `tensor(0.8118, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)`.
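For reference, a minimal sketch of the PyTorch check, assuming `bert-base-uncased` as a stand-in for my checkpoint and torch-tensor versions of the inputs and labels above:

```python
import torch
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",  # stand-in; my real path is omitted
    problem_type="multi_label_classification",
    num_labels=36,
)
# With problem_type="multi_label_classification", the PyTorch model applies
# BCEWithLogitsLoss internally, so labels must be float multi-hot vectors.
outputs = model(**inputs, labels=labels.float())
print(outputs.loss)  # e.g. tensor(0.8118, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
```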
I believe the TF loss function should be changed in the line described above. Furthermore, the loss I passed to model.compile() doesn't seem to have any effect. Please let me know where I have gone wrong or if there are other details I can provide.
Thank you!
Mike
Expected behavior
Loss should be computed with binary crossentropy, given that it is a multi-label classification scenario and I am passing a multi-hot encoded label vector.
Hi @ds-mike, our TF models don't support multi-label classification with the default loss (I'm not sure Torch does either - you got a loss value there but I don't know if it's the correct one!)
Our TFXXXForSequenceClassification models expect labels to just be integer category indices, so if `num_labels=36`, then the label should be an integer between 0 and 35. This is by design - if we switched to binary crossentropy, then lots of people's existing code wouldn't work anymore!
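To illustrate the difference, a quick sketch of the two label formats (the index values are just taken from the example above):

```python
import tensorflow as tf

# Default single-label format: one integer class index per example.
single_label = tf.constant([13])  # shape (1,), value in [0, 35]

# Multi-label format from your report: a multi-hot vector per example,
# here with classes 13, 15 and 33 active.
multi_label = tf.reduce_max(
    tf.one_hot([13, 15, 33], depth=36, dtype=tf.int64), axis=0
)[tf.newaxis, :]  # shape (1, 36)
```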
Multi-label classification with these models is possible, but you'll need to either compile with a new loss function like binary crossentropy, as you did there, or for more advanced use cases you can use the base TFBertModel as a layer in your own model class, and add a custom output head onto it that can have whatever shape and loss you want.
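Something like this minimal sketch, with `bert-base-uncased` as a placeholder checkpoint and illustrative names:

```python
import tensorflow as tf
from transformers import TFBertModel

class MultiLabelBert(tf.keras.Model):
    def __init__(self, checkpoint, num_labels):
        super().__init__()
        self.bert = TFBertModel.from_pretrained(checkpoint)
        # No activation: output raw logits and let the loss apply the sigmoid.
        self.classifier = tf.keras.layers.Dense(num_labels)

    def call(self, inputs):
        outputs = self.bert(inputs)
        pooled = outputs.last_hidden_state[:, 0]  # [CLS] token representation
        return self.classifier(pooled)

model = MultiLabelBert("bert-base-uncased", num_labels=36)
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
)
```

Since this is a plain Keras model, model.fit() uses the compiled binary crossentropy directly, and multi-hot label tensors of shape (batch, 36) work as-is.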
In other words, I don't think you're making any mistake here - you just have a slightly advanced use-case that requires a bit of custom loss computation rather than relying on the built-in losses, and that's okay!