LayoutLMv3 inference issue: bounding boxes are not plotted correctly
rajasekarkrish opened this issue · 0 comments
LayoutLMv3 inference issue: bounding boxes are not plotted correctly
from transformers import AutoModelForTokenClassification
from datasets import load_dataset
import torch
from transformers import AutoProcessor
import matplotlib.pyplot as plt
from updatetrain import id2label
import matplotlib.patches as patches
# Load the fine-tuned checkpoint, the custom dataset and the base processor.
# apply_ocr=False: words and boxes are taken from the dataset, not from OCR.
model = AutoModelForTokenClassification.from_pretrained("/new_dataset/new_layoutlmv3/checkpoint-3000")
dataset = load_dataset(r"/new_layoutlmv3_dataset/new_dataset/updateddataset.py")
processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

# Pick one test document and show what it contains.
example = dataset["test"][1]
print(example["image"])
print(example.keys())

image, words, boxes, word_labels = (
    example[key] for key in ("image", "words", "bboxes", "ner_tags")
)

# NOTE(review): `stride` only has an effect together with
# `return_overflowing_tokens=True`. With plain truncation, words past the
# maximum sequence length are dropped and will get no box or prediction.
encoding = processor(
    image,
    words,
    boxes=boxes,
    word_labels=word_labels,
    truncation=True,
    stride=128,
    return_tensors="pt",
)
for field, tensor in encoding.items():
    print(field, tensor.shape)

# Forward pass without building the autograd graph.
with torch.no_grad():
    outputs = model(**encoding)
logits = outputs.logits
print(logits.shape)

# Single-document batch: drop the batch dimension and go to plain lists.
predictions = logits.argmax(-1).squeeze().tolist()
print(predictions)
labels = encoding.labels.squeeze().tolist()
print(labels)
print("printing five labels", labels[:5])  # Print the first 5 labels
def unnormalize_box(bbox, width, height):
    """Map a box on the 0-1000 scale back to pixel coordinates.

    Args:
        bbox: sequence whose first four items are ``x0, y0, x1, y1`` on the
            0-1000 scale.
        width: target image width in pixels.
        height: target image height in pixels.

    Returns:
        A list ``[x0, y0, x1, y1]`` in pixel units.
    """
    # x coordinates scale with the width, y coordinates with the height.
    scales = (width, height, width, height)
    return [scale * (bbox[idx] / 1000) for idx, scale in enumerate(scales)]
# Token-level boxes on the 0-1000 scale, as emitted by the processor.
token_boxes = encoding.bbox.squeeze().tolist()
width, height = image.size
print("Image size:", width, "x", height)
print("printing boxes", boxes[:5])
print("printing token boxes", token_boxes[:5])  # Print the first 5 bounding boxes

# Keep only positions whose label is not -100 (the ignore marker), so that
# predictions, gold labels and boxes stay aligned one-to-one.
true_predictions, true_labels = [], []
for pred_id, gold_id in zip(predictions, labels):
    if gold_id == -100:
        continue
    true_predictions.append(model.config.id2label[pred_id])
    true_labels.append(model.config.id2label[gold_id])

true_boxes = [
    unnormalize_box(tok_box, width, height)
    for tok_box, gold_id in zip(token_boxes, labels)
    if gold_id != -100
]
print("printing true Predictions:", true_predictions[:5])  # Print the first 5 predictions
print("printing true Labels:", true_labels[:5])  # Print the first 5 labels
from PIL import ImageDraw, ImageFont
# Drawing surface on the dataset image; boxes are drawn onto it in place.
draw = ImageDraw.Draw(image)
font = ImageFont.load_default()

# Tag prefixes of the IOB/BIO scheme. Only these are stripped from labels.
_IOB_PREFIXES = ("B-", "I-")


def iob_to_label(label):
    """Return *label* without its IOB prefix, or ``'other'`` for outside tags.

    Only a leading ``B-``/``I-`` (case-insensitive) is removed. Labels without
    such a prefix, like this dataset's ``'base_tax'`` or ``'year'``, are
    returned unchanged. The previous version always chopped the first two
    characters, which turned ``'base_tax'`` into ``'se_tax'``.
    An empty label or the bare outside tag ``'O'`` maps to ``'other'``.
    """
    if label[:2].upper() in _IOB_PREFIXES:
        label = label[2:]
    if not label or label.upper() == "O":
        return 'other'
    return label


# Colour per class. Keys are the class names declared by the dataset's
# ClassLabel (lower-case, no IOB prefix).
label2color = {
    'irrelevant': 'red',
    'base_tax_header': 'green',
    'base_tax_due_header': 'green',
    'year_header': 'green',
    'base_tax_total': 'blue',
    'base_tax_due_total': 'blue',
    'base_tax': 'yellow',
    'base_tax_due': 'orange',
    'year': 'orange',
}
print(label2color)
print(model.config.id2label)

# Debug output: predicted vs. gold label for every labelled token.
for i, (prediction, label) in enumerate(zip(predictions, labels)):
    if label == -100:
        continue  # Skip the padding tokens or any token that should be ignored
    predicted_label = model.config.id2label.get(prediction, "Label not found")
    actual_label = id2label.get(label, "Label not found")
    print(f"Token {i}: Predicted - {predicted_label}, Actual - {actual_label}")

# Draw each kept token's box and predicted class in that class's colour.
for prediction, box in zip(true_predictions, true_boxes):
    predicted_label = iob_to_label(prediction).lower()
    color = label2color.get(predicted_label)
    if color is None:
        print(f"Label {predicted_label} not in label2color dictionary.")
        continue
    draw.rectangle(box, outline=color)
    draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=color, font=font)

plt.imshow(image)
plt.show()
Dataset preparation code
import json
import os
import numpy as np
from PIL import Image
import datasets
import torch
# Module-level logger. `name` was undefined (the underscores of `__name__`
# were lost), which raised NameError at import time.
logger = datasets.logging.get_logger(__name__)
def normalize_bbox(bbox, size):
    """Scale a pixel box to the 0-1000 coordinate range used by LayoutLMv3.

    Args:
        bbox: ``[x0, y0, x1, y1]`` in pixels of the image whose size is given.
            NOTE(review): this assumes corner format. If the annotation export
            stores ``[x, y, width, height]``, or coordinates for a differently
            sized image, the normalised boxes (and anything plotted from them)
            will be misplaced. Confirm against the annotation tool.
        size: ``(width, height)`` of that image in pixels.

    Returns:
        A list of four ints clamped to ``[0, 1000]``. Coordinates slightly
        outside the image therefore cannot produce out-of-range positions.
    """
    width, height = size[0], size[1]
    scaled = (
        1000 * bbox[0] / width,
        1000 * bbox[1] / height,
        1000 * bbox[2] / width,
        1000 * bbox[3] / height,
    )
    return [min(max(int(value), 0), 1000) for value in scaled]
def load_image(image_path):
    """Open *image_path*, convert it to RGB and return it with its size.

    Returns:
        ``(image, (width, height))``, where ``image`` is an RGB PIL image that
        no longer depends on the file on disk.
    """
    # The context manager closes the file handle once `convert` has copied
    # the pixels. Before, every example kept a file open until it was
    # garbage-collected.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    w, h = image.size
    return image, (w, h)
class CustomDatasetConfig(datasets.BuilderConfig):
    """BuilderConfig for CustomDataset"""

    def __init__(self, **kwargs):
        """BuilderConfig for CustomDataset.

        The pasted code had ``init`` / ``super(...).init``, which lost the
        double underscores. That made it an ordinary method that would raise
        AttributeError if called.

        Args:
            **kwargs: keyword arguments forwarded to super.
        """
        super().__init__(**kwargs)
class CustomDataset(datasets.GeneratorBasedBuilder):
    """Custom dataset for document understanding."""

    BUILDER_CONFIGS = [
        CustomDatasetConfig(name="custom_dataset", version=datasets.Version("1.0.0"), description="Custom dataset"),
    ]

    def _info(self):
        """Declare the features of one example.

        Each example has one box (from ``normalize_bbox``) and one class label
        per word, plus the page image and its path.
        """
        return datasets.DatasetInfo(
            features=datasets.Features(
                {
                    "id": datasets.Value("string"),
                    "words": datasets.Sequence(datasets.Value("string")),
                    "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
                    "ner_tags": datasets.Sequence(
                        datasets.features.ClassLabel(
                            names=[
                                'irrelevant',
                                'base_tax_header',
                                'base_tax_due_header',
                                'year_header',
                                'base_tax_total',
                                'base_tax_due_total',
                                'base_tax',
                                'base_tax_due',
                                'year',
                                # Add more labels as per your requirement
                            ]
                        )
                    ),
                    "image": datasets.features.Image(),
                    "image_path": datasets.Value("string"),
                }
            ),
            supervised_keys=None,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators."""
        # Assuming the data is already downloaded/extracted and available in a specific directory
        data_dir = '/new_layoutlmv3_dataset/new_dataset/data/'
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN, gen_kwargs={"filepath": os.path.join(data_dir, "train.json")},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST, gen_kwargs={"filepath": os.path.join(data_dir, "test.json")},
            ),
        ]

    def _generate_examples(self, filepath):
        """Yield ``(guid, example)`` pairs from the JSON list in *filepath*.

        Entries without ``'file_name'`` are skipped with a warning.
        """
        logger.info("⏳ Generating examples from = %s", filepath)
        with open(filepath, "r", encoding="utf8") as f:
            data = json.load(f)
        # Define the base directory for your images
        base_dir = r"D:\new_layoutlmv3_dataset\new_dataset"  # Update this path to your base directory
        for guid, item in enumerate(data):
            if 'file_name' not in item:
                logger.warning("Skipping entry %s due to missing 'file_name'", guid)
                continue
            # Drop "../" components and rebuild the path with the host OS
            # separator. Hard-coding "\\" produced an invalid file name on
            # POSIX; on Windows this yields the same path as before.
            image_relative_path = item['file_name'].replace('../', '')
            path_parts = [part for part in image_relative_path.split('/') if part]
            image_path = os.path.join(base_dir, *path_parts)
            image, size = load_image(image_path)

            # NOTE(review): annotation["box"] is normalised as [x0, y0, x1, y1]
            # in pixels of the image loaded above. If the annotation tool
            # exported [x, y, w, h] or used another image size, plotted boxes
            # will be misplaced. Confirm the export format.
            words, bboxes, ner_tags = [], [], []
            for annotation in item["annotations"]:
                words.append(annotation["text"])
                bboxes.append(normalize_bbox(annotation["box"], size))
                ner_tags.append(annotation["label"])

            yield guid, {
                "id": str(guid),
                "words": words,
                "bboxes": bboxes,
                "ner_tags": ner_tags,
                "image": image,  # Adjusting according to the expected format
                "image_path": image_path,  # Keep this as 'image_path' for consistency in your dataset features
            }