The possibility of supporting batch input
noahzn opened this issue · 12 comments
Hi @fabio-sim Right now the repo only supports batch size = 1. Do you think it would be possible, when not enough keypoints are extracted, to pad with a random array so that all images in a batch have the same number of keypoints? For example, if the input is 2×N×2 and image 1 has N1 = 128 keypoints while image 2 has N2 = 125, could we stack three random points onto image 2 as fake keypoints so that the pair can be run in batch mode?
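Something like this minimal sketch is what I have in mind (untested; pad_keypoints, kpts_img1, and kpts_img2 are just placeholder names, and keypoint coordinates are assumed to be normalized):
import torch
def pad_keypoints(kpts: torch.Tensor, target_n: int) -> torch.Tensor:
    # Pad an (N, 2) keypoint tensor to (target_n, 2) with random fake points.
    n = kpts.shape[0]
    if n >= target_n:
        return kpts[:target_n]
    filler = torch.rand(target_n - n, 2)  # fake points in normalized coordinates
    return torch.cat([kpts, filler], dim=0)
# N1 = 128, N2 = 125 -> pad both to 128, then stack into a (2, 128, 2) batch
batch = torch.stack([pad_keypoints(k, 128) for k in (kpts_img1, kpts_img2)])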
Thank you! I will be waiting for your thoughts.
I've added batch input support in 9ebf215. Rather than padding with a random array, I went with a different design choice; details here: https://fabio-sim.github.io/blog/accelerating-lightglue-inference-onnx-runtime-tensorrt/
That's really amazing! I will take a careful look and give you feedback. Thanks a lot!
Hi @fabio-sim , I noticed that you also modified this file, but you don't use it when exporting models. Can I use it if I want to export non-end2end models with batch input? My two image batches have different numbers of keypoints: for example, the keypoints of image1 are always (B × 100 × 2) and those of image2 are always (B × 200 × 2).
Hi, that file is from the original impl, so it's unrelated to export.
For your use case, I recommend passing the left and right batches separately, like this (note: untested):
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from ..config import Extractor
from ..ops import multi_head_attention_dispatch
torch.backends.cudnn.deterministic = True
class LearnableFourierPositionalEncoding(nn.Module):
def __init__(self, M: int, descriptor_dim: int, num_heads: int, gamma: float = 1.0) -> None:
super().__init__()
self.num_heads = num_heads
head_dim = descriptor_dim // num_heads
self.Wr = nn.Linear(M, head_dim // 2, bias=False)
self.gamma = gamma
nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""encode position vector"""
projected = self.Wr(x)
cosines, sines = torch.cos(projected), torch.sin(projected)
emb = torch.stack([cosines, sines])
return emb.repeat_interleave(2, dim=3).repeat(1, 1, 1, self.num_heads).unsqueeze(4)
class TokenConfidence(nn.Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
"""get confidence tokens"""
return (
self.token(desc0.detach()).squeeze(-1),
self.token(desc1.detach()).squeeze(-1),
)
class SelfBlock(nn.Module):
def __init__(self, embed_dim: int, num_heads: int, bias: bool = True) -> None:
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.ffn = nn.Sequential(
nn.Linear(2 * embed_dim, 2 * embed_dim),
nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
nn.GELU(),
nn.Linear(2 * embed_dim, embed_dim),
)
def forward(self, x: torch.Tensor, encoding: torch.Tensor) -> torch.Tensor:
b, n, _ = x.shape
qkv: torch.Tensor = self.Wqkv(x)
qkv = qkv.reshape((b, n, self.embed_dim, 3))
qk, v = qkv[..., :2], qkv[..., 2]
qk = self.apply_cached_rotary_emb(encoding, qk)
q, k = qk[..., 0], qk[..., 1]
context = multi_head_attention_dispatch(q, k, v, self.num_heads)
message = self.out_proj(context)
return x + self.ffn(torch.concat([x, message], 2))
def rotate_half(self, qk: torch.Tensor) -> torch.Tensor:
b, n, _, _ = qk.shape
qk = qk.reshape((b, n, self.num_heads, self.head_dim // 2, 2, 2))
qk = torch.stack((-qk[..., 1, :], qk[..., 0, :]), dim=4)
qk = qk.reshape((b, n, self.embed_dim, 2))
return qk
def apply_cached_rotary_emb(self, encoding: torch.Tensor, qk: torch.Tensor) -> torch.Tensor:
return qk * encoding[0] + self.rotate_half(qk) * encoding[1]
class CrossBlock(nn.Module):
def __init__(self, embed_dim: int, num_heads: int, bias: bool = True) -> None:
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.to_qk = nn.Linear(embed_dim, embed_dim, bias=bias)
self.to_v = nn.Linear(embed_dim, embed_dim, bias=bias)
self.to_out = nn.Linear(embed_dim, embed_dim, bias=bias)
self.ffn = nn.Sequential(
nn.Linear(2 * embed_dim, 2 * embed_dim),
nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
nn.GELU(),
nn.Linear(2 * embed_dim, embed_dim),
)
def forward(self, descriptors0: torch.Tensor, descriptors1: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
b, _, _ = descriptors0.shape
qk0, v0 = self.to_qk(descriptors0), self.to_v(descriptors0)
qk1, v1 = self.to_qk(descriptors1), self.to_v(descriptors1)
m0 = multi_head_attention_dispatch(qk0, qk1, v1, self.num_heads)
m0 = self.to_out(m0)
descriptors0 = descriptors0 + self.ffn(torch.concat([descriptors0, m0], 2))
m1 = multi_head_attention_dispatch(qk1, qk0, v0, self.num_heads)
m1 = self.to_out(m1)
descriptors1 = descriptors1 + self.ffn(torch.concat([descriptors1, m1], 2))
return descriptors0, descriptors1
class TransformerLayer(nn.Module):
def __init__(self, embed_dim: int, num_heads: int):
super().__init__()
self.self_attn = SelfBlock(embed_dim, num_heads)
self.cross_attn = CrossBlock(embed_dim, num_heads)
def forward(
self, descriptors0: torch.Tensor, descriptors1: torch.Tensor, encodings0: torch.Tensor, encodings1: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
descriptors0 = self.self_attn(descriptors0, encodings0)
descriptors1 = self.self_attn(descriptors1, encodings1)
return self.cross_attn(descriptors0, descriptors1)
def sigmoid_log_double_softmax(similarities: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor) -> torch.Tensor:
"""create the log assignment matrix from logits and similarity"""
certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
scores0 = F.log_softmax(similarities, 2)
scores1 = F.log_softmax(similarities, 1)
scores = scores0 + scores1 + certainties
return scores
class MatchAssignment(nn.Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.scale = dim**0.25
self.final_proj = nn.Linear(dim, dim, bias=True)
self.matchability = nn.Linear(dim, 1, bias=True)
def forward(self, descriptors0: torch.Tensor, descriptors1: torch.Tensor) -> torch.Tensor:
"""build assignment matrix from descriptors"""
mdescriptors0 = self.final_proj(descriptors0) / self.scale
mdescriptors1 = self.final_proj(descriptors1) / self.scale
similarities = mdescriptors0 @ mdescriptors1.transpose(1, 2)
z0 = self.matchability(descriptors0)
z1 = self.matchability(descriptors1)
scores = sigmoid_log_double_softmax(similarities, z0, z1)
return scores
def get_matchability(self, desc: torch.Tensor):
return torch.sigmoid(self.matchability(desc)).squeeze(-1)
def filter_matches(scores: torch.Tensor, threshold: float):
"""obtain matches from a log assignment matrix [BxNxN]"""
max0 = torch.topk(scores, k=1, dim=2, sorted=False) # scores.max(2)
max1 = torch.topk(scores, k=1, dim=1, sorted=False) # scores.max(1)
m0, m1 = max0.indices[:, :, 0], max1.indices[:, 0, :]
indices = torch.arange(m0.shape[1], device=m0.device).expand_as(m0)
mutual = indices == m1.gather(1, m0)
mscores = max0.values[:, :, 0].exp()
valid = mscores > threshold
b_idx, m0_idx = torch.where(valid & mutual)
m1_idx = m0[b_idx, m0_idx]
matches = torch.concat([b_idx[:, None], m0_idx[:, None], m1_idx[:, None]], 1)
mscores = mscores[b_idx, m0_idx]
return matches, mscores
class LightGlue(nn.Module):
version = "v0.1_arxiv"
url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"
def __init__(
self,
extractor: Extractor,
descriptor_dim: int = 256,
num_heads: int = 4,
n_layers: int = 9,
filter_threshold: float = 0.1, # match threshold
depth_confidence: float = -1, # -1 is no early stopping, recommend: 0.95
width_confidence: float = -1, # -1 is no point pruning, recommend: 0.99
) -> None:
super().__init__()
self.descriptor_dim = descriptor_dim
self.num_heads = num_heads
self.n_layers = n_layers
self.filter_threshold = filter_threshold
self.depth_confidence = depth_confidence
self.width_confidence = width_confidence
if extractor.dim != self.descriptor_dim:
self.input_proj = nn.Linear(extractor.dim, self.descriptor_dim, bias=True)
else:
self.input_proj = nn.Identity()
self.posenc = LearnableFourierPositionalEncoding(2, self.descriptor_dim, self.num_heads)
d, h, n = self.descriptor_dim, self.num_heads, self.n_layers
self.transformers = nn.ModuleList([TransformerLayer(d, h) for _ in range(n)])
self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
self.token_confidence = nn.ModuleList([TokenConfidence(d) for _ in range(n - 1)])
self.register_buffer(
"confidence_thresholds",
torch.Tensor([self.confidence_threshold(i) for i in range(n)]),
)
state_dict = torch.hub.load_state_dict_from_url(self.url.format(self.version, extractor.value))
# rename old state dict entries
for i in range(n):
pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
self.load_state_dict(state_dict, strict=False)
def forward(
self,
keypoints0: torch.Tensor, # (2B, N, 2), normalized
keypoints1: torch.Tensor,
descriptors0: torch.Tensor, # (2B, N, D)
descriptors1: torch.Tensor,
):
descriptors0 = self.input_proj(descriptors0)
descriptors1 = self.input_proj(descriptors1)
# positional embeddings
encodings0 = self.posenc(keypoints0) # (2, 2B, *, 64, 1)
encodings1 = self.posenc(keypoints1)
# GNN + final_proj + assignment
for i in range(self.n_layers):
# self+cross attention
descriptors0, descriptors1 = self.transformers[i](descriptors0, descriptors1, encodings0, encodings1)
scores = self.log_assignment[i](descriptors0, descriptors1) # (B, N, N)
matches, mscores = filter_matches(scores, self.filter_threshold)
return matches, mscores # (M, 3), (M,)
def confidence_threshold(self, layer_index: int) -> float:
"""scaled confidence threshold"""
threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.n_layers)
return np.clip(threshold, 0, 1)
def get_pruning_mask(
self,
confidences: torch.Tensor | None,
scores: torch.Tensor,
layer_index: int,
) -> torch.Tensor:
"""mask points which should be removed"""
keep = scores > (1 - self.width_confidence)
if confidences is not None: # Low-confidence points are never pruned.
keep |= confidences <= self.confidence_thresholds[layer_index]
return keep
def check_if_stop(
self,
confidences0: torch.Tensor,
confidences1: torch.Tensor,
layer_index: int,
num_points: int,
) -> torch.Tensor:
"""evaluate stopping condition"""
confidences = torch.cat([confidences0, confidences1], -1)
threshold = self.confidence_thresholds[layer_index]
ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
return ratio_confident > self.depth_confidence
and then adjusting the Pipeline class to orchestrate SuperPoint(100) and SuperPoint(200) accordingly.
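Roughly, the forward() above would then be fed inputs shaped like this (an untested sketch, assuming keypoints are already normalized and a 256-D descriptor):
import torch
# Left batch always has 100 keypoints, right batch always has 200; batch size B = 2.
B, D = 2, 256                                # D must match the extractor's descriptor dim
keypoints0 = torch.rand(B, 100, 2) * 2 - 1   # assumed normalized coordinates
keypoints1 = torch.rand(B, 200, 2) * 2 - 1
descriptors0 = torch.randn(B, 100, D)
descriptors1 = torch.randn(B, 200, D)
# matcher(keypoints0, keypoints1, descriptors0, descriptors1) -> matches (M, 3), mscores (M,)
# where each row of matches is (batch_idx, index_into_keypoints0, index_into_keypoints1)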
Hi @fabio-sim thank you very much! I'm now working on the code, but I ran into an error:
fused_multi_head_attention = torch.library.custom_op(CUSTOM_OP_NAME, mutates_args=())(multi_head_attention)
AttributeError: module 'torch.library' has no attribute 'custom_op'
My torch is >=2.1
Oh, apologies, my mistake. torch.library.custom_op needs torch >= 2.4; I should've put a check. I think it's fine if you comment it out.
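For example, a guard along these lines should work (untested sketch, assuming CUSTOM_OP_NAME and multi_head_attention are defined as in the repo):
import torch
# Register the fused custom op only when torch.library.custom_op exists (torch >= 2.4);
# otherwise fall back to the plain Python function.
if hasattr(torch.library, "custom_op"):
    fused_multi_head_attention = torch.library.custom_op(CUSTOM_OP_NAME, mutates_args=())(multi_head_attention)
else:
    fused_multi_head_attention = multi_head_attention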
@fabio-sim Thank you for your comments.
orig_image0 = cv2.imread(img0_path, cv2.IMREAD_COLOR)
orig_image1 = cv2.imread(img1_path, cv2.IMREAD_COLOR)
viz2d.plot_images(
[orig_image0, orig_image1]
)
assert np.all(kpts0[2][matches[..., 1]] == kpts0[0][matches[..., 1]])
assert np.all(kpts1[2][matches[..., 2]] == kpts1[0][matches[..., 2]])
viz2d.plot_matches(kpts0[0][matches[..., 1]], kpts1[0][matches[...,2]], color="lime", lw=0.2)
viz2d.save_plot('aaa1.jpg', dpi=300)
viz2d.plt.show()
viz2d.plot_matches(kpts0[2][matches[..., 1]], kpts1[2][matches[..., 2]], color="lime", lw=0.2)
viz2d.save_plot('aaa2.jpg', dpi=300)
viz2d.plt.show()
I used the above code to visualize. I ran with batch size 4; the first and the third image pairs are identical, and for the other two pairs I used random arrays. The assertions above confirm that the outputs for the first and the third pairs are the same. However, when visualizing the results, several matches always differ and are incorrect. Do you know the reason?
Update: the problem has been solved. I wasn't parsing the returned matches correctly. Now it works. Thanks a million for your help!! I'm closing this ticket now.
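For anyone else hitting this: the returned matches tensor has shape (M, 3), where column 0 is the batch index and columns 1 and 2 index into kpts0 and kpts1 (matching filter_matches above), so each pair's matches have to be selected by batch index before plotting. Roughly (untested; pair_idx, sel, mkpts0, mkpts1 are just placeholder names):
# Select only the matches belonging to one image pair (here the third pair, index 2)
# before indexing into that pair's keypoints.
pair_idx = 2
sel = matches[:, 0] == pair_idx
mkpts0 = kpts0[pair_idx][matches[sel, 1]]
mkpts1 = kpts1[pair_idx][matches[sel, 2]]
viz2d.plot_matches(mkpts0, mkpts1, color="lime", lw=0.2)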
Hi, sorry, I still have a problem.
def multi_head_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int) -> torch.Tensor:
b, n, d = q.shape
head_dim = d // num_heads
q, k, v = (t.reshape((b, n, num_heads, head_dim)).transpose(1, 2) for t in (q, k, v))
return F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape((b, n, d))
I found that when an image pair has different numbers of keypoints, the multi_head_attention function throws an error. For example, the left-image batch has shape (2, 99, 64) and the right-image batch has shape (2, 256, 64); 256 is the maximum number of keypoints I set, but only 99 keypoints were extracted from the left images. The multi_head_attention function then throws:
q, k, v = (t.reshape((b, n, num_heads, head_dim)).transpose(1, 2) for t in (q, k, v)) RuntimeError: shape '[2, 99, 4, 16]' is invalid for input of size 32768
because for the right images the shape is [2, 256, 4, 16]: 2 × 256 × 4 × 16 = 32768. (Please note that my keypoint descriptors are 64-D instead of 256-D; it's a customized network.) I tried modifying the function as follows; the error went away, but the matching results are bad.
def multi_head_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int) -> torch.Tensor:
b, n, d = q.shape
nk = k.shape[1]
head_dim = d // num_heads
q = q.reshape(b, n, num_heads, head_dim).transpose(1, 2)
k = k.reshape(b, nk, num_heads, head_dim).transpose(1, 2)
v = v.reshape(b, nk, num_heads, head_dim).transpose(1, 2)
return F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape((b, n, d))
Could you help me with that? Thank you in advance!
Update: I have used the old implementation for CrossBlock and it works with different numbers of keypoints.
Ah yes, if you have different numbers of keypoints, the sequence length of Q differs from that of K and V.
Use something like this instead:
def multi_head_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int) -> torch.Tensor:
b, n, d = q.shape
_, n1, _ = k.shape
head_dim = d // num_heads
q = q.reshape((b, n, num_heads, head_dim)).transpose(1, 2)
k, v = (t.reshape((b, n1, num_heads, head_dim)).transpose(1, 2) for t in (k, v))
return F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape((b, n, d))
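As a quick sanity check with the shapes from your example (64-D descriptors, 4 heads; untested, and it assumes the imports and the function above are in scope):
q = torch.randn(2, 99, 64)    # left batch:  99 keypoints
k = torch.randn(2, 256, 64)   # right batch: 256 keypoints
v = torch.randn(2, 256, 64)
out = multi_head_attention(q, k, v, num_heads=4)
print(out.shape)  # expected: torch.Size([2, 99, 64])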
Thank you again for your help!