Add a label-format check to the TRADES adversarial trainer
GiulioZizzo opened this issue
Describe the bug
The TRADES adversarial trainer assumes in several places that labels are one-hot encoded. We propose a simple change: check the dimensionality of the supplied labels at those places before applying an argmax operation.
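A minimal sketch of the proposed dimensionality check follows. It assumes the labels are NumPy arrays. The helper name to_index_labels is hypothetical and is not part of ART. Labels that have a class axis are reduced with argmax, index-style labels pass through unchanged, and any other shape raises an informative error.

```python
import numpy as np


def to_index_labels(y: np.ndarray) -> np.ndarray:
    """Hypothetical helper: convert one-hot or index labels to integer class indices."""
    y = np.asarray(y)
    if y.ndim == 1:
        # Already index-style labels of shape (n_samples,)
        return y.astype(np.int64)
    if y.ndim == 2 and y.shape[1] == 1:
        # Column vector of indices, shape (n_samples, 1); argmax here would wrongly give all zeros
        return y[:, 0].astype(np.int64)
    if y.ndim == 2:
        # One-hot labels of shape (n_samples, n_classes)
        return np.argmax(y, axis=1).astype(np.int64)
    raise ValueError(
        f"Unsupported label shape {y.shape}: expected (n_samples,), (n_samples, 1) "
        "or one-hot (n_samples, n_classes)."
    )
```

With a helper like this, each place that currently calls argmax on the labels could instead call the helper, so that both label formats yield the same class indices.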
To Reproduce
Steps to reproduce the behavior:
- Go to test_adversarial_trainer_trades_pytorch.py.
- After the MNIST data is loaded, add a line such as y_test_mnist = np.argmax(y_test_mnist, axis=1) so that the labels are in index format.
- Run the test. It fails with the following error:
numpy.AxisError: axis 1 is out of bounds for array of dimension 1
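The same failure can be reproduced in isolation with NumPy alone. The function argmax_over_class_axis is a hypothetical stand-in for the trainer's one-hot assumption. It is not the trainer's actual code. Calling argmax with axis=1 on 1-D index labels raises the AxisError reported above.

```python
import numpy as np


def argmax_over_class_axis(y: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for the one-hot assumption: labels are expected
    # to have a class axis at axis=1.
    return np.argmax(y, axis=1)


def reproduce() -> Exception:
    # Index-style labels, like those produced by np.argmax(y_test_mnist, axis=1)
    y_index = np.array([7, 2, 1, 0])
    try:
        argmax_over_class_axis(y_index)
    except ValueError as err:  # NumPy's AxisError subclasses ValueError and IndexError
        return err
    raise AssertionError("expected an AxisError for 1-D labels")


if __name__ == "__main__":
    err = reproduce()
    print(type(err).__name__, err)
```

The code catches ValueError instead of numpy.AxisError because newer NumPy releases expose that class as numpy.exceptions.AxisError.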
Expected behavior
The trainer should accept both one-hot and index-style labels, or give an informative warning message if an unsupported label format is supplied.
System information:
- OS: macOS
- Python version: 3.9
- ART version or commit number: ART 1.15
- TensorFlow / Keras / PyTorch / MXNet version: torch==2.0.1