linfa_bayes::MultinomialNb: the program panics when calling `predict()`
coolstudio1678 opened this issue · 0 comments
coolstudio1678 commented
Error message:
thread 'main' panicked at C:\Users\c1895\.cargo\registry\src\index.crates.io-6f17d22bba15001f\linfa-bayes-0.7.0\src\base_nb.rs:44:32:
called `Result::unwrap()` on an `Err` value: UndefinedOrder
use linfa::metrics::ToConfusionMatrix;
use linfa::traits::{Fit, Predict};
use linfa_bayes::{MultinomialNb, Result};
use linfa::prelude::*;
use ndarray::prelude::*;
use polars::prelude::*;
use std::{fs,path::Path};
use std::time::Instant;
use ciborium::*;
fn main() -> Result<()> {
let start = Instant::now();
let df =CsvReader::from_path("./model_data_f.csv")
.unwrap()
.infer_schema(None)
.has_header(true)
.finish()
.unwrap();
let df = df.tail(Some(500));
let col_target = df.column("d_target_high").unwrap();
let target_high = col_target.f64().unwrap().iter().collect::<Vec<_>>();
let target: ArrayBase<ndarray::OwnedRepr<_>, _> = ArrayBase::from_vec(target_high.to_vec());
let mut features = df.clone();
// drop some data;
for col_name in features.get_column_names_owned(){
let casted_col = df.column(&col_name).unwrap()
.cast(&DataType::Float32)
.expect("Failed to cast column");
features.with_column(casted_col).unwrap();
}
let features_ar = features.to_ndarray::<Float64Type>(IndexOrder::C).unwrap();
let linfa_dataset = Dataset::new(features_ar,target)
//.with_weights(wt)
.map_targets(|x| if *x > Some(2.0) {"good"} else {"bad"});
let (train,valid) = linfa_dataset.split_with_ratio(0.9);
// dbg!(&train);
let model = MultinomialNb::params().fit(&train)?;
dbg!("model is ok");
let pred = model.predict(&valid);
dbg!("test is ok");
let cm = pred.confusion_matrix(&valid)?;
let file = Path::new(".").join("model_bayes.cobr");
// dbg!(&file);
//let value_model = value::Value::from(model);
// let model_value = bincode::serialize(&model).unwrap();
let value_model = cbor!(model).unwrap();
let mut vec_model = Vec::new();
let _result = ciborium::ser::into_writer(&value_model, &mut vec_model).unwrap();
fs::write(file, vec_model).unwrap();
println!("accuracy {}, MCC {}", cm.accuracy() *100., cm.mcc());
let duration = start.elapsed();
println!("Time elapsed is: {:?}", duration);
Ok(())
}