"""
Code is from https://github.com/BorealisAI/advertorch/blob/master/advertorch/attacks/spsa.py
"""
import torch
from torch.nn.modules.loss import _Loss
from ..attack import Attack
class MarginalLoss(_Loss):
    """Margin loss: largest non-target logit minus the target logit."""

    def forward(self, logits, targets):
        assert logits.shape[-1] >= 2  # the margin needs at least two classes
        top_logits, top_classes = torch.topk(logits, 2, dim=-1)
        target_logits = logits[torch.arange(logits.shape[0]), targets]
        max_nontarget_logits = torch.where(
            top_classes[..., 0] == targets,
            top_logits[..., 1],
            top_logits[..., 0],
        )

        loss = max_nontarget_logits - target_logits
        if self.reduction == "none":
            pass
        elif self.reduction == "sum":
            loss = loss.sum()
        elif self.reduction == "mean":
            loss = loss.mean()
        else:
            raise ValueError("unknown reduction: '%s'" % (self.reduction,))
        return loss
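
# A quick worked example of MarginalLoss (illustrative note, not part of the
# original module): for logits [[2.0, 1.0, 0.5]] and target class 0, the
# largest non-target logit is 1.0, so the margin loss is 1.0 - 2.0 = -1.0
# (negative whenever the target class currently has the highest logit).
#
#   >>> loss_fn = MarginalLoss(reduction="none")
#   >>> loss_fn(torch.tensor([[2.0, 1.0, 0.5]]), torch.tensor([0]))
#   tensor([-1.])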
class SPSA(Attack):
    r"""
    SPSA in the paper 'Adversarial Risk and the Dangers of Evaluating Against Weak Attacks'
    [https://arxiv.org/abs/1802.05666]

    Distance Measure : Linf

    Arguments:
        model (nn.Module): model to attack.
        eps (float): maximum perturbation. (Default: 0.3)
        delta (float): scaling parameter of SPSA. (Default: 0.01)
        lr (float): learning rate of the `Adam` optimizer. (Default: 0.01)
        nb_iter (int): number of iterations of the attack. (Default: 1)
        nb_sample (int): number of samples for SPSA gradient approximation. (Default: 128)
        max_batch_size (int): maximum batch size to be evaluated at once. (Default: 64)

    Shape:
        - images: :math:`(N, C, H, W)` where `N = batch size`, `C = number of channels`, `H = height` and `W = width`. It must have a range [0, 1].
        - labels: :math:`(N)` where each value :math:`y_i` satisfies :math:`0 \leq y_i <` `number of labels`.
        - output: :math:`(N, C, H, W)`.

    Examples::
        >>> attack = torchattacks.SPSA(model, eps=0.3)
        >>> adv_images = attack(images, labels)
    """
    def __init__(
        self,
        model,
        eps=0.3,
        delta=0.01,
        lr=0.01,
        nb_iter=1,
        nb_sample=128,
        max_batch_size=64,
    ):
        super().__init__("SPSA", model)
        self.eps = eps
        self.delta = delta
        self.lr = lr
        self.nb_iter = nb_iter
        self.nb_sample = nb_sample
        self.max_batch_size = max_batch_size
        self.loss_fn = MarginalLoss(reduction="none")
        self.supported_mode = ["default", "targeted"]
    def forward(self, images, labels):
        r"""
        Overridden.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        adv_images = self.spsa_perturb(images, labels)

        return adv_images
    def loss(self, *args):
        if self.targeted:
            return self.loss_fn(*args)
        else:
            return -self.loss_fn(*args)
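
    # Sign convention (explanatory note, not in the original source): Adam
    # *minimizes* the value returned here. For a targeted attack the margin
    # loss itself is minimized, pushing the target class's logit above all
    # others; for an untargeted attack the negated margin is minimized,
    # which is equivalent to maximizing the gap between the best non-label
    # class and the true label.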
    def linf_clamp_(self, dx, x, eps):
        """Clamps perturbation `dx` to fit the L_inf norm and image bounds.

        Limit the L_inf norm of `dx` to be <= `eps`, and the bounds of
        `x + dx` to be in `[0, 1]`.

        Return: the clamped perturbation `dx`.
        """
        dx_clamped = torch.clamp(dx, min=-eps, max=eps)
        x_adv = torch.clamp(x + dx_clamped, min=0, max=1)
        # `dx` is changed *inplace* so the optimizer will keep
        # tracking it. The simplest mechanism for inplace was
        # adding the difference between the new value `x_adv - x`
        # and the old value `dx`.
        dx += x_adv - x - dx
        return dx
    def _get_batch_sizes(self, n, max_batch_size):
        batches = [max_batch_size for _ in range(n // max_batch_size)]
        if n % max_batch_size > 0:
            batches.append(n % max_batch_size)
        return batches
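
    # For example (illustrative): with nb_sample=128 and max_batch_size=48,
    # _get_batch_sizes(128, 48) returns [48, 48, 32], so the 128 SPSA samples
    # are evaluated in chunks no larger than max_batch_size.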
    @torch.no_grad()
    def spsa_grad(self, images, labels, delta, nb_sample, max_batch_size):
        """Uses the SPSA method to approximate the gradient w.r.t `x`.

        Use the SPSA method to approximate the gradient of `loss(predict(x), y)`
        with respect to `x`, based on random Rademacher perturbations `v`.

        Return the approximated gradient of `loss_fn(predict(x), y)` with respect to `x`.
        """
        grad = torch.zeros_like(images)
        images = torch.unsqueeze(images, 0)
        labels = torch.unsqueeze(labels, 0)

        def f(xvar, yvar):
            return self.loss(self.get_logits(xvar), yvar)

        images = images.expand(max_batch_size, *images.shape[1:]).contiguous()
        labels = labels.expand(max_batch_size, *labels.shape[1:]).contiguous()
        v = torch.empty_like(images[:, :1, ...])

        for batch_size in self._get_batch_sizes(nb_sample, max_batch_size):
            x_ = images[:batch_size]
            y_ = labels[:batch_size]
            vb = v[:batch_size]
            vb = vb.bernoulli_().mul_(2.0).sub_(1.0)  # Rademacher: entries in {-1, +1}
            v_ = vb.expand_as(x_).contiguous()
            x_shape = x_.shape
            x_ = x_.view(-1, *images.shape[2:])
            y_ = y_.view(-1, *labels.shape[2:])
            v_ = v_.view(-1, *v.shape[2:])
            df = f(x_ + delta * v_, y_) - f(x_ - delta * v_, y_)
            df = df.view(-1, *[1 for _ in v_.shape[1:]])
            grad_ = df / (2.0 * delta * v_)
            grad_ = grad_.view(x_shape)
            grad_ = grad_.sum(dim=0, keepdim=False)
            grad += grad_
        grad /= nb_sample
        return grad
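
    # The SPSA estimator used above (explanatory note, not in the original
    # source): each sample draws v with i.i.d. +/-1 entries and computes
    #
    #     grad ~= [f(x + delta*v) - f(x - delta*v)] / (2 * delta) * v
    #
    # Dividing by v, as the code does, equals multiplying by v, since every
    # entry of v is +1 or -1. Averaging over nb_sample draws reduces the
    # variance of the estimate; no backward pass through the model is needed.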
    def spsa_perturb(self, x, y):
        dx = torch.zeros_like(x)
        dx.grad = torch.zeros_like(dx)
        optimizer = torch.optim.Adam([dx], lr=self.lr)
        for _ in range(self.nb_iter):
            optimizer.zero_grad()
            # Use the SPSA estimate in place of an autograd gradient.
            dx.grad = self.spsa_grad(
                x + dx, y, self.delta, self.nb_sample, self.max_batch_size
            )
            optimizer.step()
            # Project back into the eps-ball and the valid image range.
            dx = self.linf_clamp_(dx, x, self.eps)

        x_adv = x + dx
        return x_adv
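
# Usage sketch (illustrative, mirroring the docstring example; `model`,
# `images`, and `labels` are assumed to be a classifier, a batch of images
# in [0, 1], and their integer labels):
#
#   >>> import torchattacks
#   >>> attack = torchattacks.SPSA(model, eps=0.3, nb_iter=5, nb_sample=128)
#   >>> adv_images = attack(images, labels)
#
# Because SPSA only queries the model's outputs, it suits black-box settings
# and gradient-masked defenses, at the cost of many forward passes per step.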