This code provides a quick and easy way to assess the performance and statistical significance of multiple binary classifiers on datasets that exhibit extreme class imbalance.
- Pandas
- Numpy
- Scipy for significance testing
- Scikit-learn for classification algorithms
evaluate_models(X, y, models, times, undersampling=[0], time_unit='s')
Perform Nx2-fold cross-validation (Nx2cv) for a set of data, models, and parameters. For each level of undersampling, all models are run in parallel.
- X - The feature matrix. Features must be encoded as 0s and 1s
- y - The target classes
- models - A Python list of binary classifier models to be tested. Each model must have a name assigned to it after being created (e.g. model.name = 'model1')
- times - The number of 2-fold cross-validation iterations to run (e.g. 5 will run 2-fold validation 5 times and generate 10 AUC measurements)
- undersampling - A list of undersampling proportions to attempt. 0 indicates no undersampling and is the default
- time_unit - The unit of time used for the avg. fit time and avg. predict time in the results table ('s', 'm', or 'h')
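Internally, Nx2cv amounts to repeating 2-fold cross-validation N times with a fresh shuffle each time. A minimal sketch of that loop (function and column names are illustrative, not the package's actual internals):

```python
import time
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

def nx2cv_sketch(X, y, models, times):
    """Run 2-fold CV `times` times, recording AUC and timings per model/fold."""
    rows = []
    for it in range(times):
        # A new shuffle per iteration; stratification preserves class balance
        skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=it)
        for fold, (train_idx, test_idx) in enumerate(skf.split(X, y), start=1):
            for model in models:
                t0 = time.time()
                model.fit(X.iloc[train_idx], y.iloc[train_idx])
                fit_time = time.time() - t0
                t0 = time.time()
                scores = model.predict_proba(X.iloc[test_idx])[:, 1]
                predict_time = time.time() - t0
                rows.append({'model': model.name, 'iteration': it, 'fold': fold,
                             'auc': roc_auc_score(y.iloc[test_idx], scores),
                             'fit_time': fit_time, 'predict_time': predict_time})
    return pd.DataFrame(rows)
```

Each of the N x 2 folds yields one row per model, which is why 5 iterations produce 10 AUC measurements per model.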
rank_models(results)
When provided with the results dataframe generated by evaluate_models(), rank_models() returns the average statistics for each model and undersampling level, ranked by ROC AUC in descending order.
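Conceptually this is a straightforward pandas aggregation; a possible sketch (not necessarily the package's exact implementation):

```python
import pandas as pd

def rank_models_sketch(results):
    """Average AUC and timings per (model, undersampling) pair, best AUC first."""
    return (results
            .groupby(['model', 'undersampling'])[['auc', 'fit_time', 'predict_time']]
            .mean()
            .sort_values('auc', ascending=False))
```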
significance(results)
Generate a matrix with a row and a column for each model and undersampling level. Each cell of this matrix holds the p-value from an F test of the hypothesis that the two models perform equally well; high values indicate that the two models are not significantly different.
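For 5x2cv specifically, this kind of test is usually the combined 5x2cv F test (Alpaydin's refinement of Dietterich's 5x2cv t test). A minimal sketch, assuming two aligned vectors of ten per-fold scores each (the function name and input layout are illustrative):

```python
import numpy as np
from scipy import stats

def f_test_5x2_sketch(scores_a, scores_b):
    """Combined 5x2cv F test on two aligned (5 iterations x 2 folds) score vectors.

    Returns the p-value; high values mean the two models are not
    significantly different.
    """
    p = np.asarray(scores_a, dtype=float) - np.asarray(scores_b, dtype=float)
    p = p.reshape(5, 2)                     # one row per 2-fold iteration
    p_bar = p.mean(axis=1, keepdims=True)   # per-iteration mean difference
    s2 = ((p - p_bar) ** 2).sum(axis=1)     # per-iteration variance estimate
    f = (p ** 2).sum() / (2 * s2.sum())     # combined F statistic
    return stats.f.sf(f, 10, 5)             # F(10, 5) survival function = p-value
```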
First, we generate a number of models to test and provide them with names.
from sklearn.ensemble import (RandomForestClassifier, AdaBoostClassifier,
                              GradientBoostingClassifier)
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

forest = RandomForestClassifier(n_estimators=10, n_jobs=2, max_depth=50)
forest.name = 'forest1'
forest2 = RandomForestClassifier(n_estimators=50, n_jobs=4, max_depth=100)
forest2.name = 'forest2'
logit = LogisticRegression()
logit.name = 'logit1'
ada = AdaBoostClassifier(DecisionTreeClassifier(max_depth=3), n_estimators=10)
ada.name = 'ada1'
gforest = GradientBoostingClassifier(n_estimators=10, max_depth=2, subsample=.5)
gforest.name = 'gforest1'
gforest2 = GradientBoostingClassifier(n_estimators=50, max_depth=3, subsample=.5)
gforest2.name = 'gforest2'
Next, an artificial classification dataset with severe class imbalance is generated. Approximately 95% of the records in this set will belong to class 0 (as specified by the weights parameter).
import pandas as pd
from sklearn.datasets import make_classification

data = make_classification(n_samples=100000, n_features=100, n_informative=4,
                           weights=[.95], flip_y=.02, n_repeated=13, class_sep=.5)
X = pd.DataFrame(data[0])
y = pd.Series(data[1])
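As a sanity check, the resulting class distribution can be inspected before any modelling (note that flip_y introduces some label noise, so the majority share lands slightly below the nominal 95%):

```python
import pandas as pd
from sklearn.datasets import make_classification

data = make_classification(n_samples=100000, n_features=100, n_informative=4,
                           weights=[.95], flip_y=.02, n_repeated=13, class_sep=.5)
y = pd.Series(data[1])
print(y.value_counts(normalize=True))  # class 0 should dominate at roughly 93-95%
```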
This dataset is run through our list of models with 5x2-fold cross-validation, with the majority class undersampled until the minority class accounts for 10% and 20% of the data, as well as with no undersampling.
results=evaluate_models(X,y,[logit,forest,forest2,ada,gforest,gforest2],5, undersampling=[0,.1,.2], time_unit='m')
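Undersampling here means discarding majority-class rows until the minority class reaches the requested share of the data. A minimal pandas sketch of that idea (the function name is illustrative, not the package's internal API):

```python
import pandas as pd

def undersample_sketch(X, y, minority_share, random_state=0):
    """Drop majority-class rows until the minority class makes up
    `minority_share` of the data (e.g. 0.1 or 0.2)."""
    minority = y.value_counts().idxmin()
    n_min = (y == minority).sum()
    # Keep n_maj majority rows so that n_min / (n_min + n_maj) == minority_share
    n_maj = int(n_min * (1 - minority_share) / minority_share)
    maj_idx = y[y != minority].sample(n=n_maj, random_state=random_state).index
    keep = y[y == minority].index.union(maj_idx)
    return X.loc[keep], y.loc[keep]
```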
The results are large, but here is a preview...
>>> results.head(10)
model iteration fold undersampling auc fit_time predict_time
0 logit1 0 1 none 0.664417 0.015710 0.000430
1 logit1 0 2 none 0.672645 0.017392 0.000408
2 forest1 0 1 none 0.782605 0.097416 0.002430
3 forest1 0 2 none 0.808222 0.095979 0.002295
4 forest2 0 1 none 0.838131 0.245126 0.003866
5 forest2 0 2 none 0.838551 0.232675 0.003725
6 ada1 0 1 none 0.817182 0.346996 0.003207
7 ada1 0 2 none 0.817818 0.320188 0.003003
8 gforest1 0 1 none 0.705441 0.073240 0.000530
9 gforest1 0 2 none 0.719508 0.071568 0.000460
The rank_models() function then aggregates the data to show which model performed best. Times shown are in minutes.
>>> rank_models(results)
auc fit_time predict_time
model undersampling
forest2 0.2 0.846917 0.048240 0.004250
forest2 0.1 0.845637 0.120340 0.003976
forest2 none 0.839483 0.237978 0.003786
gforest2 0.2 0.825184 0.118848 0.001160
gforest2 0.1 0.821249 0.249924 0.001234
forest1 0.2 0.816427 0.019376 0.002227
gforest2 none 0.815302 0.448876 0.001377
forest1 0.1 0.808853 0.048392 0.002267
ada1 none 0.808564 0.335405 0.003131
ada1 0.1 0.805994 0.189129 0.003135
ada1 0.2 0.804026 0.089088 0.003203
forest1 none 0.795091 0.096185 0.002295
gforest1 0.2 0.734390 0.018959 0.000504
gforest1 0.1 0.724401 0.039700 0.000483
gforest1 none 0.714376 0.071366 0.000470
logit1 0.2 0.674707 0.003152 0.000414
logit1 0.1 0.671341 0.008916 0.000391
logit1 none 0.669132 0.015787 0.000413
The significance results are also large, but here are the results for the gforest2 and forest2 models. High values indicate that the two models are not significantly different.
>>> significance(results)[['gforest2','forest2']]
Model gforest2 gforest2 gforest2 \
Undersampling none 0.2 0.1
Model Undersampling
ada1 none 0.5381128 0.1590714 0.3717211
ada1 0.2 0.007221477 0.002052937 9.122191e-05
ada1 0.1 0.3916473 0.04004108 0.03473235
gforest2 none NaN 0.01308544 0.09802531
gforest2 0.2 0.01308544 NaN 0.1741726
gforest2 0.1 0.09802531 0.1741726 NaN
forest2 none 0.002961209 0.006890251 0.01451488
forest2 0.2 0.0002602179 2.377107e-06 0.0005662108
forest2 0.1 0.0001374467 2.365901e-05 0.0003685025
forest1 none 0.0483831 0.008635023 0.02564205
forest1 0.2 0.5963419 0.07934876 0.541395
forest1 0.1 0.2102436 0.004488396 0.03481045
gforest1 none 8.510569e-06 8.251625e-06 8.129192e-06
gforest1 0.2 2.620663e-06 7.024162e-06 7.280183e-06
gforest1 0.1 4.977668e-05 0.0001022563 0.0001428671
logit1 none 8.258857e-07 4.268434e-07 1.153785e-07
logit1 0.2 8.324898e-07 5.092532e-07 1.357096e-07
logit1 0.1 9.334062e-07 4.178198e-07 1.478798e-07
Model forest2 forest2 forest2
Undersampling none 0.2 0.1
Model Undersampling
ada1 none 0.01016006 0.01025593 0.008048507
ada1 0.2 0.0009343802 0.0001625981 6.61957e-05
ada1 0.1 0.01351534 0.002296429 0.002967971
gforest2 none 0.002961209 0.0002602179 0.0001374467
gforest2 0.2 0.006890251 2.377107e-06 2.365901e-05
gforest2 0.1 0.01451488 0.0005662108 0.0003685025
forest2 none NaN 0.03965544 0.04566424
forest2 0.2 0.03965544 NaN 0.3502626
forest2 0.1 0.04566424 0.3502626 NaN
forest1 none 0.001257939 0.0006595105 0.0005784847
forest1 0.2 0.000312439 0.0002433438 0.0004074213
forest1 0.1 5.765322e-05 4.521528e-05 1.641226e-05
gforest1 none 8.637735e-06 5.23887e-06 2.411338e-06
gforest1 0.2 2.073483e-06 3.211817e-06 9.978427e-07
gforest1 0.1 9.419857e-05 5.354328e-05 4.748392e-05
logit1 none 8.117545e-07 2.211232e-07 1.839379e-07
logit1 0.2 9.489241e-07 2.667403e-07 2.066221e-07
logit1 0.1 7.953832e-07 2.061502e-07 1.806246e-07
- Handling Class Imbalance in Customer Churn Prediction - Burez and Van den Poel (2009): http://scholar.google.com/scholar?cluster=9767278684451707768&hl=en&as_sdt=0,11
- Approximate statistical tests for comparing supervised classification learning algorithms - Dietterich (1998): http://scholar.google.com/scholar?cluster=1634956806564791154&hl=en&as_sdt=0,11&as_vis=1
- Applied Predictive Modeling Chapter 16 - Kuhn and Johnson: http://appliedpredictivemodeling.com/
- ROC Curve: http://en.wikipedia.org/wiki/Receiver_operating_characteristic