Source code for torchattacks.attacks.pixle

from itertools import chain

import numpy as np
import torch
from torch.nn.functional import softmax

from ..attack import Attack


[docs] class Pixle(Attack): r""" Pixle: a fast and effective black-box attack based on rearranging pixels' [https://arxiv.org/abs/2202.02236] Distance Measure : L0 Arguments: model (nn.Module): model to attack. x_dimensions (int or float, or a tuple containing a combination of those): size of the sampled patch along ther x side for each iteration. The integers are considered as fixed number of size, while the float as parcentage of the size. A tuple is used to specify both under and upper bound of the size. (Default: (2, 10)) y_dimensions (int or float, or a tuple containing a combination of those): size of the sampled patch along ther y side for each iteration. The integers are considered as fixed number of size, while the float as parcentage of the size. A tuple is used to specify both under and upper bound of the size. (Default: (2, 10)) pixel_mapping (str): the type of mapping used to move the pixels. Can be: 'random', 'similarity', 'similarity_random', 'distance', 'distance_random' (Default: random) restarts (int): the number of restarts that the algortihm performs. (Default: 20) max_iterations (int): number of iterations to perform for each restart. (Default: 10) update_each_iteration (bool): if the attacked images must be modified after each iteration (True) or after each restart (False). (Default: False) Shape: - images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`, `H = height` and `W = width`. It must have a range [0, 1]. - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i \leq` `number of labels`. - output: :math:`(N, C, H, W)`. Examples:: >>> attack = torchattacks.Pixle(model, x_dimensions=(0.1, 0.2), restarts=10, max_iterations=50) >>> adv_images = attack(images, labels) """ def __init__( self, model, x_dimensions=(2, 10), y_dimensions=(2, 10), pixel_mapping="random", restarts=20, max_iterations=10, update_each_iteration=False, ): super().__init__("Pixle", model) if restarts < 0 or not isinstance(restarts, int): raise ValueError( "restarts must be and integer >= 0 " "({})".format(restarts) ) self.update_each_iteration = update_each_iteration self.max_patches = max_iterations self.restarts = restarts self.pixel_mapping = pixel_mapping.lower() if self.pixel_mapping not in [ "random", "similarity", "similarity_random", "distance", "distance_random", ]: raise ValueError( "pixel_mapping must be one of [random, similarity," "similarity_random, distance, distance_random]" " ({})".format(self.pixel_mapping) ) if isinstance(y_dimensions, (int, float)): y_dimensions = [y_dimensions, y_dimensions] if isinstance(x_dimensions, (int, float)): x_dimensions = [x_dimensions, x_dimensions] if not all( [ (isinstance(d, (int)) and d > 0) or (isinstance(d, float) and 0 <= d <= 1) for d in chain(y_dimensions, x_dimensions) ] ): raise ValueError( "dimensions of first patch must contains integers" " or floats in [0, 1]" " ({})".format(y_dimensions) ) self.p1_x_dimensions = x_dimensions self.p1_y_dimensions = y_dimensions self.supported_mode = ["default", "targeted"]
[docs] def forward(self, images, labels): if not self.update_each_iteration: adv_images = self.restart_forward(images, labels) return adv_images else: adv_images = self.iterative_forward(images, labels) return adv_images
def restart_forward(self, images, labels): if len(images.shape) == 3: images = images.unsqueeze(0) if self.targeted: labels = self.get_target_label(images, labels) x_bounds = tuple( [ max(1, d if isinstance(d, int) else round(images.size(3) * d)) for d in self.p1_x_dimensions ] ) y_bounds = tuple( [ max(1, d if isinstance(d, int) else round(images.size(2) * d)) for d in self.p1_y_dimensions ] ) adv_images = [] images = images.clone().detach().to(self.device) labels = labels.clone().detach().to(self.device) bs, _, _, _ = images.shape for idx in range(bs): image, label = images[idx : idx + 1], labels[idx : idx + 1] best_image = image.clone() pert_image = image.clone() loss, callback = self._get_fun(image, label, target_attack=self.targeted) best_solution = None best_p = loss(solution=image, solution_as_perturbed=True) image_probs = [best_p] it = 0 for r in range(self.restarts): stop = False for it in range(self.max_patches): (x, y), (x_offset, y_offset) = self.get_patch_coordinates( image=image, x_bounds=x_bounds, y_bounds=y_bounds ) destinations = self.get_pixel_mapping( image, x, x_offset, y, y_offset, destination_image=best_image ) solution = [x, y, x_offset, y_offset] + destinations pert_image = self._perturb( source=image, destination=best_image, solution=solution ) p = loss(solution=pert_image, solution_as_perturbed=True) if p < best_p: best_p = p best_solution = pert_image image_probs.append(best_p) if callback(pert_image, None, True): best_solution = pert_image stop = True break if best_solution is None: best_image = pert_image else: best_image = best_solution if stop: break adv_images.append(best_image) adv_images = torch.cat(adv_images) return adv_images def iterative_forward(self, images, labels): assert len(images.shape) == 3 or ( len(images.shape) == 4 and images.size(0) == 1 ) if len(images.shape) == 3: images = images.unsqueeze(0) if self.targeted: labels = self.get_target_label(images, labels) x_bounds = tuple( [ max(1, d if isinstance(d, int) else round(images.size(3) * d)) for d in self.p1_x_dimensions ] ) y_bounds = tuple( [ max(1, d if isinstance(d, int) else round(images.size(2) * d)) for d in self.p1_y_dimensions ] ) adv_images = [] images = images.clone().detach().to(self.device) labels = labels.clone().detach().to(self.device) bs, _, _, _ = images.shape for idx in range(bs): image, label = images[idx : idx + 1], labels[idx : idx + 1] best_image = image.clone() loss, callback = self._get_fun(image, label, target_attack=self.targeted) best_p = loss(solution=image, solution_as_perturbed=True) image_probs = [best_p] for it in range(self.max_patches): (x, y), (x_offset, y_offset) = self.get_patch_coordinates( image=image, x_bounds=x_bounds, y_bounds=y_bounds ) destinations = self.get_pixel_mapping( image, x, x_offset, y, y_offset, destination_image=best_image ) solution = [x, y, x_offset, y_offset] + destinations pert_image = self._perturb( source=image, destination=best_image, solution=solution ) p = loss(solution=pert_image, solution_as_perturbed=True) if p < best_p: best_p = p best_image = pert_image image_probs.append(best_p) if callback(pert_image, None, True): best_image = pert_image break adv_images.append(best_image) adv_images = torch.cat(adv_images) return adv_images def _get_prob(self, image): out = self.get_logits(image.to(self.device)) prob = softmax(out, dim=1) return prob.detach().cpu().numpy() def loss(self, img, label, target_attack=False): p = self._get_prob(img) p = p[np.arange(len(p)), label] if target_attack: p = 1 - p return p.sum() def get_patch_coordinates(self, image, x_bounds, y_bounds): c, h, w = image.shape[1:] x, y = np.random.uniform(0, 1, 2) x_offset = np.random.randint(x_bounds[0], x_bounds[1] + 1) y_offset = np.random.randint(y_bounds[0], y_bounds[1] + 1) x, y = int(x * (w - 1)), int(y * (h - 1)) if x + x_offset > w: x_offset = w - x if y + y_offset > h: y_offset = h - y return (x, y), (x_offset, y_offset) def get_pixel_mapping( self, source_image, x, x_offset, y, y_offset, destination_image=None ): if destination_image is None: destination_image = source_image destinations = [] c, h, w = source_image.shape[1:] source_image = source_image[0] if self.pixel_mapping == "random": for i in range(x_offset): for j in range(y_offset): dx, dy = np.random.uniform(0, 1, 2) dx, dy = int(dx * (w - 1)), int(dy * (h - 1)) destinations.append([dx, dy]) else: for i in np.arange(y, y + y_offset): for j in np.arange(x, x + x_offset): pixel = source_image[:, i : i + 1, j : j + 1] diff = destination_image - pixel diff = diff[0].abs().mean(0).view(-1) if "similarity" in self.pixel_mapping: diff = 1 / (1 + diff) diff[diff == 1] = 0 probs = torch.softmax(diff, 0).cpu().numpy() indexes = np.arange(len(diff)) pair = None linear_iter = iter( sorted( zip(indexes, probs), key=lambda pit: pit[1], reverse=True ) ) while True: if "random" in self.pixel_mapping: index = np.random.choice(indexes, p=probs) else: index = next(linear_iter)[0] _y, _x = np.unravel_index(index, (h, w)) if _y == i and _x == j: continue pair = (_x, _y) break destinations.append(pair) return destinations def _get_fun(self, img, label, target_attack=False): img = img.to(self.device) if isinstance(label, torch.Tensor): label = label.cpu().numpy() @torch.no_grad() def func(solution, destination=None, solution_as_perturbed=False, **kwargs): if not solution_as_perturbed: pert_image = self._perturb( source=img, destination=destination, solution=solution ) else: pert_image = solution p = self._get_prob(pert_image) p = p[np.arange(len(p)), label] if target_attack: p = 1 - p return p.sum() @torch.no_grad() def callback(solution, destination=None, solution_as_perturbed=False, **kwargs): if not solution_as_perturbed: pert_image = self._perturb( source=img, destination=destination, solution=solution ) else: pert_image = solution p = self._get_prob(pert_image)[0] mx = np.argmax(p) if target_attack: return mx == label else: return mx != label return func, callback def _perturb(self, source, solution, destination=None): if destination is None: destination = source c, h, w = source.shape[1:] x, y, xl, yl = solution[:4] destinations = solution[4:] source_pixels = np.ix_(range(c), np.arange(y, y + yl), np.arange(x, x + xl)) indexes = torch.tensor(destinations) destination = destination.clone().detach().to(self.device) s = source[0][source_pixels].view(c, -1) destination[0, :, indexes[:, 0], indexes[:, 1]] = s return destination