NielsRogge/Transformers-Tutorials

LayoutLMv3: bounding boxes are not plotted correctly during inference

rajasekarkrish opened this issue · 0 comments

When running inference with a fine-tuned LayoutLMv3 model (checkpoint below), the predicted bounding boxes are not plotted at the correct positions on the image. Inference code:
from transformers import AutoModelForTokenClassification
from datasets import load_dataset
import torch
from transformers import AutoProcessor
import matplotlib.pyplot as plt
from updatetrain import id2label
import matplotlib.patches as patches

model = AutoModelForTokenClassification.from_pretrained("/new_dataset/new_layoutlmv3/checkpoint-3000")

dataset = load_dataset(r"/new_layoutlmv3_dataset/new_dataset/updateddataset.py")

processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

example = dataset["test"][1]
print(example["image"])
print(example.keys())

image = example["image"]
words = example["words"]
boxes = example["bboxes"]
word_labels = example["ner_tags"]

encoding = processor(image, words, boxes=boxes, word_labels=word_labels, truncation=True, stride=128, return_tensors="pt")

for k, v in encoding.items():
    print(k, v.shape)

with torch.no_grad():
    outputs = model(**encoding)

logits = outputs.logits
print(logits.shape)

predictions = logits.argmax(-1).squeeze().tolist()
print(predictions)

labels = encoding.labels.squeeze().tolist()
print(labels)

print("printing five labels",labels[:5]) # Print the first 5 labels
def unnormalize_box(bbox, width, height):
return [
width * (bbox[0]/1000),
height *(bbox[1]/1000),
width * (bbox[2]/1000),
height *(bbox[3]/1000),
]
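# Illustrative round-trip check (not in the original script): boxes were scaled
# to the 0-1000 range by the dataset script's normalize_bbox, so for a
# hypothetical 2000x1000 page the normalized box [100, 200, 300, 400] should
# map back to [200.0, 200.0, 600.0, 400.0].
print(unnormalize_box([100, 200, 300, 400], 2000, 1000))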

token_boxes = encoding.bbox.squeeze().tolist()
width, height = image.size
print("Image size:", width, "x", height)

print("printing boxes",boxes[:5] )

print("printing token boxes",token_boxes[:5]) # Print the first 5 bounding boxes

true_predictions = [model.config.id2label[pred] for pred, label in zip(predictions, labels) if label != -100]
true_labels = [model.config.id2label[label] for pred, label in zip(predictions, labels) if label != -100]
true_boxes = [unnormalize_box(box, width, height) for box, label in zip(token_boxes, labels) if label != -100]

print("printing true Predictions:", true_predictions[:5]) # Print the first 5 predictions
print("printing true Labels:", true_labels[:5]) # Print the first 5 labels

from PIL import ImageDraw, ImageFont

draw = ImageDraw.Draw(image)

font = ImageFont.load_default()

def iob_to_label(label):
    # The dataset labels are not IOB-prefixed, so this strips the first two
    # characters (e.g. 'base_tax_header' -> 'se_tax_header'); the keys of
    # label2color below use these trimmed names.
    label = label[2:]
    if not label:
        return 'other'
    return label

label2color = {
'relevant': 'red',
'se_tax_header': 'green',
'se_tax_due_header': 'green',
'ar_header': 'green',
'se_tax_total': 'blue',
'se_tax_due_total': 'blue',
'se_tax': 'yellow',
'se_tax_due': 'orange',
'ar': 'orange',
}

print(label2color)
print(model.config.id2label)

for i, (prediction, label) in enumerate(zip(predictions, labels)):
    if label == -100:
        continue  # Skip padding tokens or any token that should be ignored
    predicted_label = model.config.id2label.get(prediction, "Label not found")
    actual_label = id2label.get(label, "Label not found")
    print(f"Token {i}: Predicted - {predicted_label}, Actual - {actual_label}")

for prediction, box in zip(true_predictions, true_boxes):
    predicted_label = iob_to_label(prediction).lower()
    if predicted_label in label2color:
        draw.rectangle(box, outline=label2color[predicted_label])
        draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=label2color[predicted_label], font=font)
    else:
        print(f"Label {predicted_label} not in label2color dictionary.")

plt.imshow(image)
plt.show()

(Screenshot: output image with bounding boxes not plotted at the expected positions.)
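One way to narrow down where the boxes go wrong (a debugging sketch, not part of the original report) is to draw the dataset's own word-level boxes on a freshly decoded copy of the image, bypassing the model and the token-level encoding.bbox entirely. The snippet below reuses the dataset object, unnormalize_box, plt, and ImageDraw from the script above.

# Debugging sketch: plot the ground-truth word boxes from the dataset itself.
# Re-indexing the dataset returns a freshly decoded PIL image, so earlier
# draw calls on `image` do not carry over.
check_example = dataset["test"][1]
check_image = check_example["image"]
check_draw = ImageDraw.Draw(check_image)
w, h = check_image.size

for box in check_example["bboxes"]:
    # bboxes were normalized to 0-1000 by the dataset script, so scale them back
    check_draw.rectangle(unnormalize_box(box, w, h), outline="purple")

plt.imshow(check_image)
plt.show()

If these boxes land in the right places, the unnormalization and image size are fine and the problem is more likely in the token-level boxes returned by the processor; if they are also off, the normalization in the dataset script or the image being plotted is the more likely culprit.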

Dataset preparation code (updateddataset.py):
import json
import os
import numpy as np
from PIL import Image
import datasets

import torch

logger = datasets.logging.get_logger(__name__)

def normalize_bbox(bbox, size):
    return [
        int(1000 * bbox[0] / size[0]),
        int(1000 * bbox[1] / size[1]),
        int(1000 * bbox[2] / size[0]),
        int(1000 * bbox[3] / size[1]),
    ]

def load_image(image_path):
    image = Image.open(image_path).convert("RGB")
    w, h = image.size
    return image, (w, h)

class CustomDatasetConfig(datasets.BuilderConfig):
    """BuilderConfig for CustomDataset."""

    def __init__(self, **kwargs):
        """BuilderConfig for CustomDataset.

        Args:
            **kwargs: keyword arguments forwarded to super.
        """
        super(CustomDatasetConfig, self).__init__(**kwargs)

class CustomDataset(datasets.GeneratorBasedBuilder):
"""Custom dataset for document understanding."""

BUILDER_CONFIGS = [
    CustomDatasetConfig(name="custom_dataset", version=datasets.Version("1.0.0"), description="Custom dataset"),
]

    def _info(self):
        return datasets.DatasetInfo(
            features=datasets.Features(
                {
                    "id": datasets.Value("string"),
                    "words": datasets.Sequence(datasets.Value("string")),
                    "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
                    "ner_tags": datasets.Sequence(
                        datasets.features.ClassLabel(
                            names=[
                                'irrelevant',
                                'base_tax_header',
                                'base_tax_due_header',
                                'year_header',
                                'base_tax_total',
                                'base_tax_due_total',
                                'base_tax',
                                'base_tax_due',
                                'year',
                                # Add more labels as per your requirement
                            ]
                        )
                    ),
                    "image": datasets.features.Image(),
                    "image_path": datasets.Value("string"),
                }
            ),
            supervised_keys=None,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators."""
        # Assuming the data is already downloaded/extracted and available in a specific directory
        data_dir = '/new_layoutlmv3_dataset/new_dataset/data/'
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN, gen_kwargs={"filepath": os.path.join(data_dir, "train.json")},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST, gen_kwargs={"filepath": os.path.join(data_dir, "test.json")},
            ),
        ]

    def _generate_examples(self, filepath):
        logger.info("⏳ Generating examples from = %s", filepath)
        with open(filepath, "r", encoding="utf8") as f:
            data = json.load(f)

        # Define the base directory for your images
        base_dir = r"D:\new_layoutlmv3_dataset\new_dataset"  # Update this path to your base directory

        for guid, item in enumerate(data):
            # Check if 'file_name' exists and correct the path
            if 'file_name' not in item:
                logger.warning(f"Skipping entry {guid} due to missing 'file_name'")
                continue
            # Correct the slash direction and prepend the base directory to the file name
            image_relative_path = item['file_name'].replace('../', '').replace('/', '\\')
            image_path = os.path.join(base_dir, image_relative_path)

            image, size = load_image(image_path)
            words, bboxes, ner_tags = [], [], []

            for annotation in item["annotations"]:
                words.append(annotation["text"])
                normalized_bbox = normalize_bbox(annotation["box"], size)
                bboxes.append(normalized_bbox)
                ner_tags.append(annotation["label"])

            yield guid, {
                "id": str(guid),
                "words": words,
                "bboxes": bboxes,
                "ner_tags": ner_tags,
                "image": image,  # Adjusting according to the expected format
                "image_path": image_path,  # Keep this as 'image_path' for consistency in your dataset features
            }
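For completeness, here is a quick way to verify the dataset script's output (a sketch that assumes the paths above resolve on your machine): every example should have the same number of words, bboxes, and ner_tags, and every bbox coordinate should already lie in the 0-1000 range that the processor and unnormalize_box expect.

from datasets import load_dataset

ds = load_dataset("/new_layoutlmv3_dataset/new_dataset/updateddataset.py")
sample = ds["test"][1]

# The aligned sequences must match element-for-element.
assert len(sample["words"]) == len(sample["bboxes"]) == len(sample["ner_tags"])
# Every coordinate should already be normalized to 0-1000.
assert all(0 <= coord <= 1000 for box in sample["bboxes"] for coord in box)
print(sample["image"].size, sample["bboxes"][:3])

If either assertion fails, the boxes will inevitably be drawn in the wrong place at inference time.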