tudelft-cda-lab/GROOT

is_numerical argument GrootTreeClassifier

Closed this issue · 2 comments

laudv commented

When I run the README example code on the make_moons data, I get:

Traceback (most recent call last):
  File "/home/.../groot_test.py", line 11, in <module>
    tree = GrootTreeClassifier(attack_model=attack_model, is_numerical=is_numerical, random_state=0)
TypeError: __init__() got an unexpected keyword argument 'is_numerical'

Leaving out the argument and using this line instead:
tree = GrootTreeClassifier(attack_model=attack_model, random_state=0)
results in this error:

Traceback (most recent call last):
  File "/home/.../groot_test.py", line 15, in <module>
    adversarial_accuracy = DecisionTreeAdversary(tree, "groot").adversarial_accuracy(X_test, y_test)
  File "/home/.../venv/lib/python3.9/site-packages/groot/adversary.py", line 259, in __init__
    self.is_numeric = self.decision_tree.is_numerical
AttributeError: 'GrootTreeClassifier' object has no attribute 'is_numerical'

I'm guessing the code was updated but the README wasn't. Or I made a stupid mistake, which is also very possible.
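One quick way to check a suspected mismatch between a README and the installed package is to inspect the constructor signature directly. The sketch below is only an illustration: `accepts_keyword` is a hypothetical helper, and `StandInClassifier` is a made-up stand-in class, not the real GROOT code. In practice you would pass the installed `GrootTreeClassifier` instead.

```python
import inspect


def accepts_keyword(cls, name):
    """Return True if cls.__init__ appears to accept `name` as a keyword.

    Hypothetical helper: it ignores edge cases such as positional-only
    parameters.
    """
    params = inspect.signature(cls.__init__).parameters
    if name in params:
        return True
    # A **kwargs parameter would swallow any keyword argument.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Made-up stand-in for the installed class (an assumption for illustration,
# not the actual GROOT implementation).
class StandInClassifier:
    def __init__(self, attack_model=None, random_state=None):
        self.attack_model = attack_model
        self.random_state = random_state


print(accepts_keyword(StandInClassifier, "attack_model"))  # True
print(accepts_keyword(StandInClassifier, "is_numerical"))  # False
```

If the check returns False for an argument that the README uses, the documentation and the installed version have probably drifted apart.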

Sorry for missing this issue; somehow I had my notifications turned off! The problem is indeed my mistake: I forgot to update the README.md. I have now replaced the example in the README.md with the example from the docs, and it should work. Could you test that? Let me know if there are any remaining problems!

The new example should be:

from groot.model import GrootTreeClassifier
from groot.toolbox import Model

from sklearn.datasets import make_moons

# Load the dataset
X, y = make_moons(noise=0.3, random_state=0)
X_test, y_test = make_moons(noise=0.3, random_state=1)

# Define the attacker's capabilities (L-inf norm radius 0.3)
epsilon = 0.3
attack_model = [epsilon, epsilon]

# Create and fit a GROOT tree
tree = GrootTreeClassifier(
    attack_model=attack_model,
    random_state=0
)
tree.fit(X, y)

# Determine the accuracy and accuracy against attackers
accuracy = tree.score(X_test, y_test)
model = Model.from_groot(tree)
adversarial_accuracy = model.adversarial_accuracy(X_test, y_test, attack="tree", epsilon=epsilon)

print("Accuracy:", accuracy)
print("Adversarial Accuracy:", adversarial_accuracy)
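
For readers unfamiliar with the `attack_model = [epsilon, epsilon]` notation, here is a rough sketch of what a per-feature L-inf radius means in general. It is not GROOT's internal implementation: `perturbation_box` and `can_cross` are hypothetical helpers, and treating the bounds as inclusive is an assumption.

```python
def perturbation_box(sample, attack_model):
    """Per-feature (low, high) bounds an L-inf attacker could reach.

    Assumes attack_model holds one radius per feature, as in
    attack_model = [epsilon, epsilon].
    """
    return [(value - radius, value + radius)
            for value, radius in zip(sample, attack_model)]


def can_cross(sample, attack_model, feature, threshold):
    """True if the attacker could move `feature` to either side of `threshold`.

    Inclusive bounds are an assumption made for this illustration.
    """
    low, high = perturbation_box(sample, attack_model)[feature]
    return low <= threshold <= high


epsilon = 0.3
attack_model = [epsilon, epsilon]
sample = [1.0, 0.5]

# A threshold of 1.2 on feature 0 lies within reach of the attacker,
# while a threshold of 2.0 does not.
print(can_cross(sample, attack_model, feature=0, threshold=1.2))  # True
print(can_cross(sample, attack_model, feature=0, threshold=2.0))  # False
```

A sample counts toward adversarial accuracy only if no perturbation inside this box changes its prediction, which is why the adversarial accuracy is never higher than the plain accuracy.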

laudv commented

Yes, that works fine, thanks!