jknn

A Java implementation of the k-nearest neighbors algorithm

MNIST example

// Load the train and test datasets
Dataset trainDataset = new MnistDataset().load(trainLabelsStream, trainImagesStream);
Dataset testDataset = new MnistDataset().load(testLabelsStream, testImagesStream);

// Create a classifier with the specified distance function
Classifier classifier = new BruteForceClassifier(new EuclideanDistance());

// Fit the classifier with the training dataset
classifier.fit(trainDataset);

// The number of nearest neighbors used to predict the label
int k = 9;

// Get the accuracy of the fitted classifier on the test dataset
double accuracy = classifier.accuracy(testDataset, k);

// Use the fitted classifier to predict the label of a new, unlabeled feature set (`features`)
String predictedLabel = classifier.classify(features, k);

With k = 9 and no preprocessing, this example achieves an accuracy of 96.95% on the MNIST test dataset.
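A brute-force k-NN classifier compares a query against every training sample, ranks the samples by distance, and predicts from the labels of the k closest. The sketch below illustrates that idea in plain Java, assuming a majority vote over the k neighbors. It is a hypothetical, self-contained illustration, not jknn's source: the names `KnnSketch`, `Sample`, `Neighbor`, and `euclidean` are introduced here, and features are assumed to be `double[]` (for MNIST, for example, the 784 pixel values of a 28×28 image).

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical brute-force k-NN sketch (not jknn's actual implementation). */
public class KnnSketch {

    /** A labeled sample: a feature vector plus its class label (assumed representation). */
    public record Sample(double[] features, String label) {}

    /** A training sample's label paired with its distance to the query. */
    record Neighbor(double distance, String label) {}

    private final List<Sample> training = new ArrayList<>();

    /** Brute-force "fitting" only stores the training samples; no model is built. */
    public void fit(List<Sample> samples) {
        training.clear();
        training.addAll(samples);
    }

    /** Euclidean distance: square root of the sum of squared coordinate differences. */
    public static double euclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /** Predicts a label by majority vote among the k nearest training samples. */
    public String classify(double[] features, int k) {
        // Compute the distance from the query to every training sample (brute force).
        List<Neighbor> neighbors = new ArrayList<>();
        for (Sample s : training) {
            neighbors.add(new Neighbor(euclidean(s.features(), features), s.label()));
        }
        neighbors.sort(Comparator.comparingDouble(Neighbor::distance));

        // Count how often each label occurs among the k closest samples.
        Map<String, Integer> votes = new HashMap<>();
        for (Neighbor n : neighbors.subList(0, Math.min(k, neighbors.size()))) {
            votes.merge(n.label(), 1, Integer::sum);
        }

        // Return the most frequent label; ties are broken arbitrarily here.
        return votes.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .orElseThrow()
                .getKey();
    }

    /** Fraction of test samples whose predicted label equals their true label. */
    public double accuracy(List<Sample> test, int k) {
        int correct = 0;
        for (Sample s : test) {
            if (classify(s.features(), k).equals(s.label())) {
                correct++;
            }
        }
        return (double) correct / test.size();
    }
}
```

Because nothing is precomputed at fit time, each prediction costs one distance evaluation per training sample. With the standard 60,000/10,000 MNIST split, computing the test accuracy therefore requires 600 million distance evaluations, which is the usual trade-off of brute-force k-NN against index structures such as k-d trees.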