import torch
from ..attack import Attack
class MultiAttack(Attack):
    r"""
    MultiAttack is a class to attack a model with various attacks against the same images and labels.

    The attacks are applied in the given order. Each attack only receives the
    images that every previous attack failed on, i.e. images whose top-1
    prediction still equals ``labels``. As soon as an image is misclassified,
    its adversarial version is kept and later attacks skip it. Images that no
    attack fools are returned unchanged (clean).

    Arguments:
        attacks (list): list of attacks. At least two are required, and all
            of them must reference the same model object (the model to attack
            is taken from ``attacks[0].model``).
        verbose (bool): if True, print the cumulative success rate after each
            attack on every forward call. (Default: False)

    Examples::
        >>> atk1 = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=40, random_start=True)
        >>> atk2 = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=40, random_start=True)
        >>> atk = torchattacks.MultiAttack([atk1, atk2])
        >>> adv_images = atk(images, labels)
    """

    def __init__(self, attacks, verbose=False):
        # Validate before reading attacks[0] so that an empty list fails with
        # a clear ValueError instead of an IndexError.
        if len(attacks) < 2:
            raise ValueError("At least two attacks should be given.")
        super().__init__("MultiAttack", attacks[0].model)
        self.attacks = attacks
        self.verbose = verbose
        self.supported_mode = ["default"]
        self.check_validity()
        # Set to True only while save() runs; forward() then adds its
        # per-batch counts into self._multi_atk_records.
        self._accumulate_multi_atk_records = False
        # Index 0: number of images seen; index i (i >= 1): number of those
        # images that were still not fooled after attack i.
        self._multi_atk_records = [0.0]

    def check_validity(self):
        r"""
        Check that at least two attacks are given and that they all share one model.

        Raises:
            ValueError: if fewer than two attacks are given, or if the attacks
                reference different model objects.
        """
        if len(self.attacks) < 2:
            raise ValueError("At least two attacks should be given.")
        model_ids = {id(attack.model) for attack in self.attacks}
        if len(model_ids) != 1:
            raise ValueError(
                "At least one of attacks is referencing a different model."
            )

    def forward(self, images, labels):
        r"""
        Overridden.

        Run the attacks in sequence. Each one is applied only to the images
        that the previous attacks failed to get misclassified.

        Returns:
            A tensor indexed like ``images``. It holds the adversarial example
            wherever some attack succeeded, and the clean image elsewhere.
        """
        batch_size = images.shape[0]
        # Move inputs to self.device up front; the index tensor `fails` lives
        # there, so indexing a tensor on another device would fail.
        images = images.detach().to(self.device)
        labels = labels.clone().detach().to(self.device)
        final_images = images.clone()

        # Batch indices of images that no attack has fooled yet.
        fails = torch.arange(batch_size, device=self.device)
        multi_atk_records = [batch_size]

        for attack in self.attacks:
            adv_images = attack(images[fails], labels[fails])

            # Only the predicted class is needed; skip autograd bookkeeping.
            with torch.no_grad():
                outputs = self.get_logits(adv_images)
            _, pre = torch.max(outputs, 1)

            corrects = pre == labels[fails]
            wrongs = ~corrects

            # Keep the adversarial version of every newly fooled image.
            final_images[fails[wrongs]] = adv_images[wrongs]

            fails = fails[corrects]
            multi_atk_records.append(len(fails))

            if len(fails) == 0:
                break

        if self.verbose:
            print(self._return_sr_record(multi_atk_records))

        if self._accumulate_multi_atk_records:
            self._update_multi_atk_records(multi_atk_records)

        return final_images

    def _clear_multi_atk_records(self):
        r"""Reset the accumulated records to a single zero total."""
        self._multi_atk_records = [0.0]

    def _covert_to_success_rates(self, multi_atk_records):
        r"""
        Convert a records list into cumulative success rates, in percent.

        ``multi_atk_records[0]`` is the total number of images, and
        ``multi_atk_records[i]`` is the number still not fooled after attack i.
        The rate for attack i is therefore ``(1 - remaining / total) * 100``.
        When the total is 0 (an empty batch), every rate is reported as 0.0
        rather than raising ZeroDivisionError.
        """
        # NOTE: the misspelled name ("covert") is kept for backward compatibility.
        total = multi_atk_records[0]
        if not total:
            return [0.0] * (len(multi_atk_records) - 1)
        return [(1 - remaining / total) * 100 for remaining in multi_atk_records[1:]]

    def _return_sr_record(self, multi_atk_records):
        r"""Format the cumulative success rates as a printable string."""
        sr = self._covert_to_success_rates(multi_atk_records)
        return "Attack success rate: " + " | ".join(["%2.2f %%" % item for item in sr])

    def _update_multi_atk_records(self, multi_atk_records):
        r"""
        Add one batch's records into the running totals.

        If a batch stopped early (every image was fooled), its record is
        shorter than the running list. The missing trailing entries would have
        been 0, so leaving those slots untouched is correct.
        """
        for i, item in enumerate(multi_atk_records):
            self._multi_atk_records[i] += item

    def save(
        self,
        data_loader,
        save_path=None,
        verbose=True,
        return_verbose=False,
        save_predictions=False,
        save_clean_images=False,
    ):
        r"""
        Overridden.

        Delegates to the base class ``save`` while accumulating per-attack
        records across all batches.

        Returns:
            ``(rob_acc, sr, l2, elapsed_time)`` when ``return_verbose`` is true,
            where ``sr`` is the list of cumulative per-attack success rates.
            Otherwise returns None.
        """
        self._clear_multi_atk_records()
        # One accumulator slot per attack, after the total at index 0.
        self._multi_atk_records.extend([0.0] * len(self.attacks))
        prev_verbose = self.verbose
        # Silence forward()'s per-batch print; _save_print reports progress instead.
        self.verbose = False
        self._accumulate_multi_atk_records = True
        try:
            result = super().save(
                data_loader,
                save_path,
                verbose,
                return_verbose,
                save_predictions,
                save_clean_images,
            )
            sr = self._covert_to_success_rates(self._multi_atk_records)
        finally:
            # Restore state even if the base save raises, so later forward
            # calls neither accumulate records nor lose the user's verbose setting.
            self._clear_multi_atk_records()
            self._accumulate_multi_atk_records = False
            self.verbose = prev_verbose

        if return_verbose:
            rob_acc, l2, elapsed_time = result
            return rob_acc, sr, l2, elapsed_time

    def _save_print(self, progress, rob_acc, l2, elapsed_time, end):
        r"""
        Overridden.

        Print a save-progress line that includes the cumulative per-attack
        success rates accumulated so far.
        """
        print(
            "- Save progress: %2.2f %% / Robust accuracy: %2.2f %%"
            % (progress, rob_acc)
            + " / "
            + self._return_sr_record(self._multi_atk_records)
            + " / L2: %1.5f (%2.3f it/s) \t" % (l2, elapsed_time),
            end=end,
        )