Trusted-AI/adversarial-robustness-toolbox

Regarding duplicated queries in the Square Attack

Lodour opened this issue · 2 comments

Lodour commented

Is your feature request related to a problem? Please describe.
I was benchmarking the performance of some black-box attacks and noticed that roughly 50% of the queries issued by the Square Attack are duplicates. This happens because most queries are sent twice in a subtle way; see the code below.

Describe the solution you'd like
Update the implementation to reuse the last prediction outputs.

I am happy to send a PR if you find this performance improvement useful.

Describe alternatives you've considered
N/A

Additional context
For example, x_adv is passed to self.estimator.predict at L347, and its subset x_robust = x_adv[sample_is_robust] is immediately passed to self.loss, which results in a second call to self.estimator.predict on exactly the same inputs.

# Determine correctly predicted samples
y_pred = self.estimator.predict(x_adv, batch_size=self.batch_size)
sample_is_robust = np.logical_not(self.adv_criterion(y_pred, y))
if np.sum(sample_is_robust) == 0:  # pragma: no cover
    break
x_robust = x_adv[sample_is_robust]
x_init = x[sample_is_robust]
y_robust = y[sample_is_robust]
sample_loss_init = self.loss(x_robust, y_robust)
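A rough way to reproduce the measurement is to wrap estimator.predict with a counter that hashes every queried row. The helper name count_queries and the attack/estimator setup below are only illustrative, not part of ART; it also assumes predict can be shadowed by an instance attribute, which holds for the usual ART classifiers.

import numpy as np

def count_queries(estimator):
    # Illustrative helper: monkey-patch estimator.predict to count how many
    # queried rows have already been seen in earlier queries.
    seen = set()
    stats = {"total": 0, "duplicate": 0}
    original_predict = estimator.predict

    def counting_predict(x, **kwargs):
        for row in np.asarray(x):
            key = row.tobytes()
            stats["total"] += 1
            if key in seen:
                stats["duplicate"] += 1
            seen.add(key)
        return original_predict(x, **kwargs)

    estimator.predict = counting_predict
    return stats

# Usage sketch:
# stats = count_queries(classifier)
# SquareAttack(estimator=classifier).generate(x_test, y_test)
# print(stats["duplicate"] / stats["total"])  # ~0.5 in my benchmark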

Hi @Lodour Thank you for using ART! What would your proposed solution, keeping the same functionality, look like?

Lodour commented

My workaround is to pass y_pred to self.loss, but this only handles the default loss function self._get_logits_diff.

def _get_logits_diff(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
    y_pred = self.estimator.predict(x, batch_size=self.batch_size)
    logit_correct = np.take_along_axis(y_pred, np.expand_dims(np.argmax(y, axis=1), axis=1), axis=1)
    logit_highest_incorrect = np.take_along_axis(
        y_pred, np.expand_dims(np.argsort(y_pred, axis=1)[:, -2], axis=1), axis=1
    )
    return (logit_correct - logit_highest_incorrect)[:, 0]

The method would become something like

def _get_logits_diff(self, x: np.ndarray, y: np.ndarray, y_pred: Optional[np.ndarray] = None) -> np.ndarray:
    if y_pred is None:
        y_pred = self.estimator.predict(x, batch_size=self.batch_size)
    ...
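For concreteness, here is a minimal sketch of the full change, assuming generate() simply forwards the predictions it has already computed; the exact wiring is of course up to the maintainers.

from typing import Optional

import numpy as np

def _get_logits_diff(self, x: np.ndarray, y: np.ndarray, y_pred: Optional[np.ndarray] = None) -> np.ndarray:
    # Reuse predictions supplied by the caller when available; otherwise
    # fall back to querying the estimator, as the current implementation does.
    if y_pred is None:
        y_pred = self.estimator.predict(x, batch_size=self.batch_size)
    logit_correct = np.take_along_axis(y_pred, np.expand_dims(np.argmax(y, axis=1), axis=1), axis=1)
    logit_highest_incorrect = np.take_along_axis(
        y_pred, np.expand_dims(np.argsort(y_pred, axis=1)[:, -2], axis=1), axis=1
    )
    return (logit_correct - logit_highest_incorrect)[:, 0]

# In generate(), the predictions that were already computed would then be
# sliced and forwarded instead of triggering a second predict() on the same inputs:
# sample_loss_init = self.loss(x_robust, y_robust, y_pred=y_pred[sample_is_robust])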

I haven't come up with a general solution for custom loss functions.