Source code for torchattacks.wrappers.multiattack

import torch

from ..attack import Attack


class MultiAttack(Attack):
    r"""
    MultiAttack is a class to attack a model with various attacks against
    the same images and labels.

    The attacks are applied in the given order. Each attack only receives the
    samples that all previous attacks failed to turn into a misclassification,
    and the first successful adversarial example of every sample is kept.
    Samples that no attack fools are returned unchanged.

    Arguments:
        attacks (list): list of attacks. At least two attacks are required and
            all of them must reference the same model object.
        verbose (bool): print the cumulative attack success rates after every
            forward call. (Default: False)

    Examples::
        >>> atk1 = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=40, random_start=True)
        >>> atk2 = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=40, random_start=True)
        >>> atk = torchattacks.MultiAttack([atk1, atk2])
        >>> adv_images = atk(images, labels)

    """

    def __init__(self, attacks, verbose=False):
        # Guard before touching attacks[0]: an empty list should raise the same
        # ValueError as any other invalid attack list, not an IndexError.
        if not attacks:
            raise ValueError("At least two attacks should be given.")

        super().__init__("MultiAttack", attacks[0].model)
        self.attacks = attacks
        self.verbose = verbose
        self.supported_mode = ["default"]

        self.check_validity()

        # Bookkeeping used by save(): number of still-unfooled samples before
        # the first attack (index 0) and after each attack (index 1..n),
        # summed over all batches.
        self._accumulate_multi_atk_records = False
        self._multi_atk_records = [0.0]

    def check_validity(self):
        r"""
        Check that at least two attacks are given and that all of them share
        one model object.

        Raises:
            ValueError: if fewer than two attacks are given, or if the attacks
                do not all reference the same model object.
        """
        if len(self.attacks) < 2:
            raise ValueError("At least two attacks should be given.")

        # Models are compared by identity, not by architecture or weights.
        model_ids = {id(attack.model) for attack in self.attacks}
        if len(model_ids) != 1:
            raise ValueError(
                "At least one of attacks is referencing a different model."
            )

    def forward(self, images, labels):
        r"""
        Overridden.
        """
        # Move the inputs to the attack device up front. ``fails`` lives on
        # ``self.device`` and is used to index ``images``, so the indexed tensor
        # must be on the same device as the index.
        images = images.detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        batch_size = images.shape[0]
        final_images = images.clone()

        # Batch indices of the samples that are still classified correctly.
        fails = torch.arange(batch_size, device=self.device)

        # Element 0 is the batch size; element k is the number of samples that
        # are still unfooled after the k-th attack.
        multi_atk_records = [batch_size]

        for attack in self.attacks:
            adv_images = attack(images[fails], labels[fails])

            outputs = self.get_logits(adv_images)
            _, pre = torch.max(outputs.detach(), dim=1)

            # Masks are relative to the current ``fails`` subset.
            corrects = pre == labels[fails]
            wrongs = ~corrects

            # Keep the adversarial examples that fooled the model.
            final_images[fails[wrongs]] = adv_images[wrongs]

            # Only the remaining samples are passed to the next attack.
            fails = fails[corrects]
            multi_atk_records.append(len(fails))

            if len(fails) == 0:
                break

        if self.verbose:
            print(self._return_sr_record(multi_atk_records))

        if self._accumulate_multi_atk_records:
            self._update_multi_atk_records(multi_atk_records)

        return final_images

    def _clear_multi_atk_records(self):
        # Reset the accumulated records to their initial state.
        self._multi_atk_records = [0.0]

    def _covert_to_success_rates(self, multi_atk_records):
        # Convert "remaining unfooled samples" counts into cumulative success
        # rates in percent, one value per recorded attack.
        if len(multi_atk_records) < 2:
            return []

        total = multi_atk_records[0]
        if total == 0:
            # No sample was attacked (e.g. empty batch); report 0 % instead of
            # dividing by zero.
            return [0.0] * (len(multi_atk_records) - 1)

        return [(1 - remaining / total) * 100 for remaining in multi_atk_records[1:]]

    def _return_sr_record(self, multi_atk_records):
        # Format the cumulative success rates as a printable string.
        sr = self._covert_to_success_rates(multi_atk_records)
        return "Attack success rate: " + " | ".join(["%2.2f %%" % item for item in sr])

    def _update_multi_atk_records(self, multi_atk_records):
        # Add the per-batch counts to the running totals. A batch that stopped
        # early has fewer entries; the missing ones correspond to zero
        # remaining samples, so nothing needs to be added for them.
        for i, item in enumerate(multi_atk_records):
            self._multi_atk_records[i] += item

    def save(
        self,
        data_loader,
        save_path=None,
        verbose=True,
        return_verbose=False,
        save_predictions=False,
        save_clean_images=False,
    ):
        r"""
        Overridden.

        Delegates to the parent ``save`` while accumulating the per-attack
        records over all batches. When ``return_verbose`` is true, returns
        ``(rob_acc, sr, l2, elapsed_time)`` where ``sr`` is the list of
        cumulative success rates (in percent) after each attack; otherwise
        returns ``None``.
        """
        prev_verbose = self.verbose

        # Silence the per-batch printing of forward() and collect the records
        # instead; one slot for the total plus one per attack.
        self.verbose = False
        self._multi_atk_records = [0.0] * (len(self.attacks) + 1)
        self._accumulate_multi_atk_records = True

        try:
            result = super().save(
                data_loader,
                save_path,
                verbose,
                return_verbose,
                save_predictions,
                save_clean_images,
            )

            if return_verbose:
                rob_acc, l2, elapsed_time = result
                sr = self._covert_to_success_rates(self._multi_atk_records)
                return rob_acc, sr, l2, elapsed_time
            return None
        finally:
            # Always restore the instance state, even if saving failed, so the
            # attack does not stay silenced or keep accumulating records.
            self._clear_multi_atk_records()
            self._accumulate_multi_atk_records = False
            self.verbose = prev_verbose

    def _save_print(self, progress, rob_acc, l2, elapsed_time, end):
        r"""
        Overridden.
        """
        print(
            "- Save progress: %2.2f %% / Robust accuracy: %2.2f %%"
            % (progress, rob_acc)
            + " / "
            + self._return_sr_record(self._multi_atk_records)
            + " / L2: %1.5f (%2.3f it/s) \t" % (l2, elapsed_time),
            end=end,
        )