LayoutXLM/LayoutLMv2 output length doesn't match the input length even after post-processing
borisloktev opened this issue · 0 comments
borisloktev commented
Hi there!
I just wanted to ask about the best way to match the model's outputs back to the original input tokens when running it on pre-OCR'ed documents.
I've fine-tuned a version of LayoutXLM for token classification and am using the function below for inference. However, I've noticed that the output length doesn't match the original number of input tokens, even after removing padding, subword, and special tokens. Could you advise me on the best way to align the two? Maybe I'm missing something in the post-processing step?
Thank you in advance, and I really love your work!
def get_model_prediction(preprocessed_tokens, rotated_img, width_up, height_up):
    encoding = TOKENIZER(
        [i["token_text"] for i in preprocessed_tokens],
        boxes=[i["normalized_bbox"] for i in preprocessed_tokens],
        return_offsets_mapping=True,
        padding="max_length",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    image_input = FEATURE_EXTRACTOR(
        rotated_img,
        return_tensors="pt",
    )  # type: ignore
    outputs = MODEL(
        input_ids=encoding.input_ids,
        attention_mask=encoding.attention_mask,
        bbox=encoding.bbox,
        image=image_input.pixel_values,
        return_dict=True,
    )  # type: ignore
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding.bbox.squeeze().tolist()
    print(encoding.input_ids.shape, outputs.logits.argmax(-1).squeeze().shape, encoding.bbox.squeeze().shape)
    # get all tokens that are not subwords
    is_subword = np.array(encoding.offset_mapping.squeeze().tolist())[:, 0] != 0
    # remove padding tokens
    predictions = predictions[: sum(encoding.attention_mask.squeeze().tolist())]
    # get aligned predictions and boxes
    true_predictions = [
        ID_TO_LABEL[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]
    ]
    # remove subword and padding boxes
    true_boxes = [
        image_utils.unnormalize_box(box, width_up, height_up)
        for idx, box in enumerate(token_boxes)
        if not is_subword[idx] and encoding.attention_mask.squeeze().tolist()[idx] == 1
    ]