nx2_cross_validation

Run Nx2 Cross Validation for multiple binary classifiers in parallel with optional downsampling

Automatically comparing classifier performance

This code provides a quick and easy way to assess the performance and statistical significance of multiple binary classifiers on datasets that exhibit extreme class imbalance.

Required libraries

  • pandas
  • NumPy
  • SciPy for significance testing
  • scikit-learn for classification algorithms

The functions

evaluate_models(X, y, models, times, undersampling=[0], time_unit='s')

Performs Nx2 cross-validation for a given set of data, models, and parameters. For each level of undersampling, all models are run in parallel. A minimal calling sketch follows the parameter list below.

  • X - The feature matrix. Values must be provided as 0s and 1s
  • y - The target classes
  • models - A Python list of binary classifier models to be tested. Each model must have a name assigned to it after creation (e.g. model.name = 'model1')
  • times - The number of 2-fold cross-validation iterations to run (e.g. 5 will run 2-fold validation 5 times and generate 10 accuracy metrics)
  • undersampling - A list of undersampling proportions to attempt. 0 indicates no undersampling and is the default
  • time_unit - The unit of time used for the average fit time and average predict time in the results table ('s', 'm', or 'h')
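
Below is a minimal sketch of the calling convention, assuming a feature matrix X and target y are already loaded; the classifiers and parameter values are placeholders, and the full example further down shows a complete run.

from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

logit = LogisticRegression()
logit.name = 'logit1'                  # every model needs a .name before evaluation
forest = RandomForestClassifier(n_estimators=50)
forest.name = 'forest1'

# 5 iterations of 2-fold CV, run with no undersampling and with the majority class
# undersampled until the minority class makes up 10% of the data; timings in seconds
results = evaluate_models(X, y, [logit, forest], 5, undersampling=[0, .1], time_unit='s')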

rank_models(results)

When provided with the results dataframe that is generated by evaluate_models, rank_models() will return the average statistics of models and rank them by ROC AUC in descending order.
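
For reference, the aggregation is roughly equivalent to the following pandas operation, assuming the results columns shown in the preview further below (the actual implementation may differ in its details):

ranked = (results
          .groupby(['model', 'undersampling'])[['auc', 'fit_time', 'predict_time']]
          .mean()
          .sort_values('auc', ascending=False))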

significance(results)

Generates a matrix with a row and a column for each model and undersampling level. Each cell of this matrix holds the probability that the two models perform equally well; this probability is calculated using an F test on the Nx2 cross-validation results.
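
The F test typically used with Nx2 cross-validation is the combined 5x2cv F test of Alpaydin (1999). The sketch below illustrates that calculation on per-fold scores for two models; it is an illustration of the test under that assumption, not the package's exact code.

import numpy as np
from scipy.stats import f

def combined_5x2cv_f_test(scores_a, scores_b):
    # scores_a, scores_b: arrays of shape (n_iterations, 2) with the per-fold
    # metric (e.g. ROC AUC) of two models evaluated on the same Nx2 CV splits
    diffs = np.asarray(scores_a) - np.asarray(scores_b)   # per-fold differences p_i^(j)
    means = diffs.mean(axis=1, keepdims=True)             # per-iteration mean difference
    variances = ((diffs - means) ** 2).sum(axis=1)        # per-iteration variance s_i^2
    f_stat = (diffs ** 2).sum() / (2.0 * variances.sum())
    n = diffs.shape[0]
    return f_stat, f.sf(f_stat, 2 * n, n)                 # F(10, 5) when N = 5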

Example

First, we generate a number of models to test and provide them with names.

from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

forest = RandomForestClassifier(n_estimators=10, n_jobs=2, max_depth=50)
forest.name = 'forest1'
forest2 = RandomForestClassifier(n_estimators=50, n_jobs=4, max_depth=100)
forest2.name = 'forest2'
logit = LogisticRegression()
logit.name = 'logit1'
ada = AdaBoostClassifier(DecisionTreeClassifier(max_depth=3), n_estimators=10)
ada.name = 'ada1'
gforest = GradientBoostingClassifier(n_estimators=10, max_depth=2, subsample=.5)
gforest.name = 'gforest1'
gforest2 = GradientBoostingClassifier(n_estimators=50, max_depth=3, subsample=.5)
gforest2.name = 'gforest2'

Next, an artificial classification dataset with severe class imbalance is generated. Approximately 95% of the records in this set will belong to class 0 (as specified by the weights parameter).

import pandas as pd
from sklearn.datasets import make_classification

data = make_classification(n_samples=100000, n_features=100, n_informative=4,
                           weights=[.95], flip_y=.02, n_repeated=13, class_sep=.5)
X = pd.DataFrame(data[0])
y = pd.Series(data[1])

This dataset is run through our list of models with 5x2-fold cross-validation: once with no undersampling, and once each with the majority class undersampled until the minority class accounts for 10% and 20% of the data.

results = evaluate_models(X, y, [logit, forest, forest2, ada, gforest, gforest2], 5, undersampling=[0, .1, .2], time_unit='m')
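
For intuition, undersampling the majority class to a given minority proportion can be done roughly as follows. This is an illustrative sketch only (the package performs this step internally and may implement it differently), and it assumes class 0 is the majority class, as in this example.

def undersample(X, y, minority_fraction):
    # keep every minority (class 1) row and sample just enough majority (class 0)
    # rows so that the minority class makes up minority_fraction of the result
    minority_idx = y[y == 1].index
    n_majority = int(len(minority_idx) * (1 - minority_fraction) / minority_fraction)
    majority_idx = y[y == 0].sample(n_majority, random_state=0).index
    keep = minority_idx.union(majority_idx)
    return X.loc[keep], y.loc[keep]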

The results are large, but here is a preview...

>>> results.head(10)
      model  iteration  fold undersampling       auc  fit_time  predict_time
0    logit1          0     1          none  0.664417  0.015710      0.000430
1    logit1          0     2          none  0.672645  0.017392      0.000408
2   forest1          0     1          none  0.782605  0.097416      0.002430
3   forest1          0     2          none  0.808222  0.095979      0.002295
4   forest2          0     1          none  0.838131  0.245126      0.003866
5   forest2          0     2          none  0.838551  0.232675      0.003725
6      ada1          0     1          none  0.817182  0.346996      0.003207
7      ada1          0     2          none  0.817818  0.320188      0.003003
8  gforest1          0     1          none  0.705441  0.073240      0.000530
9  gforest1          0     2          none  0.719508  0.071568      0.000460

The rank_models() function then aggregates the data to show which model performed best. Times shown are in minutes.

>>> rank_models(results)
                             auc  fit_time  predict_time
model    undersampling                                  
forest2  0.2            0.846917  0.048240      0.004250
forest2  0.1            0.845637  0.120340      0.003976
forest2  none           0.839483  0.237978      0.003786
gforest2 0.2            0.825184  0.118848      0.001160
gforest2 0.1            0.821249  0.249924      0.001234
forest1  0.2            0.816427  0.019376      0.002227
gforest2 none           0.815302  0.448876      0.001377
forest1  0.1            0.808853  0.048392      0.002267
ada1     none           0.808564  0.335405      0.003131
ada1     0.1            0.805994  0.189129      0.003135
ada1     0.2            0.804026  0.089088      0.003203
forest1  none           0.795091  0.096185      0.002295
gforest1 0.2            0.734390  0.018959      0.000504
gforest1 0.1            0.724401  0.039700      0.000483
gforest1 none           0.714376  0.071366      0.000470
logit1   0.2            0.674707  0.003152      0.000414
logit1   0.1            0.671341  0.008916      0.000391
logit1   none           0.669132  0.015787      0.000413

The significance matrix is also large, but here are the columns for the gforest2 and forest2 models. High values indicate that the two models' performance is not significantly different; a short note on pulling a single p-value out of this matrix follows the output.

>>> significance(results)[['gforest2','forest2']]
Model                       gforest2      gforest2      gforest2  \
Undersampling                   none           0.2           0.1
Model    Undersampling
ada1     none              0.5381128     0.1590714     0.3717211
ada1     0.2             0.007221477   0.002052937  9.122191e-05
ada1     0.1               0.3916473    0.04004108    0.03473235
gforest2 none                    NaN    0.01308544    0.09802531
gforest2 0.2              0.01308544           NaN     0.1741726
gforest2 0.1              0.09802531     0.1741726           NaN
forest2  none            0.002961209   0.006890251    0.01451488
forest2  0.2            0.0002602179  2.377107e-06  0.0005662108
forest2  0.1            0.0001374467  2.365901e-05  0.0003685025
forest1  none              0.0483831   0.008635023    0.02564205
forest1  0.2               0.5963419    0.07934876      0.541395
forest1  0.1               0.2102436   0.004488396    0.03481045
gforest1 none           8.510569e-06  8.251625e-06  8.129192e-06
gforest1 0.2            2.620663e-06  7.024162e-06  7.280183e-06
gforest1 0.1            4.977668e-05  0.0001022563  0.0001428671
logit1   none           8.258857e-07  4.268434e-07  1.153785e-07
logit1   0.2            8.324898e-07  5.092532e-07  1.357096e-07
logit1   0.1            9.334062e-07  4.178198e-07  1.478798e-07

Model                        forest2       forest2       forest2
Undersampling                   none           0.2           0.1
Model    Undersampling
ada1     none             0.01016006    0.01025593   0.008048507
ada1     0.2            0.0009343802  0.0001625981   6.61957e-05
ada1     0.1              0.01351534   0.002296429   0.002967971
gforest2 none            0.002961209  0.0002602179  0.0001374467
gforest2 0.2             0.006890251  2.377107e-06  2.365901e-05
gforest2 0.1              0.01451488  0.0005662108  0.0003685025
forest2  none                    NaN    0.03965544    0.04566424
forest2  0.2              0.03965544           NaN     0.3502626
forest2  0.1              0.04566424     0.3502626           NaN
forest1  none            0.001257939  0.0006595105  0.0005784847
forest1  0.2             0.000312439  0.0002433438  0.0004074213
forest1  0.1            5.765322e-05  4.521528e-05  1.641226e-05
gforest1 none           8.637735e-06   5.23887e-06  2.411338e-06
gforest1 0.2            2.073483e-06  3.211817e-06  9.978427e-07
gforest1 0.1            9.419857e-05  5.354328e-05  4.748392e-05
logit1   none           8.117545e-07  2.211232e-07  1.839379e-07
logit1   0.2            9.489241e-07  2.667403e-07  2.066221e-07
logit1   0.1            7.953832e-07  2.061502e-07  1.806246e-07
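
To inspect a single comparison, the matrix can be indexed by its row and column labels. The snippet below assumes the two-level (model, undersampling) index layout shown above; treat it as illustrative.

sig = significance(results)
p = sig.loc[('forest2', 'none'), ('gforest2', 'none')]   # forest2 vs. gforest2, no undersampling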

Further reading