jeonggg119/DL_paper

[CV_Pose Estimation] Efficient Object Localization Using Convolutional Networks


Efficient Object Localization Using Convolutional Networks

Abstract

  • Efficient 'Position Refinement' model
    • trained to estimate joint offset location within a small region of img
    • trained in cascade within SOTA ConvNet model to achieve improved acc
    • on FLIC dataset, MPII dataset

1. Introduction

  • Human-body part localization task ↑ BY ConvNet arch + larger datasets

  • (sota) ConvNet : internal strided-pooling layers

    • reduce spatial resolution
    • output : invariant to spatial location within pooling region
    • promote spatial invariance to local input transformation
    • pooling : prevents over-training + reduces computational complexity for classification
    • Trade-off : generalization performance ↑ <-> spatial localization accuracy ↓
  • (this paper) ConvNet arch for efficient localization of human joints in RGB imgs

    • high spatial accuracy + computational efficiency
    • begin by coarse body part localization -> output : low resolution, per-pixel heat-map
    • show likelihood of a joint occurring in each spatial location
    • Max-pooling for dimensionality reduction + improving invariance to noise and local img transformations
    • reuse hidden layer conv features from coarse heat-map regression model to improve localization accuracy

2. Related Work

  • Models using Hand-crafted features (edges, contours, HoG, color histograms) : poor generalization performance

    • Deformable Part Models (DPM)
    • Mixture of templates modeled using SVMs
    • Poselet + DPM model : spatial relationship of body parts
    • Armlets : semi-global classifier, good for real-world data, but only for arms
    • Multi-modal model : holistic + local
  • ConvNets

    • formulate problem as a direct (continuous) regression
    • performs poorly in the high-precision region
    • direct mapping from input RGB img to XY location adds unnecessary learning complexity (over-training)
    • +) low-dimensional representation of input img, multi-resolution ConvNet arch, ...

3. Coarse Heat-Map Regression Model

  • Using Extension of Multi-resolution ConvNet model
  • For Sliding window detector with Overlapping contexts to produce Coarse heat-map output

3.1. Model Architecture

[Figure 2: multi-resolution coarse heat-map model architecture]

  • Input : RGB Gaussian pyramid of 3 levels (320 x 240 for FLIC, 256 x 256 for MPII) (pyramid sketch after this list)

    Figure 2 : only 2 levels for brevity

  • Output : Heat-map for each joint describing per-pixel likelihood for joint occurring in each output spatial location
  • 1st layer : LCN (Local Contrast Normalization) with the same filter kernel in each of the 3 resolution banks -> out : LCN imgs
  • Next 7 stages (11 for MPII) : multi-resolution ConvNet with pooling -> heat-map output is at a lower resolution than the input img
  • Last 4 stages (3 for MPII) : effectively simulate a fully-connected network over the target input patch size
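
A minimal sketch of the 3-level multi-resolution input, assuming a plain blur-and-downsample pyramid (the 5x5 binomial kernel and the gaussian_pyramid helper are illustrative, not the paper's exact filters; the paper additionally applies LCN with a shared kernel per bank):

import torch
import torch.nn.functional as F

def gaussian_pyramid(img, levels=3):
    # img: (B, 3, H, W) -> list of `levels` tensors, each half the previous resolution
    k1d = torch.tensor([1., 4., 6., 4., 1.])               # binomial approx. of a Gaussian
    k2d = torch.outer(k1d, k1d)
    k2d = (k2d / k2d.sum()).expand(img.shape[1], 1, 5, 5)  # one depthwise filter per channel
    pyr = [img]
    for _ in range(levels - 1):
        blurred = F.conv2d(pyr[-1], k2d.to(img), padding=2, groups=img.shape[1])
        pyr.append(F.avg_pool2d(blurred, 2))               # 2x spatial downsample
    return pyr

pyr = gaussian_pyramid(torch.randn(1, 3, 240, 320))        # FLIC-sized input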

3.2. Spatial Dropout

  • Dropout : zeroing activations -> improving generalization by preventing activations from becoming strongly correlated
  • Additional Dropout layer before the 1st 1x1 conv layer
  • Standard Dropout
    • Network is fully conv (1x1 conv) & natural imgs (so, feature-map activations) are strongly correlated
    • Result : fails to prevent over-training (dropping a single pixel barely perturbs its strongly correlated neighbors)
  • Spatial Dropout (sketch below)
    • Feature-map = n_features x Height x Width
    • How : perform only n_features dropout trials + extend each drawn value across the entire feature map
    • Result : adjacent pixels are either all 0 OR all active (good performance on FLIC)
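
A minimal sketch of the Spatial Dropout scheme described above; PyTorch's built-in nn.Dropout2d implements the same channel-wise dropout:

import torch
import torch.nn as nn

class SpatialDropout(nn.Module):
    # One Bernoulli trial per feature map; the drawn value is broadcast over the
    # whole H x W map, so adjacent pixels are either all zeroed or all kept.
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):                                  # x: (B, n_features, H, W)
        if not self.training or self.p == 0:
            return x
        keep = 1 - self.p
        mask = torch.bernoulli(
            torch.full((x.shape[0], x.shape[1], 1, 1), keep, device=x.device))
        return x * mask / keep                             # rescale to preserve expectation

drop = SpatialDropout(p=0.5)                               # equivalent: nn.Dropout2d(p=0.5)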

3.3. Training and Data Augmentation

  • Loss : MSE between predicted and target heat-maps, E1 = Σ_j Σ_(x,y) (H'_j(x,y) - H_j(x,y))²
    • H', H : Predicted and GT heat-map for joint j
    • Target GT heat-map : 2D Gaussian of constant variance (σ = 1.5 pixels) centered at GT joint (x,y) (see the sketch after this list)
  • Data Augmentation : Random rotation, scaling, flipping -> Generalization
  • Case : img contains multiple people but only a single person is annotated
    • How : Sliding-window + tree-structured MRF spatial model (approximate Torso position)
    • MRF Input : GT torso position + 14 predicted joints from ConvNet output = 15 joints locations
    • Result : selecting correct person for labeling
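
A minimal sketch of the target heat-map construction and the MSE loss, assuming one heat-map per joint (the gaussian_heatmap helper is illustrative):

import torch
import torch.nn.functional as F

def gaussian_heatmap(h, w, cx, cy, sigma=1.5):
    # 2D Gaussian of constant variance centered at the GT joint (cx, cy)
    ys = torch.arange(h, dtype=torch.float32).view(-1, 1)
    xs = torch.arange(w, dtype=torch.float32).view(1, -1)
    return torch.exp(-((xs - cx) ** 2 + (ys - cy) ** 2) / (2 * sigma ** 2))

target = gaussian_heatmap(60, 80, cx=42.0, cy=17.0)        # one joint's GT heat-map
pred = torch.rand(60, 80)                                  # stand-in network output
loss = F.mse_loss(pred, target)                            # E1 for a single joint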

4. Fine Heat-Map Regression Model

  • Purpose : Recovering spatial accuracy lost due to pooling
  • How : Using additional ConvNet to refine localization result of coarse heat-map
  • Difference : Reusing existing conv features -> reducing # of params + acting as regularizer

4.1. Model Architecture

  • Full system Architecture
    [Figure: coarse heat-map model -> feature crop module -> fine heat-map model]

    • Heat-map-based model for coarse localization
    • Module to sample and crop conv features at joint location (x, y)
    • Additional conv model for fine tuning
  • Joint Inference Steps

    1. FPROP (forward-propagate) through Coarse heat-map model
      • Infer all joint locations (x, y) from max value in each joint's heat-map
    2. Sample and Crop the first 2 conv layers' features (for all resolution banks) at each coarse location (x, y) (see the crop sketch after this section)
      • During BPROP, gradients from the fine model pass back through the crop module and are added to the gradients of the corresponding conv stages in the coarse heat-map model
    3. FPROP through Fine heat-map model -> (Δx, Δy)
      • Fine heat-map model : Siamese network of 7 instances (14 for MPII)
    4. Add Position Refinement to coarse location -> Final location (x, y) for each joint
  • Fine heat-map model

    • Siamese network : Weights and biases of each module are shared
    • Sample location for each joint is different : Conv features don't share same Spatial context
    • So, conv sub-nets must be applied to each joint independently
    • But, weights are shared across joints to reduce the # of trainable params and prevent over-training
  • Last 1x1 Conv

    • No weight sharing
    • Input : each output of 7 sub-nets
    • Output : detailed-resolution heat-map
    • Purpose : Final detection for each joint
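
A minimal sketch of the sample-and-crop step from the inference list above, assuming integer pixel coordinates and a hypothetical window size (the paper's exact crop size is not reproduced here):

import torch

def crop_features(feat, x, y, size=36):
    # Cut a size x size window of conv features centered on the coarse
    # prediction (x, y) for one joint, clamped to stay inside the map.
    half = size // 2
    h, w = feat.shape[-2:]
    x0 = max(0, min(x - half, w - size))
    y0 = max(0, min(y - half, h - size))
    return feat[:, :, y0:y0 + size, x0:x0 + size]

feat = torch.randn(1, 16, 128, 128)                        # stand-in conv features
patch = crop_features(feat, x=100, y=20)                   # (1, 16, 36, 36) window

Each joint's crop is taken at a different location, which is why the per-joint sub-nets run independently while sharing weights.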

4.2. Joint Training

  • Before Joint Training : Pre-training Coarse heat-map model first
  • Holding params Coarse heat-map model Fixed + Training Fine heat-map model
  • Jointly Training both models by minimizing E3 = E1 + λ·E2 (λ = 0.1) (loss sketch below)
    • E1 = Σ_j Σ_(x,y) (H'_j(x,y) - H_j(x,y))² , where H', H : Predicted and GT Coarse heat-maps for joint j
    • E2 = Σ_j Σ_(x,y) (G'_j(x,y) - G_j(x,y))² , where G', G : Predicted and GT Fine heat-maps for joint j
  • Regression to a set of target heat-maps rather than direct minimization of the final (x, y) prediction error
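
A minimal sketch of the joint objective, assuming the predicted and target heat-maps are already spatially aligned:

import torch
import torch.nn.functional as F

def joint_loss(coarse_pred, coarse_gt, fine_pred, fine_gt, lam=0.1):
    e1 = F.mse_loss(coarse_pred, coarse_gt)   # coarse heat-map MSE (E1)
    e2 = F.mse_loss(fine_pred, fine_gt)       # fine heat-map MSE (E2)
    return e1 + lam * e2                      # E3 = E1 + λ·E2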

5. Results

  • Framework : Torch7

  • Dataset : FLIC (easy), MPII-Human-Pose (hard)

  • Pooling impact for coarse heat-map model : Pooling ↑ -> Detection performance (spatial precision) ↓

  • Ambiguous GT labels : can be worse than expected variance in User-generated labels

  • Cascaded model impact : better than Coarse model only

  • Greedily-trained cascade (Shared features)

    • Coarse and Fine models are trained independently by adding an additional conv layer
    • How : Training the Fine model using cropped input imgs as input
    • Result : regularizing effect of joint training : prevents over-training [F14(a)]
  • Spatial Dropout : regularizing effect of dropout + reduction in strong heat-map outliers [F14(b)]

6. Conclusion

  • Localization tasks demand high degree of spatial precision
  • Cascaded architecture combining Fine and Coarse conv networks -> SOTA on FLIC, MPII-Human-Pose
  • Spatial Precision + Computational benefits of Pooling

Code

Train

import os
import sys
import time
import argparse

import torch
import numpy as np
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchcontrib

from torchvision import transforms

from dataset.cub200 import CUB200Data
from dataset.mit67 import MIT67Data
from dataset.stanford_dog import SDog120Data
from dataset.caltech256 import Caltech257Data
from dataset.stanford_40 import Stanford40Data
from dataset.flower102 import Flower102Data

from model.fe_resnet import resnet18_dropout, resnet50_dropout, resnet101_dropout
from model.fe_mobilenet import mbnetv2_dropout

class MovingAverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f', momentum=0.9):
        self.name = name
        self.fmt = fmt
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0

    def update(self, val, n=1):
        self.val = val
        self.avg = self.momentum*self.avg + (1-self.momentum)*val

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
    
class CrossEntropyLabelSmooth(nn.Module):
    def __init__(self, num_classes, epsilon = 0.1):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (-targets * log_probs).sum(1)
        return loss.mean()

def linear_l2(model):
    # L2 penalty on linear (classifier) layers, scaled by the global args.beta
    beta_loss = 0
    for m in model.modules():
        if isinstance(m, nn.Linear):
            beta_loss += (m.weight).pow(2).sum()
            beta_loss += (m.bias).pow(2).sum()
    return 0.5*beta_loss*args.beta, beta_loss

def l2sp(model, reg):
    # L2-SP regularizer: penalize the distance of the current weights from the
    # stored pre-trained weights (old_weight / old_bias)
    reg_loss = 0
    dist = 0
    for m in model.modules():
        if hasattr(m, 'weight') and hasattr(m, 'old_weight'):
            diff = (m.weight - m.old_weight).pow(2).sum()
            dist += diff
            reg_loss += diff 

        if hasattr(m, 'bias') and hasattr(m, 'old_bias'):
            diff = (m.bias - m.old_bias).pow(2).sum()
            dist += diff
            reg_loss += diff 

    if dist > 0:
        dist = dist.sqrt()
    
    loss = (reg * reg_loss)
    return loss, dist


def test(model, teacher, loader, loss=False):
    # relies on the global `args` and `reg_layers` defined in __main__
    with torch.no_grad():
        model.eval()

        if loss:
            teacher.eval()

            ce = CrossEntropyLabelSmooth(loader.dataset.num_classes, args.label_smoothing).to('cuda')
            featloss = torch.nn.MSELoss(reduction='none')

        total_ce = 0
        total_feat_reg = np.zeros(len(reg_layers))
        total_l2sp_reg = 0
        total = 0
        top1 = 0

        for i, (batch, label) in enumerate(loader):
            batch, label = batch.to('cuda'), label.to('cuda')

            total += batch.size(0)
            out = model(batch)
            _, pred = out.max(dim=1)
            top1 += int(pred.eq(label).sum().item())

            if loss:
                total_ce += ce(out, label).item()
                if teacher is not None:
                    with torch.no_grad():
                        tout = teacher(batch)

                    for key in reg_layers:
                        src_x = reg_layers[key][0].out
                        tgt_x = reg_layers[key][1].out
                        tgt_channels = tgt_x.shape[1]

                        regloss = featloss(src_x[:,:tgt_channels,:,:], tgt_x.detach()).mean()

                        total_feat_reg[key] += regloss.item()

                _, unweighted = l2sp(model, 0)
                total_l2sp_reg += unweighted.item()

    return float(top1)/total*100, total_ce/(i+1), np.sum(total_feat_reg)/(i+1), total_l2sp_reg/(i+1), total_feat_reg/(i+1)

def train(model, train_loader, val_loader, iterations=9000, lr=1e-2, name='', l2sp_lmda=1e-2, teacher=None, reg_layers={}):
    model = model.to('cuda')

    if l2sp_lmda == 0:
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=0)

    end_iter = iterations
    if args.swa:
        optimizer = torchcontrib.optim.SWA(optimizer, swa_start=args.swa_start, swa_freq=args.swa_freq)
        end_iter = args.swa_start
    if args.const_lr:
        scheduler = None
    else:
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, end_iter)

    teacher.eval()
    ce = CrossEntropyLabelSmooth(train_loader.dataset.num_classes, args.label_smoothing).to('cuda')
    featloss = torch.nn.MSELoss()


    batch_time = MovingAverageMeter('Time', ':6.3f')
    data_time = MovingAverageMeter('Data', ':6.3f')
    ce_loss_meter = MovingAverageMeter('CE Loss', ':6.3f')
    feat_loss_meter  = MovingAverageMeter('Feat. Loss', ':6.3f')
    l2sp_loss_meter  = MovingAverageMeter('L2SP Loss', ':6.3f')
    linear_loss_meter  = MovingAverageMeter('LinearL2 Loss', ':6.3f')
    total_loss_meter  = MovingAverageMeter('Total Loss', ':6.3f')
    top1_meter  = MovingAverageMeter('Acc@1', ':6.2f')

    dataloader_iterator = iter(train_loader)
    for i in range(iterations):
        if args.swa:
            if i >= int(args.swa_start) and (i-int(args.swa_start))%args.swa_freq == 0:
                scheduler = None
        model.train()
        optimizer.zero_grad()

        end = time.time()
        try:
            batch, label = next(dataloader_iterator)
        except StopIteration:
            # restart the finite dataloader once it is exhausted
            dataloader_iterator = iter(train_loader)
            batch, label = next(dataloader_iterator)
        batch, label = batch.to('cuda'), label.to('cuda')
        data_time.update(time.time() - end)

        out = model(batch)
        _, pred = out.max(dim=1)

        top1_meter.update(float(pred.eq(label).sum().item()) / label.shape[0] * 100.)

        loss = 0.
        loss += ce(out, label)

        ce_loss_meter.update(loss.item())

        with torch.no_grad():
            tout = teacher(batch)

        # Compute the feature distillation loss only when needed
        if args.feat_lmda > 0:
            regloss = 0
            for layer in args.feat_layers:
                key = int(layer)-1

                src_x = reg_layers[key][0].out
                tgt_x = reg_layers[key][1].out
                tgt_channels = tgt_x.shape[1]
                regloss += featloss(src_x[:,:tgt_channels,:,:], tgt_x.detach())

            regloss = args.feat_lmda * regloss
            loss += regloss
            feat_loss_meter.update(regloss.item())

        beta_loss, linear_norm = linear_l2(model)
        loss = loss + beta_loss 
        linear_loss_meter.update(beta_loss.item())

        if l2sp_lmda > 0:
            reg, _ = l2sp(model, l2sp_lmda)
            l2sp_loss_meter.update(reg.item())
            loss = loss + reg

        total_loss_meter.update(loss.item())

        loss.backward()
        optimizer.step()
        # read back the current learning rate for logging
        for param_group in optimizer.param_groups:
            current_lr = param_group['lr']
        if scheduler is not None:
            scheduler.step()

        batch_time.update(time.time() - end)

        if (i % args.print_freq == 0) or (i == iterations-1):
            progress = ProgressMeter(
                iterations,
                [batch_time, data_time, top1_meter, total_loss_meter, ce_loss_meter, feat_loss_meter, l2sp_loss_meter, linear_loss_meter],
                prefix="LR: {:6.5f}".format(current_lr))
            progress.display(i)

        if (i % args.test_interval == 0) or (i == iterations-1):
            test_top1, test_ce_loss, test_feat_loss, test_weight_loss, test_feat_layer_loss = test(model, teacher, val_loader, loss=True)
            train_top1, train_ce_loss, train_feat_loss, train_weight_loss, train_feat_layer_loss = test(model, teacher, train_loader, loss=True)
            print('Eval Train | Iteration {}/{} | Top-1: {:.2f} | CE Loss: {:.3f} | Feat Reg Loss: {:.6f} | L2SP Reg Loss: {:.3f}'.format(i+1, iterations, train_top1, train_ce_loss, train_feat_loss, train_weight_loss))
            print('Eval Test | Iteration {}/{} | Top-1: {:.2f} | CE Loss: {:.3f} | Feat Reg Loss: {:.6f} | L2SP Reg Loss: {:.3f}'.format(i+1, iterations, test_top1, test_ce_loss, test_feat_loss, test_weight_loss))
            if not args.no_save:
                if not os.path.exists('ckpt'):
                    os.makedirs('ckpt')
                torch.save({'state_dict': model.state_dict()}, 'ckpt/{}.pth'.format(name))

    if args.swa:
        optimizer.swap_swa_sgd()

        # Re-estimate BatchNorm running stats under the averaged SWA weights
        for m in model.modules():
            if hasattr(m, 'running_mean'):
                m.reset_running_stats()
                m.momentum = None  # cumulative moving average
        with torch.no_grad():
            model.train()
            for x, y in train_loader:
                x = x.to('cuda')
                out = model(x)

        test_top1, test_ce_loss, test_feat_loss, test_weight_loss, test_feat_layer_loss = test(model, teacher, val_loader, loss=True)
        train_top1, train_ce_loss, train_feat_loss, train_weight_loss, train_feat_layer_loss = test(model, teacher, train_loader, loss=True)
        print('Eval Train | Iteration {}/{} | Top-1: {:.2f} | CE Loss: {:.3f} | Feat Reg Loss: {:.6f} | L2SP Reg Loss: {:.3f}'.format(i+1, iterations, train_top1, train_ce_loss, train_feat_loss, train_weight_loss))
        print('Eval Test | Iteration {}/{} | Top-1: {:.2f} | CE Loss: {:.3f} | Feat Reg Loss: {:.6f} | L2SP Reg Loss: {:.3f}'.format(i+1, iterations, test_top1, test_ce_loss, test_feat_loss, test_weight_loss))

        if not args.no_save:
            if not os.path.exists('ckpt'):
                os.makedirs('ckpt')
            torch.save({'state_dict': model.state_dict()}, 'ckpt/{}.pth'.format(name))

    return model

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--datapath", type=str, default='/data', help='path to the dataset')
    parser.add_argument("--dataset", type=str, default='CUB200Data', help='Target dataset. Currently support: \{SDog120Data, CUB200Data, Stanford40Data, MIT67Data, Flower102Data\}')
    parser.add_argument("--iterations", type=int, default=30000, help='Iterations to train')
    parser.add_argument("--print_freq", type=int, default=100, help='Frequency of printing training logs')
    parser.add_argument("--test_interval", type=int, default=1000, help='Frequency of testing')
    parser.add_argument("--name", type=str, default='test', help='Name for the checkpoint')
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--const_lr", action='store_true', default=False, help='Use constant learning rate')
    parser.add_argument("--weight_decay", type=float, default=0)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--beta", type=float, default=1e-2, help='The strength of the L2 regularization on the last linear layer')
    parser.add_argument("--dropout", type=float, default=0, help='Dropout rate for spatial dropout')
    parser.add_argument("--l2sp_lmda", type=float, default=0)
    parser.add_argument("--feat_lmda", type=float, default=0)
    parser.add_argument("--feat_layers", type=str, default='1234', help='Used for DELTA (which layers or stages to match), ResNets should be 1234 and MobileNetV2 should be 12345')
    parser.add_argument("--reinit", action='store_true', default=False, help='Reinitialize before training')
    parser.add_argument("--no_save", action='store_true', default=False, help='Do not save checkpoints')
    parser.add_argument("--swa", action='store_true', default=False, help='Use SWA')
    parser.add_argument("--swa_freq", type=int, default=500, help='Frequency of averaging models in SWA')
    parser.add_argument("--swa_start", type=int, default=0, help='Start SWA since which iterations')
    parser.add_argument("--label_smoothing", type=float, default=0)
    parser.add_argument("--checkpoint", type=str, default='', help='Load a previously trained checkpoint')
    parser.add_argument("--network", type=str, default='resnet18', help='Network architecture. Currently support: \{resnet18, resnet50, resnet101, mbnetv2\}')
    parser.add_argument("--tnetwork", type=str, default='resnet18', help='Network architecture. Currently support: \{resnet18, resnet50, resnet101, mbnetv2\}')
    parser.add_argument("--width_mult", type=float, default=1)
    parser.add_argument("--shot", type=int, default=-1, help='Number of training samples per class for the training dataset. -1 indicates using the full dataset.')
    parser.add_argument("--log", action='store_true', default=False, help='Redirect the output to log/args.name.log')
    args = parser.parse_args()
    return args

# Forward hooks used to record each stage's output activations for feature matching
def record_act(self, input, output):
    self.out = output

def record_act_with_1x1(self, input, output):
    # match channel dimensions with a 1x1 conv before recording (width_mult != 1)
    self.out = self[-1].dim_matching(output)

if __name__ == '__main__':
    args = get_args()

    if args.log:
        if not os.path.exists('log'):
            os.makedirs('log')
        sys.stdout = open('log/{}.log'.format(args.name), 'w')


    print(args)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # Used to make sure we sample the same image for few-shot scenarios
    seed = 98

    train_set = eval(args.dataset)(args.datapath, True, transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]), args.shot, seed, preload=False)

    test_set = eval(args.dataset)(args.datapath, False, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]), args.shot, seed, preload=False)

    train_loader = torch.utils.data.DataLoader(train_set,
        batch_size=args.batch_size, shuffle=True,
        num_workers=8, pin_memory=True)

    # NOTE: validation reuses the training set here
    val_loader = train_loader

    test_loader = torch.utils.data.DataLoader(test_set,
        batch_size=args.batch_size, shuffle=False,
        num_workers=8, pin_memory=False)

    model = eval('{}_dropout'.format(args.network))(pretrained=True, dropout=args.dropout, width_mult=args.width_mult, num_classes=train_loader.dataset.num_classes).cuda()
    if args.checkpoint != '':
        checkpoint = torch.load(args.checkpoint)
        model.load_state_dict(checkpoint['state_dict'])

    # Pre-trained model
    teacher = eval('{}_dropout'.format(args.tnetwork))(pretrained=True, dropout=0, num_classes=train_loader.dataset.num_classes).cuda()

    if 'mbnetv2' in args.network:
        reg_layers = {0: [model.layer1], 1: [model.layer2], 2: [model.layer3], 3: [model.layer4], 4: [model.layer5]}
        model.layer1.register_forward_hook(record_act)
        model.layer2.register_forward_hook(record_act)
        model.layer3.register_forward_hook(record_act)
        model.layer4.register_forward_hook(record_act)
        model.layer5.register_forward_hook(record_act)
    else:
        reg_layers = {0: [model.layer1], 1: [model.layer2], 2: [model.layer3], 3: [model.layer4]}

        model.layer1.register_forward_hook(record_act_with_1x1)
        model.layer2.register_forward_hook(record_act_with_1x1)
        model.layer3.register_forward_hook(record_act_with_1x1)
        model.layer4.register_forward_hook(record_act_with_1x1)

        model.layer1[-1].dim_matching = torch.nn.Conv2d(model.layer1[-1].out_dim, int(teacher.layer1[-1].out_dim/args.width_mult), kernel_size=1, bias=False).cuda()
        model.layer2[-1].dim_matching = torch.nn.Conv2d(model.layer2[-1].out_dim, int(teacher.layer2[-1].out_dim/args.width_mult), kernel_size=1, bias=False).cuda()
        model.layer3[-1].dim_matching = torch.nn.Conv2d(model.layer3[-1].out_dim, int(teacher.layer3[-1].out_dim/args.width_mult), kernel_size=1, bias=False).cuda()
        model.layer4[-1].dim_matching = torch.nn.Conv2d(model.layer4[-1].out_dim, int(teacher.layer4[-1].out_dim/args.width_mult), kernel_size=1, bias=False).cuda()


    # Store the pre-trained weights used by the L2SP regularizer
    for m in model.modules():
        if hasattr(m, 'weight') and not hasattr(m, 'old_weight'):
            m.old_weight = m.weight.data.clone().detach()
        if hasattr(m, 'bias') and not hasattr(m, 'old_bias') and m.bias is not None:
            m.old_bias = m.bias.data.clone().detach()

    if args.reinit:
        for m in model.modules():
            if type(m) in [nn.Linear, nn.BatchNorm2d, nn.Conv2d]:
                m.reset_parameters()

    reg_layers[0].append(teacher.layer1)
    teacher.layer1.register_forward_hook(record_act)
    reg_layers[1].append(teacher.layer2)
    teacher.layer2.register_forward_hook(record_act)
    reg_layers[2].append(teacher.layer3)
    teacher.layer3.register_forward_hook(record_act)
    reg_layers[3].append(teacher.layer4)
    teacher.layer4.register_forward_hook(record_act)

    if '5' in args.feat_layers:
        reg_layers[4].append(teacher.layer5)
        teacher.layer5.register_forward_hook(record_act)

    train(model, train_loader, test_loader, l2sp_lmda=args.l2sp_lmda, iterations=args.iterations, lr=args.lr, name='{}'.format(args.name), teacher=teacher, reg_layers=reg_layers)
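
# A hedged usage sketch (the filename train.py is an assumption; all flags come
# from get_args above):
#   python train.py --datapath /data --dataset CUB200Data --network resnet18 \
#       --iterations 30000 --feat_lmda 1e-2 --l2sp_lmda 1e-2 --name cub200_resnet18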

Eval

import argparse
import torch
import time
import sys
import numpy as np
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchcontrib

from PIL import Image

from torchvision import transforms

from dataset.cub200 import CUB200Data
from dataset.mit67 import MIT67Data
from dataset.stanford_dog import SDog120Data
from dataset.caltech256 import Caltech257Data
from dataset.stanford_40 import Stanford40Data
from dataset.flower102 import Flower102Data

from advertorch.attacks import LinfPGDAttack

from model.fe_resnet import resnet18_dropout, resnet50_dropout, resnet101_dropout
from model.fe_mobilenet import mbnetv2_dropout
from model.fe_resnet import feresnet18, feresnet50, feresnet101
from model.fe_mobilenet import fembnetv2

def test(model, loader, adversary):
    model.eval()

    total = 0
    top1_clean = 0
    top1_adv = 0
    adv_success = 0   # adversarial prediction still correct (attack failed)
    adv_trial = 0     # clean prediction correct, so an attack was attempted
    for i, (batch, label) in enumerate(loader):
        batch, label = batch.to('cuda'), label.to('cuda')

        total += batch.size(0)
        out_clean = model(batch)

        # Task-agnostic attack target: the adversary perturbs the pre-trained
        # feature extractor's output, pushing its first unit toward m
        if 'mbnetv2' in args.network:
            y = torch.zeros(batch.shape[0], model.classifier[1].in_features).cuda()
        else:
            y = torch.zeros(batch.shape[0], model.fc.in_features).cuda()
        y[:,0] = args.m
        advbatch = adversary.perturb(batch, y)

        out_adv = model(advbatch)

        _, pred_clean = out_clean.max(dim=1)
        _, pred_adv = out_adv.max(dim=1)

        clean_correct = pred_clean.eq(label)
        adv_trial += int(clean_correct.sum().item())
        adv_success += int(pred_adv[clean_correct].eq(label[clean_correct]).sum().item())
        top1_clean += int(pred_clean.eq(label).sum().item())
        top1_adv += int(pred_adv.eq(label).sum().item())

        print('{}/{}...'.format(i+1, len(loader)))


    return float(top1_clean)/total*100, float(top1_adv)/total*100, float(adv_trial-adv_success) / adv_trial *100


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--datapath", type=str, default='/data', help='path to the dataset')
    parser.add_argument("--dataset", type=str, default='CUB200Data', help='Target dataset. Currently support: \{SDog120Data, CUB200Data, Stanford40Data, MIT67Data, Flower102Data\}')
    parser.add_argument("--name", type=str, default='test')
    parser.add_argument("--B", type=float, default=0.1, help='Attack budget')
    parser.add_argument("--m", type=float, default=1000, help='Hyper-parameter for task-agnostic attack')
    parser.add_argument("--pgd_iter", type=int, default=40)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--dropout", type=float, default=0)
    parser.add_argument("--checkpoint", type=str, default='')
    parser.add_argument("--network", type=str, default='resnet18', help='Network architecture. Currently support: \{resnet18, resnet50, resnet101, mbnetv2\}')
    args = parser.parse_args()
    return args

def myloss(yhat, y):
    # Maximized by untargeted PGD, so it drives the feature output toward y,
    # whose first unit is the large constant m
    return -((yhat[:,0]-y[:,0])**2 + 0.1*((yhat[:,1:]-y[:,1:])**2).mean(1)).mean()

if __name__ == '__main__':
    args = get_args()
    print(args)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    seed = int(time.time())

    test_set = eval(args.dataset)(args.datapath, False, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]), -1, seed, preload=False)

    test_loader = torch.utils.data.DataLoader(test_set,
        batch_size=args.batch_size, shuffle=False,
        num_workers=8, pin_memory=False)

    transferred_model = eval('{}_dropout'.format(args.network))(pretrained=False, dropout=args.dropout, num_classes=test_loader.dataset.num_classes).cuda()
    checkpoint = torch.load(args.checkpoint)
    transferred_model.load_state_dict(checkpoint['state_dict'])

    pretrained_model = eval('fe{}'.format(args.network))(pretrained=True).cuda().eval()

    adversary = LinfPGDAttack(
            pretrained_model, loss_fn=myloss, eps=args.B,
            nb_iter=args.pgd_iter, eps_iter=0.01, rand_init=True,
            clip_min=-2.2, clip_max=2.2,  # approx. input range after ImageNet normalization
            targeted=False)

    clean_top1, adv_top1, adv_sr = test(transferred_model, test_loader, adversary)

    print('Clean Top-1: {:.2f} | Adv Top-1: {:.2f} | Attack Success Rate: {:.2f}'.format(clean_top1, adv_top1, adv_sr))
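
# A hedged usage sketch (the filename eval.py is an assumption; all flags come
# from get_args above):
#   python eval.py --datapath /data --dataset CUB200Data --network resnet18 \
#       --checkpoint ckpt/cub200_resnet18.pth --B 0.1 --pgd_iter 40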