fabio-sim/LightGlue-ONNX

the possibility of supporting batch input

noahzn opened this issue · 12 comments

Hi @fabio-sim, the repo currently only supports batch size 1. Do you think it's possible, when not enough keypoints are extracted, to pad with a random array so that every image has the same number of keypoints? For example, if the input is 2 × N × 2 with N1 = 128 for image 1 and N2 = 125 for image 2, could we append three random points to image 2 as fake keypoints so that the pair can run in batch mode?
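The padding idea in this question can be sketched roughly as follows. This is a minimal NumPy sketch, assuming keypoints are already normalized to [-1, 1]; `pad_keypoints` is a hypothetical helper, not part of the repo. It also returns a validity mask, because the fake points would otherwise take part in attention exactly like real ones, and any match that lands on a padded index would have to be discarded afterwards. The descriptors would need the same padding.

```python
import numpy as np


def pad_keypoints(kpts_list, n_max, rng=None):
    """Stack per-image keypoints of shape (N_i, 2) into one (B, n_max, 2) array.

    Rows beyond N_i are filled with random points in [-1, 1) (assumed normalized
    coordinates). The returned boolean mask is True only for real keypoints.
    """
    rng = np.random.default_rng() if rng is None else rng
    batch = len(kpts_list)
    padded = rng.uniform(-1.0, 1.0, size=(batch, n_max, 2)).astype(np.float32)
    mask = np.zeros((batch, n_max), dtype=bool)
    for b, kpts in enumerate(kpts_list):
        n = kpts.shape[0]
        if n > n_max:
            raise ValueError(f"image {b} has {n} keypoints, more than n_max={n_max}")
        padded[b, :n] = kpts  # real keypoints first
        mask[b, :n] = True    # padded rows stay False
    return padded, mask


rng = np.random.default_rng(0)
kpts = [rng.uniform(-1, 1, (128, 2)), rng.uniform(-1, 1, (125, 2))]
padded, mask = pad_keypoints(kpts, 128, rng)
print(padded.shape, mask.sum(axis=1))  # (2, 128, 2) [128 125]
```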

Hello @noahzn, thanks for your interest again. I'll see what I can do.

Thank you! I will be waiting for your thoughts.

I've added batch input support in 9ebf215. Rather than padding with a random array, I went with a different design; details here: https://fabio-sim.github.io/blog/accelerating-lightglue-inference-onnx-runtime-tensorrt/

That's really amazing! I will take a careful look and give you feedback. Thanks a lot!

Hi @fabio-sim, I noticed that you also modified this file, but you didn't use it when exporting models. Can I use it if I want to export non-end2end models with batch input? My two image batches have different numbers of keypoints: for example, image1's keypoints are always (B, 100, 2) and image2's are always (B, 200, 2).

Hi, that file is from the original impl, so it's unrelated to export.

For your use case, I recommend passing the left and right batches separately, like this (note: untested):

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from ..config import Extractor
from ..ops import multi_head_attention_dispatch

torch.backends.cudnn.deterministic = True


class LearnableFourierPositionalEncoding(nn.Module):
    def __init__(self, M: int, descriptor_dim: int, num_heads: int, gamma: float = 1.0) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = descriptor_dim // num_heads
        self.Wr = nn.Linear(M, head_dim // 2, bias=False)
        self.gamma = gamma
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """encode position vector"""
        projected = self.Wr(x)
        cosines, sines = torch.cos(projected), torch.sin(projected)
        emb = torch.stack([cosines, sines])
        return emb.repeat_interleave(2, dim=3).repeat(1, 1, 1, self.num_heads).unsqueeze(4)


class TokenConfidence(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """get confidence tokens"""
        return (
            self.token(desc0.detach()).squeeze(-1),
            self.token(desc1.detach()).squeeze(-1),
        )


class SelfBlock(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, bias: bool = True) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )

    def forward(self, x: torch.Tensor, encoding: torch.Tensor) -> torch.Tensor:
        b, n, _ = x.shape
        qkv: torch.Tensor = self.Wqkv(x)
        qkv = qkv.reshape((b, n, self.embed_dim, 3))
        qk, v = qkv[..., :2], qkv[..., 2]
        qk = self.apply_cached_rotary_emb(encoding, qk)
        q, k = qk[..., 0], qk[..., 1]
        context = multi_head_attention_dispatch(q, k, v, self.num_heads)
        message = self.out_proj(context)
        return x + self.ffn(torch.concat([x, message], 2))

    def rotate_half(self, qk: torch.Tensor) -> torch.Tensor:
        b, n, _, _ = qk.shape
        qk = qk.reshape((b, n, self.num_heads, self.head_dim // 2, 2, 2))
        qk = torch.stack((-qk[..., 1, :], qk[..., 0, :]), dim=4)
        qk = qk.reshape((b, n, self.embed_dim, 2))
        return qk

    def apply_cached_rotary_emb(self, encoding: torch.Tensor, qk: torch.Tensor) -> torch.Tensor:
        return qk * encoding[0] + self.rotate_half(qk) * encoding[1]


class CrossBlock(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, bias: bool = True) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.to_qk = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.to_v = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.to_out = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )

    def forward(self, descriptors0: torch.Tensor, descriptors1: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        qk0, v0 = self.to_qk(descriptors0), self.to_v(descriptors0)
        qk1, v1 = self.to_qk(descriptors1), self.to_v(descriptors1)

        m0 = multi_head_attention_dispatch(qk0, qk1, v1, self.num_heads)
        m0 = self.to_out(m0)
        descriptors0 = descriptors0 + self.ffn(torch.concat([descriptors0, m0], 2))

        m1 = multi_head_attention_dispatch(qk1, qk0, v0, self.num_heads)
        m1 = self.to_out(m1)
        descriptors1 = descriptors1 + self.ffn(torch.concat([descriptors1, m1], 2))
        return descriptors0, descriptors1


class TransformerLayer(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.self_attn = SelfBlock(embed_dim, num_heads)
        self.cross_attn = CrossBlock(embed_dim, num_heads)

    def forward(
        self, descriptors0: torch.Tensor, descriptors1: torch.Tensor, encodings0: torch.Tensor, encodings1: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        descriptors0 = self.self_attn(descriptors0, encodings0)
        descriptors1 = self.self_attn(descriptors1, encodings1)
        return self.cross_attn(descriptors0, descriptors1)


def sigmoid_log_double_softmax(similarities: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor) -> torch.Tensor:
    """create the log assignment matrix from logits and similarity"""
    certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
    scores0 = F.log_softmax(similarities, 2)
    scores1 = F.log_softmax(similarities, 1)
    scores = scores0 + scores1 + certainties
    return scores


class MatchAssignment(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.scale = dim**0.25
        self.final_proj = nn.Linear(dim, dim, bias=True)
        self.matchability = nn.Linear(dim, 1, bias=True)

    def forward(self, descriptors0: torch.Tensor, descriptors1: torch.Tensor) -> torch.Tensor:
        """build assignment matrix from descriptors"""
        mdescriptors0 = self.final_proj(descriptors0) / self.scale
        mdescriptors1 = self.final_proj(descriptors1) / self.scale
        similarities = mdescriptors0 @ mdescriptors1.transpose(1, 2)
        z0 = self.matchability(descriptors0)
        z1 = self.matchability(descriptors1)
        scores = sigmoid_log_double_softmax(similarities, z0, z1)
        return scores

    def get_matchability(self, desc: torch.Tensor):
        return torch.sigmoid(self.matchability(desc)).squeeze(-1)


def filter_matches(scores: torch.Tensor, threshold: float):
    """obtain matches from a log assignment matrix [B x N0 x N1]"""
    max0 = torch.topk(scores, k=1, dim=2, sorted=False)  # scores.max(2)
    max1 = torch.topk(scores, k=1, dim=1, sorted=False)  # scores.max(1)
    m0, m1 = max0.indices[:, :, 0], max1.indices[:, 0, :]

    indices = torch.arange(m0.shape[1], device=m0.device).expand_as(m0)
    mutual = indices == m1.gather(1, m0)
    mscores = max0.values[:, :, 0].exp()
    valid = mscores > threshold

    b_idx, m0_idx = torch.where(valid & mutual)
    m1_idx = m0[b_idx, m0_idx]
    matches = torch.concat([b_idx[:, None], m0_idx[:, None], m1_idx[:, None]], 1)
    mscores = mscores[b_idx, m0_idx]
    return matches, mscores


class LightGlue(nn.Module):
    version = "v0.1_arxiv"
    url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"

    def __init__(
        self,
        extractor: Extractor,
        descriptor_dim: int = 256,
        num_heads: int = 4,
        n_layers: int = 9,
        filter_threshold: float = 0.1,  # match threshold
        depth_confidence: float = -1,  # -1 is no early stopping, recommend: 0.95
        width_confidence: float = -1,  # -1 is no point pruning, recommend: 0.99
    ) -> None:
        super().__init__()

        self.descriptor_dim = descriptor_dim
        self.num_heads = num_heads
        self.n_layers = n_layers
        self.filter_threshold = filter_threshold
        self.depth_confidence = depth_confidence
        self.width_confidence = width_confidence

        if extractor.dim != self.descriptor_dim:
            self.input_proj = nn.Linear(extractor.dim, self.descriptor_dim, bias=True)
        else:
            self.input_proj = nn.Identity()

        self.posenc = LearnableFourierPositionalEncoding(2, self.descriptor_dim, self.num_heads)

        d, h, n = self.descriptor_dim, self.num_heads, self.n_layers

        self.transformers = nn.ModuleList([TransformerLayer(d, h) for _ in range(n)])

        self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])

        self.token_confidence = nn.ModuleList([TokenConfidence(d) for _ in range(n - 1)])
        self.register_buffer(
            "confidence_thresholds",
            torch.Tensor([self.confidence_threshold(i) for i in range(n)]),
        )

        state_dict = torch.hub.load_state_dict_from_url(self.url.format(self.version, extractor.value))

        # rename old state dict entries
        for i in range(n):
            pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
            state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
            pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
            state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
        self.load_state_dict(state_dict, strict=False)

    def forward(
        self,
        keypoints0: torch.Tensor,  # (B, N0, 2), normalized
        keypoints1: torch.Tensor,  # (B, N1, 2), normalized
        descriptors0: torch.Tensor,  # (B, N0, D)
        descriptors1: torch.Tensor,  # (B, N1, D)
    ):
        descriptors0 = self.input_proj(descriptors0)
        descriptors1 = self.input_proj(descriptors1)

        # positional embeddings
        encodings0 = self.posenc(keypoints0)  # (2, B, N0, descriptor_dim, 1)
        encodings1 = self.posenc(keypoints1)

        # GNN + final_proj + assignment
        for i in range(self.n_layers):
            # self+cross attention
            descriptors0, descriptors1 = self.transformers[i](descriptors0, descriptors1, encodings0, encodings1)

        scores = self.log_assignment[i](descriptors0, descriptors1)  # (B, N0, N1)
        matches, mscores = filter_matches(scores, self.filter_threshold)
        return matches, mscores  # (M, 3), (M,)

    def confidence_threshold(self, layer_index: int) -> float:
        """scaled confidence threshold"""
        threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.n_layers)
        return np.clip(threshold, 0, 1)

    def get_pruning_mask(
        self,
        confidences: torch.Tensor | None,
        scores: torch.Tensor,
        layer_index: int,
    ) -> torch.Tensor:
        """mask points which should be removed"""
        keep = scores > (1 - self.width_confidence)
        if confidences is not None:  # Low-confidence points are never pruned.
            keep |= confidences <= self.confidence_thresholds[layer_index]
        return keep

    def check_if_stop(
        self,
        confidences0: torch.Tensor,
        confidences1: torch.Tensor,
        layer_index: int,
        num_points: int,
    ) -> torch.Tensor:
        """evaluate stopping condition"""
        confidences = torch.cat([confidences0, confidences1], -1)
        threshold = self.confidence_thresholds[layer_index]
        ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
        return ratio_confident > self.depth_confidence

and then adjusting the Pipeline class to orchestrate SuperPoint(100) and SuperPoint(200) accordingly.
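A rough sketch of that orchestration, assuming each extractor returns `(keypoints, descriptors)` with a fixed keypoint count per side. `PairPipeline`, `StubExtractor`, and `StubMatcher` are hypothetical names; the stubs only stand in for SuperPoint and LightGlue so the shapes can be checked, and keypoint normalization is omitted. The repo's actual Pipeline class is not reproduced here.

```python
import torch
from torch import nn


class PairPipeline(nn.Module):
    """Hypothetical: one extractor per side, so each side keeps its own keypoint count."""

    def __init__(self, extractor0: nn.Module, extractor1: nn.Module, matcher: nn.Module) -> None:
        super().__init__()
        self.extractor0 = extractor0
        self.extractor1 = extractor1
        self.matcher = matcher

    def forward(self, images0: torch.Tensor, images1: torch.Tensor):
        keypoints0, descriptors0 = self.extractor0(images0)  # (B, N0, 2), (B, N0, D)
        keypoints1, descriptors1 = self.extractor1(images1)  # (B, N1, 2), (B, N1, D)
        return self.matcher(keypoints0, keypoints1, descriptors0, descriptors1)


class StubExtractor(nn.Module):
    """Stand-in for SuperPoint(num_keypoints): random keypoints and descriptors."""

    def __init__(self, num_keypoints: int, dim: int = 64) -> None:
        super().__init__()
        self.num_keypoints = num_keypoints
        self.dim = dim

    def forward(self, images: torch.Tensor):
        b = images.shape[0]
        keypoints = torch.rand(b, self.num_keypoints, 2) * 2 - 1  # values in [-1, 1)
        descriptors = torch.randn(b, self.num_keypoints, self.dim)
        return keypoints, descriptors


class StubMatcher(nn.Module):
    """Stand-in for LightGlue: returns a (B, N0, N1) similarity matrix."""

    def forward(self, keypoints0, keypoints1, descriptors0, descriptors1):
        return descriptors0 @ descriptors1.transpose(1, 2)


pipeline = PairPipeline(StubExtractor(100), StubExtractor(200), StubMatcher())
scores = pipeline(torch.zeros(3, 1, 64, 64), torch.zeros(3, 1, 64, 64))
print(scores.shape)  # torch.Size([3, 100, 200])
```

Because the two sides never get concatenated along the batch axis, N0 and N1 are free to differ.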

Hi @fabio-sim, thank you very much! I'm now working on the code, but I ran into an error:

fused_multi_head_attention = torch.library.custom_op(CUSTOM_OP_NAME, mutates_args=())(multi_head_attention)
AttributeError: module 'torch.library' has no attribute 'custom_op'

My torch version is >= 2.1.

Oh, apologies, my mistake: torch.library.custom_op needs torch >= 2.4. I should've added a version check.
I think it's fine to comment it out.
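Instead of commenting the line out, a version guard could register the custom op only when `torch.library.custom_op` exists and otherwise fall back to the plain function. This is a sketch: `CUSTOM_OP_NAME` here is a hypothetical value (the repo's actual name may differ), and `multi_head_attention` is written so that k and v may have a different sequence length from q.

```python
import torch
import torch.nn.functional as F

# Hypothetical op name; custom ops need a "namespace::name" string.
CUSTOM_OP_NAME = "sketch::multi_head_attention"


def multi_head_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int) -> torch.Tensor:
    b, n, d = q.shape
    n1 = k.shape[1]  # k and v may have a different sequence length than q
    head_dim = d // num_heads
    q = q.reshape((b, n, num_heads, head_dim)).transpose(1, 2)
    k, v = (t.reshape((b, n1, num_heads, head_dim)).transpose(1, 2) for t in (k, v))
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape((b, n, d))


# torch.library.custom_op only exists in torch >= 2.4; on older versions use the
# plain function so the module still imports.
if hasattr(torch.library, "custom_op"):
    fused_multi_head_attention = torch.library.custom_op(CUSTOM_OP_NAME, mutates_args=())(multi_head_attention)
else:
    fused_multi_head_attention = multi_head_attention
```

Either way, callers keep using `fused_multi_head_attention`, so nothing downstream has to change.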

@fabio-sim Thank you for your comments.

orig_image0 = cv2.imread(img0_path, cv2.IMREAD_COLOR)
orig_image1 = cv2.imread(img1_path, cv2.IMREAD_COLOR)
viz2d.plot_images(
    [orig_image0, orig_image1]
)

assert np.all(kpts0[2][matches[..., 1]] == kpts0[0][matches[..., 1]])
assert np.all(kpts1[2][matches[..., 2]] == kpts1[0][matches[..., 2]])
viz2d.plot_matches(kpts0[0][matches[..., 1]], kpts1[0][matches[...,2]], color="lime", lw=0.2)

viz2d.save_plot('aaa1.jpg', dpi=300)
viz2d.plt.show()
viz2d.plot_matches(kpts0[2][matches[..., 1]], kpts1[2][matches[..., 2]], color="lime", lw=0.2)
viz2d.save_plot('aaa2.jpg', dpi=300)
viz2d.plt.show()

I used the code above to visualize the results with batch size 4. The first and third image pairs are identical, and the other two pairs are random arrays. The assertions check that the outputs for the first and third pairs are the same. However, when I visualize the results, several matches always change between the two plots and are incorrect. Do you know why?

[image: myplot1]

[image: myplot2]

Update: The problem is solved; I wasn't parsing the returned matches correctly. It works now. Thanks a million for your help! Closing this ticket.
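For reference, `filter_matches` above returns `matches` with shape (M, 3), where each row is `(batch_idx, idx0, idx1)`, for all pairs in the batch at once. Indexing with every row, as in `kpts0[0][matches[..., 1]]`, mixes in matches from other pairs, which is one likely way to get a few wrong matches per plot. A minimal sketch of per-pair parsing (`matches_for_pair` is a hypothetical helper, written with NumPy arrays):

```python
import numpy as np


def matches_for_pair(matches, mscores, kpts0, kpts1, b):
    """Return matched keypoints and scores for batch element b only.

    matches: (M, 3) int array of (batch_idx, idx0, idx1) rows
    mscores: (M,) match scores
    kpts0: (B, N0, 2), kpts1: (B, N1, 2)
    """
    sel = matches[:, 0] == b  # keep only rows belonging to pair b
    idx0, idx1 = matches[sel, 1], matches[sel, 2]
    return kpts0[b, idx0], kpts1[b, idx1], mscores[sel]


kpts0 = np.arange(2 * 4 * 2, dtype=float).reshape(2, 4, 2)
kpts1 = np.arange(2 * 5 * 2, dtype=float).reshape(2, 5, 2) + 100
matches = np.array([[0, 0, 1], [0, 2, 3], [1, 1, 4]])
mscores = np.array([0.9, 0.8, 0.7])
m0, m1, s = matches_for_pair(matches, mscores, kpts0, kpts1, 1)
print(m0, m1, s)  # [[10. 11.]] [[118. 119.]] [0.7]
```

The resulting `m0` and `m1` can then be passed to `viz2d.plot_matches` for that pair.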

Hi, sorry, I still have a problem.

def multi_head_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int) -> torch.Tensor:
    b, n, d = q.shape
    head_dim = d // num_heads
    q, k, v = (t.reshape((b, n, num_heads, head_dim)).transpose(1, 2) for t in (q, k, v))
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape((b, n, d))

I found that when the images in a pair have different numbers of keypoints, multi_head_attention throws an error.
For example, the left-image tensor is (2, 99, 64) and the right-image tensor is (2, 256, 64). Here 256 is the maximum number of keypoints I set, but only 99 keypoints are extracted from the left images. The multi_head_attention function then throws this error:
q, k, v = (t.reshape((b, n, num_heads, head_dim)).transpose(1, 2) for t in (q, k, v))
RuntimeError: shape '[2, 99, 4, 16]' is invalid for input of size 32768

This is because k and v come from the right images, whose shape would be [2, 256, 4, 16]: 2 × 256 × 4 × 16 = 32768. (Note that my keypoint descriptors are 64-D rather than 256-D; it's a custom network.)
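Worked out, the mismatch is that the reshape applies q's sequence length (n = 99) to k and v as well, while k and v carry n1 = 256 rows of 64 = 4 × 16 values each:

$$
\underbrace{2 \times 256 \times 4 \times 16}_{\text{elements in } k \text{ or } v} = 32768
\quad\neq\quad
\underbrace{2 \times 99 \times 4 \times 16}_{\text{target shape } [2,\,99,\,4,\,16]} = 12672
$$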

I tried modifying the function as follows. The error is gone, but the matching results are bad.

def multi_head_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int) -> torch.Tensor:
    b, n, d = q.shape
    nk = k.shape[1]
    head_dim = d // num_heads
    q = q.reshape(b, n, num_heads, head_dim).transpose(1, 2)
    k = k.reshape(b, nk, num_heads, head_dim).transpose(1, 2)
    v = v.reshape(b, nk, num_heads, head_dim).transpose(1, 2)
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape((b, n, d))

Could you help me with that? Thank you in advance!

Update: I switched back to the old CrossBlock implementation, and it works with different numbers of keypoints.

Ah yes, if you have different numbers of keypoints, the sequence length of Q differs from that of K and V.

Use something like this instead:

def multi_head_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int) -> torch.Tensor:
    b, n, d = q.shape
    _, n1, _ = k.shape
    head_dim = d // num_heads
    q = q.reshape((b, n, num_heads, head_dim)).transpose(1, 2)
    k, v = (t.reshape((b, n1, num_heads, head_dim)).transpose(1, 2) for t in (k, v))
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape((b, n, d))
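One way to sanity-check this version is to compare it with an explicit softmax computation in the style of the original LightGlue CrossBlock, where a single similarity matrix between the two images is normalized along each axis. `cross_attention_reference` below is a hypothetical helper written only for this check; the shapes mirror the (2, 99, 64) / (2, 256, 64) case with num_heads = 4. Because `to_qk` is shared between both images, the 1→0 similarity is just the transpose of the 0→1 one, so calling `multi_head_attention(qk1, qk0, v0, ...)` should give the second direction.

```python
import torch
import torch.nn.functional as F


def multi_head_attention(q, k, v, num_heads):
    b, n, d = q.shape
    _, n1, _ = k.shape
    head_dim = d // num_heads
    q = q.reshape((b, n, num_heads, head_dim)).transpose(1, 2)
    k, v = (t.reshape((b, n1, num_heads, head_dim)).transpose(1, 2) for t in (k, v))
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape((b, n, d))


def cross_attention_reference(qk0, qk1, v0, v1, num_heads):
    """Explicit bidirectional cross-attention with one shared similarity matrix."""
    b, n0, d = qk0.shape
    n1 = qk1.shape[1]
    head_dim = d // num_heads

    def split(t, n):  # (b, n, d) -> (b, H, n, head_dim)
        return t.reshape(b, n, num_heads, head_dim).transpose(1, 2)

    def merge(t, n):  # (b, H, n, head_dim) -> (b, n, d)
        return t.transpose(1, 2).reshape(b, n, d)

    sim = torch.einsum("bhid,bhjd->bhij", split(qk0, n0), split(qk1, n1)) / head_dim**0.5
    m0 = torch.softmax(sim, dim=-1) @ split(v1, n1)                  # image 0 attends to image 1
    m1 = torch.softmax(sim.transpose(-2, -1), dim=-1) @ split(v0, n0)  # image 1 attends to image 0
    return merge(m0, n0), merge(m1, n1)


torch.manual_seed(0)
qk0, v0 = torch.randn(2, 99, 64), torch.randn(2, 99, 64)
qk1, v1 = torch.randn(2, 256, 64), torch.randn(2, 256, 64)
ref0, ref1 = cross_attention_reference(qk0, qk1, v0, v1, num_heads=4)
print(torch.allclose(multi_head_attention(qk0, qk1, v1, 4), ref0, atol=1e-5))
print(torch.allclose(multi_head_attention(qk1, qk0, v0, 4), ref1, atol=1e-5))
```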

Thank you again for your help!