Primary language: Rust. License: GNU General Public License v3.0 (GPL-3.0).

scratch_genetic

Description

A from-scratch genetic-algorithm library used in my march-madness-predictor project

API Reference

Contents

  1. genetic module
    1. gen_pop function
    2. test_and_sort function
    3. reproduce function
    4. load_and_predict function
    5. export_model function
  2. Examples
    1. Training
    2. Predicting

genetic module

The genetic module is the only public module under the scratch_genetic library.

It contains a set of functions implementing a genetic algorithm, which mimics natural selection to evolve a model that can predict an outcome from input data.

To use it, convert your data into streams of input and output bits, generate a population of randomized networks sized to fit that data and configured with various settings, train by repeatedly running the test and reproduction functions, and then export the best network at the end.

Then you can load the exported model and use it to predict output bits for new input bits.
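As a rough illustration of the bit-conversion step, here is a minimal sketch that packs a record into bytes. The Sample struct, its fields, and its to_input_bits method are hypothetical and not part of scratch_genetic; the only thing taken from the library is that data travels as Vec<u8>, so each byte contributes 8 bits.

```rust
/// Hypothetical record type, used only for illustration.
struct Sample {
    seed_a: u8,
    seed_b: u8,
    year: u16,
}

impl Sample {
    /// Packs the fields into bytes; every byte adds 8 input bits.
    fn to_input_bits(&self) -> Vec<u8> {
        let mut bytes = vec![self.seed_a, self.seed_b];
        // Multi-byte fields are split into big-endian bytes.
        bytes.extend_from_slice(&self.year.to_be_bytes());
        bytes
    }
}

fn main() {
    let sample = Sample { seed_a: 1, seed_b: 16, year: 2019 };
    let input = sample.to_input_bits();
    println!("{} bytes = {} input bits", input.len(), input.len() * 8);
}
```

Packing whole bytes keeps the bit count a multiple of 8 automatically.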

gen_pop function

pub async fn gen_pop(
        pop_size: usize,
        layer_sizes: Vec<usize>, num_inputs: usize, num_outputs: usize,
        activation_thresh: f64, trait_swap_chance: f64,
        weight_mutate_chance: f64, weight_mutate_amount: f64,
        offset_mutate_chance: f64, offset_mutate_amount: f64) -> Vec<Network> {

This function generates a vector of randomized networks, each an instance of an underlying private struct, Network. It's private because you won't need to manipulate it directly; you'll just need to pass it between functions.

You can see it takes quite a few parameters. These are all the manual settings to give to your network to adjust how it trains.

Parameters:

  • pop_size - number of networks to train on. Bigger is better, but bigger is also slower
  • layer_sizes - a vector containing sizes for each layer of the neural network
  • num_inputs - the number of bits that your input data produces after being converted (must be divisible by 8)
  • num_outputs - the expected number of bits generated by the output (must be divisible by 8)
  • activation_thresh - how hard it is for a neuron to turn on
  • trait_swap_chance - controls the variability of a child sharing different traits from each parent when reproducing
  • weight_mutate_chance - the chance that a weight on the connections between neurons changes
  • weight_mutate_amount - how strong the change above is
  • offset_mutate_chance and offset_mutate_amount - same as the above two, but with the base value of the connection
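Because the data set is passed as byte vectors while num_inputs and num_outputs are given in bits, one way to derive them is to multiply the byte lengths by 8, which also satisfies the divisible-by-8 requirement. The bit_widths helper below is a hypothetical sketch, not part of the library.

```rust
/// Returns (num_inputs, num_outputs) in bits for a data set of
/// (input bytes, output bytes) pairs. Returns None if the set is empty
/// or if its samples do not all share the same lengths.
fn bit_widths(data_set: &[(Vec<u8>, Vec<u8>)]) -> Option<(usize, usize)> {
    let (first_in, first_out) = data_set.first()?;
    let consistent = data_set
        .iter()
        .all(|(i, o)| i.len() == first_in.len() && o.len() == first_out.len());
    if consistent {
        Some((first_in.len() * 8, first_out.len() * 8))
    } else {
        None
    }
}

fn main() {
    let data_set: Vec<(Vec<u8>, Vec<u8>)> = vec![
        (vec![1, 16, 7, 227], vec![80, 62, 0]),
        (vec![2, 15, 7, 227], vec![71, 70, 1]),
    ];
    match bit_widths(&data_set) {
        Some((num_inputs, num_outputs)) => {
            println!("num_inputs = {}, num_outputs = {}", num_inputs, num_outputs)
        }
        None => println!("data set is empty or inconsistent"),
    }
}
```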

test_and_sort function

pub async fn test_and_sort(pop: &mut Vec<Network>, data_set: &Vec<(Vec<u8>, Vec<u8>)>) {

This takes the "population" (the vector of Networks created by gen_pop) and your test data, measures how close each network gets to reproducing each test sample's expected output, and then sorts the networks by that performance.
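The scoring metric the library uses is not described here. As one plausible sketch, the example below assumes the score is the number of wrong output bits (Hamming distance) summed over the data set, and sorts so the candidate with the fewest errors comes first. Candidate, its XOR-mask "model", bit_errors, and test_and_sort_sketch are all hypothetical stand-ins, not the library's Network or API.

```rust
/// Counts the bits that differ between two byte slices.
fn bit_errors(predicted: &[u8], expected: &[u8]) -> u32 {
    predicted
        .iter()
        .zip(expected)
        .map(|(p, e)| (p ^ e).count_ones())
        .sum()
}

/// Hypothetical stand-in for a network: it "predicts" by XOR-ing
/// every input byte with a fixed mask.
struct Candidate {
    mask: u8,
    errors: u32,
}

impl Candidate {
    fn predict(&self, input: &[u8]) -> Vec<u8> {
        input.iter().map(|b| b ^ self.mask).collect()
    }
}

/// Scores every candidate on the data set, then sorts best-first.
fn test_and_sort_sketch(pop: &mut [Candidate], data_set: &[(Vec<u8>, Vec<u8>)]) {
    for candidate in pop.iter_mut() {
        let total: u32 = data_set
            .iter()
            .map(|(input, expected)| bit_errors(&candidate.predict(input), expected))
            .sum();
        candidate.errors = total;
    }
    // Fewest wrong bits first, so index 0 holds the best candidate.
    pop.sort_by_key(|c| c.errors);
}

fn main() {
    // Each expected output is the input XOR 0x0F.
    let data_set: Vec<(Vec<u8>, Vec<u8>)> = vec![
        (vec![0x12, 0x34], vec![0x1D, 0x3B]),
        (vec![0xA0, 0x0F], vec![0xAF, 0x00]),
    ];
    let mut pop = vec![
        Candidate { mask: 0xFF, errors: 0 },
        Candidate { mask: 0x0E, errors: 0 },
        Candidate { mask: 0x0F, errors: 0 },
    ];
    test_and_sort_sketch(&mut pop, &data_set);
    for c in &pop {
        println!("mask {:#04x}: {} wrong bits", c.mask, c.errors);
    }
}
```

Sorting best-first matches how the training example later exports pop[0] as the final model.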

reproduce function

pub async fn reproduce(pop: &mut Vec<Network>) {

After sorting, you'll want to reproduce. This takes your set of networks, keeps the upper half of them, and uses those to replace the bottom half with children sharing mixed genes and mutations based on the parameters you provided in gen_pop.
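To make the idea concrete, here is a minimal sketch of one reproduction step under several assumptions: each individual is reduced to a flat list of weights, parents are paired with their neighbor in the sorted top half, and a mutation shifts a weight by a random amount within plus or minus weight_mutate_amount. Rng, Settings, make_child, and reproduce_sketch are hypothetical; the library's real pairing and mutation rules may differ, and it stores these settings in the networks rather than taking them as an argument.

```rust
/// Tiny xorshift generator so the sketch needs no external crates.
/// The seed must be non-zero.
struct Rng(u64);

impl Rng {
    /// Returns a pseudo-random value in [0, 1).
    fn next_f64(&mut self) -> f64 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        (self.0 >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Hypothetical bundle of the settings named in gen_pop.
struct Settings {
    trait_swap_chance: f64,
    weight_mutate_chance: f64,
    weight_mutate_amount: f64,
}

/// Builds a child by mixing two parents gene by gene, then mutating.
fn make_child(a: &[f64], b: &[f64], s: &Settings, rng: &mut Rng) -> Vec<f64> {
    a.iter()
        .zip(b)
        .map(|(&wa, &wb)| {
            // Take the gene from parent b with probability trait_swap_chance.
            let mut w = if rng.next_f64() < s.trait_swap_chance { wb } else { wa };
            // Occasionally nudge the weight up or down.
            if rng.next_f64() < s.weight_mutate_chance {
                w += (rng.next_f64() * 2.0 - 1.0) * s.weight_mutate_amount;
            }
            w
        })
        .collect()
}

/// Keeps the top half of a sorted population and overwrites the
/// bottom half with children of the survivors.
fn reproduce_sketch(pop: &mut [Vec<f64>], s: &Settings, rng: &mut Rng) {
    let half = pop.len() / 2;
    if half == 0 {
        return; // Nothing to breed from.
    }
    let (parents, children) = pop.split_at_mut(half);
    for (i, child) in children.iter_mut().enumerate() {
        let a = &parents[i % half];
        let b = &parents[(i + 1) % half];
        *child = make_child(a, b, s, rng);
    }
}

fn main() {
    let settings = Settings {
        trait_swap_chance: 0.80,
        weight_mutate_chance: 0.65,
        weight_mutate_amount: 0.5,
    };
    let mut rng = Rng(42);
    // Already sorted best-first; the last two entries get replaced.
    let mut pop = vec![vec![1.0, 1.0], vec![2.0, 2.0], vec![9.0, 9.0], vec![9.0, 9.0]];
    reproduce_sketch(&mut pop, &settings, &mut rng);
    println!("{:?}", pop);
}
```

Because the survivors are left untouched, the best network from the last test_and_sort call is still at index 0 after reproducing.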

load_and_predict function

pub async fn load_and_predict(file_name: &'static str, input_bits: &Vec<u8>) -> Vec<u8> {

Loads a model that has been exported and, given the input bits you pass in, generates output bits.

export_model function

pub async fn export_model(file_name: &'static str, pop: &Network) {

Takes a network and exports it to a file.

Examples

The following examples use these constants:

// Neuron connection settings
pub const NEURON_ACTIVATION_THRESH: f64 = 0.60;
pub const TRAIT_SWAP_CHANCE: f64 = 0.80;
pub const WEIGHT_MUTATE_CHANCE: f64 = 0.65;
pub const WEIGHT_MUTATE_AMOUNT: f64 = 0.5;
pub const OFFSET_MUTATE_CHANCE: f64 = 0.25;
pub const OFFSET_MUTATE_AMOUNT: f64 = 0.05;

// Neural network settings
pub const LAYER_SIZES: [usize; 4] = [ 8, 32, 32, 16 ];

// Algorithm settings
const POP_SIZE: usize = 2000;

const DATA_FILE_NAME: &'static str = "NCAA Mens March Madness Historical Results.csv";
const MODEL_FILE_NAME: &'static str = "model.mmp";
const NUM_GENS: usize = 1000;

Training

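// Assumed imports; the module path is inferred from the crate and module names.
use std::time::Instant;
use scratch_genetic::genetic::{export_model, gen_pop, reproduce, test_and_sort};

// NUM_INPUTS and NUM_OUTPUTS are not among the constants listed above; they must
// match the bit widths of your data (each divisible by 8).
// The rest of this example runs inside an async function.
println!("Training new March Madness Predictor Model");

// GameInfo is a custom type from the predictor project that structures the CSV
// data and converts it to bits.
println!("Loading training data from {}", DATA_FILE_NAME);
let games = GameInfo::collection_from_file(DATA_FILE_NAME);
let games: Vec<(Vec<u8>, Vec<u8>)> = games.iter().map(|game| { // Shadows games with (input, output) pairs
    (game.to_input_bits().to_vec(), game.to_output_bits().to_vec())
}).collect();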

println!("Generating randomized population");
let now = Instant::now();
let mut pop = gen_pop(
    POP_SIZE,
    LAYER_SIZES.to_vec(), NUM_INPUTS, NUM_OUTPUTS,
    NEURON_ACTIVATION_THRESH, TRAIT_SWAP_CHANCE,
    WEIGHT_MUTATE_CHANCE, WEIGHT_MUTATE_AMOUNT,
    OFFSET_MUTATE_CHANCE, OFFSET_MUTATE_AMOUNT
).await;
let elapsed = now.elapsed();
println!("Generation took {}s", elapsed.as_secs_f64());

println!("Starting training");
for i in 0..NUM_GENS {
    println!("Generation {} / {}", i + 1, NUM_GENS);
    test_and_sort(&mut pop, &games).await;
    reproduce(&mut pop).await;
}

// Save algorithm
println!("Saving model to {}", MODEL_FILE_NAME);
export_model(MODEL_FILE_NAME, &pop[0]).await;

Predicting

pub async fn predict(team_names: &str) {
    // Split the comma-separated input into indexable fields
    let indexable_table_data: Vec<&str> = team_names.split(',').collect();
    
    // A team, A seed, B team, B seed, date, round, region
    if indexable_table_data.len() != 7 {
        println!("Invalid input string!");
        return;
    }

    // As in the training example, this converts CSV-style data into a usable form
    println!("Converting input into data...");
    let entry = TableEntry {
        winner: String::from(indexable_table_data[0]),
        win_seed: String::from(indexable_table_data[1]),
        loser: String::from(indexable_table_data[2]),
        lose_seed: String::from(indexable_table_data[3]),
        date: String::from(indexable_table_data[4]),
        round: String::from(indexable_table_data[5]),
        region: String::from(indexable_table_data[6]),

        win_score: String::from("0"),
        lose_score: String::from("0"),
        overtime: String::from("")
    };
    let game = GameInfo::from_table_entry(&entry);

    // Here's where the code is used
    println!("Predicting!");
    let result = load_and_predict(MODEL_FILE_NAME, &game.to_input_bits().to_vec()).await;

    println!("Predicted score for {}: {}", indexable_table_data[0], result[0]);
    println!("Predicted score for {}: {}", indexable_table_data[2], result[1]);
    println!("Expected overtimes: {}", result[2]);
}