xiaohu306/MLASM

A "model.py" file that improves cross-modal retrieval performance.

Opened this issue · 0 comments

If you use the following "model.py" file, cross-modal retrieval performance will be greatly improved.
model.py

`import argparse
import pickle
import torchtext
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.nn.init
import torchvision.models as models
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.backends.cudnn as cudnn
from torch.nn.utils.clip_grad import clip_grad_norm_
import numpy as np
import json
from vocab import deserialize_vocab
import os
from einops import rearrange
import math
from clipres import ModifiedResNet
from collections import OrderedDict

def l2norm(X, dim=-1, eps=1e-8):
    """L2-normalize ``X`` along dimension ``dim``.

    Args:
        X: input tensor.
        dim: dimension normalized to unit L2 norm (default: last dimension).
        eps: added to the norm so all-zero slices map to zero instead of NaN.

    Returns:
        Tensor with the same shape as ``X``.
    """
    norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
    return torch.div(X, norm)

class GateF(nn.Module):
    """Squeeze-and-excitation style feature gate.

    Computes ``x * sigmoid(linear2(relu(linear1(x))))``: every feature of
    ``x`` is rescaled by a learned factor in (0, 1).
    """

    def __init__(self, dim=1024, reduction=16):
        """Args:
            dim: feature size (last dimension of the input).
            reduction: bottleneck ratio; the hidden width is ``dim // reduction``
                (at least 1, so small ``dim`` values do not create a 0-wide layer).
        """
        super().__init__()
        self.dim = dim
        hidden = max(1, dim // reduction)
        self.linear1 = nn.Linear(dim, hidden)
        self.linear2 = nn.Linear(hidden, dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its gate (same shape as ``x``)."""
        t = self.linear1(x)
        t = self.relu(t)
        t = self.linear2(t)
        t = self.sigmoid(t)
        return t * x

class EncoderImage(nn.Module):
    """Global image encoder: CLIP ModifiedResNet backbone + joint-space projection.

    ``forward`` projects and gates the backbone output, then fuses it with a
    linear projection of an externally supplied attended-region feature.
    """

    def __init__(self, embed_size, finetune=False, cnn_type='CLIP-RN101',
                 no_imgnorm=False):
        """Build the backbone and the projection layers.

        Args:
            embed_size: size of the joint embedding space.
            finetune: whether the backbone parameters require gradients.
            cnn_type: backbone name; its prefix selects how ``fc`` is built.
            no_imgnorm: stored only; ``forward`` does not read it.

        Raises:
            ValueError: ``cnn_type`` starts with none of 'vgg', 'resnet', 'CLIP'.
        """
        super(EncoderImage, self).__init__()
        if not cnn_type.startswith(('vgg', 'resnet', 'CLIP')):
            # Fail before loading a checkpoint instead of later on a missing `fc`.
            raise ValueError('Unsupported cnn_type: {}'.format(cnn_type))
        self.embed_size = embed_size
        self.no_imgnorm = no_imgnorm

        self.cnn = self.get_cnn(cnn_type, True)
        print('finetune:', finetune)

        for param in self.cnn.parameters():
            param.requires_grad = finetune

        # NOTE(review): get_cnn(..., True) always builds the CLIP ModifiedResNet,
        # so the vgg/resnet branches (inherited from VSE++) expect attributes
        # that backbone does not have.
        if cnn_type.startswith('vgg'):
            self.fc = nn.Linear(self.cnn.classifier._modules['6'].in_features,
                                embed_size)
            self.cnn.classifier = nn.Sequential(
                *list(self.cnn.classifier.children())[:-1])
        elif cnn_type.startswith('resnet'):
            self.fc = nn.Linear(self.cnn.module.fc.in_features, embed_size)
            self.cnn.module.fc = nn.Sequential()
        else:  # CLIP
            # 512 must equal the backbone's output_dim (presumably CLIP RN101's embed_dim).
            self.fc = nn.Linear(512, embed_size)
        # NOTE(review): 1024 is hard-coded here and in the Fusion/GateF defaults,
        # so these layers assume embed_size == 1024.
        self.fc_attn_i = nn.Linear(1024, 1024)
        self.fusion = Fusion()
        self.gatef1 = GateF()
        self.init_weights()

    @staticmethod
    def _load_clip_visual(clip_path):
        """Rebuild the CLIP ModifiedResNet visual tower from a TorchScript checkpoint."""
        # `with` closes the handle; map to CPU so GPU-saved checkpoints also load
        # on CPU-only machines (the module is moved to the GPU later if available).
        with open(clip_path, 'rb') as f:
            state_dict = torch.jit.load(f, map_location='cpu').state_dict()

        # Architecture hyper-parameters are inferred from the tensor shapes.
        vision_layers = tuple(
            len({k.split('.')[2] for k in state_dict if k.startswith(f'visual.layer{b}')})
            for b in (1, 2, 3, 4))
        vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0]
        n_positions = state_dict['visual.attnpool.positional_embedding'].shape[0]
        output_width = round((n_positions - 1) ** 0.5)
        if output_width ** 2 + 1 != n_positions:
            raise ValueError(
                'Unexpected attnpool positional embedding size: {}'.format(n_positions))
        image_resolution = output_width * 32
        embed_dim = state_dict['text_projection'].shape[1]
        vision_heads = vision_width * 32 // 64

        # Keep only the visual tower and strip the 'visual.' prefix.
        visual_state = OrderedDict(
            (key[len('visual.'):], value)
            for key, value in state_dict.items() if key.startswith('visual.'))

        model = ModifiedResNet(
            layers=vision_layers,
            output_dim=embed_dim,
            heads=vision_heads,
            input_resolution=image_resolution,
            width=vision_width,
        )
        model.load_state_dict(visual_state)
        return model

    def get_cnn(self, arch, pretrained, clip_path='RN101.pt'):
        """Build the image backbone and wrap it in nn.DataParallel.

        Args:
            arch: backbone name; used for logging, for the non-pretrained
                torchvision path and to choose the DataParallel wrapping.
            pretrained: if True, the CLIP visual tower is loaded from
                ``clip_path`` regardless of ``arch``.
            clip_path: TorchScript CLIP checkpoint (default ``'RN101.pt'``).

        Returns:
            The wrapped backbone, moved to the GPU when CUDA is available.
        """
        if pretrained:
            print("=> using pre-trained model '{}'".format(arch))
            model = self._load_clip_visual(clip_path)
        else:
            print("=> creating model '{}'".format(arch))
            model = models.__dict__[arch]()

        if arch.startswith(('alexnet', 'vgg')):
            model.features = nn.DataParallel(model.features)
        else:
            model = nn.DataParallel(model)

        if torch.cuda.is_available():
            model.cuda()

        return model

    def load_state_dict(self, state_dict, strict=True):
        """Load weights, renaming classifier keys from checkpoints saved before
        pytorch/vision@989d52a (classifier indices 1/4 became 0/3)."""
        if 'cnn.classifier.1.weight' in state_dict:
            for old, new in (('cnn.classifier.1.weight', 'cnn.classifier.0.weight'),
                             ('cnn.classifier.1.bias', 'cnn.classifier.0.bias'),
                             ('cnn.classifier.4.weight', 'cnn.classifier.3.weight'),
                             ('cnn.classifier.4.bias', 'cnn.classifier.3.bias')):
                state_dict[new] = state_dict.pop(old)

        return super(EncoderImage, self).load_state_dict(state_dict, strict=strict)

    def init_weights(self):
        """Xavier-uniform init of ``fc`` and ``fc_attn_i``; biases set to zero."""
        for layer in (self.fc, self.fc_attn_i):
            r = np.sqrt(6.) / np.sqrt(layer.in_features + layer.out_features)
            layer.weight.data.uniform_(-r, r)
            layer.bias.data.fill_(0)

    def forward(self, images, local_image):
        """Encode images and fuse them with their attended region feature.

        Args:
            images: batch passed straight to the backbone.
            local_image: attended region features fed to ``fc_attn_i``
                (JZK.forward_emb passes one 1024-d row per image).

        Returns:
            (features, features_f): the gated global projection (NOT
            L2-normalized) and the L2-normalized fusion with ``local_image``.
        """
        features = l2norm(self.cnn(images))
        # linear projection to the joint embedding space, then channel gating
        features = self.gatef1(self.fc(features))
        # NOTE(review): `features` is returned without a final l2norm even though
        # no_imgnorm defaults to False; the loss compares it by raw dot product.
        features_local = self.fc_attn_i(local_image)
        features_f = l2norm(self.fusion(features, features_local))
        return features, features_f

class Fusion(nn.Module):
    """Gated convex combination of two equally sized feature tensors.

    ``t = sigmoid(gate0([vec1; vec2]))`` and the output is
    ``t * vec1 + (1 - t) * vec2``.
    """

    def __init__(self, f_size=1024):
        """Args:
            f_size: feature size of both inputs (last dimension).
        """
        super(Fusion, self).__init__()
        self.f_size = f_size
        self.gate0 = nn.Linear(self.f_size * 2, self.f_size)
        # gate1 is not used by forward(); it is kept so parameter lists and
        # saved state_dicts keep the same keys.
        self.gate1 = nn.Linear(self.f_size, self.f_size)

    def forward(self, vec1, vec2):
        """Fuse ``vec1`` and ``vec2`` (both ``(..., f_size)``); same shape out."""
        vec = torch.cat((vec1, vec2), dim=-1)
        t = torch.sigmoid(self.gate0(vec))
        return t * vec1 + (1 - t) * vec2

class EncoderRegion(nn.Module):
    """Project pre-extracted region features into the joint space and L2-normalize them."""

    def __init__(self, opt, img_dim=2048):
        """Args:
            opt: options namespace; ``opt.embed_size`` is the output size.
            img_dim: size of each input region feature (default 2048).
        """
        super(EncoderRegion, self).__init__()
        self.fc_region = nn.Linear(img_dim, opt.embed_size)
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform init of ``fc_region``; bias set to zero."""
        layer = self.fc_region
        r = np.sqrt(6.) / np.sqrt(layer.in_features + layer.out_features)
        layer.weight.data.uniform_(-r, r)
        layer.bias.data.fill_(0)

    def forward(self, region_feat):
        """Project ``region_feat`` (last dim ``img_dim``) and L2-normalize the last dim."""
        region_feat = self.fc_region(region_feat)
        return l2norm(region_feat, dim=-1)

class EncoderWord(nn.Module):
    """Word-level caption encoder: embedding -> dropout -> GRU.

    Returns per-word features and a gated, L2-normalized caption vector.
    """

    def __init__(self, opt, wemb_type='glove'):
        """Args:
            opt: reads embed_size, vocab_size, word_dim, num_layers, data_name.
            wemb_type: 'glove' (default), 'fasttext' or 'random_init'.
        """
        super(EncoderWord, self).__init__()
        self.embed_size = opt.embed_size
        # word embedding
        self.embed = nn.Embedding(opt.vocab_size, opt.word_dim)
        # caption embedding
        self.rnn = nn.GRU(opt.word_dim, opt.embed_size, opt.num_layers, batch_first=True)
        self.gatef2 = GateF()
        vocab = deserialize_vocab(
            os.path.join('./vocab', opt.data_name + '_precomp_vocab.json'))
        self.init_weights(wemb_type, vocab.word2idx, opt.word_dim)
        self.dropout = nn.Dropout(0.1)

    def init_weights(self, wemb_type, word2idx, word_dim):
        """Initialise the word-embedding table.

        Args:
            wemb_type: 'random_init' (Xavier uniform), 'fasttext' or 'glove'
                (torchtext vectors built with default arguments).
            word2idx: mapping word -> row index of ``self.embed``.
            word_dim: expected vector width; must match the pretrained vectors.

        Raises:
            ValueError: unknown ``wemb_type`` or vector-width mismatch.
        """
        kind = wemb_type.lower()
        if kind == 'random_init':
            nn.init.xavier_uniform_(self.embed.weight)
            return

        if kind == 'fasttext':
            wemb = torchtext.vocab.FastText()
        elif kind == 'glove':
            wemb = torchtext.vocab.GloVe()
        else:
            raise ValueError('Unknown word embedding type: {}'.format(wemb_type))
        if wemb.vectors.shape[1] != word_dim:
            raise ValueError('Pretrained vectors have width {}, expected word_dim={}'.format(
                wemb.vectors.shape[1], word_dim))

        # Quick-and-dirty normalisation to improve the hit rate; rows of words
        # that are still missing keep nn.Embedding's default initialisation.
        missing_words = []
        for word, idx in word2idx.items():
            lookup = word
            if lookup not in wemb.stoi:
                lookup = lookup.replace('-', '').replace('.', '').replace("'", '')
                if '/' in lookup:
                    lookup = lookup.split('/')[0]
            if lookup in wemb.stoi:
                self.embed.weight.data[idx] = wemb.vectors[wemb.stoi[lookup]]
            else:
                missing_words.append(word)
        print('Words: {}/{} found in vocabulary; {} words missing'.format(
            len(word2idx) - len(missing_words), len(word2idx), len(missing_words)))

    def forward(self, x, lengths):
        """Encode a padded, batch-first tensor of token ids.

        Args:
            x: batch-first token-id tensor.
            lengths: true length of each caption (list or tensor, any device).

        Returns:
            cap_emb: per-word GRU outputs, L2-normalized, padded to
                max(lengths); padded positions are zero.
            cap_emb_mean: gated, L2-normalized mean of ``cap_emb`` over time.
        """
        x = self.dropout(self.embed(x))

        # pack_padded_sequence requires the lengths as a CPU int64 tensor or list.
        lengths_cpu = torch.as_tensor(lengths, dtype=torch.int64).cpu()
        packed = pack_padded_sequence(x, lengths_cpu, batch_first=True, enforce_sorted=False)
        out, _ = self.rnn(packed)
        cap_emb, _ = pad_packed_sequence(out, batch_first=True)

        cap_emb = l2norm(cap_emb, dim=-1)
        # NOTE(review): this mean includes the zero padding, so shorter captions
        # are scaled down and the result depends on the longest caption in the
        # batch. A length-aware mean would change trained-model outputs, so the
        # existing behaviour is kept.
        cap_emb_mean = torch.mean(cap_emb, 1)
        cap_emb_mean = l2norm(self.gatef2(cap_emb_mean))

        return cap_emb, cap_emb_mean

class EncoderText(nn.Module):
    """Fuse a projected attention-pooled word feature with the sentence embedding."""

    def __init__(self, opt):
        """Args:
            opt: accepted for interface compatibility; not read.
        """
        super(EncoderText, self).__init__()
        # NOTE(review): 1024 is hard-coded here as in Fusion's default size.
        self.fc_text = nn.Linear(1024, 1024)
        self.fusion = Fusion()
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform init of ``fc_text``; bias set to zero."""
        r = np.sqrt(6.) / np.sqrt(self.fc_text.in_features +
                                  self.fc_text.out_features)
        self.fc_text.weight.data.uniform_(-r, r)
        self.fc_text.bias.data.fill_(0)

    def forward(self, word_emb, text_emb):
        """Return ``l2norm(fusion(text_emb, fc_text(word_emb)))``.

        Args:
            word_emb: attended word feature (last dim 1024).
            text_emb: sentence embedding of the same shape.
        """
        word_emb = self.fc_text(word_emb)
        emb = self.fusion(text_emb, word_emb)
        return l2norm(emb)

def cosine_similarity(x1, x2, dim=1, eps=1e-8):
    """Return the cosine similarity between ``x1`` and ``x2`` along ``dim``.

    The denominator is clamped to ``eps`` so zero vectors give 0. The result
    goes through ``squeeze()``, which drops *every* size-1 dimension (e.g. a
    batch of one); callers needing a fixed rank should reshape the output.
    """
    w12 = torch.sum(x1 * x2, dim)
    w1 = torch.norm(x1, 2, dim)
    w2 = torch.norm(x2, 2, dim)
    return (w12 / (w1 * w2).clamp(min=eps)).squeeze()

def func_attention(query, context, opt, smooth, eps=1e-8):
    """Attend from every query position over the context positions.

    Args:
        query: (n_context, queryL, d) tensor.
        context: (n_context, sourceL, d) tensor.
        opt: not read; kept for signature compatibility.
        smooth: inverse temperature applied before the softmax.
        eps: not read; kept for signature compatibility.

    Returns:
        (n_context, queryL, d): for each query position, the attention-weighted
        sum of the context vectors.
    """
    # --> (batch, d, queryL)
    queryT = torch.transpose(query, 1, 2)
    # (batch, sourceL, d) x (batch, d, queryL) --> (batch, sourceL, queryL)
    attn = torch.bmm(context, queryT)

    attn = F.leaky_relu(attn, 0.1)
    # normalize each source position's scores across the query positions (dim 2)
    attn = l2norm(attn, 2)

    # --> (batch, queryL, sourceL); softmax over the source positions
    attn = torch.transpose(attn, 1, 2).contiguous()
    attn = F.softmax(attn * smooth, dim=2)

    # --> (batch, sourceL, queryL)
    attnT = torch.transpose(attn, 1, 2).contiguous()
    # --> (batch, d, sourceL)
    contextT = torch.transpose(context, 1, 2)
    # (batch, d, sourceL) x (batch, sourceL, queryL) --> (batch, d, queryL)
    weightedContext = torch.bmm(contextT, attnT)
    # --> (batch, queryL, d)
    return torch.transpose(weightedContext, 1, 2)

def xattn_score_t2i(images, captions, cap_lens, opt):
    """Text-to-image cross attention: each word attends over the image regions.

    Args:
        images: (n_image, n_regions, d) region features.
        captions: (n_caption, max_n_word, d) word features.
        cap_lens: true length of each caption (n_caption entries).
        opt: forwarded to ``func_attention``.

    Returns:
        similarities: (n_image, n_caption) mean cosine similarity between each
            word and its attended region context.
        weiContext_i: (n_image, n_caption, d) attended region context averaged
            over the caption's words, for every image/caption pair.
            JZK.forward_emb runs self-attention over it and then keeps the
            matched-pair diagonal.
    """
    similarities = []
    weiContext_i = []
    n_image = images.size(0)
    n_caption = captions.size(0)
    for i in range(n_caption):
        n_word = cap_lens[i]
        cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous()
        # --> (n_image, n_word, d)
        cap_i_expand = cap_i.repeat(n_image, 1, 1)
        # (n_image, n_word, d): region context attended by each word
        weiContext = func_attention(cap_i_expand, images, opt, smooth=9.)
        # (n_image, n_word); reshape because cosine_similarity squeezes
        # size-1 dims (a single image or a one-word caption)
        row_sim = cosine_similarity(cap_i_expand, weiContext.contiguous(), dim=2)
        row_sim = row_sim.reshape(n_image, -1)
        similarities.append(row_sim.mean(dim=1, keepdim=True))
        weiContext_i.append(weiContext.mean(dim=1, keepdim=True))

    # (n_image, n_caption)
    similarities = torch.cat(similarities, 1)
    # (n_image, n_caption, d)
    weiContext_i = torch.cat(weiContext_i, 1)
    return similarities, weiContext_i

def xattn_score_i2t(images, captions, cap_lens, opt):
    """Image-to-text cross attention: each region attends over a caption's words.

    Args:
        images: (n_image, n_regions, d) region features.
        captions: (n_caption, max_n_word, d) word features.
        cap_lens: true length of each caption (n_caption entries).
        opt: forwarded to ``func_attention``.

    Returns:
        similarities: (n_image, n_caption) mean cosine similarity between each
            region and its attended word context.
        weiContext_t: (n_image, n_caption, d) attended word context averaged
            over the regions, for every image/caption pair.
            JZK.forward_emb runs self-attention over it and then keeps the
            matched-pair diagonal.
    """
    similarities = []
    weiContext_t = []
    n_image = images.size(0)
    n_caption = captions.size(0)
    for i in range(n_caption):
        n_word = cap_lens[i]
        cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous()
        # --> (n_image, n_word, d)
        cap_i_expand = cap_i.repeat(n_image, 1, 1)
        # (n_image, n_region, d): word context attended by each region
        weiContext = func_attention(images, cap_i_expand, opt, smooth=4.)
        # (n_image, n_region); reshape because cosine_similarity squeezes size-1 dims
        row_sim = cosine_similarity(images, weiContext, dim=2).reshape(n_image, -1)
        similarities.append(row_sim.mean(dim=1, keepdim=True))
        weiContext_t.append(weiContext.mean(dim=1, keepdim=True))

    # (n_image, n_caption)
    similarities = torch.cat(similarities, 1)
    # (n_image, n_caption, d) -- built from every caption, not just the last one
    weiContext_t = torch.cat(weiContext_t, 1)
    return similarities, weiContext_t

class LA(nn.Module):
    """Multi-scale depthwise 1-D convolution with a residual connection.

    The channels are split into one group per kernel size; each group is
    filtered along the sequence axis by a depthwise ``Conv1d`` and the results
    are concatenated and added back to the input.
    """

    def __init__(self, nchannels=1024, kernel_sizes=(1, 3, 5, 7)):
        """Args:
            nchannels: feature size (last dimension of the input).
            kernel_sizes: one odd kernel size per channel group; odd sizes with
                ``k // 2`` padding keep the sequence length unchanged.

        Raises:
            ValueError: a kernel size is even.
        """
        super().__init__()
        self.kernel_sizes = tuple(kernel_sizes)
        if any(k % 2 == 0 for k in self.kernel_sizes):
            raise ValueError('LA kernel sizes must be odd, got {}'.format(self.kernel_sizes))
        self.nchannels = nchannels
        self.groups = len(self.kernel_sizes)
        self.split_channels = [nchannels // self.groups for _ in range(self.groups)]
        # the remainder of the division goes to the first group
        self.split_channels[0] += nchannels - sum(self.split_channels)
        self.layers = nn.Sequential(*[
            nn.Conv1d(in_channels=channels, out_channels=channels,
                      kernel_size=k, padding=k // 2, groups=channels)
            for channels, k in zip(self.split_channels, self.kernel_sizes)
        ])

    def forward(self, x):
        """Args: x: (batch, seq, nchannels). Returns a tensor of the same shape."""
        xi = x.transpose(1, 2)  # Conv1d expects (batch, channels, seq)
        split_x = torch.split(xi, self.split_channels, dim=1)
        outputs = torch.cat([layer(sp_x) for layer, sp_x in zip(self.layers, split_x)], dim=1)
        return outputs.transpose(1, 2) + x

class SLA(nn.Module):
    """Pre-norm multi-head self-attention with an extra LA (multi-scale conv) branch.

    Queries and keys come from a linear projection of the LayerNorm-ed input.
    Values are the LayerNorm-ed input itself split into heads, so ``dim`` must
    equal ``heads * dim_head``. Output: ``to_out(attn @ v) + x + LA(LN(x))``.
    """

    def __init__(self, dim, heads=16, dim_head=64, dropout=0.1):
        """Args:
            dim: feature size; must equal ``heads * dim_head``.
            heads: number of attention heads.
            dim_head: per-head size.
            dropout: dropout on the attention map and on the output projection.

        Raises:
            ValueError: ``dim != heads * dim_head``.
        """
        super(SLA, self).__init__()
        inner_dim = dim_head * heads
        if dim != inner_dim:
            raise ValueError('SLA needs dim == heads * dim_head ({} != {} * {})'.format(
                dim, heads, dim_head))
        self.dim_head = dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attn = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.to_qk = nn.Linear(dim, inner_dim * 2)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        )
        self.ln = nn.LayerNorm(dim)
        # the conv branch must work on the same feature size as the attention
        self.la = LA(nchannels=dim)

    def forward(self, x):
        """Args: x: (batch, n, dim). Returns a tensor of the same shape."""
        b, n, _ = x.shape
        xo = self.ln(x)
        xl = self.la(xo)
        # values: the normalized input split into heads -> (b, heads, n, dim_head)
        v = xo.view(b, n, self.heads, self.dim_head).transpose(1, 2)
        q, k = (rearrange(t, 'b n (h d) -> b h n d', h=self.heads)
                for t in self.to_qk(xo).chunk(2, dim=-1))
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.dropout(self.attn(dots))
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out) + x + xl

def cosine_sim(im, s):
    """Return the (n_im, n_s) matrix of dot products ``im @ s.T``.

    This equals cosine similarity only when the rows of ``im`` and ``s`` are
    already L2-normalized; no normalization is done here.
    """
    return im.mm(s.t())

def Adaptive_Margin(scores):
    """Return the fraction of off-diagonal entries of ``scores`` that are >= 0.

    ContrastiveLoss passes ``scores - diagonal + margin``, so this is the share
    of negatives that still violate the margin. The diagonal is excluded
    explicitly (it is not assumed to be non-negative).

    Args:
        scores: 2-D tensor.

    Returns:
        Python float rounded to 4 decimals; 0.0 when there are no off-diagonal
        entries (e.g. a batch of one, which previously produced NaN).
    """
    off_diag = ~torch.eye(scores.size(0), scores.size(1), dtype=torch.bool,
                          device=scores.device)
    n_negatives = int(off_diag.sum())
    if n_negatives == 0:
        return 0.0
    n_violating = int((scores[off_diag] >= 0).sum())
    return round(n_violating / n_negatives, 4)

class ContrastiveLoss(nn.Module):
    """Hardest-negative hinge ranking loss with an adaptively growing margin.

    The score matrix is the sum of the global, local and fused similarity
    matrices. Every ``opt.interval`` iterations the share of margin-violating
    negatives (``Adaptive_Margin``) is checked. If it reaches the threshold
    ``self.t``, the margin is increased and the threshold is decayed.
    """

    def __init__(self, opt, margin=0, measure=False, max_violation=False):
        """Args:
            opt: reads ``opt.interval``, ``opt.threshold`` and, in forward,
                ``opt.iterations``.
            margin: initial margin; updated in place during training.
            measure: not read; kept for interface compatibility.
            max_violation: stored but not read. NOTE(review): forward always
                uses the hardest negative regardless of this flag.
        """
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.opt = opt
        self.sim = cosine_sim
        self.max_violation = max_violation
        self.interval = self.opt.interval
        self.t = self.opt.threshold

    def forward(self, im, s, region_feats, word_feats, length, gimg, gtxt, sims_local, iters):
        """Return the summed hardest-negative hinge loss of the batch.

        Args:
            im, s: fused image / caption embeddings (compared by dot product).
            region_feats, word_feats, length: not read.
            gimg, gtxt: global image / caption embeddings.
            sims_local: precomputed (n_image, n_caption) local similarities.
            iters: current training iteration; drives the margin schedule.
        """
        scores = self.sim(gimg, gtxt) + sims_local + self.sim(im, s)

        diagonal = scores.diag().view(im.size(0), 1)
        d1 = diagonal.expand_as(scores)      # d1[i, j] = scores[i, i]
        d2 = diagonal.t().expand_as(scores)  # d2[i, j] = scores[j, j]

        if iters % self.interval == 0:
            k1 = Adaptive_Margin(scores - d1 + self.margin)
            k2 = Adaptive_Margin(scores - d2 + self.margin)
            k = max(k1, k2)
            print(k, self.t)
            if k >= self.t:
                self.margin += 0.01 * (1 - math.exp(-k))
                self.t *= math.exp(-iters / self.opt.iterations)
                self.margin = round(self.margin, 4)
                self.t = round(self.t, 4)

        # caption retrieval: for image i, negatives are the other entries of row i
        cost_s = (self.margin + scores - d1).clamp(min=0)
        # image retrieval: for caption j, negatives are the other entries of column j
        cost_im = (self.margin + scores - d2).clamp(min=0)

        # the diagonal holds the positive pairs; build the mask on the scores' device
        mask = torch.eye(scores.size(0), dtype=torch.bool, device=scores.device)
        cost_s = cost_s.masked_fill(mask, 0)
        cost_im = cost_im.masked_fill(mask, 0)

        # keep only the hardest negative for each query
        cost_s = cost_s.max(1)[0]
        cost_im = cost_im.max(0)[0]

        return cost_s.sum() + cost_im.sum()

class JZK(nn.Module):
    """Full retrieval model: encoders, cross attention, loss and optimizer."""

    def __init__(self, opt, pre_scan=False):
        """Build all sub-networks, the loss and the Adam optimizer.

        Args:
            opt: options namespace (grad_clip, embed_size, finetune, cnn_type,
                no_imgnorm, margin, measure, max_violation, learning_rate, ...).
            pre_scan: not read; kept for interface compatibility.
        """
        super().__init__()
        self.opt = opt
        self.grad_clip = opt.grad_clip
        self.img_enc = EncoderImage(opt.embed_size,
                                    opt.finetune, opt.cnn_type,
                                    no_imgnorm=opt.no_imgnorm)
        self.region_enc = EncoderRegion(opt)
        self.cap_enc = EncoderText(opt)
        self.word_enc = EncoderWord(opt)
        self.sai = SLA(opt.embed_size)  # self-attention over image-side contexts
        self.sat = SLA(opt.embed_size)  # self-attention over text-side contexts
        if torch.cuda.is_available():
            for module in self._encoders():
                module.cuda()
            cudnn.benchmark = True

        # Loss and Optimizer
        self.criterion = ContrastiveLoss(opt, margin=opt.margin,
                                         measure=opt.measure,
                                         max_violation=opt.max_violation)

        # Every image-encoder parameter outside the backbone (fc, fc_attn_i,
        # fusion, gatef1) is trained; previously only `fc` reached the optimizer,
        # so the other image-side layers stayed at their random initialisation.
        params = [p for name, p in self.img_enc.named_parameters()
                  if not name.startswith('cnn.')]
        if opt.finetune:
            params += list(self.img_enc.cnn.parameters())
        for module in (self.word_enc, self.region_enc, self.cap_enc, self.sai, self.sat):
            params += list(module.parameters())

        self.params = params
        self.optimizer = torch.optim.Adam(params, lr=opt.learning_rate)
        self.Eiters = 0

    def _encoders(self):
        """Sub-networks in checkpoint order: img, cap, region, word, sai, sat."""
        return (self.img_enc, self.cap_enc, self.region_enc,
                self.word_enc, self.sai, self.sat)

    def state_dict(self):
        """Return the sub-networks' state dicts as a list in ``_encoders()`` order."""
        return [module.state_dict() for module in self._encoders()]

    def load_state_dict(self, state_dict):
        """Load a list produced by ``state_dict()``.

        Raises:
            ValueError: the list does not have one entry per sub-network.
        """
        encoders = self._encoders()
        if len(state_dict) != len(encoders):
            raise ValueError('Expected {} state dicts, got {}'.format(
                len(encoders), len(state_dict)))
        for module, module_state in zip(encoders, state_dict):
            module.load_state_dict(module_state)

    def train_start(self):
        """Switch every sub-network to train mode."""
        for module in self._encoders():
            module.train()

    def val_start(self):
        """Switch every sub-network to evaluation mode."""
        for module in self._encoders():
            module.eval()

    def forward_emb(self, images, region_feat, captions, lengths):
        """Compute all embeddings for a batch.

        Returns:
            (img_emb, cap_emb, region_emb, word_emb, lengths, gloi_emb,
            text_emb, sims_local).

        Raises:
            ValueError: ``opt.cross_attn`` is not 'all' (both contexts are needed).
        """
        if torch.cuda.is_available():
            images = images.cuda()
            captions = captions.cuda()
            region_feat = region_feat.cuda()

        region_emb = self.region_enc(region_feat)
        word_emb, text_emb = self.word_enc(captions, lengths)
        sims_local, attn_txt, attn_img = self.local_sim(region_emb, word_emb, lengths)
        if attn_img is None or attn_txt is None:
            raise ValueError("forward_emb needs both attention contexts; set opt.cross_attn to 'all'")

        # Self-attention over each context tensor (assumed (n_image, n_caption, d)),
        # then keep the matched image/caption pairs, i.e. the diagonal -> (batch, d).
        attn_img = self.sai(attn_img)
        attn_txt = self.sat(attn_txt)
        idx = torch.arange(attn_img.size(0), device=attn_img.device)
        attn_img = attn_img[idx, idx]
        idx = torch.arange(attn_txt.size(0), device=attn_txt.device)
        attn_txt = attn_txt[idx, idx]

        gloi_emb, img_emb = self.img_enc(images, attn_img)
        cap_emb = self.cap_enc(attn_txt, text_emb)
        return img_emb, cap_emb, region_emb, word_emb, lengths, gloi_emb, text_emb, sims_local

    def forward_loss(self, img_emb, cap_emb, region_emb, word_emb, lengths, gimg, gtxt, sims_local, iters, **kwargs):
        """Compute the loss for one batch and log it.

        NOTE: ``self.logger`` is not created in ``__init__``; it must be
        assigned by the training script before this is called.
        """
        loss = self.criterion(img_emb, cap_emb, region_emb, word_emb, lengths,
                              gimg, gtxt, sims_local, iters)
        self.logger.update('Loss', loss.item(), img_emb.size(0))
        return loss

    def train_emb(self, images, region_feat, captions, lengths, ids=None, *args):
        """Run one optimisation step on a batch (requires ``self.logger``)."""
        self.Eiters += 1
        self.logger.update('Eit', self.Eiters)
        self.logger.update('lr', self.optimizer.param_groups[0]['lr'])

        img_emb, cap_emb, region_emb, word_emb, lengths, gimg, gtxt, sims_local = \
            self.forward_emb(images, region_feat, captions, lengths)
        self.optimizer.zero_grad()
        loss = self.forward_loss(img_emb, cap_emb, region_emb, word_emb, lengths,
                                 gimg, gtxt, sims_local, self.Eiters)

        loss.backward()
        if self.grad_clip > 0:
            clip_grad_norm_(self.params, self.grad_clip)
        self.optimizer.step()

    def local_sim(self, region_emb, word_emb, length):
        """Cross-attention similarity selected by ``opt.cross_attn``.

        Returns:
            (scores, attn_t, attn_i): the (n_image, n_caption) similarity, the
            text-side context from i2t and the image-side context from t2i.
            A context the chosen mode does not compute is None.

        Raises:
            ValueError: ``opt.cross_attn`` is not 't2i', 'i2t' or 'all'.
        """
        attn_i = None
        attn_t = None
        mode = self.opt.cross_attn
        if mode == 't2i':
            scores, attn_i = xattn_score_t2i(region_emb, word_emb, length, self.opt)
        elif mode == 'i2t':
            scores, attn_t = xattn_score_i2t(region_emb, word_emb, length, self.opt)
        elif mode == 'all':
            score1, attn_t = xattn_score_i2t(region_emb, word_emb, length, self.opt)
            score2, attn_i = xattn_score_t2i(region_emb, word_emb, length, self.opt)
            scores = 0.5 * (score1 + score2)
        else:
            raise ValueError('Unknown cross_attn mode: {}'.format(mode))
        return scores, attn_t, attn_i

def xattn_score_t2i1(images, captions, cap_lens, opt):
    """Similarity-only text-to-image cross attention.

    Args:
        images: (n_image, n_regions, d) region features.
        captions: (n_caption, max_n_word, d) word features.
        cap_lens: true length of each caption.
        opt: forwarded to ``func_attention``.

    Returns:
        (n_image, n_caption) mean cosine similarity between each word and its
        attended region context.
    """
    similarities = []
    n_image = images.size(0)
    n_caption = captions.size(0)
    for i in range(n_caption):
        n_word = cap_lens[i]
        cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous()
        # --> (n_image, n_word, d)
        cap_i_expand = cap_i.repeat(n_image, 1, 1)
        # (n_image, n_word, d): region context attended by each word
        weiContext = func_attention(cap_i_expand, images, opt, smooth=9.)
        # (n_image, n_word); reshape because cosine_similarity squeezes size-1 dims
        row_sim = cosine_similarity(cap_i_expand, weiContext.contiguous(), dim=2)
        similarities.append(row_sim.reshape(n_image, -1).mean(dim=1, keepdim=True))

    # (n_image, n_caption)
    return torch.cat(similarities, 1)

def xattn_score_i2t1(images, captions, cap_lens, opt):
    """Similarity-only image-to-text cross attention.

    Args:
        images: (n_image, n_regions, d) region features.
        captions: (n_caption, max_n_word, d) word features.
        cap_lens: true length of each caption.
        opt: forwarded to ``func_attention``.

    Returns:
        (n_image, n_caption) mean cosine similarity between each region and its
        attended word context.
    """
    similarities = []
    n_image = images.size(0)
    n_caption = captions.size(0)
    for i in range(n_caption):
        n_word = cap_lens[i]
        cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous()
        # --> (n_image, n_word, d)
        cap_i_expand = cap_i.repeat(n_image, 1, 1)
        # (n_image, n_region, d): word context attended by each region
        weiContext = func_attention(images, cap_i_expand, opt, smooth=4.)
        # (n_image, n_region); reshape because cosine_similarity squeezes size-1 dims
        row_sim = cosine_similarity(images, weiContext, dim=2)
        similarities.append(row_sim.reshape(n_image, -1).mean(dim=1, keepdim=True))

    # (n_image, n_caption)
    return torch.cat(similarities, 1)