Source code for torchattacks.attack

import time
from collections import OrderedDict

import torch
from torch.utils.data import DataLoader, TensorDataset


def wrapper_method(func):
    def wrapper_func(self, *args, **kwargs):
        result = func(self, *args, **kwargs)
        for atk in self.__dict__.get("_attacks").values():
            eval("atk." + func.__name__ + "(*args, **kwargs)")
        return result

    return wrapper_func


[docs] class Attack(object): r""" Base class for all attacks. .. note:: It automatically set device to the device where given model is. It basically changes training mode to eval during attack process. To change this, please see `set_model_training_mode`. """ def __init__(self, name, model): r""" Initializes internal attack state. Arguments: name (str): name of attack. model (torch.nn.Module): model to attack. """ self.attack = name self._attacks = OrderedDict() self.set_model(model) try: self.device = next(model.parameters()).device except Exception: self.device = None print("Failed to set device automatically, please try set_device() manual.") # Controls attack mode. self.attack_mode = "default" self.supported_mode = ["default"] self.targeted = False self._target_map_function = None # Controls when normalization is used. self.normalization_used = None self._normalization_applied = None if self.model.__class__.__name__ == "RobModel": self._set_rmodel_normalization_used(model) # Controls model mode during attack. self._model_training = False self._batchnorm_training = False self._dropout_training = False
[docs] def forward(self, inputs, labels=None, *args, **kwargs): r""" It defines the computation performed at every call. Should be overridden by all subclasses. """ raise NotImplementedError
@wrapper_method def set_model(self, model): self.model = model self.model_name = model.__class__.__name__ def get_logits(self, inputs, labels=None, *args, **kwargs): if self._normalization_applied is False: inputs = self.normalize(inputs) logits = self.model(inputs) return logits @wrapper_method def _set_normalization_applied(self, flag): self._normalization_applied = flag @wrapper_method def set_device(self, device): self.device = device @wrapper_method def _set_rmodel_normalization_used(self, model): r""" Set attack normalization for MAIR [https://github.com/Harry24k/MAIR]. """ mean = getattr(model, "mean", None) std = getattr(model, "std", None) if (mean is not None) and (std is not None): if isinstance(mean, torch.Tensor): mean = mean.cpu().numpy() if isinstance(std, torch.Tensor): std = std.cpu().numpy() if (mean != 0).all() or (std != 1).all(): self.set_normalization_used(mean, std) @wrapper_method def set_normalization_used(self, mean, std): self.normalization_used = {} n_channels = len(mean) mean = torch.tensor(mean).reshape(1, n_channels, 1, 1) std = torch.tensor(std).reshape(1, n_channels, 1, 1) self.normalization_used["mean"] = mean self.normalization_used["std"] = std self._set_normalization_applied(True) def normalize(self, inputs): mean = self.normalization_used["mean"].to(inputs.device) std = self.normalization_used["std"].to(inputs.device) return (inputs - mean) / std def inverse_normalize(self, inputs): mean = self.normalization_used["mean"].to(inputs.device) std = self.normalization_used["std"].to(inputs.device) return inputs * std + mean
[docs] def get_mode(self): r""" Get attack mode. """ return self.attack_mode
@wrapper_method def set_mode_default(self): r""" Set attack mode as default mode. """ self.attack_mode = "default" self.targeted = False print("Attack mode is changed to 'default.'") @wrapper_method def _set_mode_targeted(self, mode, quiet): if "targeted" not in self.supported_mode: raise ValueError("Targeted mode is not supported.") self.targeted = True self.attack_mode = mode if not quiet: print("Attack mode is changed to '%s'." % mode) @wrapper_method def set_mode_targeted_by_function(self, target_map_function, quiet=False): r""" Set attack mode as targeted. Arguments: target_map_function (function): Label mapping function. e.g. lambda inputs, labels:(labels+1)%10. None for using input labels as targeted labels. (Default) quiet (bool): Display information message or not. (Default: False) """ self._set_mode_targeted("targeted(custom)", quiet) self._target_map_function = target_map_function @wrapper_method def set_mode_targeted_random(self, quiet=False): r""" Set attack mode as targeted with random labels. Arguments: quiet (bool): Display information message or not. (Default: False) """ self._set_mode_targeted("targeted(random)", quiet) self._target_map_function = self.get_random_target_label @wrapper_method def set_mode_targeted_least_likely(self, kth_min=1, quiet=False): r""" Set attack mode as targeted with least likely labels. Arguments: kth_min (str): label with the k-th smallest probability used as target labels. (Default: 1) num_classses (str): number of classes. (Default: False) """ self._set_mode_targeted("targeted(least-likely)", quiet) assert kth_min > 0 self._kth_min = kth_min self._target_map_function = self.get_least_likely_label @wrapper_method def set_mode_targeted_by_label(self, quiet=False): r""" Set attack mode as targeted. Arguments: quiet (bool): Display information message or not. (Default: False) .. note:: Use user-supplied labels as target labels. """ self._set_mode_targeted("targeted(label)", quiet) self._target_map_function = "function is a string" @wrapper_method def set_model_training_mode( self, model_training=False, batchnorm_training=False, dropout_training=False ): r""" Set training mode during attack process. Arguments: model_training (bool): True for using training mode for the entire model during attack process. batchnorm_training (bool): True for using training mode for batchnorms during attack process. dropout_training (bool): True for using training mode for dropouts during attack process. .. note:: For RNN-based models, we cannot calculate gradients with eval mode. Thus, it should be changed to the training mode during the attack. """ self._model_training = model_training self._batchnorm_training = batchnorm_training self._dropout_training = dropout_training @wrapper_method def _change_model_mode(self, given_training): if self._model_training: self.model.train() for _, m in self.model.named_modules(): if not self._batchnorm_training: if "BatchNorm" in m.__class__.__name__: m = m.eval() if not self._dropout_training: if "Dropout" in m.__class__.__name__: m = m.eval() else: self.model.eval() @wrapper_method def _recover_model_mode(self, given_training): if given_training: self.model.train()
[docs] def save( self, data_loader, save_path=None, verbose=True, return_verbose=False, save_predictions=False, save_clean_inputs=False, save_type="float", ): r""" Save adversarial inputs as torch.tensor from given torch.utils.data.DataLoader. Arguments: save_path (str): save_path. data_loader (torch.utils.data.DataLoader): data loader. verbose (bool): True for displaying detailed information. (Default: True) return_verbose (bool): True for returning detailed information. (Default: False) save_predictions (bool): True for saving predicted labels (Default: False) save_clean_inputs (bool): True for saving clean inputs (Default: False) """ if save_path is not None: adv_input_list = [] label_list = [] if save_predictions: pred_list = [] if save_clean_inputs: input_list = [] correct = 0 total = 0 l2_distance = [] total_batch = len(data_loader) given_training = self.model.training for step, (inputs, labels) in enumerate(data_loader): start = time.time() adv_inputs = self.__call__(inputs, labels) batch_size = len(inputs) if verbose or return_verbose: with torch.no_grad(): outputs = self.get_output_with_eval_nograd(adv_inputs) # Calculate robust accuracy _, pred = torch.max(outputs.data, 1) total += labels.size(0) right_idx = pred == labels.to(self.device) correct += right_idx.sum() rob_acc = 100 * float(correct) / total # Calculate l2 distance delta = (adv_inputs - inputs.to(self.device)).view( batch_size, -1 ) # nopep8 l2_distance.append( torch.norm(delta[~right_idx], p=2, dim=1) ) # nopep8 l2 = torch.cat(l2_distance).mean().item() # Calculate time computation progress = (step + 1) / total_batch * 100 end = time.time() elapsed_time = end - start if verbose: self._save_print( progress, rob_acc, l2, elapsed_time, end="\r" ) # nopep8 if save_path is not None: adv_input_list.append(adv_inputs.detach().cpu()) label_list.append(labels.detach().cpu()) adv_input_list_cat = torch.cat(adv_input_list, 0) label_list_cat = torch.cat(label_list, 0) save_dict = { "adv_inputs": adv_input_list_cat, "labels": label_list_cat, } # nopep8 if save_predictions: pred_list.append(pred.detach().cpu()) pred_list_cat = torch.cat(pred_list, 0) save_dict["preds"] = pred_list_cat if save_clean_inputs: input_list.append(inputs.detach().cpu()) input_list_cat = torch.cat(input_list, 0) save_dict["clean_inputs"] = input_list_cat if self.normalization_used is not None: save_dict["adv_inputs"] = self.inverse_normalize( save_dict["adv_inputs"] ) # nopep8 if save_clean_inputs: save_dict["clean_inputs"] = self.inverse_normalize( save_dict["clean_inputs"] ) # nopep8 if save_type == "int": save_dict["adv_inputs"] = self.to_type( save_dict["adv_inputs"], "int" ) # nopep8 if save_clean_inputs: save_dict["clean_inputs"] = self.to_type( save_dict["clean_inputs"], "int" ) # nopep8 save_dict["save_type"] = save_type torch.save(save_dict, save_path) # To avoid erasing the printed information. if verbose: self._save_print(progress, rob_acc, l2, elapsed_time, end="\n") if given_training: self.model.train() if return_verbose: return rob_acc, l2, elapsed_time
[docs] @staticmethod def to_type(inputs, type): r""" Return inputs as int if float is given. """ if type == "int": if isinstance(inputs, torch.FloatTensor) or isinstance( inputs, torch.cuda.FloatTensor ): return (inputs * 255).type(torch.uint8) elif type == "float": if isinstance(inputs, torch.ByteTensor) or isinstance( inputs, torch.cuda.ByteTensor ): return inputs.float() / 255 else: raise ValueError(type + " is not a valid type. [Options: float, int]") return inputs
@staticmethod def _save_print(progress, rob_acc, l2, elapsed_time, end): print( "- Save progress: %2.2f %% / Robust accuracy: %2.2f %% / L2: %1.5f (%2.3f it/s) \t" % (progress, rob_acc, l2, elapsed_time), end=end, ) @staticmethod def load( load_path, batch_size=128, shuffle=False, normalize=None, load_predictions=False, load_clean_inputs=False, ): save_dict = torch.load(load_path) keys = ["adv_inputs", "labels"] if load_predictions: keys.append("preds") if load_clean_inputs: keys.append("clean_inputs") if save_dict["save_type"] == "int": save_dict["adv_inputs"] = save_dict["adv_inputs"].float() / 255 if load_clean_inputs: save_dict["clean_inputs"] = ( save_dict["clean_inputs"].float() / 255 ) # nopep8 if normalize is not None: n_channels = len(normalize["mean"]) mean = torch.tensor(normalize["mean"]).reshape(1, n_channels, 1, 1) std = torch.tensor(normalize["std"]).reshape(1, n_channels, 1, 1) save_dict["adv_inputs"] = (save_dict["adv_inputs"] - mean) / std if load_clean_inputs: save_dict["clean_inputs"] = ( save_dict["clean_inputs"] - mean ) / std # nopep8 adv_data = TensorDataset(*[save_dict[key] for key in keys]) adv_loader = DataLoader(adv_data, batch_size=batch_size, shuffle=shuffle) print( "Data is loaded in the following order: [%s]" % (", ".join(keys)) ) # nopep8 return adv_loader @torch.no_grad() def get_output_with_eval_nograd(self, inputs): given_training = self.model.training if given_training: self.model.eval() outputs = self.get_logits(inputs) if given_training: self.model.train() return outputs
[docs] def get_target_label(self, inputs, labels=None): r""" Function for changing the attack mode. Return input labels. """ if self._target_map_function is None: raise ValueError( "target_map_function is not initialized by set_mode_targeted." ) if self.attack_mode == "targeted(label)": target_labels = labels else: target_labels = self._target_map_function(inputs, labels) return target_labels
@torch.no_grad() def get_least_likely_label(self, inputs, labels=None): outputs = self.get_output_with_eval_nograd(inputs) if labels is None: _, labels = torch.max(outputs, dim=1) n_classses = outputs.shape[-1] target_labels = torch.zeros_like(labels) for counter in range(labels.shape[0]): l = list(range(n_classses)) l.remove(labels[counter]) _, t = torch.kthvalue(outputs[counter][l], self._kth_min) target_labels[counter] = l[t] return target_labels.long().to(self.device) @torch.no_grad() def get_random_target_label(self, inputs, labels=None): outputs = self.get_output_with_eval_nograd(inputs) if labels is None: _, labels = torch.max(outputs, dim=1) n_classses = outputs.shape[-1] target_labels = torch.zeros_like(labels) for counter in range(labels.shape[0]): l = list(range(n_classses)) l.remove(labels[counter]) t = (len(l) * torch.rand([1])).long().to(self.device) target_labels[counter] = l[t] return target_labels.long().to(self.device) def __call__(self, inputs, labels=None, *args, **kwargs): given_training = self.model.training self._change_model_mode(given_training) if self._normalization_applied is True: inputs = self.inverse_normalize(inputs) self._set_normalization_applied(False) adv_inputs = self.forward(inputs, labels, *args, **kwargs) # adv_inputs = self.to_type(adv_inputs, self.return_type) adv_inputs = self.normalize(adv_inputs) self._set_normalization_applied(True) else: adv_inputs = self.forward(inputs, labels, *args, **kwargs) # adv_inputs = self.to_type(adv_inputs, self.return_type) self._recover_model_mode(given_training) return adv_inputs def __repr__(self): info = self.__dict__.copy() del_keys = ["model", "attack", "supported_mode"] for key in info.keys(): if key[0] == "_": del_keys.append(key) for key in del_keys: del info[key] info["attack_mode"] = self.attack_mode info["normalization_used"] = ( True if self.normalization_used is not None else False ) return ( self.attack + "(" + ", ".join("{}={}".format(key, val) for key, val in info.items()) + ")" ) def __setattr__(self, name, value): object.__setattr__(self, name, value) attacks = self.__dict__.get("_attacks") # Get all items in iterable items. def get_all_values(items, stack=[]): if items not in stack: stack.append(items) if isinstance(items, list) or isinstance(items, dict): if isinstance(items, dict): items = list(items.keys()) + list(items.values()) for item in items: yield from get_all_values(item, stack) else: if isinstance(items, Attack): yield items else: if isinstance(items, Attack): yield items for num, value in enumerate(get_all_values(value)): attacks[name + "." + str(num)] = value for subname, subvalue in value.__dict__.get("_attacks").items(): attacks[name + "." + subname] = subvalue