A "model.py" file for cross-modal retrieval performance improvement.
Opened this issue · 0 comments
If you use the following "model.py" file, cross-modal retrieval performance will be greatly improved.
model.py
`import argparse
import pickle
import torchtext
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.nn.init
import torchvision.models as models
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.backends.cudnn as cudnn
from torch.nn.utils.clip_grad import clip_grad_norm_
import numpy as np
import json
from vocab import deserialize_vocab
import os
from einops import rearrange
import math
from clipres import ModifiedResNet
from collections import OrderedDict
def l2norm(X, dim=-1, eps=1e-8):
    """Scale ``X`` to unit L2 norm along ``dim``.

    ``eps`` is added to the norm after the square root, so an all-zero slice
    is divided by ``eps`` (and stays zero) instead of being divided by zero.
    """
    squared_sum = X.pow(2).sum(dim=dim, keepdim=True)
    denom = squared_sum.sqrt() + eps
    return X / denom
class GateF(nn.Module):
    """Feature gate: multiplies the input by a learned sigmoid weight.

    The gate is produced by a two-layer bottleneck
    (``dim -> dim // 16 -> dim``) with a ReLU in between.

    Args:
        dim: size of the last dimension of the tensors passed to ``forward``.
    """

    def __init__(self, dim=1024):
        # The constructor must be named ``__init__`` (and call the parent's
        # ``__init__``); otherwise nn.Module never runs it and the layers
        # below are never created.
        super().__init__()
        self.dim = dim
        self.linear1 = nn.Linear(dim, dim // 16)
        self.linear2 = nn.Linear(dim // 16, dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(linear2(relu(linear1(x)))) * x``."""
        gate = self.linear1(x)
        gate = self.relu(gate)
        gate = self.linear2(gate)
        gate = self.sigmoid(gate)
        return gate * x
class EncoderImage(nn.Module):
    """Image encoder: CNN backbone -> linear projection -> gate -> fusion.

    ``forward`` returns the gated global embedding and that embedding fused
    with an externally supplied attention feature.

    NOTE(review): ``fc_attn_i`` is a fixed 1024 -> 1024 layer and ``Fusion()``
    / ``GateF()`` are built with their default sizes, so this class
    presumably expects ``embed_size == 1024`` -- confirm against the config.
    """

    def __init__(self, embed_size, finetune=False, cnn_type='CLIP-RN101',
                 no_imgnorm=False):
        """Build the backbone and the projection / fusion heads.

        Args:
            embed_size: output size of the projection layer ``fc``.
            finetune: value assigned to ``requires_grad`` of every backbone
                parameter.
            cnn_type: backbone family; must start with ``'vgg'``,
                ``'resnet'`` or ``'CLIP'``.
            no_imgnorm: stored on the instance; ``forward`` does not read it.

        Raises:
            ValueError: if ``cnn_type`` matches none of the supported
                prefixes (previously this surfaced later as an
                AttributeError on the missing ``fc`` layer).
        """
        super(EncoderImage, self).__init__()
        self.embed_size = embed_size
        self.no_imgnorm = no_imgnorm

        self.cnn = self.get_cnn(cnn_type, True)
        print('finetune:', finetune)
        for param in self.cnn.parameters():
            param.requires_grad = finetune

        # Replace the backbone's classification head with our projection.
        if cnn_type.startswith('vgg'):
            self.fc = nn.Linear(self.cnn.classifier._modules['6'].in_features,
                                embed_size)
            self.cnn.classifier = nn.Sequential(
                *list(self.cnn.classifier.children())[:-1])
        elif cnn_type.startswith('resnet'):
            self.fc = nn.Linear(self.cnn.module.fc.in_features, embed_size)
            self.cnn.module.fc = nn.Sequential()
        elif cnn_type.startswith('CLIP'):
            self.fc = nn.Linear(512, embed_size)
        else:
            raise ValueError(
                'Unsupported cnn_type: {!r}'.format(cnn_type))

        self.fc_attn_i = nn.Linear(1024, 1024)
        self.fusion = Fusion()
        self.gatef1 = GateF()
        self.init_weights()

    def get_cnn(self, arch, pretrained):
        """Build the backbone, wrap it in ``nn.DataParallel``, move it to GPU.

        When ``pretrained`` is true, the TorchScript checkpoint ``RN101.pt``
        is read from the current working directory and only its ``visual.*``
        weights are loaded into a ``ModifiedResNet``. In that case ``arch``
        is used only for the log line and for choosing how the model is
        wrapped. Otherwise ``models.__dict__[arch]()`` is constructed with
        default arguments.
        """
        if pretrained:
            print("=> using pre-trained model '{}'".format(arch))
            # Close the checkpoint file deterministically, and map storages
            # to CPU so the archive loads on hosts without the saving device.
            with open('RN101.pt', 'rb') as checkpoint:
                statefile = torch.jit.load(checkpoint, map_location='cpu')
            state_dict = statefile.state_dict()

            # Number of distinct residual blocks in each of the four stages.
            vision_layers = tuple(
                len(set(k.split(".")[2] for k in state_dict
                        if k.startswith(f"visual.layer{b}")))
                for b in [1, 2, 3, 4])
            vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
            n_positions = state_dict[
                "visual.attnpool.positional_embedding"].shape[0]
            output_width = round((n_positions - 1) ** 0.5)
            # Sanity check: positional embedding must be a square grid + 1.
            assert output_width ** 2 + 1 == n_positions
            image_resolution = output_width * 32
            embed_dim = state_dict["text_projection"].shape[1]
            vision_heads = vision_width * 32 // 64

            # Keep only the visual tower and strip its "visual." prefix.
            prefix = 'visual.'
            dictin = OrderedDict(
                (key[len(prefix):], value)
                for key, value in state_dict.items()
                if key.startswith(prefix))

            model = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
            model.load_state_dict(dictin)
        else:
            print("=> creating model '{}'".format(arch))
            model = models.__dict__[arch]()

        if arch.startswith('alexnet') or arch.startswith('vgg'):
            model.features = nn.DataParallel(model.features)
        else:
            model = nn.DataParallel(model)

        if torch.cuda.is_available():
            model.cuda()
        return model

    def load_state_dict(self, state_dict):
        """Load weights, renaming legacy classifier keys in place first.

        Handles the models saved before commit pytorch/vision@989d52a, whose
        classifier layers are indexed 1 and 4 instead of 0 and 3. The passed
        ``state_dict`` is mutated.
        """
        if 'cnn.classifier.1.weight' in state_dict:
            renames = (('cnn.classifier.1', 'cnn.classifier.0'),
                       ('cnn.classifier.4', 'cnn.classifier.3'))
            for old, new in renames:
                for suffix in ('weight', 'bias'):
                    state_dict['{}.{}'.format(new, suffix)] = state_dict.pop(
                        '{}.{}'.format(old, suffix))

        super(EncoderImage, self).load_state_dict(state_dict)

    def init_weights(self):
        """Xavier-uniform weights and zero bias for ``fc`` and ``fc_attn_i``."""
        for layer in (self.fc, self.fc_attn_i):
            r = np.sqrt(6.) / np.sqrt(layer.in_features + layer.out_features)
            layer.weight.data.uniform_(-r, r)
            layer.bias.data.fill_(0)

    def forward(self, images, local_image):
        """Extract image feature vectors.

        Args:
            images: input passed to the CNN backbone.
            local_image: attention feature passed through ``fc_attn_i``
                (its last dimension must be 1024).

        Returns:
            ``(features, features_f)``: the gated projection of the backbone
            output, and the L2-normalized fusion of that projection with the
            projected ``local_image``.
        """
        features = self.cnn(images)
        features = l2norm(features)

        # Linear projection to the joint embedding space, then gating.
        features = self.fc(features)
        features = self.gatef1(features)

        features_1 = self.fc_attn_i(local_image)
        features_f = l2norm(self.fusion(features, features_1))
        return features, features_f
class Fusion(nn.Module):
    """Gated blend of two equally sized feature vectors.

    A sigmoid gate ``t`` is computed from the concatenation of both inputs
    and the result is ``t * vec1 + (1 - t) * vec2``.

    Args:
        f_size: feature width of each input (default 1024, the value that
            was previously hard-coded).
    """

    def __init__(self, f_size=1024):
        # Must be ``__init__``; with the former name ``init`` the layers were
        # never created.
        super(Fusion, self).__init__()
        self.f_size = f_size
        self.gate0 = nn.Linear(self.f_size * 2, self.f_size)
        # ``gate1`` does not influence the output of ``forward``; it is kept
        # so parameter names and ordering (checkpoints, optimizer state)
        # stay unchanged.
        self.gate1 = nn.Linear(self.f_size, self.f_size)

    def forward(self, vec1, vec2):
        """Blend ``vec1`` and ``vec2`` (both ``(batch, f_size)``).

        The inputs are concatenated along dim 1, so they must be 2-D.
        """
        gate = torch.sigmoid(self.gate0(torch.cat((vec1, vec2), dim=1)))
        return gate * vec1 + (1 - gate) * vec2
class EncoderRegion(nn.Module):
    """Projects 2048-d region features into the joint embedding space.

    Args:
        opt: options object; only ``opt.embed_size`` is read.
    """

    def __init__(self, opt):
        # Must be ``__init__``; with the former name ``init`` the projection
        # layer was never created.
        super(EncoderRegion, self).__init__()
        self.fc_region = nn.Linear(2048, opt.embed_size)
        self.init_weights()

    def init_weights(self):
        """Xavier initialization for the fully connected layer."""
        r = np.sqrt(6.) / np.sqrt(self.fc_region.in_features +
                                  self.fc_region.out_features)
        self.fc_region.weight.data.uniform_(-r, r)
        self.fc_region.bias.data.fill_(0)

    def forward(self, region_feat):
        """Project ``region_feat`` and L2-normalize along the last dim."""
        region_feat = self.fc_region(region_feat)
        region_feat = l2norm(region_feat, dim=-1)
        return region_feat
class EncoderWord(nn.Module):
    """Word-level caption encoder: embedding -> dropout -> GRU -> gated mean.

    Reads ``embed_size``, ``vocab_size``, ``word_dim``, ``num_layers`` and
    ``data_name`` from ``opt``. The vocabulary JSON is loaded from
    ``./vocab/<data_name>_precomp_vocab.json`` and the embedding table is
    seeded from ``torchtext.vocab.GloVe()``.
    """

    def __init__(self, opt):
        super(EncoderWord, self).__init__()
        self.embed_size = opt.embed_size

        # Word embedding table and caption GRU.
        self.embed = nn.Embedding(opt.vocab_size, opt.word_dim)
        self.rnn = nn.GRU(opt.word_dim, opt.embed_size, opt.num_layers,
                          batch_first=True)
        self.gatef2 = GateF()

        vocab_path = './vocab/' + opt.data_name + '_precomp_vocab.json'
        word2idx = deserialize_vocab(vocab_path).word2idx
        self.init_weights('glove', word2idx, opt.word_dim)

        self.dropout = nn.Dropout(0.1)

    def init_weights(self, wemb_type, word2idx, word_dim):
        """Initialise the embedding table.

        ``'random_init'`` applies Xavier-uniform; ``'fasttext'`` / ``'glove'``
        copy pretrained vectors for every vocabulary word that can be found.
        Any other value raises ``Exception``.
        """
        kind = wemb_type.lower()
        if kind == 'random_init':
            nn.init.xavier_uniform_(self.embed.weight)
            return

        # Load pretrained word embedding.
        if kind == 'fasttext':
            wemb = torchtext.vocab.FastText()
        elif kind == 'glove':
            wemb = torchtext.vocab.GloVe()
        else:
            raise Exception('Unknown word embedding type: {}'.format(wemb_type))
        assert wemb.vectors.shape[1] == word_dim

        missing_words = []
        for token, row in word2idx.items():
            lookup = token
            if lookup not in wemb.stoi:
                # Quick-and-dirty normalisation to improve the hit rate.
                lookup = lookup.replace('-', '').replace('.', '').replace("'", '')
                if '/' in lookup:
                    lookup = lookup.split('/')[0]
            if lookup in wemb.stoi:
                self.embed.weight.data[row] = wemb.vectors[wemb.stoi[lookup]]
            else:
                missing_words.append(lookup)

        n_missing = len(missing_words)
        print('Words: {}/{} found in vocabulary; {} words missing'.format(
            len(word2idx) - n_missing, len(word2idx), n_missing))

    def forward(self, x, lengths):
        """Encode a padded batch of token ids.

        Returns:
            ``(cap_emb, cap_emb_mean)``: the per-step GRU outputs,
            L2-normalized along the last dim, and their mean over dim 1
            after gating and L2 normalization. The mean is taken over the
            whole padded time axis; no length mask is applied.
        """
        embedded = self.dropout(self.embed(x))
        packed = pack_padded_sequence(embedded, lengths, batch_first=True,
                                      enforce_sorted=False)
        rnn_out, _ = self.rnn(packed)

        # Back to a padded (batch, max_len, hidden) tensor.
        cap_emb, _ = pad_packed_sequence(rnn_out, batch_first=True)
        cap_emb = l2norm(cap_emb, dim=-1)

        cap_emb_mean = self.gatef2(cap_emb.mean(dim=1))
        cap_emb_mean = l2norm(cap_emb_mean)
        return cap_emb, cap_emb_mean
class EncoderText(nn.Module):
    """Fuses a global caption embedding with an attention-derived feature.

    Args:
        opt: options object; accepted for interface parity, not read here.
    """

    def __init__(self, opt):
        # Must be ``__init__``; with the former name ``init`` neither
        # ``fc_text`` nor ``fusion`` was ever created.
        super(EncoderText, self).__init__()
        self.fc_text = nn.Linear(1024, 1024)
        self.fusion = Fusion()
        self.init_weights()

    def init_weights(self):
        """Xavier initialization for the fully connected layer."""
        r = np.sqrt(6.) / np.sqrt(self.fc_text.in_features +
                                  self.fc_text.out_features)
        self.fc_text.weight.data.uniform_(-r, r)
        self.fc_text.bias.data.fill_(0)

    def forward(self, word_emb, text_emb):
        """Return the L2-normalized fusion of ``text_emb`` and ``word_emb``.

        ``word_emb`` is first projected by ``fc_text`` (so its last dim must
        be 1024); ``text_emb`` is passed to the fusion unchanged.
        """
        word_emb = self.fc_text(word_emb)
        emb = self.fusion(text_emb, word_emb)
        emb = l2norm(emb)
        return emb
def cosine_similarity(x1, x2, dim=1, eps=1e-8):
    """Cosine similarity of ``x1`` and ``x2`` along ``dim``.

    The product of the two norms is clamped to at least ``eps``. The result
    is passed through ``squeeze()``, so every size-1 dimension is removed.
    """
    dot = (x1 * x2).sum(dim)
    norm_product = x1.norm(2, dim) * x2.norm(2, dim)
    return (dot / norm_product.clamp(min=eps)).squeeze()
def func_attention(query, context, opt, smooth, eps=1e-8):
    """Attend over ``context`` for every position of ``query``.

    query: (n_context, queryL, d)
    context: (n_context, sourceL, d)

    Returns a (n_context, queryL, d) tensor holding, for each query position,
    the softmax-weighted sum of the context vectors. ``smooth`` scales the
    logits before the softmax. ``opt`` and ``eps`` are accepted but not used.
    """
    # (batch, sourceL, d) x (batch, d, queryL) --> (batch, sourceL, queryL)
    logits = torch.bmm(context, query.transpose(1, 2))
    logits = F.leaky_relu(logits, 0.1)
    # Normalize each context row across the query positions.
    logits = l2norm(logits, 2)

    # --> (batch, queryL, sourceL); softmax runs over sourceL.
    logits = logits.transpose(1, 2).contiguous()
    weights = F.softmax(logits * smooth, dim=2)

    # (batch, d, sourceL) x (batch, sourceL, queryL) --> (batch, d, queryL)
    attended = torch.bmm(context.transpose(1, 2),
                         weights.transpose(1, 2).contiguous())
    # --> (batch, queryL, d)
    return attended.transpose(1, 2)
def xattn_score_t2i(images, captions, cap_lens, opt):
    """Text-to-image cross attention scores plus the attended image context.

    Images: (n_image, n_regions, d) matrix of images
    Captions: (n_caption, max_n_word, d) matrix of captions
    CapLens: (n_caption) array of caption lengths

    Returns:
        ``(similarities, weiContext_i)`` where ``similarities`` is
        (n_image, n_caption) and ``weiContext_i`` is
        (n_image, n_caption, d): for every image/caption pair, the attended
        image context averaged over the caption's words.

    NOTE(review): the second value used to be reduced to the pair diagonal
    and then averaged over the feature axis, which produced an (n_image,)
    vector. ``JZK.forward_emb`` passes this value to ``SLA`` (which unpacks a
    3-D shape) and only then selects the ``[i, i, :]`` diagonal, so the full
    3-D tensor is returned here. Confirm against the training run this file
    was taken from.
    """
    similarities = []
    pooled_context = []
    n_image = images.size(0)
    n_caption = captions.size(0)
    for i in range(n_caption):
        # Get the i-th text description, trimmed to its true length.
        n_word = cap_lens[i]
        cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous()
        # --> (n_image, n_word, d)
        cap_i_expand = cap_i.repeat(n_image, 1, 1)

        # word(query): (n_image, n_word, d); image(context): (n_image, n_regions, d)
        # weiContext: (n_image, n_word, d)
        weiContext = func_attention(cap_i_expand, images, opt, smooth=9.)
        weiContext = weiContext.contiguous()

        row_sim = cosine_similarity(cap_i_expand, weiContext, dim=2)
        # cosine_similarity squeezes size-1 dims; restore (n_image, n_word) so
        # a single image or a one-word caption does not break the reduction.
        row_sim = row_sim.reshape(n_image, -1).mean(dim=1, keepdim=True)
        similarities.append(row_sim)

        # (n_image, 1, d): context averaged over the caption's words.
        pooled_context.append(weiContext.mean(dim=1, keepdim=True))

    # (n_image, n_caption)
    similarities = torch.cat(similarities, 1)
    # (n_image, n_caption, d)
    weiContext_i = torch.cat(pooled_context, 1)
    return similarities, weiContext_i
def xattn_score_i2t(images, captions, cap_lens, opt):
    """Image-to-text cross attention scores plus the attended text context.

    Images: (batch_size, n_regions, d) matrix of images
    Captions: (batch_size, max_n_words, d) matrix of captions
    CapLens: (batch_size) array of caption lengths

    Returns:
        ``(similarities, weiContext_t)`` where ``similarities`` is
        (n_image, n_caption) and ``weiContext_t`` is
        (n_image, n_caption, d): for every image/caption pair, the attended
        caption context averaged over the image regions.

    NOTE(review): the previous tail of this function computed the pair
    diagonal of the stacked contexts and then discarded it by averaging the
    loop variable of the *last* caption instead. ``JZK.forward_emb`` passes
    the returned value to ``SLA`` (which unpacks a 3-D shape) and only then
    selects the ``[i, i, :]`` diagonal, so the full 3-D tensor is returned
    here. Confirm against the training run this file was taken from.
    """
    similarities = []
    pooled_context = []
    n_image = images.size(0)
    n_caption = captions.size(0)
    for i in range(n_caption):
        # Get the i-th text description, trimmed to its true length.
        n_word = cap_lens[i]
        cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous()
        # (n_image, n_word, d)
        cap_i_expand = cap_i.repeat(n_image, 1, 1)

        # image(query): (n_image, n_region, d); word(context): (n_image, n_word, d)
        # weiContext: (n_image, n_region, d)
        weiContext = func_attention(images, cap_i_expand, opt, smooth=4.)

        row_sim = cosine_similarity(images, weiContext, dim=2)
        # cosine_similarity squeezes size-1 dims; restore (n_image, n_region)
        # so a single image or a single region does not break the reduction.
        row_sim = row_sim.reshape(n_image, -1).mean(dim=1, keepdim=True)
        similarities.append(row_sim)

        # (n_image, 1, d): context averaged over the image regions.
        pooled_context.append(weiContext.mean(dim=1, keepdim=True))

    # (n_image, n_caption)
    similarities = torch.cat(similarities, 1)
    # (n_image, n_caption, d)
    weiContext_t = torch.cat(pooled_context, 1)
    return similarities, weiContext_t
class LA(nn.Module):
    """Residual multi-kernel depthwise 1-D convolution over dim 1 of the input.

    The channel axis is split into four groups which are filtered by
    depthwise ``Conv1d`` layers with kernel sizes 1, 3, 5 and 7 ("same"
    padding); the results are concatenated and added back to the input.

    Args:
        nchannels: size of the last dimension of the input. Any remainder of
            ``nchannels / 4`` is assigned to the first group.
    """

    def __init__(self, nchannels=1024):
        # Must be ``__init__``; with the former name ``init`` the convolution
        # layers were never created.
        super().__init__()
        self.kernel_sizes = (1, 3, 5, 7)
        self.nchannels = nchannels
        self.groups = len(self.kernel_sizes)
        self.split_channels = [nchannels // self.groups
                               for _ in range(self.groups)]
        self.split_channels[0] += nchannels - sum(self.split_channels)
        # One depthwise convolution (groups == channels) per channel group.
        self.layers = nn.Sequential(*[
            nn.Conv1d(in_channels=channels, out_channels=channels,
                      kernel_size=kernel, padding=kernel // 2,
                      groups=channels)
            for channels, kernel in zip(self.split_channels, self.kernel_sizes)
        ])

    def forward(self, x):
        """Apply the grouped convolutions to ``x`` and add the residual.

        ``x`` is transposed on dims 1 and 2 before the convolution, i.e. it
        is treated as ``(batch, length, nchannels)``; the output has the same
        shape as ``x``.
        """
        xi = x.transpose(1, 2)
        split_x = torch.split(xi, self.split_channels, dim=1)
        outputs = torch.cat(
            [layer(part) for layer, part in zip(self.layers, split_x)], dim=1)
        return outputs.transpose(1, 2) + x
class SLA(nn.Module):
    """Self-attention block with an additional local (convolutional) branch.

    Queries and keys are learned projections of the layer-normalized input;
    the values are the layer-normalized input itself, reshaped into heads.
    The output is ``attention(x) + x + LA(layer_norm(x))``.

    Args:
        dim: feature width of the input.
        heads: number of attention heads.
        dim_head: width of each head.
        dropout: dropout rate on the attention weights and output projection.

    Raises:
        ValueError: if ``heads * dim_head != dim``. ``forward`` reshapes the
            un-projected input into ``(heads, dim_head)``, so any other
            combination can only fail later at run time.

    NOTE(review): the local branch is built as ``LA()`` with its default
    channel count, so in practice ``dim`` presumably has to be 1024.
    """

    def __init__(self, dim, heads=16, dim_head=64, dropout=0.1):
        # Must be ``__init__``; with the former name ``init`` none of the
        # sub-modules below were ever created.
        super(SLA, self).__init__()
        inner_dim = dim_head * heads
        if inner_dim != dim:
            raise ValueError(
                'SLA requires heads * dim_head == dim, got '
                '{} * {} != {}'.format(heads, dim_head, dim))
        self.dim_head = dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attn = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.to_qk = nn.Linear(dim, inner_dim * 2)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        )
        self.ln = nn.LayerNorm(dim)
        self.la = LA()

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(b, n, c)``; shape is preserved."""
        b, n, c = x.shape
        xo = self.ln(x)
        xl = self.la(xo)

        # Values are not projected: just split the normalized input into heads.
        v = xo.view(b, n, self.heads, self.dim_head).transpose(1, 2)
        qk = self.to_qk(xo).chunk(2, dim=-1)
        q, k = map(
            lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qk)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.dropout(self.attn(dots))

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out + x + xl
def cosine_sim(im, s):
    """Similarity between all the image and sentence pairs.

    Computed as the matrix product ``im @ s^T`` of two 2-D tensors; this is
    a cosine similarity only if both inputs are already L2-normalized.
    """
    return torch.mm(im, s.t())
def Adaptive_Margin(scores):
    """Fraction of non-negative entries in ``scores``, discounting ``bs`` of them.

    With ``bs = scores.size(0)``, returns
    ``(count(scores >= 0) - bs) / (scores.numel() - bs)`` rounded to four
    decimals.
    """
    batch = scores.size(0)
    non_negative = (scores >= 0).sum()
    ratio = (non_negative - batch) / (scores.numel() - batch)
    return round(ratio.item(), 4)
class ContrastiveLoss(nn.Module):
    """Hardest-negative triplet loss with a periodically adapted margin.

    Reads ``opt.interval`` and ``opt.threshold`` at construction and
    ``opt.iterations`` whenever the margin is adapted.

    Args:
        opt: options object (see above).
        margin: initial margin.
        measure: accepted for interface parity; not used.
        max_violation: stored on the instance; ``forward`` does not read it
            and always keeps only the hardest negative.
    """

    def __init__(self, opt, margin=0, measure=False, max_violation=False):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.opt = opt
        self.sim = cosine_sim
        self.max_violation = max_violation
        self.interval = self.opt.interval
        self.t = self.opt.threshold

    def forward(self, im, s, region_feats, word_feats, length, gimg, gtxt,
                sims_local, iters):
        """Compute the loss for one batch of matching image/caption pairs.

        The score matrix is the sum of three similarity matrices:
        ``sim(gimg, gtxt)``, ``sims_local`` and ``sim(im, s)``. Matching pairs
        are expected on the diagonal. ``region_feats``, ``word_feats`` and
        ``length`` are accepted but not used.

        Every ``interval`` iterations the share of margin-violating negatives
        is measured; if it reaches the threshold ``t`` the margin is
        increased and the threshold is decayed (both rounded to 4 decimals).
        This mutates ``self.margin`` and ``self.t``.
        """
        scores_g2l = self.sim(im, s)
        scores_global = self.sim(gimg, gtxt)
        scores_local = sims_local
        scores = scores_global + scores_local + scores_g2l

        diagonal = scores.diag().view(im.size(0), 1)
        d1 = diagonal.expand_as(scores)
        d2 = diagonal.t().expand_as(scores)

        if iters % self.interval == 0:
            k1 = Adaptive_Margin(scores - d1 + self.margin)
            k2 = Adaptive_Margin(scores - d2 + self.margin)
            k = max(k1, k2)
            print(k, self.t)
            if k >= self.t:
                self.margin += 0.01 * (1 - math.exp(-k))
                self.t *= math.exp(-iters / self.opt.iterations)
                self.margin = round(self.margin, 4)
                self.t = round(self.t, 4)

        # Compare every diagonal score to scores in its row
        # (caption retrieval) ...
        cost_s = (self.margin + scores - d1).clamp(min=0)
        # ... and to scores in its column (image retrieval).
        cost_im = (self.margin + scores - d2).clamp(min=0)

        # Clear the diagonal. The mask is created directly on the device of
        # ``scores`` so CPU tensors still work on a CUDA-capable host.
        mask = torch.eye(scores.size(0), dtype=torch.bool,
                         device=scores.device)
        cost_s = cost_s.masked_fill_(mask, 0)
        cost_im = cost_im.masked_fill_(mask, 0)

        # Keep the maximum violating negative for each query.
        cost_s = cost_s.max(1)[0]
        cost_im = cost_im.max(0)[0]
        return cost_s.sum() + cost_im.sum()
class JZK(nn.Module):
    """Training wrapper owning the encoders, the loss and the optimizer.

    ``train_emb`` and ``forward_loss`` call ``self.logger.update(...)``; the
    constructor does not create ``logger``, so it has to be attached to the
    instance before training.

    ``forward_emb`` needs both attention directions, which ``local_sim``
    only produces for ``opt.cross_attn == 'all'``.
    """

    def __init__(self, opt, pre_scan=False):
        """Build all sub-modules, the loss and an Adam optimizer from ``opt``.

        ``pre_scan`` is accepted but not used.
        """
        super().__init__()
        self.opt = opt
        self.grad_clip = opt.grad_clip
        self.img_enc = EncoderImage(opt.embed_size,
                                    opt.finetune, opt.cnn_type,
                                    no_imgnorm=opt.no_imgnorm)
        self.region_enc = EncoderRegion(opt)
        self.cap_enc = EncoderText(opt)
        self.word_enc = EncoderWord(opt)
        self.sai = SLA(opt.embed_size)
        self.sat = SLA(opt.embed_size)

        if torch.cuda.is_available():
            self.img_enc.cuda()
            self.cap_enc.cuda()
            self.region_enc.cuda()
            self.word_enc.cuda()
            self.sai.cuda()
            self.sat.cuda()
            cudnn.benchmark = True

        # Loss and Optimizer
        self.criterion = ContrastiveLoss(opt, margin=opt.margin,
                                         measure=opt.measure,
                                         max_violation=opt.max_violation)
        # NOTE(review): from the image encoder only ``fc`` (plus the backbone
        # when fine-tuning) is optimized; its ``fc_attn_i``, ``fusion`` and
        # ``gatef1`` layers are not in this list and therefore keep their
        # initial weights. Left unchanged -- confirm whether intended.
        params = list(self.img_enc.fc.parameters())
        if opt.finetune:
            params += list(self.img_enc.cnn.parameters())
        params += list(self.word_enc.parameters())
        params += list(self.region_enc.parameters())
        params += list(self.cap_enc.parameters())
        params += list(self.sai.parameters())
        params += list(self.sat.parameters())
        self.params = params

        self.optimizer = torch.optim.Adam(params, lr=opt.learning_rate)
        self.Eiters = 0

    def state_dict(self):
        """Return the sub-module state dicts as a list.

        Order: img_enc, cap_enc, region_enc, word_enc, sai, sat.
        """
        state_dict = [self.img_enc.state_dict(), self.cap_enc.state_dict(),
                      self.region_enc.state_dict(), self.word_enc.state_dict(),
                      self.sai.state_dict(), self.sat.state_dict()]
        return state_dict

    def load_state_dict(self, state_dict):
        """Load a list produced by :meth:`state_dict` (same order)."""
        self.img_enc.load_state_dict(state_dict[0])
        self.cap_enc.load_state_dict(state_dict[1])
        self.region_enc.load_state_dict(state_dict[2])
        self.word_enc.load_state_dict(state_dict[3])
        self.sai.load_state_dict(state_dict[4])
        self.sat.load_state_dict(state_dict[5])

    def train_start(self):
        """switch to train mode"""
        self.img_enc.train()
        self.cap_enc.train()
        self.region_enc.train()
        self.word_enc.train()
        self.sai.train()
        self.sat.train()

    def val_start(self):
        """switch to evaluate mode"""
        self.img_enc.eval()
        self.cap_enc.eval()
        self.region_enc.eval()
        self.word_enc.eval()
        self.sai.eval()
        self.sat.eval()

    def forward_emb(self, images, region_feat, captions, lengths):
        """Compute the image and caption embeddings.

        Returns:
            ``(img_emb, cap_emb, region_emb, word_emb, lengths, gloi_emb,
            text_emb, sims_local)``.

        Raises:
            ValueError: if ``local_sim`` did not produce both attention
                tensors (i.e. ``opt.cross_attn`` is not ``'all'``).
        """
        if torch.cuda.is_available():
            images = images.cuda()
            captions = captions.cuda()
            region_feat = region_feat.cuda()

        # Forward
        region_emb = self.region_enc(region_feat)
        word_emb, text_emb = self.word_enc(captions, lengths)
        sims_local, attn_txt, attn_img = self.local_sim(
            region_emb, word_emb, lengths)
        if attn_txt is None or attn_img is None:
            raise ValueError(
                "forward_emb needs both attention directions; "
                "set opt.cross_attn to 'all'")

        attn_img = self.sai(attn_img)
        attn_txt = self.sat(attn_txt)

        # Keep only the matching pairs, i.e. entry [i, i, :] for every i.
        # Advanced indexing replaces the per-row Python loop and no longer
        # assumes a feature width of 1024.
        img_idx = torch.arange(attn_img.size(0), device=attn_img.device)
        attn_img = attn_img[img_idx, img_idx, :]
        txt_idx = torch.arange(attn_txt.size(0), device=attn_txt.device)
        attn_txt = attn_txt[txt_idx, txt_idx, :]

        gloi_emb, img_emb = self.img_enc(images, attn_img)
        cap_emb = self.cap_enc(attn_txt, text_emb)
        return (img_emb, cap_emb, region_emb, word_emb, lengths,
                gloi_emb, text_emb, sims_local)

    def forward_loss(self, img_emb, cap_emb, region_emb, word_emb, lengths,
                     gimg, gtxt, sims_local, iters, **kwargs):
        """Compute the loss given pairs of image and caption embeddings.

        Also records the loss value on ``self.logger``.
        """
        loss = self.criterion(img_emb, cap_emb, region_emb, word_emb, lengths,
                              gimg, gtxt, sims_local, iters)
        self.logger.update('Loss', loss.item(), img_emb.size(0))
        return loss

    def train_emb(self, images, region_feat, captions, lengths, ids=None,
                  *args):
        """One training step given images and captions."""
        self.Eiters += 1
        self.logger.update('Eit', self.Eiters)
        self.logger.update('lr', self.optimizer.param_groups[0]['lr'])

        # compute the embeddings
        img_emb, cap_emb, region_emb, word_emb, lengths, gimg, gtxt, sims_local = \
            self.forward_emb(images, region_feat, captions, lengths)

        # measure accuracy and record loss
        self.optimizer.zero_grad()
        loss = self.forward_loss(img_emb, cap_emb, region_emb, word_emb,
                                 lengths, gimg, gtxt, sims_local, self.Eiters)

        # compute gradient and do the optimizer step
        loss.backward()
        if self.grad_clip > 0:
            clip_grad_norm_(self.params, self.grad_clip)
        self.optimizer.step()

    def local_sim(self, region_emb, word_emb, length):
        """Region/word cross-attention similarity, per ``opt.cross_attn``.

        Returns:
            ``(scores, attn_t, attn_i)``. For ``'t2i'`` ``attn_t`` is
            ``None``; for ``'i2t'`` ``attn_i`` is ``None``; for ``'all'``
            the scores of both directions are averaged.

        Raises:
            ValueError: for any other ``opt.cross_attn`` value (previously
                three ``None`` values were returned silently).
        """
        attn_i = None
        attn_t = None
        mode = self.opt.cross_attn
        if mode == 't2i':
            scores, attn_i = xattn_score_t2i(
                region_emb, word_emb, length, self.opt)
        elif mode == 'i2t':
            scores, attn_t = xattn_score_i2t(
                region_emb, word_emb, length, self.opt)
        elif mode == 'all':
            score1, attn_t = xattn_score_i2t(
                region_emb, word_emb, length, self.opt)
            score2, attn_i = xattn_score_t2i(
                region_emb, word_emb, length, self.opt)
            scores = 0.5 * (score1 + score2)
        else:
            raise ValueError(
                'Unknown cross_attn mode: {!r}'.format(mode))
        return scores, attn_t, attn_i
def xattn_score_t2i1(images, captions, cap_lens, opt):
    """Text-to-image cross attention similarity (scores only).

    Images: (n_image, n_regions, d) matrix of images
    Captions: (n_caption, max_n_word, d) matrix of captions
    CapLens: (n_caption) array of caption lengths

    Returns an (n_image, n_caption) similarity matrix.
    """
    similarities = []
    n_image = images.size(0)
    n_caption = captions.size(0)
    for i in range(n_caption):
        # Get the i-th text description, trimmed to its true length.
        n_word = cap_lens[i]
        cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous()
        # --> (n_image, n_word, d)
        cap_i_expand = cap_i.repeat(n_image, 1, 1)

        # word(query): (n_image, n_word, d); image(context): (n_image, n_regions, d)
        # weiContext: (n_image, n_word, d)
        weiContext = func_attention(cap_i_expand, images, opt, smooth=9.)
        weiContext = weiContext.contiguous()

        row_sim = cosine_similarity(cap_i_expand, weiContext, dim=2)
        # cosine_similarity squeezes size-1 dims; restore (n_image, n_word) so
        # a single image or a one-word caption does not break the reduction.
        row_sim = row_sim.reshape(n_image, -1).mean(dim=1, keepdim=True)
        similarities.append(row_sim)

    # (n_image, n_caption)
    similarities = torch.cat(similarities, 1)
    return similarities
def xattn_score_i2t1(images, captions, cap_lens, opt):
    """Image-to-text cross attention similarity (scores only).

    Images: (n_image, n_regions, d) matrix of images
    Captions: (n_caption, max_n_word, d) matrix of captions
    CapLens: (n_caption) array of caption lengths

    Returns an (n_image, n_caption) similarity matrix.
    """
    similarities = []
    n_image = images.size(0)
    n_caption = captions.size(0)
    for i in range(n_caption):
        # Get the i-th text description, trimmed to its true length.
        n_word = cap_lens[i]
        cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous()
        # (n_image, n_word, d)
        cap_i_expand = cap_i.repeat(n_image, 1, 1)

        # image(query): (n_image, n_region, d); word(context): (n_image, n_word, d)
        # weiContext: (n_image, n_region, d)
        weiContext = func_attention(images, cap_i_expand, opt, smooth=4.)

        row_sim = cosine_similarity(images, weiContext, dim=2)
        # cosine_similarity squeezes size-1 dims; restore (n_image, n_region)
        # so a single image or a single region does not break the reduction.
        row_sim = row_sim.reshape(n_image, -1).mean(dim=1, keepdim=True)
        similarities.append(row_sim)

    # (n_image, n_caption)
    similarities = torch.cat(similarities, 1)
    return similarities