# Source code for torchattacks.wrappers.lgv

import os
import copy
import itertools
import torch
import torch.nn as nn
from random import shuffle, sample

from ..attack import Attack
from ..attacks.bim import BIM

# tqdm is an optional dependency: when it is not installed, substitute a
# no-op wrapper so progress reporting degrades to plain iteration.
try:
    from tqdm import tqdm
except ImportError:

    def tqdm(iterator, *_args, **_kwargs):
        """Pass-through replacement for ``tqdm.tqdm``.

        Extra positional/keyword arguments (description, total, ...) are
        accepted for call compatibility and ignored; ``iterator`` is handed
        back unchanged.
        """
        return iterator


class LGV(Attack):
    r"""
    LGV attack in the paper 'LGV: Boosting Adversarial Example Transferability from Large Geometric Vicinity'
    [https://arxiv.org/abs/2207.13129]

    Arguments:
        model (nn.Module): initial model to attack. It is deep-copied, so the
            model passed in is not modified when LGV models are collected.
        trainloader (torch.utils.data.DataLoader): data loader of the unnormalized train set. Must load data in [0, 1].
            Be aware that the batch size may impact success rate. The original paper uses a batch size of 256.
            A different batch-size might require to tune the learning rate.
        lr (float): constant learning rate to collect models. In the paper, 0.05 is best for ResNet-50.
            0.1 seems best for some other architectures. Integers are accepted too. (Default: 0.05)
        epochs (int): number of epochs. (Default: 10)
        nb_models_epoch (int): number of models to collect per epoch. Must be strictly positive when the
            attack collects models itself (it may be 0 if models are only provided via `load_models()`). (Default: 4)
        wd (float): weight decay of SGD to collect models. (Default: 1e-4)
        n_grad (int): number of models to ensemble at each attack iteration. 1 (default) is recommended for efficient
            iterative attacks. Higher numbers give generally better results at the expense of computations.
            -1 uses all models (should be used for single-step attacks like FGSM).
        verbose (bool): print progress. Install the tqdm package for better print. (Default: True)
        attack_class (class): attack class instantiated on the ensemble of collected models to craft
            adversarial examples. (Default: BIM)
        **kwargs: extra keyword arguments forwarded to `attack_class` (e.g. eps, alpha, steps).

    .. note:: If a list of models is not provided to `load_models()`, the attack will start by collecting models
        along the SGD trajectory for `epochs` epochs with the constant learning rate `lr`.

    Shape:
        - images: :math:`(N, C, H, W)` where `N = number of images in the batch`, `C = number of channels`,
          `H = height` and `W = width`. It must have a range [0, 1].
        - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i <` `number of labels`.
        - output: :math:`(N, C, H, W)`.

    Examples::
        >>> attack = torchattacks.LGV(model, trainloader, lr=0.05, epochs=10, nb_models_epoch=4, wd=1e-4, n_grad=1, attack_class=BIM, eps=4/255, alpha=4/255/10, steps=50, verbose=True)
        >>> attack.collect_models()
        >>> attack.save_models('./models/lgv/')
        >>> adv_images = attack(images, labels)
    """

    def __init__(
        self,
        model,
        trainloader,
        lr=0.05,
        epochs=10,
        nb_models_epoch=4,
        wd=1e-4,
        n_grad=1,
        verbose=True,
        attack_class=BIM,
        **kwargs,
    ):
        # Validate before the (potentially expensive) deep copy of the model.
        # bool is a subclass of int, so it is excluded explicitly for lr.
        if isinstance(lr, bool) or not isinstance(lr, (int, float)) or lr < 0:
            raise ValueError("lr should be a non-negative float")
        if not isinstance(epochs, int) or epochs < 0:
            raise ValueError("epochs should be a non-negative integer")
        if not isinstance(nb_models_epoch, int) or nb_models_epoch < 0:
            raise ValueError("nb_models_epoch should be a non-negative integer")
        model = copy.deepcopy(model)  # deep copy the model to train it
        super().__init__("LGV", model)
        self.trainloader = trainloader
        self.lr = lr
        self.epochs = epochs
        self.nb_models_epoch = nb_models_epoch
        self.wd = wd
        self.n_grad = n_grad
        self.order = "shuffle"
        self.attack_class = attack_class
        self.verbose = verbose
        self.kwargs_att = kwargs
        self.supported_mode = ["default", "targeted"]
        self.list_models = []
        self.base_attack = None  # will be initialized after model collection

    def collect_models(self):
        """
        Collect LGV models along the SGD trajectory.

        The model is trained with SGD (constant `lr`, momentum 0.9, weight
        decay `wd`) on `trainloader`. Every ``len(trainloader) // nb_models_epoch``
        batches a deep copy of the current model is appended to `list_models`,
        for `epochs * nb_models_epoch` copies in total. The model's original
        train/eval mode is restored afterwards, and collected copies are put in
        eval mode if the model was in eval mode.

        Raises:
            ValueError: if `nb_models_epoch` is not strictly positive.
        """
        if self.nb_models_epoch < 1:
            raise ValueError(
                "nb_models_epoch should be a strictly positive integer to collect models"
            )
        given_training = self.model.training
        self.model.train()
        optimizer = torch.optim.SGD(
            self.model.parameters(), lr=self.lr, momentum=0.9, weight_decay=self.wd
        )
        loss_fn = nn.CrossEntropyLoss()
        # Batches must be sent to the device holding the weights being trained
        # (a hard-coded "cuda" breaks CPU models on GPU hosts, or non-default GPUs).
        device = next(self.model.parameters()).device
        # Exact integer floor instead of float arithmetic on 1 / nb_models_epoch.
        # NOTE(review): if len(trainloader) < nb_models_epoch this is 0 and the
        # collected models are untrained copies.
        n_batches = len(self.trainloader) // self.nb_models_epoch
        for _ in tqdm(range(self.epochs * self.nb_models_epoch), "Collecting models"):
            for inputs, targets in itertools.islice(self.trainloader, n_batches):
                inputs = inputs.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                loss = loss_fn(self.get_logits(inputs), targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            model_sample = copy.deepcopy(self.model)
            if not given_training:
                model_sample.eval()
            self.list_models.append(model_sample)
        if not given_training:
            self.model.eval()
        # Any previously built ensemble attack does not include the new models.
        self.base_attack = None

    def load_models(self, list_models):
        """
        Load collected models

        Arguments:
            list_models (list of nn.Module): list of LGV models.

        Raises:
            ValueError: if `list_models` is not a list.
        """
        if not isinstance(list_models, list):
            raise ValueError("list_models should be a list of pytorch models")
        # Store a copy so later changes to the caller's list do not leak in.
        self.list_models = list(list_models)
        # Force the ensemble attack to be rebuilt on the newly loaded models.
        self.base_attack = None

    def save_models(self, path):
        """
        Save collected models to the `path` directory

        Each model's state_dict is saved as ``{"state_dict": ...}`` in
        ``lgv_model_XXXXX.pt`` (zero-padded index), creating `path` if needed.

        Arguments:
            path (str): directory where to save models.

        Raises:
            RuntimeError: if there are no models to save.
        """
        if not self.list_models:
            raise RuntimeError("Call collect_models() before saving collected models.")
        os.makedirs(path, exist_ok=True)
        for i, model in enumerate(self.list_models):
            path_i = os.path.join(path, f"lgv_model_{i:05}.pt")
            torch.save({"state_dict": model.state_dict()}, path_i)

    def forward(self, images, labels):
        r"""
        Overridden.

        Collects models first if none were collected or loaded. Then it lazily
        builds `attack_class` on a `LightEnsemble` of the models, mirroring this
        attack's training-mode and targeted settings onto it. Crafting is then
        delegated to that base attack.
        """
        if not self.list_models:
            if self.verbose:
                print(f"Phase 1: collect models for {self.epochs} epochs")
            self.collect_models()
        if self.base_attack is None:
            if self.verbose:
                print(
                    f"Phase 2: craft adversarial examples with {self.attack_class.__name__}"
                )
            self.list_models = [model.to(self.device) for model in self.list_models]
            f_model = LightEnsemble(
                self.list_models, order=self.order, n_grad=self.n_grad
            )
            # NOTE(review): kept from upstream; the mode used while crafting is
            # configured on the base attack via set_model_training_mode below --
            # confirm whether this condition was meant to be inverted.
            if self._model_training:
                f_model.eval()
            self.base_attack = self.attack_class(
                model=f_model.to(self.device), **self.kwargs_att
            )
            # set_model_training_mode() to base attack
            self.base_attack.set_model_training_mode(
                model_training=self._model_training,
                batchnorm_training=self._batchnorm_training,
                dropout_training=self._dropout_training,
            )
            # set targeted to base attack
            if self.targeted:
                if self.attack_mode == "targeted":
                    self.base_attack.set_mode_targeted_by_function(
                        target_map_function=self._target_map_function
                    )
                elif self.attack_mode == "targeted(least-likely)":
                    self.base_attack.set_mode_targeted_least_likely(kth_min=self._kth_min)
                elif self.attack_mode == "targeted(random)":
                    self.base_attack.set_mode_targeted_random()
                else:
                    raise NotImplementedError("Targeted attack mode not supported by LGV.")
        adv_images = self.base_attack(images, labels)
        return adv_images
class LightEnsemble(nn.Module):
    """Ensemble that forwards through one (or a few) of its models per call."""

    def __init__(self, list_models, order="shuffle", n_grad=1):
        """
        Perform a single forward pass to one of the models when call forward()

        Arguments:
            list_models (list of nn.Module): list of LGV models. The list itself
                is not modified (shuffling is applied to a copy).
            order (str): 'shuffle' draw a model without replacement (default), 'random' draw a model with replacement,
                None cycle in provided order.
            n_grad (int): number of models to ensemble in each forward pass (fused logits). Select models according to
                `order`. If equal to -1, use all models and order is ignored.

        Raises:
            ValueError: on an empty model list, an invalid `n_grad`, or an
                unsupported `order`.
        """
        super().__init__()
        self.n_models = len(list_models)
        if self.n_models < 1:
            raise ValueError("Empty list of models")
        if not (n_grad > 0 or n_grad == -1):
            raise ValueError("n_grad should be strictly positive or equal to -1")
        # Shuffle a copy: random.shuffle works in place and would otherwise
        # reorder the caller's list as a side effect.
        list_models = list(list_models)
        if order == "shuffle":
            shuffle(list_models)
        elif order not in [None, "random"]:
            raise ValueError("Not supported order")
        self.models = nn.ModuleList(list_models)
        self.order = order
        self.n_grad = n_grad
        self.f_count = 0  # position in the cyclic order ('shuffle' / None)

    def forward(self, x):
        """Return the output of the selected model(s), averaged when several are used."""
        if self.n_grad >= self.n_models or self.n_grad < 0:
            indexes = list(range(self.n_models))
        elif self.order == "random":
            indexes = sample(range(self.n_models), self.n_grad)
        else:
            # Cycle through the (possibly pre-shuffled) models, wrapping around.
            indexes = [
                i % self.n_models for i in range(self.f_count, self.f_count + self.n_grad)
            ]
            self.f_count += self.n_grad
        if self.n_grad == 1:
            return self.models[indexes[0]](x)
        selected = set(indexes)  # O(1) membership; iteration keeps module order
        # clone to make sure x is not changed by inplace methods
        x_list = [
            model(x.clone()) for i, model in enumerate(self.models) if i in selected
        ]
        return torch.mean(torch.stack(x_list), dim=0, keepdim=False)