is_numerical argument GrootTreeClassifier
Closed this issue · 2 comments
laudv commented
Running the README example code on the make_moons data, I get:
Traceback (most recent call last):
  File "/home/.../groot_test.py", line 11, in <module>
    tree = GrootTreeClassifier(attack_model=attack_model, is_numerical=is_numerical, random_state=0)
TypeError: __init__() got an unexpected keyword argument 'is_numerical'
Leaving out the argument and using this line instead:
tree = GrootTreeClassifier(attack_model=attack_model, random_state=0)
results in this error:
Traceback (most recent call last):
  File "/home/.../groot_test.py", line 15, in <module>
    adversarial_accuracy = DecisionTreeAdversary(tree, "groot").adversarial_accuracy(X_test, y_test)
  File "/home/.../venv/lib/python3.9/site-packages/groot/adversary.py", line 259, in __init__
    self.is_numeric = self.decision_tree.is_numerical
AttributeError: 'GrootTreeClassifier' object has no attribute 'is_numerical'
I'm guessing the code was updated but the README wasn't. Or I made a silly mistake, which is also very possible.
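Both tracebacks appear to stem from the same mismatch: the README example passes `is_numerical` to a constructor that no longer accepts it, and the second traceback shows `DecisionTreeAdversary` in `adversary.py` still reading `decision_tree.is_numerical`. One quick way to check which keywords an installed class accepts is `inspect.signature`. The sketch below uses a hypothetical stand-in class, `HypotheticalTreeClassifier`, so it runs without GROOT installed; the assumption is that `inspect.signature(GrootTreeClassifier)` on the real class would list its constructor parameters the same way.

```python
import inspect


class HypotheticalTreeClassifier:
    """Hypothetical stand-in for an estimator whose constructor no longer
    takes `is_numerical` (mirrors the API change behind the tracebacks)."""

    def __init__(self, attack_model=None, random_state=None):
        self.attack_model = attack_model
        self.random_state = random_state


def constructor_params(cls):
    """Names of the parameters accepted by cls(...); `self` is excluded
    automatically when inspecting a class rather than its __init__."""
    return list(inspect.signature(cls).parameters)


# Check the installed API before passing a keyword the README mentions.
params = constructor_params(HypotheticalTreeClassifier)
print(params)  # ['attack_model', 'random_state']
print("is_numerical" in params)  # False

# Passing the removed keyword reproduces the first error (a TypeError) ...
try:
    HypotheticalTreeClassifier(attack_model=[0.3, 0.3], is_numerical=[True, True])
except TypeError as err:
    print("TypeError:", err)

# ... and code that still reads the old attribute would hit the second one.
tree = HypotheticalTreeClassifier(attack_model=[0.3, 0.3], random_state=0)
print(hasattr(tree, "is_numerical"))  # False
```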
daniel-vos commented
Sorry for missing this issue; somehow I had my notifications turned off! The problem is indeed my mistake: I forgot to update the README.md. I have now replaced the example in the README.md with the example from the docs, and this should work. Could you test that? Let me know if there are any remaining problems!
The new example should be:
from groot.model import GrootTreeClassifier
from groot.toolbox import Model
from sklearn.datasets import make_moons
# Load the dataset
X, y = make_moons(noise=0.3, random_state=0)
X_test, y_test = make_moons(noise=0.3, random_state=1)
# Define the attacker's capabilities (L-inf norm radius 0.3)
epsilon = 0.3
attack_model = [epsilon, epsilon]
# Create and fit a GROOT tree
tree = GrootTreeClassifier(
    attack_model=attack_model,
    random_state=0
)
tree.fit(X, y)
# Determine the accuracy and accuracy against attackers
accuracy = tree.score(X_test, y_test)
model = Model.from_groot(tree)
adversarial_accuracy = model.adversarial_accuracy(X_test, y_test, attack="tree", epsilon=epsilon)
print("Accuracy:", accuracy)
print("Adversarial Accuracy:", adversarial_accuracy)
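In the example above, `attack_model = [epsilon, epsilon]` appears to give one perturbation radius per feature, and the adversarial accuracy counts test points that stay correctly classified under every allowed perturbation. To make that idea concrete, here is a from-scratch sketch for a single decision tree. The `Node` class and the `predict`, `reachable_labels`, and `adversarial_accuracy` helpers are hypothetical; this is not GROOT's implementation of `Model.adversarial_accuracy`. It assumes the attacker may move each feature `i` anywhere in `[x[i] - eps[i], x[i] + eps[i]]`, and that splits send `x` left when `x[feature] <= threshold`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    """Hypothetical binary tree node: internal nodes split on
    x[feature] <= threshold (go left) vs > threshold (go right);
    leaves store the predicted class in `label`."""
    feature: int = -1
    threshold: float = 0.0
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    label: Optional[int] = None


def predict(node, x):
    """Ordinary (unperturbed) prediction: follow one path to a leaf."""
    while node.label is None:
        node = node.left if x[node.feature] <= node.threshold else node.right
    return node.label


def reachable_labels(node, x, eps):
    """Labels of every leaf the attacker can reach by moving feature i
    anywhere within [x[i] - eps[i], x[i] + eps[i]]."""
    if node.label is not None:
        return {node.label}
    f, t = node.feature, node.threshold
    labels = set()
    if x[f] - eps[f] <= t:  # some perturbation lands on the left side
        labels |= reachable_labels(node.left, x, eps)
    if x[f] + eps[f] > t:  # some perturbation lands on the right side
        labels |= reachable_labels(node.right, x, eps)
    return labels


def adversarial_accuracy(tree, X, y, eps):
    """Fraction of samples for which every reachable leaf predicts the true label."""
    robust = sum(reachable_labels(tree, x, eps) == {label} for x, label in zip(X, y))
    return robust / len(y)


# Usage: a stump on feature 0 and an attacker with radius 0.3 on both features.
stump = Node(feature=0, threshold=0.5, left=Node(label=0), right=Node(label=1))
X = [[0.0, 0.0], [0.4, 0.0], [1.0, 0.0], [0.6, 1.0]]
y = [0, 0, 1, 1]
print(sum(predict(stump, x) == t for x, t in zip(X, y)) / len(y))  # 1.0
print(adversarial_accuracy(stump, X, y, [0.3, 0.3]))  # 0.5
```

The points at 0.4 and 0.6 are classified correctly but lie within 0.3 of the threshold, so an attacker can push them across it; that is why adversarial accuracy drops below regular accuracy.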
laudv commented
Yes, that works fine, thanks!