from itertools import chain
import numpy as np
import torch
from torch.nn.functional import softmax
from ..attack import Attack
[docs]
class Pixle(Attack):
r"""
Pixle: a fast and effective black-box attack based on rearranging pixels'
[https://arxiv.org/abs/2202.02236]
Distance Measure : L0
Arguments:
model (nn.Module): model to attack.
x_dimensions (int or float, or a tuple containing a combination of those): size of the sampled patch along ther x side for each iteration. The integers are considered as fixed number of size,
while the float as parcentage of the size. A tuple is used to specify both under and upper bound of the size. (Default: (2, 10))
y_dimensions (int or float, or a tuple containing a combination of those): size of the sampled patch along ther y side for each iteration. The integers are considered as fixed number of size,
while the float as parcentage of the size. A tuple is used to specify both under and upper bound of the size. (Default: (2, 10))
pixel_mapping (str): the type of mapping used to move the pixels. Can be: 'random', 'similarity', 'similarity_random', 'distance', 'distance_random' (Default: random)
restarts (int): the number of restarts that the algortihm performs. (Default: 20)
max_iterations (int): number of iterations to perform for each restart. (Default: 10)
update_each_iteration (bool): if the attacked images must be modified after each iteration (True) or after each restart (False). (Default: False)
Shape:
- images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`, `H = height` and `W = width`. It must have a range [0, 1].
- labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i \leq` `number of labels`.
- output: :math:`(N, C, H, W)`.
Examples::
>>> attack = torchattacks.Pixle(model, x_dimensions=(0.1, 0.2), restarts=10, max_iterations=50)
>>> adv_images = attack(images, labels)
"""
def __init__(
self,
model,
x_dimensions=(2, 10),
y_dimensions=(2, 10),
pixel_mapping="random",
restarts=20,
max_iterations=10,
update_each_iteration=False,
):
super().__init__("Pixle", model)
if restarts < 0 or not isinstance(restarts, int):
raise ValueError(
"restarts must be and integer >= 0 " "({})".format(restarts)
)
self.update_each_iteration = update_each_iteration
self.max_patches = max_iterations
self.restarts = restarts
self.pixel_mapping = pixel_mapping.lower()
if self.pixel_mapping not in [
"random",
"similarity",
"similarity_random",
"distance",
"distance_random",
]:
raise ValueError(
"pixel_mapping must be one of [random, similarity,"
"similarity_random, distance, distance_random]"
" ({})".format(self.pixel_mapping)
)
if isinstance(y_dimensions, (int, float)):
y_dimensions = [y_dimensions, y_dimensions]
if isinstance(x_dimensions, (int, float)):
x_dimensions = [x_dimensions, x_dimensions]
if not all(
[
(isinstance(d, (int)) and d > 0)
or (isinstance(d, float) and 0 <= d <= 1)
for d in chain(y_dimensions, x_dimensions)
]
):
raise ValueError(
"dimensions of first patch must contains integers"
" or floats in [0, 1]"
" ({})".format(y_dimensions)
)
self.p1_x_dimensions = x_dimensions
self.p1_y_dimensions = y_dimensions
self.supported_mode = ["default", "targeted"]
[docs]
def forward(self, images, labels):
if not self.update_each_iteration:
adv_images = self.restart_forward(images, labels)
return adv_images
else:
adv_images = self.iterative_forward(images, labels)
return adv_images
def restart_forward(self, images, labels):
if len(images.shape) == 3:
images = images.unsqueeze(0)
if self.targeted:
labels = self.get_target_label(images, labels)
x_bounds = tuple(
[
max(1, d if isinstance(d, int) else round(images.size(3) * d))
for d in self.p1_x_dimensions
]
)
y_bounds = tuple(
[
max(1, d if isinstance(d, int) else round(images.size(2) * d))
for d in self.p1_y_dimensions
]
)
adv_images = []
images = images.clone().detach().to(self.device)
labels = labels.clone().detach().to(self.device)
bs, _, _, _ = images.shape
for idx in range(bs):
image, label = images[idx : idx + 1], labels[idx : idx + 1]
best_image = image.clone()
pert_image = image.clone()
loss, callback = self._get_fun(image, label, target_attack=self.targeted)
best_solution = None
best_p = loss(solution=image, solution_as_perturbed=True)
image_probs = [best_p]
it = 0
for r in range(self.restarts):
stop = False
for it in range(self.max_patches):
(x, y), (x_offset, y_offset) = self.get_patch_coordinates(
image=image, x_bounds=x_bounds, y_bounds=y_bounds
)
destinations = self.get_pixel_mapping(
image, x, x_offset, y, y_offset, destination_image=best_image
)
solution = [x, y, x_offset, y_offset] + destinations
pert_image = self._perturb(
source=image, destination=best_image, solution=solution
)
p = loss(solution=pert_image, solution_as_perturbed=True)
if p < best_p:
best_p = p
best_solution = pert_image
image_probs.append(best_p)
if callback(pert_image, None, True):
best_solution = pert_image
stop = True
break
if best_solution is None:
best_image = pert_image
else:
best_image = best_solution
if stop:
break
adv_images.append(best_image)
adv_images = torch.cat(adv_images)
return adv_images
def iterative_forward(self, images, labels):
assert len(images.shape) == 3 or (
len(images.shape) == 4 and images.size(0) == 1
)
if len(images.shape) == 3:
images = images.unsqueeze(0)
if self.targeted:
labels = self.get_target_label(images, labels)
x_bounds = tuple(
[
max(1, d if isinstance(d, int) else round(images.size(3) * d))
for d in self.p1_x_dimensions
]
)
y_bounds = tuple(
[
max(1, d if isinstance(d, int) else round(images.size(2) * d))
for d in self.p1_y_dimensions
]
)
adv_images = []
images = images.clone().detach().to(self.device)
labels = labels.clone().detach().to(self.device)
bs, _, _, _ = images.shape
for idx in range(bs):
image, label = images[idx : idx + 1], labels[idx : idx + 1]
best_image = image.clone()
loss, callback = self._get_fun(image, label, target_attack=self.targeted)
best_p = loss(solution=image, solution_as_perturbed=True)
image_probs = [best_p]
for it in range(self.max_patches):
(x, y), (x_offset, y_offset) = self.get_patch_coordinates(
image=image, x_bounds=x_bounds, y_bounds=y_bounds
)
destinations = self.get_pixel_mapping(
image, x, x_offset, y, y_offset, destination_image=best_image
)
solution = [x, y, x_offset, y_offset] + destinations
pert_image = self._perturb(
source=image, destination=best_image, solution=solution
)
p = loss(solution=pert_image, solution_as_perturbed=True)
if p < best_p:
best_p = p
best_image = pert_image
image_probs.append(best_p)
if callback(pert_image, None, True):
best_image = pert_image
break
adv_images.append(best_image)
adv_images = torch.cat(adv_images)
return adv_images
def _get_prob(self, image):
out = self.get_logits(image.to(self.device))
prob = softmax(out, dim=1)
return prob.detach().cpu().numpy()
def loss(self, img, label, target_attack=False):
p = self._get_prob(img)
p = p[np.arange(len(p)), label]
if target_attack:
p = 1 - p
return p.sum()
def get_patch_coordinates(self, image, x_bounds, y_bounds):
c, h, w = image.shape[1:]
x, y = np.random.uniform(0, 1, 2)
x_offset = np.random.randint(x_bounds[0], x_bounds[1] + 1)
y_offset = np.random.randint(y_bounds[0], y_bounds[1] + 1)
x, y = int(x * (w - 1)), int(y * (h - 1))
if x + x_offset > w:
x_offset = w - x
if y + y_offset > h:
y_offset = h - y
return (x, y), (x_offset, y_offset)
def get_pixel_mapping(
self, source_image, x, x_offset, y, y_offset, destination_image=None
):
if destination_image is None:
destination_image = source_image
destinations = []
c, h, w = source_image.shape[1:]
source_image = source_image[0]
if self.pixel_mapping == "random":
for i in range(x_offset):
for j in range(y_offset):
dx, dy = np.random.uniform(0, 1, 2)
dx, dy = int(dx * (w - 1)), int(dy * (h - 1))
destinations.append([dx, dy])
else:
for i in np.arange(y, y + y_offset):
for j in np.arange(x, x + x_offset):
pixel = source_image[:, i : i + 1, j : j + 1]
diff = destination_image - pixel
diff = diff[0].abs().mean(0).view(-1)
if "similarity" in self.pixel_mapping:
diff = 1 / (1 + diff)
diff[diff == 1] = 0
probs = torch.softmax(diff, 0).cpu().numpy()
indexes = np.arange(len(diff))
pair = None
linear_iter = iter(
sorted(
zip(indexes, probs), key=lambda pit: pit[1], reverse=True
)
)
while True:
if "random" in self.pixel_mapping:
index = np.random.choice(indexes, p=probs)
else:
index = next(linear_iter)[0]
_y, _x = np.unravel_index(index, (h, w))
if _y == i and _x == j:
continue
pair = (_x, _y)
break
destinations.append(pair)
return destinations
def _get_fun(self, img, label, target_attack=False):
img = img.to(self.device)
if isinstance(label, torch.Tensor):
label = label.cpu().numpy()
@torch.no_grad()
def func(solution, destination=None, solution_as_perturbed=False, **kwargs):
if not solution_as_perturbed:
pert_image = self._perturb(
source=img, destination=destination, solution=solution
)
else:
pert_image = solution
p = self._get_prob(pert_image)
p = p[np.arange(len(p)), label]
if target_attack:
p = 1 - p
return p.sum()
@torch.no_grad()
def callback(solution, destination=None, solution_as_perturbed=False, **kwargs):
if not solution_as_perturbed:
pert_image = self._perturb(
source=img, destination=destination, solution=solution
)
else:
pert_image = solution
p = self._get_prob(pert_image)[0]
mx = np.argmax(p)
if target_attack:
return mx == label
else:
return mx != label
return func, callback
def _perturb(self, source, solution, destination=None):
if destination is None:
destination = source
c, h, w = source.shape[1:]
x, y, xl, yl = solution[:4]
destinations = solution[4:]
source_pixels = np.ix_(range(c), np.arange(y, y + yl), np.arange(x, x + xl))
indexes = torch.tensor(destinations)
destination = destination.clone().detach().to(self.device)
s = source[0][source_pixels].view(c, -1)
destination[0, :, indexes[:, 0], indexes[:, 1]] = s
return destination