Source code for torchattacks.attacks.apgdt

import time

import numpy as np

import torch

from ..attack import Attack


class APGDT(Attack):
    r"""
    APGD-Targeted in the paper 'Reliable evaluation of adversarial robustness with
    an ensemble of diverse parameter-free attacks.'
    Targeted attack for every wrong class.
    [https://arxiv.org/abs/2003.01690]
    [https://github.com/fra31/auto-attack]

    Distance Measure : Linf, L2

    Arguments:
        model (nn.Module): model to attack.
        norm (str): Lp-norm of the attack. ['Linf', 'L2'] (Default: 'Linf')
        eps (float): maximum perturbation. (Default: 8/255)
        steps (int): number of steps. (Default: 10)
        n_restarts (int): number of random restarts. (Default: 1)
        seed (int): random seed for the starting point. (Default: 0)
        eot_iter (int): number of iterations for EOT. (Default: 1)
        rho (float): parameter for the step-size update. (Default: 0.75)
        verbose (bool): print progress. (Default: False)
        n_classes (int): number of classes. (Default: 10)

    Shape:
        - images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`,
          `H = height` and `W = width`. It must have a range [0, 1].
        - labels: :math:`(N)` where each value :math:`y_i` satisfies :math:`0 \leq y_i \leq` `number of labels`.
        - output: :math:`(N, C, H, W)`.

    Examples::
        >>> attack = torchattacks.APGDT(model, norm='Linf', eps=8/255, steps=10, n_restarts=1, seed=0, eot_iter=1, rho=.75, verbose=False, n_classes=10)
        >>> adv_images = attack(images, labels)

    """

    def __init__(
        self,
        model,
        norm="Linf",
        eps=8 / 255,
        steps=10,
        n_restarts=1,
        seed=0,
        eot_iter=1,
        rho=0.75,
        verbose=False,
        n_classes=10,
    ):
        super().__init__("APGDT", model)
        self.eps = eps
        self.steps = steps
        self.norm = norm
        self.n_restarts = n_restarts
        self.seed = seed
        self.eot_iter = eot_iter
        self.thr_decr = rho
        self.verbose = verbose
        self.target_class = None
        self.n_target_classes = n_classes - 1
        self.supported_mode = ["default"]
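    # Control flow: forward() delegates to perturb(), which loops over candidate
    # target classes (the 2nd- through n_classes-th-highest-scoring class of each
    # input) and, for each, runs n_restarts independent runs of the APGD inner
    # loop in attack_single_run(). Only inputs that are still classified
    # correctly get re-attacked in later rounds.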
    def forward(self, images, labels):
        r"""
        Overridden.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)
        _, adv_images = self.perturb(images, labels, cheap=True)

        return adv_images
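    # The targeted DLR loss implemented below is the one from the AutoAttack
    # paper:
    #     -(z_y - z_t) / (z_{pi_1} - 0.5 * (z_{pi_3} + z_{pi_4})),
    # where z is the logit vector, y the true label, t the target class, and
    # pi sorts the logits in decreasing order. The denominator makes the loss
    # invariant to positive rescaling of the logits; the 1e-12 term guards
    # against division by zero.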
    def check_oscillation(self, x, j, k, y5, k3=0.5):
        # Count, over the last k steps, how often the loss increased; a run is
        # flagged as oscillating if it improved in at most k * k3 of them.
        t = np.zeros(x.shape[1])
        for counter5 in range(k):
            t += x[j - counter5] > x[j - counter5 - 1]

        return t <= k * k3 * np.ones(t.shape)

    def check_shape(self, x):
        return x if len(x.shape) > 0 else np.expand_dims(x, 0)

    def dlr_loss_targeted(self, x, y, y_target):
        x_sorted, ind_sorted = x.sort(dim=1)

        return -(x[np.arange(x.shape[0]), y] - x[np.arange(x.shape[0]), y_target]) / (
            x_sorted[:, -1] - 0.5 * x_sorted[:, -3] - 0.5 * x_sorted[:, -4] + 1e-12
        )

    def attack_single_run(self, x_in, y_in):
        x = x_in.clone() if len(x_in.shape) == 4 else x_in.clone().unsqueeze(0)
        y = y_in.clone() if len(y_in.shape) == 1 else y_in.clone().unsqueeze(0)

        self.steps_2, self.steps_min, self.size_decr = (
            max(int(0.22 * self.steps), 1),
            max(int(0.06 * self.steps), 1),
            max(int(0.03 * self.steps), 1),
        )
        if self.verbose:
            print(
                "parameters: ", self.steps, self.steps_2, self.steps_min, self.size_decr
            )

        # Random start on the boundary of the eps-ball around x (then clamped
        # to the valid image range [0, 1]).
        if self.norm == "Linf":
            t = 2 * torch.rand(x.shape).to(self.device).detach() - 1
            x_adv = x.detach() + self.eps * torch.ones([x.shape[0], 1, 1, 1]).to(
                self.device
            ).detach() * t / (
                t.reshape([t.shape[0], -1])
                .abs()
                .max(dim=1, keepdim=True)[0]
                .reshape([-1, 1, 1, 1])
            )
        elif self.norm == "L2":
            t = torch.randn(x.shape).to(self.device).detach()
            x_adv = x.detach() + self.eps * torch.ones([x.shape[0], 1, 1, 1]).to(
                self.device
            ).detach() * t / (
                (t ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + 1e-12
            )
        x_adv = x_adv.clamp(0.0, 1.0)
        x_best = x_adv.clone()
        x_best_adv = x_adv.clone()
        loss_steps = torch.zeros([self.steps, x.shape[0]])
        loss_best_steps = torch.zeros([self.steps + 1, x.shape[0]])
        acc_steps = torch.zeros_like(loss_best_steps)

        # Target: the class with the target_class-th largest clean logit.
        output = self.get_logits(x)
        y_target = output.sort(dim=1)[1][:, -self.target_class]

        x_adv.requires_grad_()
        grad = torch.zeros_like(x)
        for _ in range(self.eot_iter):
            with torch.enable_grad():
                # 1 forward pass (eot_iter = 1)
                logits = self.get_logits(x_adv)
                loss_indiv = self.dlr_loss_targeted(logits, y, y_target)
                loss = loss_indiv.sum()

            # 1 backward pass (eot_iter = 1)
            grad += torch.autograd.grad(loss, [x_adv])[0].detach()

        grad /= float(self.eot_iter)
        grad_best = grad.clone()

        acc = logits.detach().max(1)[1] == y
        acc_steps[0] = acc + 0
        loss_best = loss_indiv.detach().clone()

        # Initial step size: 2 * eps per example.
        step_size = (
            self.eps
            * torch.ones([x.shape[0], 1, 1, 1]).to(self.device).detach()
            * torch.Tensor([2.0]).to(self.device).detach().reshape([1, 1, 1, 1])
        )
        x_adv_old = x_adv.clone()
        # counter = 0
        k = self.steps_2 + 0
        u = np.arange(x.shape[0])
        counter3 = 0

        loss_best_last_check = loss_best.clone()
        reduced_last_check = np.zeros(loss_best.shape) == np.zeros(loss_best.shape)
        # n_reduced = 0

        for i in range(self.steps):
            # gradient step
            with torch.no_grad():
                x_adv = x_adv.detach()
                grad2 = x_adv - x_adv_old
                x_adv_old = x_adv.clone()

                # Momentum weight: full step on the first iteration, 0.75 afterwards.
                a = 0.75 if i > 0 else 1.0

                if self.norm == "Linf":
                    x_adv_1 = x_adv + step_size * torch.sign(grad)
                    x_adv_1 = torch.clamp(
                        torch.min(torch.max(x_adv_1, x - self.eps), x + self.eps),
                        0.0,
                        1.0,
                    )
                    x_adv_1 = torch.clamp(
                        torch.min(
                            torch.max(
                                x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a),
                                x - self.eps,
                            ),
                            x + self.eps,
                        ),
                        0.0,
                        1.0,
                    )
                elif self.norm == "L2":
                    x_adv_1 = x_adv + step_size[0] * grad / (
                        (grad ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + 1e-12
                    )
                    x_adv_1 = torch.clamp(
                        x
                        + (x_adv_1 - x)
                        / (
                            ((x_adv_1 - x) ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt()
                            + 1e-12
                        )
                        * torch.min(
                            self.eps * torch.ones(x.shape).to(self.device).detach(),
                            ((x_adv_1 - x) ** 2)
                            .sum(dim=(1, 2, 3), keepdim=True)
                            .sqrt(),
                        ),
                        0.0,
                        1.0,
                    )
                    x_adv_1 = x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a)
                    x_adv_1 = torch.clamp(
                        x
                        + (x_adv_1 - x)
                        / (
                            ((x_adv_1 - x) ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt()
                            + 1e-12
                        )
                        * torch.min(
                            self.eps * torch.ones(x.shape).to(self.device).detach(),
                            ((x_adv_1 - x) ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt()
                            + 1e-12,
                        ),
                        0.0,
                        1.0,
                    )

                x_adv = x_adv_1 + 0.0

            # get gradient
            x_adv.requires_grad_()
            grad = torch.zeros_like(x)
            for _ in range(self.eot_iter):
                with torch.enable_grad():
                    # 1 forward pass (eot_iter = 1)
                    logits = self.get_logits(x_adv)
                    loss_indiv = self.dlr_loss_targeted(logits, y, y_target)
                    loss = loss_indiv.sum()

                # 1 backward pass (eot_iter = 1)
                grad += torch.autograd.grad(loss, [x_adv])[0].detach()

            grad /= float(self.eot_iter)

            pred = logits.detach().max(1)[1] == y
            acc = torch.min(acc, pred)
            acc_steps[i + 1] = acc + 0
            x_best_adv[(pred == 0).nonzero().squeeze()] = (
                x_adv[(pred == 0).nonzero().squeeze()] + 0.0
            )
            if self.verbose:
                print("iteration: {} - Best loss: {:.6f}".format(i, loss_best.sum()))

            # check step size
            with torch.no_grad():
                y1 = loss_indiv.detach().clone()
                loss_steps[i] = y1.cpu() + 0
                ind = (y1 > loss_best).nonzero().squeeze()
                x_best[ind] = x_adv[ind].clone()
                grad_best[ind] = grad[ind].clone()
                loss_best[ind] = y1[ind] + 0
                loss_best_steps[i + 1] = loss_best + 0

                counter3 += 1

                if counter3 == k:
                    # Halve the step size (and restart from the best point so
                    # far) for runs whose loss oscillates or whose best loss
                    # stopped improving since the last checkpoint.
                    fl_oscillation = self.check_oscillation(
                        loss_steps.detach().cpu().numpy(),
                        i,
                        k,
                        loss_best.detach().cpu().numpy(),
                        k3=self.thr_decr,
                    )
                    fl_reduce_no_impr = (~reduced_last_check) * (
                        loss_best_last_check.cpu().numpy() >= loss_best.cpu().numpy()
                    )
                    fl_oscillation = ~(~fl_oscillation * ~fl_reduce_no_impr)
                    reduced_last_check = np.copy(fl_oscillation)
                    loss_best_last_check = loss_best.clone()

                    if np.sum(fl_oscillation) > 0:
                        step_size[u[fl_oscillation]] /= 2.0
                        # n_reduced = fl_oscillation.astype(float).sum()

                        fl_oscillation = np.where(fl_oscillation)

                        x_adv[fl_oscillation] = x_best[fl_oscillation].clone()
                        grad[fl_oscillation] = grad_best[fl_oscillation].clone()

                    counter3 = 0
                    k = np.maximum(k - self.size_decr, self.steps_min)

        return x_best, acc, loss_best, x_best_adv

    def perturb(self, x_in, y_in, best_loss=False, cheap=True):
        assert self.norm in ["Linf", "L2"]
        x = x_in.clone() if len(x_in.shape) == 4 else x_in.clone().unsqueeze(0)
        y = y_in.clone() if len(y_in.shape) == 1 else y_in.clone().unsqueeze(0)

        adv = x.clone()
        acc = self.get_logits(x).max(1)[1] == y
        # loss = -1e10 * torch.ones_like(acc).float()
        if self.verbose:
            print(
                "-------------------------- running {}-attack with epsilon {:.4f} --------------------------".format(
                    self.norm, self.eps
                )
            )
            print("initial accuracy: {:.2%}".format(acc.float().mean()))
        startt = time.time()

        torch.random.manual_seed(self.seed)
        torch.cuda.random.manual_seed(self.seed)

        if not cheap:
            raise ValueError("not implemented yet")
        else:
            # Attack each candidate target class in turn; within each class,
            # retry up to n_restarts times on the still-unbroken inputs.
            for target_class in range(2, self.n_target_classes + 2):
                self.target_class = target_class
                for counter in range(self.n_restarts):
                    ind_to_fool = acc.nonzero().squeeze()
                    if len(ind_to_fool.shape) == 0:
                        ind_to_fool = ind_to_fool.unsqueeze(0)
                    if ind_to_fool.numel() != 0:
                        x_to_fool, y_to_fool = (
                            x[ind_to_fool].clone(),
                            y[ind_to_fool].clone(),
                        )
                        (
                            best_curr,
                            acc_curr,
                            loss_curr,
                            adv_curr,
                        ) = self.attack_single_run(x_to_fool, y_to_fool)
                        ind_curr = (acc_curr == 0).nonzero().squeeze()

                        acc[ind_to_fool[ind_curr]] = 0
                        adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone()
                        if self.verbose:
                            print(
                                "restart {} - target_class {} - robust accuracy: {:.2%} at eps = {:.5f} - cum. time: {:.1f} s".format(
                                    counter,
                                    self.target_class,
                                    acc.float().mean(),
                                    self.eps,
                                    time.time() - startt,
                                )
                            )

        return acc, adv