NielsRogge/Transformers-Tutorials

LayoutXLM/LayoutLMv2 output length doesn't match the input length even after post-processing

borisloktev opened this issue

Hi there!

I'd like to ask about the best way to match the model outputs to the original input tokens when using the model with pre-OCR'ed documents.

I've fine-tuned a version of LayoutXLM for token classification and use the function below for inference. However, I've noticed that the output length doesn't match the original number of input tokens, even after removing padding, subword and special tokens. Could you advise on the best way to align the two? Or am I missing something in the post-processing step?
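One likely contributor, judging only from the code below: with a fast tokenizer (required for `return_offsets_mapping=True`), special tokens such as `<s>` and `</s>` get the offset `(0, 0)`, so a filter on `offset[0] != 0` keeps them as if they were word starts. The sketch below uses hypothetical offsets for three words (made up for illustration, not taken from a real tokenizer run) to show the count coming out two higher than the number of words.

```python
# Hypothetical offsets for the words ["Invoice", "number", "42"], padded to 8 tokens:
# <s>, "▁In", "voice", "▁number", "▁42", </s>, <pad>, <pad>
offsets = [(0, 0), (0, 2), (2, 7), (0, 6), (0, 2), (0, 0), (0, 0), (0, 0)]
attention_mask = [1, 1, 1, 1, 1, 1, 0, 0]
num_words = 3

# Same rule as the offset-based filter: a token is a subword if its start offset != 0
is_subword = [start != 0 for start, _ in offsets]

# Drop padding (via the attention mask), then keep every non-subword token
kept = [idx for idx in range(sum(attention_mask)) if not is_subword[idx]]

print(kept)                   # [0, 1, 3, 4, 5]
print(len(kept), num_words)   # 5 3 -> <s> (index 0) and </s> (index 5) survive the filter
```

Truncation can push the count the other way: with `truncation=True` and `max_length=512`, words that don't fit produce no tokens at all, so they silently disappear from the output.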

Thank you in advance, and I really love your work!

```python
import numpy as np
import torch

# TOKENIZER, FEATURE_EXTRACTOR, MODEL, ID_TO_LABEL and image_utils are defined elsewhere


def get_model_prediction(preprocessed_tokens, rotated_img, width_up, height_up):
    # encode the pre-OCR'ed words together with their normalized boxes
    encoding = TOKENIZER(
        [i["token_text"] for i in preprocessed_tokens],
        boxes=[i["normalized_bbox"] for i in preprocessed_tokens],
        return_offsets_mapping=True,
        padding="max_length",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )

    image_input = FEATURE_EXTRACTOR(
        rotated_img,
        return_tensors="pt",
    )  # type: ignore

    # inference only, so no gradients are needed
    with torch.no_grad():
        outputs = MODEL(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            bbox=encoding.bbox,
            image=image_input.pixel_values,
            return_dict=True,
        )  # type: ignore
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding.bbox.squeeze().tolist()
    attention_mask = encoding.attention_mask.squeeze().tolist()
    print(encoding.input_ids.shape, outputs.logits.argmax(-1).squeeze().shape, encoding.bbox.squeeze().shape)

    # a token counts as a subword if its offset does not start at 0
    # (special and padding tokens have offset (0, 0), so they are NOT flagged here)
    is_subword = np.array(encoding.offset_mapping.squeeze().tolist())[:, 0] != 0

    # remove padding tokens
    predictions = predictions[: sum(attention_mask)]

    # get aligned predictions and boxes
    true_predictions = [
        ID_TO_LABEL[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]
    ]
    # remove subword and padding boxes
    true_boxes = [
        image_utils.unnormalize_box(box, width_up, height_up)
        for idx, box in enumerate(token_boxes)
        if not is_subword[idx] and attention_mask[idx] == 1
    ]
    return true_predictions, true_boxes
```
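A sketch of an alternative alignment step that does not rely on offsets, assuming a fast tokenizer: `BatchEncoding.word_ids(batch_index=0)` maps each token to the index of the input word it came from, and to `None` for special and padding tokens. Taking the value of the first token of each word yields exactly one entry per input word. The helper name `align_to_words` and the stand-in inputs are hypothetical; in the function above, `word_ids` would come from `encoding.word_ids(0)`, `per_token` from the full argmax list (no slicing needed, since padding maps to `None`), and `num_words` from `len(preprocessed_tokens)`.

```python
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def align_to_words(
    per_token: Sequence[T],
    word_ids: Sequence[Optional[int]],
    num_words: int,
) -> List[Optional[T]]:
    """Return one value per input word, taken from that word's first token.

    per_token: one value (label id, box, ...) per token position
    word_ids:  word index per token, None for special/padding tokens
    num_words: number of words passed to the tokenizer
    Words that produced no tokens (e.g. cut off by truncation) get None,
    so the output length always equals num_words.
    """
    aligned: List[Optional[T]] = [None] * num_words
    previous = None
    for value, word_id in zip(per_token, word_ids):
        # skip special/padding tokens (None) and continuation subwords (same id as before)
        if word_id is not None and word_id != previous:
            aligned[word_id] = value
        previous = word_id
    return aligned


# Stand-in values: <s>, "▁In", "voice", "▁number", </s>; the third word was truncated away
word_ids = [None, 0, 0, 1, None]
token_preds = [0, 3, 3, 4, 0]
print(align_to_words(token_preds, word_ids, num_words=3))  # [3, 4, None]
```

The same helper can align `token_boxes`, so predictions and boxes stay index-for-index with `preprocessed_tokens`; any `None` entries mark words the model never saw.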