import functools
import time
from collections import OrderedDict

import torch
from torch.utils.data import DataLoader, TensorDataset
def wrapper_method(func):
    """Decorate a method so that each call is also forwarded to sub-attacks.

    The wrapped method first runs on ``self``; afterwards the method with the
    same name is invoked, with the same arguments, on every object stored in
    ``self._attacks`` (a mapping read from ``self.__dict__``).

    Arguments:
        func (callable): method taking ``self`` as its first argument.

    Returns:
        callable: wrapper returning the result of ``func`` on ``self``; the
        results of the forwarded calls are discarded.
    """

    @functools.wraps(func)  # keep __name__/__doc__ of the wrapped method
    def wrapper_func(self, *args, **kwargs):
        result = func(self, *args, **kwargs)
        # Resolve the method by name on each sub-attack so that its own
        # (possibly overridden and itself wrapped) implementation is used.
        # getattr replaces the former eval() of a generated source string.
        for atk in self.__dict__.get("_attacks").values():
            getattr(atk, func.__name__)(*args, **kwargs)
        return result

    return wrapper_func
class Attack(object):
    r"""
    Base class for all attacks.

    Subclasses implement :meth:`forward`. Calling an attack instance runs
    ``forward`` with the model switched to the configured mode and, when
    normalization has been registered, with inputs de-normalized before and
    re-normalized after the attack.

    Attack objects assigned as attributes (directly, or inside lists/dicts)
    are recorded in ``self._attacks``; methods decorated with
    ``wrapper_method`` are forwarded to them.

    .. note::
        It automatically sets the device to the device where the given model is.
        It changes the model to eval mode during the attack process by default.
        To change this, please see `set_model_training_mode`.
    """

    def __init__(self, name, model):
        r"""
        Initializes internal attack state.

        Arguments:
            name (str): name of attack.
            model (torch.nn.Module): model to attack.
        """
        self.attack = name
        # Registry of nested attacks; filled by __setattr__ and walked by
        # every method decorated with wrapper_method.
        self._attacks = OrderedDict()
        self.set_model(model)
        try:
            self.device = next(model.parameters()).device
        except Exception:
            # Best effort only: the device can still be set with set_device().
            self.device = None
            print("Failed to set device automatically, please try set_device() manual.")

        # Controls attack mode.
        self.attack_mode = "default"
        self.supported_mode = ["default"]
        self.targeted = False
        self._target_map_function = None

        # Controls when normalization is used.
        self.normalization_used = None
        self._normalization_applied = None
        if self.model.__class__.__name__ == "RobModel":
            self._set_rmodel_normalization_used(model)

        # Controls model mode during attack.
        self._model_training = False
        self._batchnorm_training = False
        self._dropout_training = False

    def forward(self, inputs, labels=None, *args, **kwargs):
        r"""
        It defines the computation performed at every call.
        Should be overridden by all subclasses.
        """
        raise NotImplementedError

    @wrapper_method
    def set_model(self, model):
        r"""
        Set the model to attack and record its class name in ``model_name``.
        """
        self.model = model
        self.model_name = model.__class__.__name__

    def get_logits(self, inputs, labels=None, *args, **kwargs):
        r"""
        Return ``self.model(inputs)``.

        Inputs are normalized first when ``_normalization_applied`` is
        exactly ``False`` (the state used while an attack runs on
        de-normalized inputs). ``labels`` and extra arguments are unused here.
        """
        if self._normalization_applied is False:
            inputs = self.normalize(inputs)
        logits = self.model(inputs)
        return logits

    @wrapper_method
    def _set_normalization_applied(self, flag):
        r"""
        Record whether the inputs currently handled are already normalized.
        """
        self._normalization_applied = flag

    @wrapper_method
    def set_device(self, device):
        r"""
        Set the device used by the attack.
        """
        self.device = device

    @wrapper_method
    def _set_rmodel_normalization_used(self, model):
        r"""
        Set attack normalization for MAIR [https://github.com/Harry24k/MAIR].

        Reads ``model.mean`` and ``model.std`` and registers them through
        `set_normalization_used` unless they describe the identity
        normalization (all means 0 and all stds 1).
        """
        mean = getattr(model, "mean", None)
        std = getattr(model, "std", None)
        if (mean is not None) and (std is not None):
            if isinstance(mean, torch.Tensor):
                mean = mean.cpu().numpy()
            if isinstance(std, torch.Tensor):
                std = std.cpu().numpy()
            # Normalization matters as soon as ANY channel deviates from
            # mean 0 / std 1 (the previous all() test skipped e.g.
            # mean=[0.5, 0.0, 0.5]). as_tensor also accepts lists/tuples.
            mean_differs = bool((torch.as_tensor(mean) != 0).any())
            std_differs = bool((torch.as_tensor(std) != 1).any())
            if mean_differs or std_differs:
                self.set_normalization_used(mean, std)

    @wrapper_method
    def set_normalization_used(self, mean, std):
        r"""
        Register the normalization applied to the model's inputs.

        Arguments:
            mean (sequence): per-channel means; its length sets the number of channels.
            std (sequence): per-channel standard deviations.

        Both are stored with shape ``(1, n_channels, 1, 1)`` in
        ``self.normalization_used`` and inputs are marked as normalized.
        """
        self.normalization_used = {}
        n_channels = len(mean)
        mean = torch.tensor(mean).reshape(1, n_channels, 1, 1)
        std = torch.tensor(std).reshape(1, n_channels, 1, 1)
        self.normalization_used["mean"] = mean
        self.normalization_used["std"] = std
        self._set_normalization_applied(True)

    def normalize(self, inputs):
        r"""
        Return ``(inputs - mean) / std`` using the registered statistics.
        """
        mean = self.normalization_used["mean"].to(inputs.device)
        std = self.normalization_used["std"].to(inputs.device)
        return (inputs - mean) / std

    def inverse_normalize(self, inputs):
        r"""
        Return ``inputs * std + mean`` using the registered statistics.
        """
        mean = self.normalization_used["mean"].to(inputs.device)
        std = self.normalization_used["std"].to(inputs.device)
        return inputs * std + mean

    def get_mode(self):
        r"""
        Get attack mode.
        """
        return self.attack_mode

    @wrapper_method
    def set_mode_default(self):
        r"""
        Set attack mode as default mode.
        """
        self.attack_mode = "default"
        self.targeted = False
        print("Attack mode is changed to 'default.'")

    @wrapper_method
    def _set_mode_targeted(self, mode, quiet):
        r"""
        Switch to the targeted mode named ``mode``.

        Raises:
            ValueError: if ``"targeted"`` is not in ``self.supported_mode``.
        """
        if "targeted" not in self.supported_mode:
            raise ValueError("Targeted mode is not supported.")
        self.targeted = True
        self.attack_mode = mode
        if not quiet:
            print("Attack mode is changed to '%s'." % mode)

    @wrapper_method
    def set_mode_targeted_by_function(self, target_map_function, quiet=False):
        r"""
        Set attack mode as targeted.

        Arguments:
            target_map_function (function): Label mapping function, called as
                ``target_map_function(inputs, labels)``.
                e.g. lambda inputs, labels:(labels+1)%10.
                Must not be None (`get_target_label` raises in that case);
                to use the given labels as targets see
                `set_mode_targeted_by_label`.
            quiet (bool): True to suppress the information message. (Default: False)
        """
        self._set_mode_targeted("targeted(custom)", quiet)
        self._target_map_function = target_map_function

    @wrapper_method
    def set_mode_targeted_random(self, quiet=False):
        r"""
        Set attack mode as targeted with random labels.

        Arguments:
            quiet (bool): True to suppress the information message. (Default: False)
        """
        self._set_mode_targeted("targeted(random)", quiet)
        self._target_map_function = self.get_random_target_label

    @wrapper_method
    def set_mode_targeted_least_likely(self, kth_min=1, quiet=False):
        r"""
        Set attack mode as targeted with least likely labels.

        Arguments:
            kth_min (int): the label with the k-th smallest logit (true label
                excluded) is used as target label. Must be positive. (Default: 1)
            quiet (bool): True to suppress the information message. (Default: False)
        """
        self._set_mode_targeted("targeted(least-likely)", quiet)
        assert kth_min > 0
        self._kth_min = kth_min
        self._target_map_function = self.get_least_likely_label

    @wrapper_method
    def set_mode_targeted_by_label(self, quiet=False):
        r"""
        Set attack mode as targeted.

        Arguments:
            quiet (bool): True to suppress the information message. (Default: False)

        .. note::
            Use user-supplied labels as target labels.
        """
        self._set_mode_targeted("targeted(label)", quiet)
        # Non-None placeholder: get_target_label only checks for None and
        # returns the labels directly in this mode, so it is never called.
        self._target_map_function = "function is a string"

    @wrapper_method
    def set_model_training_mode(
        self, model_training=False, batchnorm_training=False, dropout_training=False
    ):
        r"""
        Set training mode during attack process.

        Arguments:
            model_training (bool): True for using training mode for the entire model during attack process.
            batchnorm_training (bool): True for using training mode for batchnorms during attack process.
            dropout_training (bool): True for using training mode for dropouts during attack process.

        .. note::
            For RNN-based models, we cannot calculate gradients with eval mode.
            Thus, it should be changed to the training mode during the attack.
        """
        self._model_training = model_training
        self._batchnorm_training = batchnorm_training
        self._dropout_training = dropout_training

    @wrapper_method
    def _change_model_mode(self, given_training):
        r"""
        Put the model into the mode configured by `set_model_training_mode`.

        ``given_training`` is accepted for symmetry with
        `_recover_model_mode` and is not used.
        """
        if not self._model_training:
            self.model.eval()
            return

        self.model.train()
        # Selected layer families are matched by class name and kept in eval.
        for _, module in self.model.named_modules():
            class_name = module.__class__.__name__
            if not self._batchnorm_training and "BatchNorm" in class_name:
                module.eval()
            if not self._dropout_training and "Dropout" in class_name:
                module.eval()

    @wrapper_method
    def _recover_model_mode(self, given_training):
        r"""
        Restore the train/eval mode the model had before the attack.

        Arguments:
            given_training (bool): value of ``model.training`` before the attack.
        """
        if given_training:
            self.model.train()
        else:
            # Needed when the attack itself ran in training mode
            # (model_training=True) on a model that was in eval mode.
            self.model.eval()

    def save(
        self,
        data_loader,
        save_path=None,
        verbose=True,
        return_verbose=False,
        save_predictions=False,
        save_clean_inputs=False,
        save_type="float",
    ):
        r"""
        Save adversarial inputs as torch.tensor from given torch.utils.data.DataLoader.

        Arguments:
            data_loader (torch.utils.data.DataLoader): data loader yielding ``(inputs, labels)``.
            save_path (str): save_path. Nothing is written when None. (Default: None)
            verbose (bool): True for displaying detailed information. (Default: True)
            return_verbose (bool): True for returning detailed information. (Default: False)
            save_predictions (bool): True for saving predicted labels (Default: False)
            save_clean_inputs (bool): True for saving clean inputs (Default: False)
            save_type (str): "int" stores images as uint8 via `to_type`; any
                other value stores them unchanged. (Default: "float")

        Returns:
            ``(rob_acc, l2, elapsed_time)`` of the last batch's running
            statistics when ``return_verbose`` is True (all None if the
            loader was empty), otherwise None.
        """
        saving = save_path is not None
        adv_input_list = []
        label_list = []
        pred_list = []
        input_list = []

        track_stats = verbose or return_verbose
        # Predictions are needed for the statistics and for saving them;
        # previously save_predictions without verbose raised NameError.
        need_preds = track_stats or (saving and save_predictions)

        correct = 0
        total = 0
        l2_distance = []
        progress = rob_acc = l2 = elapsed_time = None

        total_batch = len(data_loader)
        given_training = self.model.training

        for step, (inputs, labels) in enumerate(data_loader):
            start = time.time()
            adv_inputs = self(inputs, labels)
            batch_size = len(inputs)

            if need_preds:
                with torch.no_grad():
                    outputs = self.get_output_with_eval_nograd(adv_inputs)
                    _, pred = torch.max(outputs.data, 1)

            if track_stats:
                with torch.no_grad():
                    # Calculate robust accuracy
                    total += labels.size(0)
                    right_idx = pred == labels.to(self.device)
                    correct += right_idx.sum()
                    rob_acc = 100 * float(correct) / total

                    # Calculate l2 distance (over misclassified samples only)
                    delta = (adv_inputs - inputs.to(self.device)).view(
                        batch_size, -1
                    )
                    l2_distance.append(torch.norm(delta[~right_idx], p=2, dim=1))
                    l2 = torch.cat(l2_distance).mean().item()

                    # Calculate time computation
                    progress = (step + 1) / total_batch * 100
                    end = time.time()
                    elapsed_time = end - start

                    if verbose:
                        self._save_print(
                            progress, rob_acc, l2, elapsed_time, end="\r"
                        )

            if saving:
                # NOTE(review): the file is rebuilt and rewritten after every
                # batch, so a partial result exists on disk if the run is
                # interrupted, at the cost of repeated concatenation.
                adv_input_list.append(adv_inputs.detach().cpu())
                label_list.append(labels.detach().cpu())

                save_dict = {
                    "adv_inputs": torch.cat(adv_input_list, 0),
                    "labels": torch.cat(label_list, 0),
                }

                if save_predictions:
                    pred_list.append(pred.detach().cpu())
                    save_dict["preds"] = torch.cat(pred_list, 0)

                if save_clean_inputs:
                    input_list.append(inputs.detach().cpu())
                    save_dict["clean_inputs"] = torch.cat(input_list, 0)

                if self.normalization_used is not None:
                    save_dict["adv_inputs"] = self.inverse_normalize(
                        save_dict["adv_inputs"]
                    )
                    if save_clean_inputs:
                        save_dict["clean_inputs"] = self.inverse_normalize(
                            save_dict["clean_inputs"]
                        )

                if save_type == "int":
                    save_dict["adv_inputs"] = self.to_type(
                        save_dict["adv_inputs"], "int"
                    )
                    if save_clean_inputs:
                        save_dict["clean_inputs"] = self.to_type(
                            save_dict["clean_inputs"], "int"
                        )

                save_dict["save_type"] = save_type
                torch.save(save_dict, save_path)

        # To avoid erasing the printed information.
        # (Skipped for an empty loader, where no statistics exist.)
        if verbose and rob_acc is not None:
            self._save_print(progress, rob_acc, l2, elapsed_time, end="\n")

        if given_training:
            self.model.train()

        if return_verbose:
            return rob_acc, l2, elapsed_time

    @staticmethod
    def to_type(inputs, type):
        r"""
        Convert image tensors between float and uint8 representation.

        Arguments:
            inputs: tensor to convert; anything that is not a float32
                (for "int") or uint8 (for "float") tensor is returned unchanged.
            type (str): "int" multiplies by 255 and casts to uint8;
                "float" casts to float and divides by 255.

        Raises:
            ValueError: if ``type`` is neither "int" nor "float".
        """
        # dtype checks match the former torch.FloatTensor / torch.ByteTensor
        # isinstance tests on CPU and CUDA and also work on other devices.
        is_tensor = isinstance(inputs, torch.Tensor)
        if type == "int":
            if is_tensor and inputs.dtype == torch.float32:
                return (inputs * 255).type(torch.uint8)
        elif type == "float":
            if is_tensor and inputs.dtype == torch.uint8:
                return inputs.float() / 255
        else:
            raise ValueError(type + " is not a valid type. [Options: float, int]")
        return inputs

    @staticmethod
    def _save_print(progress, rob_acc, l2, elapsed_time, end):
        r"""
        Print one progress line for `save`, terminated by ``end``.
        """
        # NOTE(review): `save` passes the seconds spent on the batch as
        # elapsed_time although the message labels it "it/s".
        print(
            "- Save progress: %2.2f %% / Robust accuracy: %2.2f %% / L2: %1.5f (%2.3f it/s) \t"
            % (progress, rob_acc, l2, elapsed_time),
            end=end,
        )

    @staticmethod
    def load(
        load_path,
        batch_size=128,
        shuffle=False,
        normalize=None,
        load_predictions=False,
        load_clean_inputs=False,
    ):
        r"""
        Load a file written by `save` and wrap it in a DataLoader.

        Arguments:
            load_path (str): path of the saved file.
            batch_size (int): batch size of the returned loader. (Default: 128)
            shuffle (bool): whether the returned loader shuffles. (Default: False)
            normalize (dict): optional ``{"mean": ..., "std": ...}`` per-channel
                statistics applied to the loaded images. (Default: None)
            load_predictions (bool): also load the "preds" entry. (Default: False)
            load_clean_inputs (bool): also load the "clean_inputs" entry. (Default: False)

        Returns:
            torch.utils.data.DataLoader yielding tensors in the order
            adv_inputs, labels[, preds][, clean_inputs].
        """
        # SECURITY: torch.load unpickles the file; only load files from
        # trusted sources.
        save_dict = torch.load(load_path)
        keys = ["adv_inputs", "labels"]

        if load_predictions:
            keys.append("preds")
        if load_clean_inputs:
            keys.append("clean_inputs")

        # Images stored as uint8 are scaled back to floats.
        if save_dict["save_type"] == "int":
            save_dict["adv_inputs"] = save_dict["adv_inputs"].float() / 255
            if load_clean_inputs:
                save_dict["clean_inputs"] = save_dict["clean_inputs"].float() / 255

        if normalize is not None:
            n_channels = len(normalize["mean"])
            mean = torch.tensor(normalize["mean"]).reshape(1, n_channels, 1, 1)
            std = torch.tensor(normalize["std"]).reshape(1, n_channels, 1, 1)
            save_dict["adv_inputs"] = (save_dict["adv_inputs"] - mean) / std
            if load_clean_inputs:
                save_dict["clean_inputs"] = (save_dict["clean_inputs"] - mean) / std

        adv_data = TensorDataset(*[save_dict[key] for key in keys])
        adv_loader = DataLoader(adv_data, batch_size=batch_size, shuffle=shuffle)
        print("Data is loaded in the following order: [%s]" % (", ".join(keys)))
        return adv_loader

    @torch.no_grad()
    def get_output_with_eval_nograd(self, inputs):
        r"""
        Return the logits for ``inputs`` with the model in eval mode and
        gradients disabled; the model's training mode is restored afterwards.
        """
        given_training = self.model.training
        if given_training:
            self.model.eval()
        try:
            outputs = self.get_logits(inputs)
        finally:
            # Restore the mode even if the forward pass raises.
            if given_training:
                self.model.train()
        return outputs

    def get_target_label(self, inputs, labels=None):
        r"""
        Return the target labels for the current targeted mode.

        In mode "targeted(label)" the given ``labels`` are returned as is;
        otherwise the configured target map function is called with
        ``(inputs, labels)``.

        Raises:
            ValueError: if no targeted mode has been set.
        """
        if self._target_map_function is None:
            raise ValueError(
                "target_map_function is not initialized by set_mode_targeted."
            )
        if self.attack_mode == "targeted(label)":
            return labels
        return self._target_map_function(inputs, labels)

    @torch.no_grad()
    def get_least_likely_label(self, inputs, labels=None):
        r"""
        Return, per sample, the class with the ``_kth_min``-th smallest logit
        among all classes except the sample's label.

        When ``labels`` is None the model's predictions are used instead.
        """
        outputs = self.get_output_with_eval_nograd(inputs)
        if labels is None:
            _, labels = torch.max(outputs, dim=1)
        n_classes = outputs.shape[-1]

        target_labels = torch.zeros_like(labels)
        for idx in range(labels.shape[0]):
            candidates = list(range(n_classes))
            candidates.remove(int(labels[idx]))
            _, kth = torch.kthvalue(outputs[idx][candidates], self._kth_min)
            target_labels[idx] = candidates[int(kth)]

        return target_labels.long().to(self.device)

    @torch.no_grad()
    def get_random_target_label(self, inputs, labels=None):
        r"""
        Return, per sample, a uniformly drawn class different from its label.

        When ``labels`` is None the model's predictions are used instead.
        """
        outputs = self.get_output_with_eval_nograd(inputs)
        if labels is None:
            _, labels = torch.max(outputs, dim=1)
        n_classes = outputs.shape[-1]

        target_labels = torch.zeros_like(labels)
        for idx in range(labels.shape[0]):
            candidates = list(range(n_classes))
            candidates.remove(int(labels[idx]))
            # One torch.rand draw per sample, truncated to a list index.
            pick = int(len(candidates) * torch.rand([1]))
            target_labels[idx] = candidates[pick]

        return target_labels.long().to(self.device)

    def __call__(self, inputs, labels=None, *args, **kwargs):
        r"""
        Run the attack and return the adversarial inputs.

        The model is switched to the configured mode for the duration of the
        call. If inputs are marked as normalized, they are de-normalized
        before `forward` and its result is normalized again.
        """
        given_training = self.model.training
        self._change_model_mode(given_training)
        try:
            if self._normalization_applied is True:
                inputs = self.inverse_normalize(inputs)
                self._set_normalization_applied(False)
                try:
                    adv_inputs = self.forward(inputs, labels, *args, **kwargs)
                    adv_inputs = self.normalize(adv_inputs)
                finally:
                    # Never leave the flag cleared if forward() raises.
                    self._set_normalization_applied(True)
            else:
                adv_inputs = self.forward(inputs, labels, *args, **kwargs)
        finally:
            self._recover_model_mode(given_training)

        return adv_inputs

    def __repr__(self):
        r"""
        Return ``name(key=value, ...)`` built from the public attributes,
        without ``model``, ``attack`` and ``supported_mode``.
        """
        hidden = ("model", "attack", "supported_mode")
        info = {
            key: val
            for key, val in self.__dict__.items()
            if key not in hidden and not key.startswith("_")
        }
        info["attack_mode"] = self.attack_mode
        info["normalization_used"] = self.normalization_used is not None

        return (
            self.attack
            + "("
            + ", ".join("{}={}".format(key, val) for key, val in info.items())
            + ")"
        )

    def __setattr__(self, name, value):
        r"""
        Set the attribute and register any Attack objects found in ``value``.

        ``value`` may be an Attack or arbitrarily nested lists/dicts (dict
        keys and values are both searched). Each Attack found is stored in
        ``self._attacks`` as ``"<name>.<n>"``, and its own registered
        sub-attacks as ``"<name>.<subname>"``.
        """
        object.__setattr__(self, name, value)

        attacks = self.__dict__.get("_attacks")

        # Containers are tracked by identity: comparing by equality can
        # raise for tensors ("Boolean value of Tensor ... is ambiguous").
        seen_containers = set()

        # Get all Attack items in (nested) iterable items.
        def get_all_values(item):
            if isinstance(item, Attack):
                yield item
            elif isinstance(item, (list, dict)):
                if id(item) in seen_containers:
                    return  # already traversed; also breaks reference cycles
                seen_containers.add(id(item))
                if isinstance(item, dict):
                    children = list(item.keys()) + list(item.values())
                else:
                    children = item
                for child in children:
                    yield from get_all_values(child)

        for num, sub_attack in enumerate(get_all_values(value)):
            attacks[name + "." + str(num)] = sub_attack
            for subname, subvalue in sub_attack.__dict__.get("_attacks").items():
                attacks[name + "." + subname] = subvalue