Source code for torchattacks.attacks.sparsefool

import numpy as np

import torch

from ..attack import Attack
from .deepfool import DeepFool


class SparseFool(Attack):
    r"""
    Attack in the paper 'SparseFool: a few pixels make a big difference'
    [https://arxiv.org/abs/1811.02248]

    Modified from "https://github.com/LTS4/SparseFool/"

    Distance Measure : L0

    Arguments:
        model (nn.Module): model to attack.
        steps (int): number of steps. (Default: 10)
        lam (float): parameter for scaling DeepFool noise. (Default: 3)
        overshoot (float): parameter for enhancing the noise. (Default: 0.02)

    Shape:
        - images: :math:`(N, C, H, W)` where `N = batch size`, `C = number of channels`, `H = height` and `W = width`. It must have a range [0, 1].
        - labels: :math:`(N)` where each value :math:`y_i` satisfies :math:`0 \leq y_i <` `number of labels`.
        - output: :math:`(N, C, H, W)`.

    Examples::
        >>> attack = torchattacks.SparseFool(model, steps=10, lam=3, overshoot=0.02)
        >>> adv_images = attack(images, labels)
    """

    def __init__(self, model, steps=10, lam=3, overshoot=0.02):
        super().__init__("SparseFool", model)
        self.steps = steps
        self.lam = lam
        self.overshoot = overshoot
        self.deepfool = DeepFool(model)
        self.supported_mode = ["default"]
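
    # Outline of forward() below: for each sample, (1) DeepFool finds a point
    # just across the decision boundary, (2) that step is scaled by lam to
    # move well past the boundary, (3) _linear_solver crosses the locally
    # linearized boundary again while modifying as few coordinates as
    # possible, and (4) the sparse perturbation is enhanced by (1 + overshoot).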

    def forward(self, images, labels):
        r"""
        Overridden.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        batch_size = len(images)
        correct = torch.tensor([True] * batch_size)
        curr_steps = 0

        adv_images = []
        for idx in range(batch_size):
            image = images[idx : idx + 1].clone().detach()
            adv_images.append(image)

        while (True in correct) and (curr_steps < self.steps):
            for idx in range(batch_size):
                image = images[idx : idx + 1]
                label = labels[idx : idx + 1]
                adv_image = adv_images[idx]

                # Stop attacking samples that are already misclassified.
                fs = self.get_logits(adv_image)[0]
                _, pre = torch.max(fs, dim=0)
                if pre != label:
                    correct[idx] = False
                    continue

                # DeepFool step, scaled by lam to overshoot the boundary.
                adv_image, target_label = self.deepfool.forward_return_target_labels(
                    adv_image, label
                )
                adv_image = image + self.lam * (adv_image - image)

                # Normal of the locally linearized decision boundary.
                adv_image.requires_grad = True
                fs = self.get_logits(adv_image)[0]
                _, pre = torch.max(fs, dim=0)
                if pre == label:
                    pre = target_label

                cost = fs[pre] - fs[label]
                grad = torch.autograd.grad(
                    cost, adv_image, retain_graph=False, create_graph=False
                )[0]
                grad = grad / grad.norm()

                # Sparse projection across the boundary, then enhancement.
                adv_image = self._linear_solver(image, grad, adv_image)
                adv_image = image + (1 + self.overshoot) * (adv_image - image)
                adv_images[idx] = torch.clamp(adv_image, min=0, max=1).detach()

            curr_steps += 1

        adv_images = torch.cat(adv_images).detach()
        return adv_images
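
    # _linear_solver is the LinearSolver step of SparseFool: given the clean
    # image x_0 and the hyperplane through boundary_point with normal
    # coord_vec, it walks x_0 across the hyperplane one coordinate at a time,
    # always along the coordinate with the largest remaining |normal|
    # component. Each used coordinate is zeroed out of coord_vec, so the loop
    # ends when the sign of the signed distance f_k flips or no coordinates
    # remain.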

    def _linear_solver(self, x_0, coord_vec, boundary_point):
        input_shape = x_0.size()

        plane_normal = coord_vec.clone().detach().view(-1)
        plane_point = boundary_point.clone().detach().view(-1)

        x_i = x_0.clone().detach()

        f_k = torch.dot(plane_normal, x_0.view(-1) - plane_point)
        sign_true = f_k.sign().item()

        beta = 0.001 * sign_true
        current_sign = sign_true

        while current_sign == sign_true and coord_vec.nonzero().size()[0] > 0:
            f_k = torch.dot(plane_normal, x_i.view(-1) - plane_point) + beta
            pert = f_k.abs() / coord_vec.abs().max()

            mask = torch.zeros_like(coord_vec)
            max_idx = np.unravel_index(torch.argmax(coord_vec.abs()).cpu(), input_shape)
            mask[max_idx] = 1.0

            r_i = torch.clamp(pert, min=1e-4) * mask * coord_vec.sign()

            x_i = x_i + r_i
            x_i = torch.clamp(x_i, min=0, max=1)

            f_k = torch.dot(plane_normal, x_i.view(-1) - plane_point)
            current_sign = f_k.sign().item()

            coord_vec[r_i != 0] = 0

        return x_i
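
A minimal end-to-end sketch, expanding the Examples block above. The torchvision classifier and random tensors are stand-ins for illustration only, not part of the torchattacks source; any model that takes `(N, C, H, W)` inputs in [0, 1] is used the same way:

    >>> import torch
    >>> import torchvision
    >>> from torchattacks import SparseFool
    >>> model = torchvision.models.resnet18(weights="IMAGENET1K_V1").eval()
    >>> attack = SparseFool(model, steps=10, lam=3, overshoot=0.02)
    >>> images = torch.rand(4, 3, 224, 224)   # inputs must lie in [0, 1]
    >>> labels = torch.randint(0, 1000, (4,))
    >>> adv_images = attack(images, labels)
    >>> (adv_images != images).flatten(1).sum(dim=1)  # L0 footprint per image

Since the distance measure is L0, the last line counts the coordinates SparseFool actually changed, which is the quantity the attack tries to keep small.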