rust-ml/linfa

linfa_bayes::MultinomialNb: program panics when calling predict()

coolstudio1678 opened this issue · 0 comments

error message:
thread 'main' panicked at C:\Users\c1895\.cargo\registry\src\index.crates.io-6f17d22bba15001f\linfa-bayes-0.7.0\src\base_nb.rs:44:32:
called `Result::unwrap()` on an `Err` value: UndefinedOrder

use linfa::metrics::ToConfusionMatrix;
use linfa::traits::{Fit, Predict};
use linfa_bayes::{MultinomialNb, Result};
use linfa::prelude::*;
use ndarray::prelude::*;
use polars::prelude::*;
use std::{fs,path::Path};
use std::time::Instant;
use ciborium::*;

fn main() -> Result<()> {

    let start = Instant::now();
    let df =CsvReader::from_path("./model_data_f.csv")
            .unwrap()
            .infer_schema(None)
            .has_header(true)
            .finish()
            .unwrap();
    let df = df.tail(Some(500));
    let col_target = df.column("d_target_high").unwrap();
    let target_high = col_target.f64().unwrap().iter().collect::<Vec<_>>();
    let target: ArrayBase<ndarray::OwnedRepr<_>, _> = ArrayBase::from_vec(target_high.to_vec());

    let mut features = df.clone();
   // drop some data;
 
    for col_name in features.get_column_names_owned(){
        let casted_col = df.column(&col_name).unwrap()
            .cast(&DataType::Float32)
            .expect("Failed to cast column");
        features.with_column(casted_col).unwrap();        
    }

    let features_ar = features.to_ndarray::<Float64Type>(IndexOrder::C).unwrap();

    let linfa_dataset = Dataset::new(features_ar,target)
        //.with_weights(wt)
        .map_targets(|x| if *x > Some(2.0) {"good"} else {"bad"});
    
    let (train,valid) = linfa_dataset.split_with_ratio(0.9);
    // dbg!(&train);
    let model = MultinomialNb::params().fit(&train)?;
    dbg!("model is ok");
    let pred = model.predict(&valid);
    dbg!("test is ok");
    let cm = pred.confusion_matrix(&valid)?;

     let file = Path::new(".").join("model_bayes.cobr");
    // dbg!(&file);
    //let value_model = value::Value::from(model);
    // let model_value = bincode::serialize(&model).unwrap();
    let value_model = cbor!(model).unwrap();
    let mut vec_model = Vec::new();
    let _result = ciborium::ser::into_writer(&value_model, &mut vec_model).unwrap();
    fs::write(file, vec_model).unwrap();


    println!("accuracy {}, MCC {}", cm.accuracy() *100., cm.mcc());
    let duration = start.elapsed();
    println!("Time elapsed is: {:?}", duration);

    Ok(())
}