
Prototype-based ML implementation for ascertaining the confidence of predicted labels from the Learning Vector Quantization (LVQ) family of advanced machine learning classification algorithms.


Python 3.9, PyTorch 1.11, ProtoTorch 0.7.3, License: MIT


What is it?

Classification label security is a prototype-based recall procedure that determines the confidence of predicted labels from the Learning Vector Quantization family of advanced machine learning classification algorithms.
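The README does not state the formula behind the label security. As a rough illustration only, the sketch below assumes a fuzzy c-means style membership: each test point receives a membership to every prototype from the ratios of its squared distances, controlled by a fuzzifier m > 1, and the label security is the membership of the winning (nearest) prototype. The function `label_security_sketch` is hypothetical and is not the API of label_security1.py.

```python
import numpy as np


def label_security_sketch(x, prototypes, proto_labels, m=2.0, eps=1e-12):
    """Hypothetical FCM-style label security (an assumption, not the repo's code).

    x            : (n_samples, n_features) test data
    prototypes   : (n_prototypes, n_features) learned prototypes
    proto_labels : (n_prototypes,) class label of each prototype
    m            : fuzzifier, must be > 1
    Returns the predicted labels and their security values in [0, 1].
    """
    x = np.atleast_2d(x)
    # Squared Euclidean distances, shape (n_samples, n_prototypes); eps avoids 0-division
    d = ((x[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=2) + eps
    # FCM membership on squared distances: u_k = 1 / sum_j (d_k / d_j)^(1 / (m - 1))
    ratio = d[:, :, None] / d[:, None, :]
    u = 1.0 / (ratio ** (1.0 / (m - 1.0))).sum(axis=2)
    # The winner is the nearest prototype; its membership is the label security
    winner = d.argmin(axis=1)
    return proto_labels[winner], u[np.arange(len(x)), winner]


protos = np.array([[0.0, 0.0], [4.0, 0.0]])
labels = np.array([0, 1])
pred, sec = label_security_sketch(np.array([[1.0, 0.0], [3.0, 0.0]]), protos, labels)
print(pred, sec)  # both points get a security of roughly 0.9 with m=2
```

Because the memberships of one point sum to 1 over all prototypes, a point halfway between two prototypes of different classes gets a security near 0.5, while a point sitting close to one prototype gets a security near 1.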

File structure

├── contour.py                               # Visualization of plots
├── iris_securitycelvq.py                    # Iris test-set example with CELVQ
├── iris_securityglvq.py                     # Iris test-set example with GLVQ
├── iris_securitygmlvq.py                    # Iris test-set example with GMLVQ
├── optimised_m.py                           # Script for searching the optimal hyperparameter m
├── label_security1.py                       # classification label security/certainty for LVQs
├── protocert.py                             # Auxiliary code
└── README.md

How to use?

from label_security1 import LabelSecurity, LabelSecurityM, LabelSecurityLM

# Non-matrix LVQs (e.g. GLVQ, CELVQ)
label_security = LabelSecurity(x_test, class_labels, predict_results, model_prototypes, X)
# Matrix and local-matrix LVQs (e.g. GMLVQ)
label_security = LabelSecurityM(x_test, class_labels, model_prototypes, model_omega, X)

The LVQ models are first trained on the training data. The learned prototypes (and, for matrix LVQs, the learned omega matrix) are then used to compute the classification label certainties of the test data.
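For matrix LVQs such as GMLVQ, the learned omega matrix changes how distances to the prototypes are measured. The sketch below shows the standard GMLVQ squared distance d(x, w) = (x - w)^T Ω^T Ω (x - w), which could stand in for the squared Euclidean distance when computing a security value. `omega_distances` is a hypothetical helper, and whether LabelSecurityM uses exactly this form internally is an assumption.

```python
import numpy as np


def omega_distances(x, prototypes, omega):
    """Squared GMLVQ distances ||Omega (x - w)||^2 for every sample/prototype pair.

    x          : (n_samples, n_features)
    prototypes : (n_prototypes, n_features)
    omega      : (n_projected, n_features) learned relevance projection
    """
    diff = x[:, None, :] - prototypes[None, :, :]  # (n_samples, n_prototypes, n_features)
    projected = diff @ omega.T                     # (n_samples, n_prototypes, n_projected)
    return (projected ** 2).sum(axis=2)            # (n_samples, n_prototypes)


x = np.array([[1.0, 2.0]])
protos = np.array([[0.0, 0.0], [1.0, 0.0]])
omega = np.array([[1.0, 0.0]])  # toy omega: only the first feature is relevant
print(omega_distances(x, protos, omega))
```

With an identity omega the function reduces to the plain squared Euclidean distance used by non-matrix LVQs, so the same security computation can serve both model families. A local-matrix model would use a separate omega per prototype.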

Visualization / Results

Classification results with and without a reject option based on Chow's approach (from simulated test results with a security threshold of 0.7) are shown below for the GLVQ, GMLVQ and CELVQ models, respectively.
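A Chow-style reject option simply refuses to classify samples whose label security falls below a threshold. The minimal sketch below applies the 0.7 threshold to precomputed securities and reports the accuracy on the accepted samples together with the rejection rate. `chow_reject` is a hypothetical helper, the toy labels and security values are made up for illustration, and treating a security exactly equal to the threshold as accepted is an assumption.

```python
import numpy as np


def chow_reject(y_true, y_pred, security, threshold=0.7):
    """Reject predictions whose label security is below `threshold` (Chow-style rule).

    Returns (accuracy on accepted samples, rejection rate, boolean accept mask).
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    accept = np.asarray(security) >= threshold  # boundary counted as accepted (assumption)
    reject_rate = 1.0 - accept.mean()
    if accept.any():
        acc = float((y_true[accept] == y_pred[accept]).mean())
    else:
        acc = float("nan")  # everything rejected: accuracy is undefined
    return acc, float(reject_rate), accept


y_true = [0, 1, 1, 2]
y_pred = [0, 1, 2, 2]
security = [0.95, 0.80, 0.55, 0.75]  # toy values; the wrong prediction has low security
no_reject_acc = float((np.array(y_true) == np.array(y_pred)).mean())
acc, rate, accept = chow_reject(y_true, y_pred, security)
print(f"without reject: {no_reject_acc:.2f}, with reject: {acc:.2f} (rejected {rate:.0%})")
```

Raising the threshold trades a higher rejection rate for higher accuracy on the samples that are still classified, which is the comparison the reject and non-reject results illustrate.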

The plot below shows the classification label security of a sample data point decreasing as the hyperparameter m increases, using label_security1.py.
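The decreasing trend is what one would expect if the security is a fuzzy c-means style membership, since a larger fuzzifier m flattens the memberships towards 1/(number of prototypes). The sketch below assumes that form (it is not taken from label_security1.py) and evaluates the winner's security for one toy point at several values of m; `winner_security` is a hypothetical helper.

```python
import numpy as np


def winner_security(x, prototypes, m):
    """Assumed FCM-style security of the nearest prototype for a single point."""
    d = ((prototypes - x) ** 2).sum(axis=1)  # squared distances to each prototype
    w = d ** (-1.0 / (m - 1.0))              # unnormalised memberships (x must not equal a prototype)
    return float(w.max() / w.sum())          # the nearest prototype has the largest weight


x = np.array([1.0, 0.0])
protos = np.array([[0.0, 0.0], [4.0, 0.0]])
ms = [1.5, 2.0, 3.0, 5.0, 10.0]
securities = [winner_security(x, protos, m) for m in ms]
for m, s in zip(ms, securities):
    print(f"m = {m:>4}: security = {s:.3f}")
```

With two prototypes the security falls from close to 1 at small m towards 0.5 as m grows, so the choice of m directly controls how confident the reported labels look.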


The optimal choice of the hyperparameter m, compared with the default m=2, is shown below for label_security1.py on the Iris data set using GLVQ, GMLVQ and CELVQ.
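The criterion used by optimised_m.py is not described here. The sketch below assumes one plausible choice: a grid search that picks the m whose security values best match whether each prediction is actually correct, measured by a Brier-style mean squared error. It also assumes a fuzzy c-means style membership for the security. `brier_for_m` and `search_m` are hypothetical names, and the toy distances and labels are made up for illustration.

```python
import numpy as np


def brier_for_m(d, proto_labels, y_true, m):
    """Mean squared gap between winner security and correctness (assumed criterion).

    d : (n_samples, n_prototypes) squared distances, all strictly positive
    """
    w = d ** (-1.0 / (m - 1.0))                  # unnormalised FCM-style memberships
    u = w / w.sum(axis=1, keepdims=True)         # memberships sum to 1 per sample
    winner = d.argmin(axis=1)
    security = u[np.arange(len(d)), winner]
    correct = (proto_labels[winner] == y_true).astype(float)
    return float(((security - correct) ** 2).mean())


def search_m(d, proto_labels, y_true, candidates):
    """Return the candidate m with the lowest score, plus all scores."""
    scores = [brier_for_m(d, proto_labels, y_true, m) for m in candidates]
    return candidates[int(np.argmin(scores))], scores


d = np.array([[1.0, 9.0], [1.0, 4.0], [4.0, 1.0], [2.0, 3.0]])  # toy squared distances
proto_labels = np.array([0, 1])
y_true = np.array([0, 0, 1, 1])  # the last sample is misclassified by its nearest prototype
candidates = [1.5, 2.0, 3.0, 5.0]
best_m, scores = search_m(d, proto_labels, y_true, candidates)
print("best m:", best_m)
```

In practice the search would run on held-out data, so that the chosen m reflects how well the security predicts errors on unseen samples rather than on the training set.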
