[CV_Pose Estimation] Efficient Object Localization Using Convolutional Networks
jeonggg119 opened this issue · 0 comments
Efficient Object Localization Using Convolutional Networks
Abstract
- Efficient 'Position Refinement' model
- trained to estimate joint offset location within a small region of img
- trained in cascade with a SOTA ConvNet model to achieve improved acc
- on FLIC dataset, MPII dataset
1. Introduction
-
Human-body part localization task ↑ BY ConvNet arch + larger datasets
-
(sota) ConvNet : internal strided-pooling layers
- reduce spatial resolution
- output : invariant to spatial location within pooling region
- promote spatial invariance to local input transformation
- pooling : prevent over-training + reducing computational complexity for classification
- Trade-off : generalization performance ↑ <-> spatial localization accuracy ↓
-
(this paper) ConvNet architecture for efficient localization of human joints in RGB imgs
- high spatial accuracy + computational efficiency
- begin by coarse body part localization -> output : low resolution, per-pixel heat-map
- show likelihood of a joint occurring in each spatial location
- Max-pooling for dimensionality reduction + improving invariance to noise and local img transformations
- reuse hidden layer conv features from coarse heat-map regression model to improve localization accuracy
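The pooling trade-off noted above can be seen in a tiny NumPy sketch. This is only an illustration, not code from the paper: `max_pool_2x2` is a hypothetical helper, and the 4 x 4 inputs are toy data. Two peaks that sit at different pixels inside the same 2 x 2 pooling window give identical pooled outputs, so the exact position is lost.

```python
import numpy as np

def max_pool_2x2(x):
    """Non-overlapping 2x2 max pooling on a 2D array (odd edges are cropped)."""
    h, w = x.shape
    x = x[:h - h % 2, :w - w % 2]
    return x.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))

# Same activation, shifted by one pixel inside the same pooling window.
a = np.zeros((4, 4))
a[0, 0] = 1.0
b = np.zeros((4, 4))
b[1, 1] = 1.0

# The pooled maps are identical: invariance gained, localization lost.
same = np.array_equal(max_pool_2x2(a), max_pool_2x2(b))
print(same)
```

This is the invariance that helps generalization and, at the same time, the reason a refinement stage is needed to recover spatial accuracy.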
2. Related Work
-
Models using Hand-crafted features (edges, contours, HoG, color histograms) : poor generalization performance
- Deformable Part Models (DPM)
- Mixture of templates modeled using SVMs
- Poselet + DPM model : spatial relationship of body parts
- Armlets : semi-global classifier, good for real-world data, but only arms
- Multi-modal model : holistic + local
-
ConvNets
- formulate problem as a direct (continuous) regression
- performs poorly in the high-precision region
- unnecessary learning complexity by mapping from input RGB img to XY location (over-training)
- +) low-dimensional representation of input img, multi-resolution ConvNet arch, ...
3. Coarse Heat-Map Regression Model
- Using Extension of Multi-resolution ConvNet model
- For Sliding window detector with Overlapping contexts to produce Coarse heat-map output
3.1. Model Architecture
- Input : RGB Gaussian pyramid of 3 levels (320 x 240 for FLIC, 256 x 256 for MPII)
Figure 2 : only 2 levels for brevity
- Output : Heat-map for each joint describing per-pixel likelihood for joint occurring in each output spatial location
- 1st layer : LCN (Local Contrast Normalization) with same filter kernel in each of the 3 resolution banks -> out : LCN imgs
- Next 7 stages (11 for MPII) : multi-resolution ConvNet with Pooling -> heat-map output is at a lower resolution than input img
- Last 4 stages (3 for MPII) : effectively simulate an FC network for the target input patch size
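A rough sketch of the 3-level Gaussian pyramid input, in NumPy. The notes do not give the blur kernel or the downsampling factor, so the 5-tap binomial kernel and the factor of 2 are assumptions, and `blur` / `gaussian_pyramid` are hypothetical helper names.

```python
import numpy as np

def blur(img):
    """Separable 5-tap binomial blur on an (H, W, C) image with reflect padding."""
    k = np.array([1, 4, 6, 4, 1], dtype=np.float64) / 16.0  # assumed kernel
    h, w = img.shape[:2]
    padded = np.pad(img, ((2, 2), (2, 2), (0, 0)), mode="reflect")
    # Filter along rows first, then along columns.
    tmp = sum(k[i] * padded[i:i + h, :, :] for i in range(5))
    return sum(k[j] * tmp[:, j:j + w, :] for j in range(5))

def gaussian_pyramid(img, levels=3):
    """Return [full, half, quarter] resolution images (factor 2 is an assumption)."""
    pyramid = [img.astype(np.float64)]
    for _ in range(levels - 1):
        pyramid.append(blur(pyramid[-1])[::2, ::2, :])
    return pyramid

# Toy FLIC-sized input: 320 x 240 (width x height), 3 channels.
img = np.random.default_rng(0).random((240, 320, 3))
shapes = [p.shape for p in gaussian_pyramid(img)]
print(shapes)
```

Each level would then feed one resolution bank of the multi-resolution ConvNet.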
3.2. Spatial Dropout
- Dropout : zeroing activation -> improving generalization by preventing activations from becoming strongly correlated
- Additional Dropout layer before 1st 1x1 conv layer
- Standard Dropout
- Network is fully conv (1d conv) & natural imgs (so, feature map activations) are strongly correlated
- Result : over-training (Fail)
- Spatial Dropout
- Feature-map = n_features x Height x Width
- How : perform only n_features dropout trials + extend value across entire feature map
- Result : adjacent pixels are either all 0 OR all active (good performance on FLIC)
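A minimal NumPy sketch of the SpatialDropout idea above: one dropout trial per feature map, with the result extended across the whole H x W map. The `1 / (1 - p)` rescaling is the usual inverted-dropout convention and is an assumption here, not something stated in the notes; `spatial_dropout` is a hypothetical helper name.

```python
import numpy as np

def spatial_dropout(x, p, rng):
    """x: (n_features, H, W). Zero whole feature maps with probability p (p < 1)."""
    keep = rng.random(x.shape[0]) >= p           # n_features Bernoulli trials only
    mask = keep[:, None, None].astype(x.dtype)   # extend each trial across H x W
    return x * mask / (1.0 - p)                  # assumed inverted-dropout scaling

rng = np.random.default_rng(0)
x = np.ones((8, 4, 5))
y = spatial_dropout(x, p=0.5, rng=rng)

# Every feature map is either entirely zero or entirely active.
print([float(fmap[0, 0]) for fmap in y])
```

In PyTorch, `nn.Dropout2d` applies the same channel-wise dropout to `(N, C, H, W)` inputs.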
3.3. Training and Data Augmentation
- Loss : MSE
- H', H : Predicted and GT heat-map for joint
- Target GT heat-map : 2D gaussian of constant variance (sigma = 1.5 pixels) centered at GT joint (x,y)
- Data Augmentation : Random rotation, scaling, flipping -> Generalization
- Multiple people contained but Single person annotated case
- How : Sliding-window + tree-structured MRF spatial model (approximate Torso position)
- MRF Input : GT torso position + 14 predicted joints from ConvNet output = 15 joints locations
- Result : selecting correct person for labeling
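The target heat-map and the MSE loss from this section can be sketched as follows. This is a sketch under assumptions: the Gaussian is left unnormalized (peak value 1), the loss is a plain mean over pixels, and the 60 x 90 heat-map size is a toy value; `gaussian_heatmap` and `heatmap_mse` are hypothetical names.

```python
import numpy as np

def gaussian_heatmap(height, width, cx, cy, sigma=1.5):
    """2D Gaussian of constant variance centered at the GT joint (cx, cy)."""
    ys, xs = np.mgrid[0:height, 0:width]
    return np.exp(-((xs - cx) ** 2 + (ys - cy) ** 2) / (2.0 * sigma ** 2))

def heatmap_mse(pred, target):
    """Mean squared error between predicted (H') and GT (H) heat-maps."""
    return float(np.mean((pred - target) ** 2))

target = gaussian_heatmap(60, 90, cx=30, cy=20)
pred = np.zeros_like(target)  # an untrained model predicting all zeros

peak = np.unravel_index(np.argmax(target), target.shape)  # (row, col)
print(peak, heatmap_mse(pred, target))
```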
4. Fine Heat-Map Regression Model
- Purpose : Recovering spatial accuracy lost due to pooling
- How : Using additional ConvNet to refine localization result of coarse heat-map
- Difference : Reusing existing conv features -> reducing # of params + acting as regularizer
4.1. Model Architecture
-
- Heat-map-based model for coarse localization
- Module to sample and crop conv features at joint location (x, y)
- Additional conv model for fine tuning
-
Joint Inference Steps
- FPROP (forward-propagate) through Coarse heat-map model
- Infer all joint locations (x, y) from max value in each joint's heat-map
- Sample and Crop first 2 conv layers (for all resolution) at each coarse location (x, y)
- FPROP through Fine heat-map model -> (△x, △y)
- Fine heat-map model : Siamese network of 7 instances (14 for MPII)
- Add Position Refinement to coarse location -> Final location (x, y) for each joint
-
- Siamese network : Weights and biases of each module are shared
- Sample location for each joint is different : Conv features don't share same Spatial context
- So, conv sub-nets must be applied to each joint independently
- But, parameters are shared across the sub-nets to reduce the total # of params and prevent over-training
-
Last 1x1 Conv
- No weight sharing
- Input : each output of 7 sub-nets
- Output : detailed-resolution heat-map
- Purpose : Final detection for each joint
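The joint inference steps of this section can be sketched with NumPy on toy data. Everything here is illustrative: `stub_fine_model` is a stand-in for the learned fine heat-map model (it just returns the offset of the strongest response in the crop), `scale` is an assumed ratio between heat-map and feature-map resolution, and the 9 x 9 crop size is arbitrary.

```python
import numpy as np

def coarse_locations(heatmaps):
    """heatmaps: (n_joints, H, W) -> (n_joints, 2) array of (x, y) argmax positions."""
    n, h, w = heatmaps.shape
    flat = heatmaps.reshape(n, -1).argmax(axis=1)
    ys, xs = np.unravel_index(flat, (h, w))
    return np.stack([xs, ys], axis=1)

def crop_features(features, x, y, size=9):
    """Crop an odd-sized window from (C, H, W) features, centered at (x, y), zero-padded."""
    half = size // 2
    padded = np.pad(features, ((0, 0), (half, half), (half, half)))
    return padded[:, y:y + size, x:x + size]

def stub_fine_model(crop):
    """Stand-in for the learned fine model: returns (dx, dy) from the window center."""
    summed = crop.sum(axis=0)
    row, col = np.unravel_index(summed.argmax(), summed.shape)
    half = summed.shape[0] // 2
    return np.array([col - half, row - half])

scale = 4                                   # assumed heat-map -> feature-map ratio
heatmaps = np.zeros((2, 8, 8))
heatmaps[0, 2, 3] = 1.0                     # joint 0: coarse (x=3, y=2)
heatmaps[1, 5, 6] = 1.0                     # joint 1: coarse (x=6, y=5)
features = np.zeros((1, 32, 32))
features[0, 9, 14] = 1.0                    # fine evidence for joint 0 at (x=14, y=9)
features[0, 21, 23] = 1.0                   # fine evidence for joint 1 at (x=23, y=21)

coarse = coarse_locations(heatmaps) * scale          # steps 1-2: FPROP + argmax
final = np.stack([
    np.array([x, y]) + stub_fine_model(crop_features(features, x, y))  # steps 3-5
    for x, y in coarse
])
print(final.tolist())
```

In the paper's setup the per-joint sub-nets share weights (the Siamese part), which the stub mimics by applying the same function to every joint's crop.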
4.2. Joint Training
- Before Joint Training : Pre-training Coarse heat-map model first
- Holding params Coarse heat-map model Fixed + Training Fine heat-map model
- Jointly Training both models by minimizing E3 = E1 + λE2 ..... (λ = 0.1)
- Regression to set of target heat-maps for minimizing final (x, y) prediction
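A small sketch of the joint objective E3 = E1 + λE2, assuming E1 is the coarse heat-map MSE and E2 the fine heat-map MSE; the arrays are toy values and `joint_loss` is a hypothetical name.

```python
import numpy as np

def mse(pred, target):
    return float(np.mean((pred - target) ** 2))

def joint_loss(coarse_pred, coarse_gt, fine_pred, fine_gt, lam=0.1):
    """E3 = E1 + lam * E2 (E1: coarse heat-map loss, E2: fine heat-map loss)."""
    e1 = mse(coarse_pred, coarse_gt)
    e2 = mse(fine_pred, fine_gt)
    return e1 + lam * e2

coarse_pred, coarse_gt = np.zeros((4, 4)), np.ones((4, 4))       # E1 = 1
fine_pred, fine_gt = np.zeros((4, 4)), 2.0 * np.ones((4, 4))     # E2 = 4
e3 = joint_loss(coarse_pred, coarse_gt, fine_pred, fine_gt)
print(e3)
```

With λ = 0.1 the fine term contributes only a tenth of its raw value, so the pre-trained coarse model is not pulled far from its solution during joint training.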
5. Results
-
Framework : Torch7
-
Dataset : FLIC(easy), MPII-Human-Pose(hard)
-
Pooling impact for coarse heat-map model : Pooling ↑ -> Detection performance(spatial precision) ↓
-
Ambiguous GT labels : can be worse than expected variance in User-generated labels
-
Greedily-trained cascade (Shared features)
- Coarse and Fine models are trained independently by adding additional conv layer
- How : Training Fine model by using cropped input imgs as input
- Result : regularizing effect of joint training : preventing over-training [F14(a)]
-
SpatialDropout : regularizing effect of dropout + reduction in strong heat-map outliers [F14(b)]
6. Conclusion
- Localization tasks demand high degree of spatial precision
- Cascaded architecture that combined Fine and Coarse conv networks -> SOTA on FLIC, MPII-human-pose
- Spatial Precision + Computational benefits of Pooling
Code
Train
import os
import sys
import time
import argparse
import torch
import numpy as np
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchcontrib
from torchvision import transforms
from dataset.cub200 import CUB200Data
from dataset.mit67 import MIT67Data
from dataset.stanford_dog import SDog120Data
from dataset.caltech256 import Caltech257Data
from dataset.stanford_40 import Stanford40Data
from dataset.flower102 import Flower102Data
from model.fe_resnet import resnet18_dropout, resnet50_dropout, resnet101_dropout
from model.fe_mobilenet import mbnetv2_dropout
class MovingAverageMeter(object):
    """Computes and stores the exponential moving average and current value"""
    def __init__(self, name, fmt=':f', momentum=0.9):
        self.name = name
        self.fmt = fmt
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0

    def update(self, val, n=1):
        self.val = val
        self.avg = self.momentum*self.avg + (1-self.momentum)*val

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
class CrossEntropyLabelSmooth(nn.Module):
    def __init__(self, num_classes, epsilon = 0.1):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        # One-hot encode the labels, then smooth them towards the uniform distribution
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (-targets * log_probs).sum(1)
        return loss.mean()
def linear_l2(model):
    # L2 penalty on all nn.Linear layers; relies on the global `args` for beta
    beta_loss = 0
    for m in model.modules():
        if isinstance(m, nn.Linear):
            beta_loss += (m.weight).pow(2).sum()
            beta_loss += (m.bias).pow(2).sum()
    return 0.5*beta_loss*args.beta, beta_loss
def l2sp(model, reg):
    # L2-SP: penalize the squared distance from the stored pre-trained weights
    reg_loss = 0
    dist = 0
    for m in model.modules():
        if hasattr(m, 'weight') and hasattr(m, 'old_weight'):
            diff = (m.weight - m.old_weight).pow(2).sum()
            dist += diff
            reg_loss += diff
        if hasattr(m, 'bias') and hasattr(m, 'old_bias'):
            diff = (m.bias - m.old_bias).pow(2).sum()
            dist += diff
            reg_loss += diff
    if dist > 0:
        dist = dist.sqrt()
    loss = (reg * reg_loss)
    return loss, dist
def test(model, teacher, loader, loss=False):
    # Relies on the globals `args` and `reg_layers` set in the __main__ block
    with torch.no_grad():
        model.eval()
        if loss:
            teacher.eval()
            ce = CrossEntropyLabelSmooth(loader.dataset.num_classes, args.label_smoothing).to('cuda')
            featloss = torch.nn.MSELoss(reduction='none')
            total_ce = 0
            total_feat_reg = np.zeros(len(reg_layers))
            total_l2sp_reg = 0
        total = 0
        top1 = 0
        for i, (batch, label) in enumerate(loader):
            batch, label = batch.to('cuda'), label.to('cuda')
            total += batch.size(0)
            out = model(batch)
            _, pred = out.max(dim=1)
            top1 += int(pred.eq(label).sum().item())
            if loss:
                total_ce += ce(out, label).item()
                if teacher is not None:
                    with torch.no_grad():
                        tout = teacher(batch)
                    for key in reg_layers:
                        src_x = reg_layers[key][0].out
                        tgt_x = reg_layers[key][1].out
                        tgt_channels = tgt_x.shape[1]
                        regloss = featloss(src_x[:,:tgt_channels,:,:], tgt_x.detach()).mean()
                        total_feat_reg[key] += regloss.item()
                _, unweighted = l2sp(model, 0)
                total_l2sp_reg += unweighted.item()
    return float(top1)/total*100, total_ce/(i+1), np.sum(total_feat_reg)/(i+1), total_l2sp_reg/(i+1), total_feat_reg/(i+1)
def train(model, train_loader, val_loader, iterations=9000, lr=1e-2, name='', l2sp_lmda=1e-2, teacher=None, reg_layers={}):
    model = model.to('cuda')
    if l2sp_lmda == 0:
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=0)
    end_iter = iterations
    if args.swa:
        optimizer = torchcontrib.optim.SWA(optimizer, swa_start=args.swa_start, swa_freq=args.swa_freq)
        end_iter = args.swa_start
    if args.const_lr:
        scheduler = None
    else:
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, end_iter)
    teacher.eval()
    ce = CrossEntropyLabelSmooth(train_loader.dataset.num_classes, args.label_smoothing).to('cuda')
    featloss = torch.nn.MSELoss()
    batch_time = MovingAverageMeter('Time', ':6.3f')
    data_time = MovingAverageMeter('Data', ':6.3f')
    ce_loss_meter = MovingAverageMeter('CE Loss', ':6.3f')
    feat_loss_meter = MovingAverageMeter('Feat. Loss', ':6.3f')
    l2sp_loss_meter = MovingAverageMeter('L2SP Loss', ':6.3f')
    linear_loss_meter = MovingAverageMeter('LinearL2 Loss', ':6.3f')
    total_loss_meter = MovingAverageMeter('Total Loss', ':6.3f')
    top1_meter = MovingAverageMeter('Acc@1', ':6.2f')
    dataloader_iterator = iter(train_loader)
    for i in range(iterations):
        if args.swa:
            # Stop the LR schedule once SWA averaging starts
            if i >= int(args.swa_start) and (i-int(args.swa_start))%args.swa_freq == 0:
                scheduler = None
        model.train()
        optimizer.zero_grad()
        end = time.time()
        try:
            batch, label = next(dataloader_iterator)
        except StopIteration:
            # Loader exhausted: restart it so training can run for a fixed number of iterations
            dataloader_iterator = iter(train_loader)
            batch, label = next(dataloader_iterator)
        batch, label = batch.to('cuda'), label.to('cuda')
        data_time.update(time.time() - end)
        out = model(batch)
        _, pred = out.max(dim=1)
        top1_meter.update(float(pred.eq(label).sum().item()) / label.shape[0] * 100.)
        loss = 0.
        loss += ce(out, label)
        ce_loss_meter.update(loss.item())
        # Teacher forward pass fills the hooked activations used for feature matching
        with torch.no_grad():
            tout = teacher(batch)
        # Compute the feature distillation loss only when needed
        if args.feat_lmda > 0:
            regloss = 0
            for layer in args.feat_layers:
                key = int(layer)-1
                src_x = reg_layers[key][0].out
                tgt_x = reg_layers[key][1].out
                tgt_channels = tgt_x.shape[1]
                regloss += featloss(src_x[:,:tgt_channels,:,:], tgt_x.detach())
            regloss = args.feat_lmda * regloss
            loss += regloss
            feat_loss_meter.update(regloss.item())
        beta_loss, linear_norm = linear_l2(model)
        loss = loss + beta_loss
        linear_loss_meter.update(beta_loss.item())
        if l2sp_lmda > 0:
            reg, _ = l2sp(model, l2sp_lmda)
            l2sp_loss_meter.update(reg.item())
            loss = loss + reg
        total_loss_meter.update(loss.item())
        loss.backward()
        optimizer.step()
        for param_group in optimizer.param_groups:
            current_lr = param_group['lr']
        if scheduler is not None:
            scheduler.step()
        batch_time.update(time.time() - end)
        if (i % args.print_freq == 0) or (i == iterations-1):
            progress = ProgressMeter(
                iterations,
                [batch_time, data_time, top1_meter, total_loss_meter, ce_loss_meter, feat_loss_meter, l2sp_loss_meter, linear_loss_meter],
                prefix="LR: {:6.5f}".format(current_lr))
            progress.display(i)
        if (i % args.test_interval == 0) or (i == iterations-1):
            test_top1, test_ce_loss, test_feat_loss, test_weight_loss, test_feat_layer_loss = test(model, teacher, val_loader, loss=True)
            train_top1, train_ce_loss, train_feat_loss, train_weight_loss, train_feat_layer_loss = test(model, teacher, train_loader, loss=True)
            print('Eval Train | Iteration {}/{} | Top-1: {:.2f} | CE Loss: {:.3f} | Feat Reg Loss: {:.6f} | L2SP Reg Loss: {:.3f}'.format(i+1, iterations, train_top1, train_ce_loss, train_feat_loss, train_weight_loss))
            print('Eval Test | Iteration {}/{} | Top-1: {:.2f} | CE Loss: {:.3f} | Feat Reg Loss: {:.6f} | L2SP Reg Loss: {:.3f}'.format(i+1, iterations, test_top1, test_ce_loss, test_feat_loss, test_weight_loss))
            if not args.no_save:
                if not os.path.exists('ckpt'):
                    os.makedirs('ckpt')
                torch.save({'state_dict': model.state_dict()}, 'ckpt/{}.pth'.format(name))
    if args.swa:
        # Swap in the averaged weights, then recompute BatchNorm statistics on the training set
        optimizer.swap_swa_sgd()
        for m in model.modules():
            if hasattr(m, 'running_mean'):
                m.reset_running_stats()
                m.momentum = None
        with torch.no_grad():
            model.train()
            for x, y in train_loader:
                x = x.to('cuda')
                out = model(x)
        test_top1, test_ce_loss, test_feat_loss, test_weight_loss, test_feat_layer_loss = test(model, teacher, val_loader, loss=True)
        train_top1, train_ce_loss, train_feat_loss, train_weight_loss, train_feat_layer_loss = test(model, teacher, train_loader, loss=True)
        print('Eval Train | Iteration {}/{} | Top-1: {:.2f} | CE Loss: {:.3f} | Feat Reg Loss: {:.6f} | L2SP Reg Loss: {:.3f}'.format(i+1, iterations, train_top1, train_ce_loss, train_feat_loss, train_weight_loss))
        print('Eval Test | Iteration {}/{} | Top-1: {:.2f} | CE Loss: {:.3f} | Feat Reg Loss: {:.6f} | L2SP Reg Loss: {:.3f}'.format(i+1, iterations, test_top1, test_ce_loss, test_feat_loss, test_weight_loss))
        if not args.no_save:
            if not os.path.exists('ckpt'):
                os.makedirs('ckpt')
            torch.save({'state_dict': model.state_dict()}, 'ckpt/{}.pth'.format(name))
    return model
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--datapath", type=str, default='/data', help='path to the dataset')
    parser.add_argument("--dataset", type=str, default='CUB200Data', help='Target dataset. Currently support: {SDog120Data, CUB200Data, Stanford40Data, MIT67Data, Flower102Data}')
    parser.add_argument("--iterations", type=int, default=30000, help='Iterations to train')
    parser.add_argument("--print_freq", type=int, default=100, help='Frequency of printing training logs')
    parser.add_argument("--test_interval", type=int, default=1000, help='Frequency of testing')
    parser.add_argument("--name", type=str, default='test', help='Name for the checkpoint')
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--const_lr", action='store_true', default=False, help='Use constant learning rate')
    parser.add_argument("--weight_decay", type=float, default=0)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--beta", type=float, default=1e-2, help='The strength of the L2 regularization on the last linear layer')
    parser.add_argument("--dropout", type=float, default=0, help='Dropout rate for spatial dropout')
    parser.add_argument("--l2sp_lmda", type=float, default=0)
    parser.add_argument("--feat_lmda", type=float, default=0)
    parser.add_argument("--feat_layers", type=str, default='1234', help='Used for DELTA (which layers or stages to match), ResNets should be 1234 and MobileNetV2 should be 12345')
    parser.add_argument("--reinit", action='store_true', default=False, help='Reinitialize before training')
    parser.add_argument("--no_save", action='store_true', default=False, help='Do not save checkpoints')
    parser.add_argument("--swa", action='store_true', default=False, help='Use SWA')
    parser.add_argument("--swa_freq", type=int, default=500, help='Frequency of averaging models in SWA')
    parser.add_argument("--swa_start", type=int, default=0, help='Iteration at which SWA starts')
    parser.add_argument("--label_smoothing", type=float, default=0)
    parser.add_argument("--checkpoint", type=str, default='', help='Load a previously trained checkpoint')
    parser.add_argument("--network", type=str, default='resnet18', help='Network architecture. Currently support: {resnet18, resnet50, resnet101, mbnetv2}')
    parser.add_argument("--tnetwork", type=str, default='resnet18', help='Teacher network architecture. Currently support: {resnet18, resnet50, resnet101, mbnetv2}')
    parser.add_argument("--width_mult", type=float, default=1)
    parser.add_argument("--shot", type=int, default=-1, help='Number of training samples per class for the training dataset. -1 indicates using the full dataset.')
    parser.add_argument("--log", action='store_true', default=False, help='Redirect the output to log/args.name.log')
    args = parser.parse_args()
    return args
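# Forward hooks used for matching features
def record_act(self, input, output):
    self.out = output

def record_act_with_1x1(self, input, output):
    # Project the student activation through a 1x1 conv so its channels match the teacher's
    self.out = self[-1].dim_matching(output)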
if __name__ == '__main__':
    args = get_args()
    if args.log:
        if not os.path.exists('log'):
            os.makedirs('log')
        sys.stdout = open('log/{}.log'.format(args.name), 'w')
    print(args)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # Used to make sure we sample the same image for few-shot scenarios
    seed = 98
    train_set = eval(args.dataset)(args.datapath, True, transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]), args.shot, seed, preload=False)
    test_set = eval(args.dataset)(args.datapath, False, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]), args.shot, seed, preload=False)
    train_loader = torch.utils.data.DataLoader(train_set,
        batch_size=args.batch_size, shuffle=True,
        num_workers=8, pin_memory=True)
    val_loader = train_loader
    test_loader = torch.utils.data.DataLoader(test_set,
        batch_size=args.batch_size, shuffle=False,
        num_workers=8, pin_memory=False)
    model = eval('{}_dropout'.format(args.network))(pretrained=True, dropout=args.dropout, width_mult=args.width_mult, num_classes=train_loader.dataset.num_classes).cuda()
    if args.checkpoint != '':
        checkpoint = torch.load(args.checkpoint)
        model.load_state_dict(checkpoint['state_dict'])
    # Pre-trained model
    teacher = eval('{}_dropout'.format(args.tnetwork))(pretrained=True, dropout=0, num_classes=train_loader.dataset.num_classes).cuda()
    if 'mbnetv2' in args.network:
        reg_layers = {0: [model.layer1], 1: [model.layer2], 2: [model.layer3], 3: [model.layer4], 4: [model.layer5]}
        model.layer1.register_forward_hook(record_act)
        model.layer2.register_forward_hook(record_act)
        model.layer3.register_forward_hook(record_act)
        model.layer4.register_forward_hook(record_act)
        model.layer5.register_forward_hook(record_act)
    else:
        reg_layers = {0: [model.layer1], 1: [model.layer2], 2: [model.layer3], 3: [model.layer4]}
        # if args.width_mult > 1:
        #     model.layer1.register_forward_hook(record_act_with_1x1)
        #     model.layer2.register_forward_hook(record_act_with_1x1)
        #     model.layer3.register_forward_hook(record_act_with_1x1)
        #     model.layer4.register_forward_hook(record_act_with_1x1)
        #     model.layer1[-1].dim_matching = torch.nn.Conv2d(model.layer1[-1].out_dim, int(model.layer1[-1].out_dim/args.width_mult), kernel_size=1, bias=False).cuda()
        #     model.layer2[-1].dim_matching = torch.nn.Conv2d(model.layer2[-1].out_dim, int(model.layer2[-1].out_dim/args.width_mult), kernel_size=1, bias=False).cuda()
        #     model.layer3[-1].dim_matching = torch.nn.Conv2d(model.layer3[-1].out_dim, int(model.layer3[-1].out_dim/args.width_mult), kernel_size=1, bias=False).cuda()
        #     model.layer4[-1].dim_matching = torch.nn.Conv2d(model.layer4[-1].out_dim, int(model.layer4[-1].out_dim/args.width_mult), kernel_size=1, bias=False).cuda()
        # else:
        #     model.layer1.register_forward_hook(record_act)
        #     model.layer2.register_forward_hook(record_act)
        #     model.layer3.register_forward_hook(record_act)
        #     model.layer4.register_forward_hook(record_act)
        model.layer1.register_forward_hook(record_act_with_1x1)
        model.layer2.register_forward_hook(record_act_with_1x1)
        model.layer3.register_forward_hook(record_act_with_1x1)
        model.layer4.register_forward_hook(record_act_with_1x1)
        model.layer1[-1].dim_matching = torch.nn.Conv2d(model.layer1[-1].out_dim, int(teacher.layer1[-1].out_dim/args.width_mult), kernel_size=1, bias=False).cuda()
        model.layer2[-1].dim_matching = torch.nn.Conv2d(model.layer2[-1].out_dim, int(teacher.layer2[-1].out_dim/args.width_mult), kernel_size=1, bias=False).cuda()
        model.layer3[-1].dim_matching = torch.nn.Conv2d(model.layer3[-1].out_dim, int(teacher.layer3[-1].out_dim/args.width_mult), kernel_size=1, bias=False).cuda()
        model.layer4[-1].dim_matching = torch.nn.Conv2d(model.layer4[-1].out_dim, int(teacher.layer4[-1].out_dim/args.width_mult), kernel_size=1, bias=False).cuda()
    # Store pre-trained weights for computing L2SP
    for m in model.modules():
        if hasattr(m, 'weight') and not hasattr(m, 'old_weight'):
            m.old_weight = m.weight.data.clone().detach()
            # all_weights = torch.cat([all_weights.reshape(-1), m.weight.data.abs().reshape(-1)], dim=0)
        if hasattr(m, 'bias') and not hasattr(m, 'old_bias') and m.bias is not None:
            m.old_bias = m.bias.data.clone().detach()
    if args.reinit:
        for m in model.modules():
            if type(m) in [nn.Linear, nn.BatchNorm2d, nn.Conv2d]:
                m.reset_parameters()
    # Pair each student stage with the matching teacher stage and hook its activations
    reg_layers[0].append(teacher.layer1)
    teacher.layer1.register_forward_hook(record_act)
    reg_layers[1].append(teacher.layer2)
    teacher.layer2.register_forward_hook(record_act)
    reg_layers[2].append(teacher.layer3)
    teacher.layer3.register_forward_hook(record_act)
    reg_layers[3].append(teacher.layer4)
    teacher.layer4.register_forward_hook(record_act)
    if '5' in args.feat_layers:
        reg_layers[4].append(teacher.layer5)
        teacher.layer5.register_forward_hook(record_act)
    train(model, train_loader, test_loader, l2sp_lmda=args.l2sp_lmda, iterations=args.iterations, lr=args.lr, name='{}'.format(args.name), teacher=teacher, reg_layers=reg_layers)
Eval
import argparse
import torch
import time
import sys
import numpy as np
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchcontrib
from PIL import Image
from torchvision import transforms
from dataset.cub200 import CUB200Data
from dataset.mit67 import MIT67Data
from dataset.stanford_dog import SDog120Data
from dataset.caltech256 import Caltech257Data
from dataset.stanford_40 import Stanford40Data
from dataset.flower102 import Flower102Data
from advertorch.attacks import LinfPGDAttack
from model.fe_resnet import resnet18_dropout, resnet50_dropout, resnet101_dropout
from model.fe_mobilenet import mbnetv2_dropout
from model.fe_resnet import feresnet18, feresnet50, feresnet101
from model.fe_mobilenet import fembnetv2
def test(model, loader, adversary):
    model.eval()
    total_ce = 0
    total = 0
    top1 = 0
    top1_clean = 0
    top1_adv = 0
    adv_success = 0
    adv_trial = 0
    for i, (batch, label) in enumerate(loader):
        batch, label = batch.to('cuda'), label.to('cuda')
        total += batch.size(0)
        out_clean = model(batch)
        # Attack target in feature space: push the first feature dimension towards args.m
        if 'mbnetv2' in args.network:
            y = torch.zeros(batch.shape[0], model.classifier[1].in_features).cuda()
        else:
            y = torch.zeros(batch.shape[0], model.fc.in_features).cuda()
        y[:,0] = args.m
        advbatch = adversary.perturb(batch, y)
        out_adv = model(advbatch)
        _, pred_clean = out_clean.max(dim=1)
        _, pred_adv = out_adv.max(dim=1)
        # Attack success is counted only over samples that were classified correctly when clean
        clean_correct = pred_clean.eq(label)
        adv_trial += int(clean_correct.sum().item())
        adv_success += int(pred_adv[clean_correct].eq(label[clean_correct]).sum().item())
        top1_clean += int(pred_clean.eq(label).sum().item())
        top1_adv += int(pred_adv.eq(label).sum().item())
        print('{}/{}...'.format(i+1, len(loader)))
    return float(top1_clean)/total*100, float(top1_adv)/total*100, float(adv_trial-adv_success) / adv_trial *100
def record_act(self, input, output):
    pass

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--datapath", type=str, default='/data', help='path to the dataset')
    parser.add_argument("--dataset", type=str, default='CUB200Data', help='Target dataset. Currently support: {SDog120Data, CUB200Data, Stanford40Data, MIT67Data, Flower102Data}')
    parser.add_argument("--name", type=str, default='test')
    parser.add_argument("--B", type=float, default=0.1, help='Attack budget')
    parser.add_argument("--m", type=float, default=1000, help='Hyper-parameter for task-agnostic attack')
    parser.add_argument("--pgd_iter", type=int, default=40)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--dropout", type=float, default=0)
    parser.add_argument("--checkpoint", type=str, default='')
    parser.add_argument("--network", type=str, default='resnet18', help='Network architecture. Currently support: {resnet18, resnet50, resnet101, mbnetv2}')
    args = parser.parse_args()
    return args

def myloss(yhat, y):
    # Negative squared error to the target features; the first dimension is weighted most heavily
    return -((yhat[:,0]-y[:,0])**2 + 0.1*((yhat[:,1:]-y[:,1:])**2).mean(1)).mean()
if __name__ == '__main__':
    args = get_args()
    print(args)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    seed = int(time.time())
    test_set = eval(args.dataset)(args.datapath, False, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]), -1, seed, preload=False)
    test_loader = torch.utils.data.DataLoader(test_set,
        batch_size=args.batch_size, shuffle=False,
        num_workers=8, pin_memory=False)
    # Fine-tuned model under evaluation
    transferred_model = eval('{}_dropout'.format(args.network))(pretrained=False, dropout=args.dropout, num_classes=test_loader.dataset.num_classes).cuda()
    checkpoint = torch.load(args.checkpoint)
    transferred_model.load_state_dict(checkpoint['state_dict'])
    # Pre-trained feature extractor used to craft the adversarial examples
    pretrained_model = eval('fe{}'.format(args.network))(pretrained=True).cuda().eval()
    adversary = LinfPGDAttack(
        pretrained_model, loss_fn=myloss, eps=args.B,
        nb_iter=args.pgd_iter, eps_iter=0.01, rand_init=True, clip_min=-2.2, clip_max=2.2,
        targeted=False)
    clean_top1, adv_top1, adv_sr = test(transferred_model, test_loader, adversary)
    print('Clean Top-1: {:.2f} | Adv Top-1: {:.2f} | Attack Success Rate: {:.2f}'.format(clean_top1, adv_top1, adv_sr))