Source code for torchattacks.attacks.cw

import torch
import torch.nn as nn
import torch.optim as optim

from ..attack import Attack


class CW(Attack):
    r"""
    CW in the paper 'Towards Evaluating the Robustness of Neural Networks'
    [https://arxiv.org/abs/1608.04644]

    Distance Measure : L2

    Arguments:
        model (nn.Module): model to attack.
        c (float): c in the paper; weight of the classification loss f
            relative to the L2 distance term. (Default: 1)
            :math:`minimize \Vert\frac{1}{2}(tanh(w)+1)-x\Vert^2_2+c\cdot f(\frac{1}{2}(tanh(w)+1))`
        kappa (float): kappa (also written as 'confidence') in the paper. (Default: 0)
            :math:`f(x')=max(max\{Z(x')_i:i\neq t\} -Z(x')_t, - \kappa)`
        steps (int): number of steps. (Default: 50)
        lr (float): learning rate of the Adam optimizer. (Default: 0.01)

    .. warning:: With a small c, the attack often fails to find adversarial
        images. If that happens, increase c (e.g. to 1 or more).

    Shape:
        - images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`,        `H = height` and `W = width`. It must have a range [0, 1].
        - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i <` `number of labels`.
        - output: :math:`(N, C, H, W)`.

    Examples::
        >>> attack = torchattacks.CW(model, c=1, kappa=0, steps=50, lr=0.01)
        >>> adv_images = attack(images, labels)

    .. note:: Binary search for c, as described in the paper, is NOT
        implemented because it is time-consuming.
    """

    def __init__(self, model, c=1, kappa=0, steps=50, lr=0.01):
        super().__init__("CW", model)
        self.c = c
        self.kappa = kappa
        self.steps = steps
        self.lr = lr
        self.supported_mode = ["default", "targeted"]

    def forward(self, images, labels):
        r"""
        Overridden.

        Runs up to ``self.steps`` Adam updates on the tanh-space variable ``w``
        and returns, per sample, the successful adversarial image with the
        smallest L2 distance seen so far. Samples for which no successful image
        was found are returned unchanged.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        if self.targeted:
            target_labels = self.get_target_label(images, labels)

        # Start from the clean images mapped into tanh space. Per the original
        # author's note, starting from w = 0 instead needs about 2x the steps.
        w = self.inverse_tanh_space(images).detach()
        w.requires_grad = True

        best_adv_images = images.clone().detach()
        best_L2 = 1e10 * torch.ones((len(images))).to(self.device)
        prev_cost = 1e10
        dim = len(images.shape)

        MSELoss = nn.MSELoss(reduction="none")
        Flatten = nn.Flatten()

        optimizer = optim.Adam([w], lr=self.lr)

        for step in range(self.steps):
            # Get adversarial images
            adv_images = self.tanh_space(w)

            # Calculate loss: per-sample squared L2 distance plus c * f
            current_L2 = MSELoss(Flatten(adv_images), Flatten(images)).sum(dim=1)
            L2_loss = current_L2.sum()

            outputs = self.get_logits(adv_images)
            if self.targeted:
                f_loss = self.f(outputs, target_labels).sum()
            else:
                f_loss = self.f(outputs, labels).sum()

            cost = L2_loss + self.c * f_loss

            optimizer.zero_grad()
            cost.backward()
            optimizer.step()

            # Update adversarial images (uses the pre-step images/outputs,
            # which belong together).
            pre = torch.argmax(outputs.detach(), 1)
            if self.targeted:
                # We want pre == target_labels in a targeted attack
                condition = (pre == target_labels).float()
            else:
                # In an untargeted attack we only need pre != labels
                condition = (pre != labels).float()

            # Keep the new image only where the attack succeeds AND its L2
            # distance is smaller than the best one recorded so far.
            mask = condition * (best_L2 > current_L2.detach())
            best_L2 = mask * current_L2.detach() + (1 - mask) * best_L2

            mask = mask.view([-1] + [1] * (dim - 1))
            best_adv_images = mask * adv_images.detach() + (1 - mask) * best_adv_images

            # Early stop when the batch cost stops decreasing, checked about
            # every tenth of the run. max(., 1) avoids modulo by zero when
            # steps < 10.
            if step % max(self.steps // 10, 1) == 0:
                if cost.item() > prev_cost:
                    return best_adv_images
                prev_cost = cost.item()

        return best_adv_images

    def tanh_space(self, x):
        """Map an unconstrained tensor into [0, 1] via 1/2 * (tanh(x) + 1)."""
        return 1 / 2 * (torch.tanh(x) + 1)

    def inverse_tanh_space(self, x):
        """Map values in [0, 1] back to the unconstrained tanh space.

        Approximate inverse of :meth:`tanh_space` (see the scale factor below).
        """
        # torch.atanh is only for torch >= 1.7.0, hence the hand-rolled atanh.
        # atanh is only finite on (-1, 1): pixels at exactly 0 or 1 would map
        # to +/-inf, where tanh has zero gradient, so they could never change
        # during the attack. Shrinking by 0.999999 (as in the authors'
        # reference implementation) keeps w finite while moving the
        # reconstructed pixel values by less than 1e-6.
        return self.atanh(torch.clamp(x * 2 - 1, min=-1, max=1) * 0.999999)

    def atanh(self, x):
        """Inverse hyperbolic tangent, 1/2 * log((1 + x) / (1 - x))."""
        return 0.5 * torch.log((1 + x) / (1 - x))

    # f-function in the paper
    def f(self, outputs, labels):
        r"""
        Per-sample classification loss f from the paper.

        ``outputs`` is indexed as (sample, class) logits and ``labels`` holds
        one integer class index per sample. Returns one value per sample:

        - targeted: :math:`max(max_{i \neq t} Z_i - Z_t, -\kappa)`, t = labels
        - untargeted: :math:`max(Z_y - max_{i \neq y} Z_i, -\kappa)`, y = labels
        """
        label_mask = nn.functional.one_hot(labels, num_classes=outputs.shape[1]).bool()
        # Logit of the label class. Gathered directly: a max over a
        # zero-masked tensor would return 0 whenever this logit is negative.
        real = outputs.gather(1, labels.unsqueeze(1)).squeeze(1)
        # Largest logit among the other classes. The label class is masked
        # with -inf (not 0) so negative logits are not shadowed by the mask.
        other = outputs.masked_fill(label_mask, float("-inf")).max(dim=1)[0]

        if self.targeted:
            return torch.clamp((other - real), min=-self.kappa)
        else:
            return torch.clamp((real - other), min=-self.kappa)