diff --git a/.gitmodules b/.gitmodules index cbbe4a5..68a3c82 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,3 @@ [submodule "xgboost-sys/xgboost"] path = xgboost-sys/xgboost - url = https://github.com/davechallis/xgboost - branch = master + url = https://github.com/dmlc/xgboost diff --git a/Cargo.toml b/Cargo.toml index b9d6584..2a68045 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,11 +8,15 @@ homepage = "https://github.com/davechallis/rust-xgboost" description = "Machine learning using XGBoost" documentation = "https://docs.rs/xgboost" readme = "README.md" +edition = "2021" [dependencies] -xgboost-sys = "0.2.0" +xgboost-sys = { path = "xgboost-sys" } libc = "0.2" -derive_builder = "0.5" +derive_builder = "0.20" log = "0.4" -tempfile = "3.0" -indexmap = "1.0" +tempfile = "3.15" +indexmap = "2.7" + +[features] +cuda = ["xgboost-sys/cuda"] diff --git a/README.md b/README.md index 009f869..c408a4c 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,12 @@ Rust bindings for the [XGBoost](https://xgboost.ai) gradient boosting library. +## Requirements + +- Clang v16.0.0 + +## Documentation + * [Documentation](https://docs.rs/xgboost) Basic usage example: @@ -81,7 +87,7 @@ more detailed examples of different features. Currently in a very early stage of development, so the API is changing as usability issues occur, or new features are supported. -Builds against XGBoost 0.81. +Builds against XGBoost 2.0.3. ### Platforms diff --git a/examples/basic/src/main.rs b/examples/basic/src/main.rs index 2e8955e..eee9713 100644 --- a/examples/basic/src/main.rs +++ b/examples/basic/src/main.rs @@ -12,9 +12,9 @@ fn main() { // load train and test matrices from text files (in LibSVM format). println!("Loading train and test matrices..."); - let dtrain = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); + let dtrain = DMatrix::load(r#"{"uri": "../../xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); println!("Train matrix: {}x{}", dtrain.num_rows(), dtrain.num_cols()); - let dtest = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap(); + let dtest = DMatrix::load(r#"{"uri": "../../xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap(); println!("Test matrix: {}x{}", dtest.num_rows(), dtest.num_cols()); // configure objectives, metrics, etc. @@ -66,15 +66,15 @@ fn main() { // save and load model file println!("\nSaving and loading Booster model..."); - booster.save("xgb.model").unwrap(); - let booster = Booster::load("xgb.model").unwrap(); + booster.save("xgb.json").unwrap(); + let booster = Booster::load("xgb.json").unwrap(); let preds2 = booster.predict(&dtest).unwrap(); assert_eq!(preds, preds2); // save and load data matrix file println!("\nSaving and loading matrix data..."); dtest.save("test.dmat").unwrap(); - let dtest2 = DMatrix::load("test.dmat").unwrap(); + let dtest2 = DMatrix::load_binary("test.dmat").unwrap(); assert_eq!(booster.predict(&dtest2).unwrap(), preds); // error handling example diff --git a/examples/custom_objective/src/main.rs b/examples/custom_objective/src/main.rs index 707f037..7af09e2 100644 --- a/examples/custom_objective/src/main.rs +++ b/examples/custom_objective/src/main.rs @@ -6,8 +6,8 @@ use xgboost::{parameters, DMatrix, Booster}; fn main() { // load train and test matrices from text files (in LibSVM format) println!("Custom objective example..."); - let dtrain = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); - let dtest = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap(); + let dtrain = DMatrix::load(r#"{"uri": "../../xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); + let dtest = DMatrix::load(r#"{"uri": "../../xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap(); // specify datasets to evaluate against during training let evaluation_sets = [(&dtest, "test"), (&dtrain, "train")]; diff --git a/examples/generalised_linear_model/src/main.rs b/examples/generalised_linear_model/src/main.rs index a34974c..ceb1022 100644 --- a/examples/generalised_linear_model/src/main.rs +++ b/examples/generalised_linear_model/src/main.rs @@ -12,8 +12,8 @@ fn main() { // load train and test matrices from text files (in LibSVM format) println!("Custom objective example..."); - let dtrain = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); - let dtest = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap(); + let dtrain = DMatrix::load(r#"{"uri": "../../xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); + let dtest = DMatrix::load(r#"{"uri": "../../xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap(); // configure objectives, metrics, etc. let learning_params = parameters::learning::LearningTaskParametersBuilder::default() diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..d976b15 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,2 @@ +max_width = 120 +single_line_if_else_max_width = 80 diff --git a/src/booster.rs b/src/booster.rs index 1f2dbac..a965b6d 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -1,19 +1,19 @@ +use crate::dmatrix::DMatrix; +use crate::error::XGBError; use libc; -use std::{fs::File, fmt, slice, ffi, ptr}; -use std::str::FromStr; -use std::io::{self, Write, BufReader, BufRead}; use std::collections::{BTreeMap, HashMap}; -use std::path::{Path, PathBuf}; -use error::XGBError; -use dmatrix::DMatrix; +use std::io::{self, BufRead, BufReader, Write}; use std::os::unix::ffi::OsStrExt; +use std::path::{Path, PathBuf}; +use std::str::FromStr; +use std::{ffi, fmt, fs::File, ptr, slice}; -use xgboost_sys; -use tempfile; use indexmap::IndexMap; +use tempfile; +use xgboost_sys; use super::XGBResult; -use parameters::{BoosterParameters, TrainingParameters}; +use crate::parameters::{BoosterParameters, TrainingParameters}; pub type CustomObjective = fn(&[f32], &DMatrix) -> (Vec, Vec); @@ -76,7 +76,11 @@ impl Booster { let mut handle = ptr::null_mut(); // TODO: check this is safe if any dmats are freed let s: Vec = dmats.iter().map(|x| x.handle).collect(); - xgb_call!(xgboost_sys::XGBoosterCreate(s.as_ptr(), dmats.len() as u64, &mut handle))?; + xgb_call!(xgboost_sys::XGBoosterCreate( + s.as_ptr(), + dmats.len() as u64, + &mut handle + ))?; let mut booster = Booster { handle }; booster.set_params(params)?; @@ -112,7 +116,11 @@ impl Booster { let mut handle = ptr::null_mut(); xgb_call!(xgboost_sys::XGBoosterCreate(ptr::null(), 0, &mut handle))?; - xgb_call!(xgboost_sys::XGBoosterLoadModelFromBuffer(handle, bytes.as_ptr() as *const _, bytes.len() as u64))?; + xgb_call!(xgboost_sys::XGBoosterLoadModelFromBuffer( + handle, + bytes.as_ptr() as *const _, + bytes.len() as u64 + ))?; Ok(Booster { handle }) } @@ -140,36 +148,8 @@ impl Booster { dmats }; - let mut bst = Booster::new_with_cached_dmats(¶ms.booster_params, &cached_dmats)?; - //let num_parallel_tree = 1; - - // load distributed code checkpoint from rabit - let version = bst.load_rabit_checkpoint()?; - debug!("Loaded Rabit checkpoint: version={}", version); - assert!(unsafe { xgboost_sys::RabitGetWorldSize() != 1 || version == 0 }); - - let _rank = unsafe { xgboost_sys::RabitGetRank() }; - let start_iteration = version / 2; - //let mut nboost = start_iteration; - - for i in start_iteration..params.boost_rounds as i32 { - // distributed code: need to resume to this point - // skip first update if a recovery step - if version % 2 == 0 { - if let Some(objective_fn) = params.custom_objective_fn { - debug!("Boosting in round: {}", i); - bst.update_custom(params.dtrain, objective_fn)?; - } else { - debug!("Updating in round: {}", i); - bst.update(params.dtrain, i)?; - } - bst.save_rabit_checkpoint()?; - } - - assert!(unsafe { xgboost_sys::RabitGetWorldSize() == 1 || version == xgboost_sys::RabitVersionNumber() }); - - //nboost += 1; - + let bst = Booster::new_with_cached_dmats(¶ms.booster_params, &cached_dmats)?; + for i in 0..params.boost_rounds as i32 { if let Some(eval_sets) = params.evaluation_sets { let mut dmat_eval_results = bst.eval_set(eval_sets, i)?; @@ -178,7 +158,8 @@ impl Booster { for (dmat, dmat_name) in eval_sets { let margin = bst.predict_margin(dmat)?; let eval_result = eval_fn(&margin, dmat); - let eval_results = dmat_eval_results.entry(eval_name.to_string()) + let eval_results = dmat_eval_results + .entry(eval_name.to_string()) .or_insert_with(IndexMap::new); eval_results.insert(dmat_name.to_string(), eval_result); } @@ -222,7 +203,11 @@ impl Booster { /// * `dtrain` - matrix to train the model with for a single iteration /// * `iteration` - current iteration number pub fn update(&mut self, dtrain: &DMatrix, iteration: i32) -> XGBResult<()> { - xgb_call!(xgboost_sys::XGBoosterUpdateOneIter(self.handle, iteration, dtrain.handle)) + xgb_call!(xgboost_sys::XGBoosterUpdateOneIter( + self.handle, + iteration, + dtrain.handle + )) } /// Update this model by training it for one round with a custom objective function. @@ -241,8 +226,11 @@ impl Booster { /// * `hessian` - second order gradient fn boost(&mut self, dtrain: &DMatrix, gradient: &[f32], hessian: &[f32]) -> XGBResult<()> { if gradient.len() != hessian.len() { - let msg = format!("Mismatch between length of gradient and hessian arrays ({} != {})", - gradient.len(), hessian.len()); + let msg = format!( + "Mismatch between length of gradient and hessian arrays ({} != {})", + gradient.len(), + hessian.len() + ); return Err(XGBError::new(msg)); } assert_eq!(gradient.len(), hessian.len()); @@ -250,14 +238,20 @@ impl Booster { // TODO: _validate_feature_names let mut grad_vec = gradient.to_vec(); let mut hess_vec = hessian.to_vec(); - xgb_call!(xgboost_sys::XGBoosterBoostOneIter(self.handle, - dtrain.handle, - grad_vec.as_mut_ptr(), - hess_vec.as_mut_ptr(), - grad_vec.len() as u64)) + xgb_call!(xgboost_sys::XGBoosterBoostOneIter( + self.handle, + dtrain.handle, + grad_vec.as_mut_ptr(), + hess_vec.as_mut_ptr(), + grad_vec.len() as u64 + )) } - fn eval_set(&self, evals: &[(&DMatrix, &str)], iteration: i32) -> XGBResult>> { + fn eval_set( + &self, + evals: &[(&DMatrix, &str)], + iteration: i32, + ) -> XGBResult>> { let (dmats, names) = { let mut dmats = Vec::with_capacity(evals.len()); let mut names = Vec::with_capacity(evals.len()); @@ -285,12 +279,14 @@ impl Booster { evptrs.shrink_to_fit(); let mut out_result = ptr::null(); - xgb_call!(xgboost_sys::XGBoosterEvalOneIter(self.handle, - iteration, - s.as_mut_ptr(), - evptrs.as_mut_ptr(), - dmats.len() as u64, - &mut out_result))?; + xgb_call!(xgboost_sys::XGBoosterEvalOneIter( + self.handle, + iteration, + s.as_mut_ptr(), + evptrs.as_mut_ptr(), + dmats.len() as u64, + &mut out_result + ))?; let out = unsafe { ffi::CStr::from_ptr(out_result).to_str().unwrap().to_owned() }; Ok(Booster::parse_eval_string(&out, &names)) } @@ -304,11 +300,9 @@ impl Booster { let name = "default"; let mut eval = self.eval_set(&[(dmat, name)], 0)?; let mut result = HashMap::new(); - eval.remove(name).unwrap() - .into_iter() - .for_each(|(k, v)| { - result.insert(k.to_owned(), v); - }); + eval.swap_remove(name).unwrap().into_iter().for_each(|(k, v)| { + result.insert(k.to_owned(), v); + }); Ok(result) } @@ -318,7 +312,12 @@ impl Booster { let key = ffi::CString::new(key).unwrap(); let mut out_buf = ptr::null(); let mut success = 0; - xgb_call!(xgboost_sys::XGBoosterGetAttr(self.handle, key.as_ptr(), &mut out_buf, &mut success))?; + xgb_call!(xgboost_sys::XGBoosterGetAttr( + self.handle, + key.as_ptr(), + &mut out_buf, + &mut success + ))?; if success == 0 { return Ok(None); } @@ -341,12 +340,16 @@ impl Booster { let mut out_len = 0; let mut out = ptr::null_mut(); xgb_call!(xgboost_sys::XGBoosterGetAttrNames(self.handle, &mut out_len, &mut out))?; - - let out_ptr_slice = unsafe { slice::from_raw_parts(out, out_len as usize) }; - let out_vec = out_ptr_slice.iter() - .map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() }) - .collect(); - Ok(out_vec) + if out_len > 0 { + let out_ptr_slice = unsafe { slice::from_raw_parts(out, out_len as usize) }; + let out_vec = out_ptr_slice + .iter() + .map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() }) + .collect(); + Ok(out_vec) + } else { + Ok(Vec::new()) + } } /// Predict results for given data. @@ -357,13 +360,15 @@ impl Booster { let ntree_limit = 0; let mut out_len = 0; let mut out_result = ptr::null(); - xgb_call!(xgboost_sys::XGBoosterPredict(self.handle, - dmat.handle, - option_mask, - ntree_limit, - 0, - &mut out_len, - &mut out_result))?; + xgb_call!(xgboost_sys::XGBoosterPredict( + self.handle, + dmat.handle, + option_mask, + ntree_limit, + 0, + &mut out_len, + &mut out_result + ))?; assert!(!out_result.is_null()); let data = unsafe { slice::from_raw_parts(out_result, out_len as usize).to_vec() }; @@ -378,13 +383,15 @@ impl Booster { let ntree_limit = 0; let mut out_len = 0; let mut out_result = ptr::null(); - xgb_call!(xgboost_sys::XGBoosterPredict(self.handle, - dmat.handle, - option_mask, - ntree_limit, - 1, - &mut out_len, - &mut out_result))?; + xgb_call!(xgboost_sys::XGBoosterPredict( + self.handle, + dmat.handle, + option_mask, + ntree_limit, + 1, + &mut out_len, + &mut out_result + ))?; assert!(!out_result.is_null()); let data = unsafe { slice::from_raw_parts(out_result, out_len as usize).to_vec() }; Ok(data) @@ -400,13 +407,15 @@ impl Booster { let ntree_limit = 0; let mut out_len = 0; let mut out_result = ptr::null(); - xgb_call!(xgboost_sys::XGBoosterPredict(self.handle, - dmat.handle, - option_mask, - ntree_limit, - 0, - &mut out_len, - &mut out_result))?; + xgb_call!(xgboost_sys::XGBoosterPredict( + self.handle, + dmat.handle, + option_mask, + ntree_limit, + 0, + &mut out_len, + &mut out_result + ))?; assert!(!out_result.is_null()); let data = unsafe { slice::from_raw_parts(out_result, out_len as usize).to_vec() }; @@ -427,13 +436,15 @@ impl Booster { let ntree_limit = 0; let mut out_len = 0; let mut out_result = ptr::null(); - xgb_call!(xgboost_sys::XGBoosterPredict(self.handle, - dmat.handle, - option_mask, - ntree_limit, - 0, - &mut out_len, - &mut out_result))?; + xgb_call!(xgboost_sys::XGBoosterPredict( + self.handle, + dmat.handle, + option_mask, + ntree_limit, + 0, + &mut out_len, + &mut out_result + ))?; assert!(!out_result.is_null()); let data = unsafe { slice::from_raw_parts(out_result, out_len as usize).to_vec() }; @@ -455,13 +466,15 @@ impl Booster { let ntree_limit = 0; let mut out_len = 0; let mut out_result = ptr::null(); - xgb_call!(xgboost_sys::XGBoosterPredict(self.handle, - dmat.handle, - option_mask, - ntree_limit, - 0, - &mut out_len, - &mut out_result))?; + xgb_call!(xgboost_sys::XGBoosterPredict( + self.handle, + dmat.handle, + option_mask, + ntree_limit, + 0, + &mut out_len, + &mut out_result + ))?; assert!(!out_result.is_null()); let data = unsafe { slice::from_raw_parts(out_result, out_len as usize).to_vec() }; @@ -482,7 +495,7 @@ impl Booster { Err(err) => return Err(XGBError::new(err.to_string())), }; - let file_path = tmp_dir.path().join("fmap.txt"); + let file_path = tmp_dir.path().join("fmap.json"); let mut file: File = match File::create(&file_path) { Ok(f) => f, Err(err) => return Err(XGBError::new(err.to_string())), @@ -507,36 +520,37 @@ impl Booster { let format = ffi::CString::new("text").unwrap(); let mut out_len = 0; let mut out_dump_array = ptr::null_mut(); - xgb_call!(xgboost_sys::XGBoosterDumpModelEx(self.handle, - fmap.as_ptr(), - with_statistics as i32, - format.as_ptr(), - &mut out_len, - &mut out_dump_array))?; - - let out_ptr_slice = unsafe { slice::from_raw_parts(out_dump_array, out_len as usize) }; - let out_vec: Vec = out_ptr_slice.iter() - .map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() }) - .collect(); - - assert_eq!(out_len as usize, out_vec.len()); - Ok(out_vec.join("\n")) - } - - pub(crate) fn load_rabit_checkpoint(&self) -> XGBResult { - let mut version = 0; - xgb_call!(xgboost_sys::XGBoosterLoadRabitCheckpoint(self.handle, &mut version))?; - Ok(version) - } - - pub(crate) fn save_rabit_checkpoint(&self) -> XGBResult<()> { - xgb_call!(xgboost_sys::XGBoosterSaveRabitCheckpoint(self.handle)) + xgb_call!(xgboost_sys::XGBoosterDumpModelEx( + self.handle, + fmap.as_ptr(), + with_statistics as i32, + format.as_ptr(), + &mut out_len, + &mut out_dump_array + ))?; + + if out_len > 0 { + let out_ptr_slice = unsafe { slice::from_raw_parts(out_dump_array, out_len as usize) }; + let out_vec: Vec = out_ptr_slice + .iter() + .map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() }) + .collect(); + + assert_eq!(out_len as usize, out_vec.len()); + Ok(out_vec.join("\n")) + } else { + Ok(String::new()) + } } - fn set_param(&mut self, name: &str, value: &str) -> XGBResult<()> { + pub fn set_param(&mut self, name: &str, value: &str) -> XGBResult<()> { let name = ffi::CString::new(name).unwrap(); let value = ffi::CString::new(value).unwrap(); - xgb_call!(xgboost_sys::XGBoosterSetParam(self.handle, name.as_ptr(), value.as_ptr())) + xgb_call!(xgboost_sys::XGBoosterSetParam( + self.handle, + name.as_ptr(), + value.as_ptr() + )) } fn parse_eval_string(eval: &str, evnames: &[&str]) -> IndexMap> { @@ -546,13 +560,14 @@ impl Booster { for part in eval.split('\t').skip(1) { for evname in evnames { if part.starts_with(evname) { - let metric_parts: Vec<&str> = part[evname.len()+1..].split(':').into_iter().collect(); + let metric_parts: Vec<&str> = part[evname.len() + 1..].split(':').collect(); assert_eq!(metric_parts.len(), 2); let metric = metric_parts[0]; - let score = metric_parts[1].parse::() + let score = metric_parts[1] + .parse::() .unwrap_or_else(|_| panic!("Unable to parse XGBoost metrics output: {}", eval)); - let metric_map = result.entry(evname.to_string()).or_insert_with(IndexMap::new); + let metric_map = result.entry(evname.to_string()).or_default(); metric_map.insert(metric.to_owned(), score); } } @@ -561,7 +576,6 @@ impl Booster { debug!("result: {:?}", &result); result } - } impl Drop for Booster { @@ -603,25 +617,31 @@ impl FeatureMap { let line = line?; let parts: Vec<&str> = line.split('\t').collect(); if parts.len() != 3 { - let msg = format!("Unable to parse features from line {}, expected 3 tab separated values", i+1); + let msg = format!( + "Unable to parse features from line {}, expected 3 tab separated values", + i + 1 + ); return Err(io::Error::new(io::ErrorKind::InvalidData, msg)); } assert_eq!(parts.len(), 3); let feature_num: u32 = match parts[0].parse() { - Ok(num) => num, + Ok(num) => num, Err(err) => { - let msg = format!("Unable to parse features from line {}, could not parse feature number: {}", - i+1, err); + let msg = format!( + "Unable to parse features from line {}, could not parse feature number: {}", + i + 1, + err + ); return Err(io::Error::new(io::ErrorKind::InvalidData, msg)); } }; let feature_name = &parts[1]; - let feature_type = match FeatureType::from_str(&parts[2]) { + let feature_type = match FeatureType::from_str(parts[2]) { Ok(feature_type) => feature_type, - Err(msg) => { - let msg = format!("Unable to parse features from line {}: {}", i+1, msg); + Err(msg) => { + let msg = format!("Unable to parse features from line {}: {}", i + 1, msg); return Err(io::Error::new(io::ErrorKind::InvalidData, msg)); } }; @@ -648,10 +668,13 @@ impl FromStr for FeatureType { fn from_str(s: &str) -> Result { match s { - "i" => Ok(FeatureType::Binary), - "q" => Ok(FeatureType::Quantitative), + "i" => Ok(FeatureType::Binary), + "q" => Ok(FeatureType::Quantitative), "int" => Ok(FeatureType::Integer), - _ => Err(format!("unrecognised feature type '{}', must be one of: 'i', 'q', 'int'", s)) + _ => Err(format!( + "unrecognised feature type '{}', must be one of: 'i', 'q', 'int'", + s + )), } } } @@ -670,10 +693,10 @@ impl fmt::Display for FeatureType { #[cfg(test)] mod tests { use super::*; - use parameters::{self, learning, tree}; + use crate::parameters::{self, learning, tree}; fn read_train_matrix() -> XGBResult { - DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train") + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#) } fn load_test_booster() -> Booster { @@ -688,12 +711,6 @@ mod tests { assert!(res.is_ok()); } - #[test] - fn load_rabit_version() { - let version = load_test_booster().load_rabit_checkpoint().unwrap(); - assert_eq!(version, 0); - } - #[test] fn get_set_attr() { let mut booster = load_test_booster(); @@ -707,7 +724,8 @@ mod tests { #[test] fn save_and_load_from_buffer() { - let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); + let dmat_train = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); let mut booster = Booster::new_with_cached_dmats(&BoosterParameters::default(), &[&dmat_train]).unwrap(); let attr = booster.get_attribute("foo").expect("Getting attribute failed"); assert_eq!(attr, None); @@ -733,9 +751,13 @@ mod tests { assert_eq!(attrs, Vec::::new()); booster.set_attribute("foo", "bar").expect("Setting attribute failed"); - booster.set_attribute("another", "another").expect("Setting attribute failed"); + booster + .set_attribute("another", "another") + .expect("Setting attribute failed"); booster.set_attribute("4", "4").expect("Setting attribute failed"); - booster.set_attribute("an even longer attribute name?", "").expect("Setting attribute failed"); + booster + .set_attribute("an even longer attribute name?", "") + .expect("Setting attribute failed"); let mut expected = vec!["foo", "another", "4", "an even longer attribute name?"]; expected.sort(); @@ -746,8 +768,10 @@ mod tests { #[test] fn predict() { - let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); - let dmat_test = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap(); + let dmat_train = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); + let dmat_test = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap(); let tree_params = tree::TreeBoosterParametersBuilder::default() .max_depth(2) @@ -756,9 +780,11 @@ mod tests { .unwrap(); let learning_params = learning::LearningTaskParametersBuilder::default() .objective(learning::Objective::BinaryLogistic) - .eval_metrics(learning::Metrics::Custom(vec![learning::EvaluationMetric::MAPCutNegative(4), - learning::EvaluationMetric::LogLoss, - learning::EvaluationMetric::BinaryErrorRate(0.5)])) + .eval_metrics(learning::Metrics::Custom(vec![ + learning::EvaluationMetric::MAPCutNegative(4), + learning::EvaluationMetric::LogLoss, + learning::EvaluationMetric::BinaryErrorRate(0.5), + ])) .build() .unwrap(); let params = parameters::BoosterParametersBuilder::default() @@ -774,39 +800,43 @@ mod tests { } let train_metrics = booster.evaluate(&dmat_train).unwrap(); - assert_eq!(*train_metrics.get("logloss").unwrap(), 0.006634); - assert_eq!(*train_metrics.get("map@4-").unwrap(), 0.001274); + assert_eq!(*train_metrics.get("logloss").unwrap(), 0.006634271); + assert_eq!(*train_metrics.get("map@4-").unwrap(), 1.0); let test_metrics = booster.evaluate(&dmat_test).unwrap(); - assert_eq!(*test_metrics.get("logloss").unwrap(), 0.00692); - assert_eq!(*test_metrics.get("map@4-").unwrap(), 0.005155); + assert_eq!(*test_metrics.get("logloss").unwrap(), 0.0069199526); + assert_eq!(*test_metrics.get("map@4-").unwrap(), 1.0); let v = booster.predict(&dmat_test).unwrap(); assert_eq!(v.len(), dmat_test.num_rows()); // first 10 predictions - let expected_start = [0.0050151693, - 0.9884467, - 0.0050151693, - 0.0050151693, - 0.026636455, - 0.11789363, - 0.9884467, - 0.01231471, - 0.9884467, - 0.00013656063]; + let expected_start = [ + 0.0050151693, + 0.9884467, + 0.0050151693, + 0.0050151693, + 0.026636455, + 0.11789363, + 0.9884467, + 0.01231471, + 0.9884467, + 0.00013656063, + ]; // last 10 predictions - let expected_end = [0.002520344, - 0.00060917926, - 0.99881005, - 0.00060917926, - 0.00060917926, - 0.00060917926, - 0.00060917926, - 0.9981102, - 0.002855195, - 0.9981102]; + let expected_end = [ + 0.002520344, + 0.00060917926, + 0.99881005, + 0.00060917926, + 0.00060917926, + 0.00060917926, + 0.00060917926, + 0.9981102, + 0.002855195, + 0.9981102, + ]; let eps = 1e-6; for (pred, expected) in v.iter().zip(&expected_start) { @@ -814,7 +844,7 @@ mod tests { assert!(pred - expected < eps); } - for (pred, expected) in v[v.len()-10..].iter().zip(&expected_end) { + for (pred, expected) in v[v.len() - 10..].iter().zip(&expected_end) { println!("predictions={}, expected={}", pred, expected); assert!(pred - expected < eps); } @@ -822,8 +852,10 @@ mod tests { #[test] fn predict_leaf() { - let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); - let dmat_test = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap(); + let dmat_train = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); + let dmat_test = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap(); let tree_params = tree::TreeBoosterParametersBuilder::default() .max_depth(2) @@ -855,8 +887,10 @@ mod tests { #[test] fn predict_contributions() { - let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); - let dmat_test = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap(); + let dmat_train = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); + let dmat_test = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap(); let tree_params = tree::TreeBoosterParametersBuilder::default() .max_depth(2) @@ -889,8 +923,10 @@ mod tests { #[test] fn predict_interactions() { - let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); - let dmat_test = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap(); + let dmat_train = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); + let dmat_test = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap(); let tree_params = tree::TreeBoosterParametersBuilder::default() .max_depth(2) @@ -941,105 +977,109 @@ mod tests { #[test] fn dump_model() { - let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); + let dmat_train = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); println!("{:?}", dmat_train.shape()); let tree_params = tree::TreeBoosterParametersBuilder::default() .max_depth(2) .eta(1.0) - .build().unwrap(); + .build() + .unwrap(); let learning_params = learning::LearningTaskParametersBuilder::default() .objective(learning::Objective::BinaryLogistic) - .build().unwrap(); + .build() + .unwrap(); let booster_params = parameters::BoosterParametersBuilder::default() .booster_type(parameters::BoosterType::Tree(tree_params)) .learning_params(learning_params) .verbose(false) - .build().unwrap(); + .build() + .unwrap(); let training_params = parameters::TrainingParametersBuilder::default() .booster_params(booster_params) .dtrain(&dmat_train) .boost_rounds(10) - .build().unwrap(); + .build() + .unwrap(); let booster = Booster::train(&training_params).unwrap(); - let features = FeatureMap::from_file("xgboost-sys/xgboost/demo/data/featmap.txt") - .expect("failed to parse feature map file"); - - assert_eq!(booster.dump_model(true, Some(&features)).unwrap(), -"0:[odor=none] yes=2,no=1,gain=4000.53101,cover=1628.25 -1:[stalk-root=club] yes=4,no=3,gain=1158.21204,cover=924.5 - 3:leaf=1.71217716,cover=812 - 4:leaf=-1.70044053,cover=112.5 -2:[spore-print-color=green] yes=6,no=5,gain=198.173828,cover=703.75 - 5:leaf=-1.94070864,cover=690.5 - 6:leaf=1.85964918,cover=13.25 - -0:[stalk-root=rooted] yes=2,no=1,gain=832.545044,cover=788.852051 -1:[odor=none] yes=4,no=3,gain=569.725098,cover=768.389709 - 3:leaf=0.78471756,cover=458.936859 - 4:leaf=-0.968530357,cover=309.45282 - 2:leaf=-6.23624468,cover=20.462389 - -0:[ring-type=pendant] yes=2,no=1,gain=368.744568,cover=457.069458 -1:[stalk-surface-below-ring=scaly] yes=4,no=3,gain=226.33696,cover=221.051468 - 3:leaf=0.658725023,cover=212.999451 - 4:leaf=5.77228642,cover=8.05200672 -2:[spore-print-color=purple] yes=6,no=5,gain=258.184265,cover=236.018005 - 5:leaf=-0.791407049,cover=233.487625 - 6:leaf=-9.421422,cover=2.53038669 - -0:[odor=foul] yes=2,no=1,gain=140.486069,cover=364.119354 -1:[gill-size=broad] yes=4,no=3,gain=139.860504,cover=274.101959 - 3:leaf=0.614153326,cover=95.8599854 - 4:leaf=-0.877905607,cover=178.241974 - 2:leaf=1.07747853,cover=90.0174103 - -0:[spore-print-color=green] yes=2,no=1,gain=112.605011,cover=189.202194 -1:[gill-spacing=close] yes=4,no=3,gain=66.4029999,cover=177.771835 - 3:leaf=-1.26934469,cover=42.277401 - 4:leaf=0.152607277,cover=135.494431 - 2:leaf=2.92190909,cover=11.4303684 - -0:[odor=almond] yes=2,no=1,gain=52.5610275,cover=170.612762 -1:[odor=anise] yes=4,no=3,gain=67.3869553,cover=150.881165 - 3:leaf=0.431742132,cover=131.902222 - 4:leaf=-1.53846073,cover=18.9789505 -2:[gill-spacing=close] yes=6,no=5,gain=12.4420624,cover=19.731596 - 5:leaf=-3.02413678,cover=3.65769386 - 6:leaf=-1.02315068,cover=16.0739021 - -0:[odor=none] yes=2,no=1,gain=66.2389145,cover=142.360611 -1:[odor=anise] yes=4,no=3,gain=31.2294312,cover=72.7557373 - 3:leaf=0.777142286,cover=64.5309982 - 4:leaf=-1.19710124,cover=8.22473907 -2:[spore-print-color=green] yes=6,no=5,gain=12.1987419,cover=69.6048737 - 5:leaf=-0.912605286,cover=66.1211166 - 6:leaf=0.836115122,cover=3.48375821 - -0:[gill-size=broad] yes=2,no=1,gain=20.6531773,cover=79.4027634 -1:[spore-print-color=white] yes=4,no=3,gain=16.0703697,cover=34.9289207 - 3:leaf=-0.0180106498,cover=25.0319824 - 4:leaf=1.4361918,cover=9.89693928 -2:[odor=foul] yes=6,no=5,gain=22.1144333,cover=44.4738464 - 5:leaf=-0.908311546,cover=36.982872 - 6:leaf=0.890622675,cover=7.49097395 - -0:[odor=almond] yes=2,no=1,gain=11.7128553,cover=53.3251991 -1:[ring-type=pendant] yes=4,no=3,gain=12.546154,cover=44.299942 - 3:leaf=-0.515293062,cover=15.7899179 - 4:leaf=0.56883812,cover=28.5100231 - 2:leaf=-1.01502442,cover=9.02525806 - -0:[population=clustered] yes=2,no=1,gain=14.8892794,cover=45.9312019 -1:[odor=none] yes=4,no=3,gain=10.1308851,cover=43.0564575 - 3:leaf=0.217203051,cover=22.3283749 - 4:leaf=-0.734555721,cover=20.7280827 -2:[stalk-root=missing] yes=6,no=5,gain=19.3462334,cover=2.87474418 - 5:leaf=3.63442755,cover=1.34154534 - 6:leaf=-0.609474957,cover=1.53319895 -"); + assert_eq!( + booster.dump_model(true, None).unwrap(), + "0:[f29<2.00001001] yes=1,no=2,missing=2,gain=4000.53101,cover=1628.25 + 1:[f109<2.00001001] yes=3,no=4,missing=4,gain=198.173828,cover=703.75 + 3:leaf=1.85964918,cover=13.25 + 4:leaf=-1.94070864,cover=690.5 + 2:[f56<2.00001001] yes=5,no=6,missing=6,gain=1158.21204,cover=924.5 + 5:leaf=-1.70044053,cover=112.5 + 6:leaf=1.71217716,cover=812 + +0:[f60<2.00001001] yes=1,no=2,missing=2,gain=832.544983,cover=788.852051 + 1:leaf=-6.23624468,cover=20.462389 + 2:[f29<2.00001001] yes=3,no=4,missing=4,gain=569.725098,cover=768.389709 + 3:leaf=-0.968530357,cover=309.45282 + 4:leaf=0.78471756,cover=458.936859 + +0:[f102<2.00001001] yes=1,no=2,missing=2,gain=368.744568,cover=457.069458 + 1:[f111<2.00001001] yes=3,no=4,missing=4,gain=258.184326,cover=236.018005 + 3:leaf=-9.421422,cover=2.53038669 + 4:leaf=-0.791407049,cover=233.487625 + 2:[f67<2.00001001] yes=5,no=6,missing=6,gain=226.336975,cover=221.051468 + 5:leaf=5.77228642,cover=8.05200672 + 6:leaf=0.658725023,cover=212.999451 + +0:[f27<2.00001001] yes=1,no=2,missing=2,gain=140.486053,cover=364.119354 + 1:leaf=1.07747853,cover=90.0174103 + 2:[f39<2.00001001] yes=3,no=4,missing=4,gain=139.860519,cover=274.101959 + 3:leaf=-0.877905607,cover=178.241974 + 4:leaf=0.614153326,cover=95.8599854 + +0:[f109<2.00001001] yes=1,no=2,missing=2,gain=112.605019,cover=189.202194 + 1:leaf=2.92190909,cover=11.4303684 + 2:[f36<2.00001001] yes=3,no=4,missing=4,gain=66.4029999,cover=177.771835 + 3:leaf=0.152607277,cover=135.494431 + 4:leaf=-1.26934469,cover=42.277401 + +0:[f23<2.00001001] yes=1,no=2,missing=2,gain=52.5610313,cover=170.612762 + 1:[f36<2.00001001] yes=3,no=4,missing=4,gain=12.4420547,cover=19.731596 + 3:leaf=-1.02315068,cover=16.0739021 + 4:leaf=-3.02413678,cover=3.65769386 + 2:[f24<2.00001001] yes=5,no=6,missing=6,gain=67.3869553,cover=150.881165 + 5:leaf=-1.53846073,cover=18.9789505 + 6:leaf=0.431742132,cover=131.902222 + +0:[f29<2.00001001] yes=1,no=2,missing=2,gain=66.2389145,cover=142.360611 + 1:[f109<2.00001001] yes=3,no=4,missing=4,gain=12.1987419,cover=69.6048737 + 3:leaf=0.836115122,cover=3.48375821 + 4:leaf=-0.912605286,cover=66.1211166 + 2:[f24<2.00001001] yes=5,no=6,missing=6,gain=31.229435,cover=72.7557373 + 5:leaf=-1.19710124,cover=8.22473907 + 6:leaf=0.777142286,cover=64.5309982 + +0:[f39<2.00001001] yes=1,no=2,missing=2,gain=20.6531773,cover=79.4027634 + 1:[f27<2.00001001] yes=3,no=4,missing=4,gain=22.1144371,cover=44.4738464 + 3:leaf=0.890622675,cover=7.49097395 + 4:leaf=-0.908311546,cover=36.982872 + 2:[f112<2.00001001] yes=5,no=6,missing=6,gain=16.0703697,cover=34.9289207 + 5:leaf=1.4361918,cover=9.89693928 + 6:leaf=-0.0180106498,cover=25.0319824 + +0:[f23<2.00001001] yes=1,no=2,missing=2,gain=11.7128553,cover=53.3251991 + 1:leaf=-1.01502442,cover=9.02525806 + 2:[f102<2.00001001] yes=3,no=4,missing=4,gain=12.5461531,cover=44.299942 + 3:leaf=0.56883812,cover=28.5100231 + 4:leaf=-0.515293062,cover=15.7899179 + +0:[f115<2.00001001] yes=1,no=2,missing=2,gain=14.8892794,cover=45.9312019 + 1:[f61<2.00001001] yes=3,no=4,missing=4,gain=19.3462334,cover=2.87474418 + 3:leaf=-0.609474957,cover=1.53319895 + 4:leaf=3.63442755,cover=1.34154534 + 2:[f29<2.00001001] yes=5,no=6,missing=6,gain=10.1308861,cover=43.0564575 + 5:leaf=-0.734555721,cover=20.7280827 + 6:leaf=0.217203051,cover=22.3283749 +" + ); } } diff --git a/src/dmatrix.rs b/src/dmatrix.rs index c67a793..4c0b959 100644 --- a/src/dmatrix.rs +++ b/src/dmatrix.rs @@ -1,17 +1,16 @@ -use std::{slice, ffi, ptr, path::Path}; -use libc::{c_uint, c_float}; +use libc::{c_float, c_uint}; use std::os::unix::ffi::OsStrExt; -use std::convert::TryInto; +use std::{ffi, path::Path, ptr, slice}; use xgboost_sys; -use super::{XGBResult, XGBError}; +use super::{XGBError, XGBResult}; -static KEY_GROUP_PTR: &'static str = "group_ptr"; -static KEY_GROUP: &'static str = "group"; -static KEY_LABEL: &'static str = "label"; -static KEY_WEIGHT: &'static str = "weight"; -static KEY_BASE_MARGIN: &'static str = "base_margin"; +static KEY_GROUP_PTR: &str = "group_ptr"; +static KEY_GROUP: &str = "group"; +static KEY_LABEL: &str = "label"; +static KEY_WEIGHT: &str = "weight"; +static KEY_BASE_MARGIN: &str = "base_margin"; /// Data matrix used throughout XGBoost for training/predicting [`Booster`](struct.Booster.html) models. /// @@ -31,7 +30,7 @@ static KEY_BASE_MARGIN: &'static str = "base_margin"; /// ```should_panic /// use xgboost::DMatrix; /// -/// let dmat = DMatrix::load("somefile.txt").unwrap(); +/// let dmat = DMatrix::load(r#"{"uri": "somefile.txt?format=csv"}"#).unwrap(); /// ``` /// /// ## Create from dense array @@ -62,12 +61,13 @@ static KEY_BASE_MARGIN: &'static str = "base_margin"; /// ``` /// use xgboost::DMatrix; /// -/// let indptr = &[0, 2, 3, 6]; +/// let indptr = &[0, 1, 2, 6]; /// let indices = &[0, 2, 2, 0, 1, 2]; /// let data = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; -/// let dmat = DMatrix::from_csr(indptr, indices, data, None).unwrap(); +/// let dmat = DMatrix::from_csc(indptr, indices, data, None).unwrap(); /// assert_eq!(dmat.shape(), (3, 3)); /// ``` +#[derive(Debug)] pub struct DMatrix { pub(super) handle: xgboost_sys::DMatrixHandle, num_rows: usize, @@ -88,7 +88,11 @@ impl DMatrix { let num_cols = out as usize; info!("Loaded DMatrix with shape: {}x{}", num_rows, num_cols); - Ok(DMatrix { handle, num_rows, num_cols }) + Ok(DMatrix { + handle, + num_rows, + num_cols, + }) } /// Create a new `DMatrix` from dense array in row-major order. @@ -109,12 +113,14 @@ impl DMatrix { /// ``` pub fn from_dense(data: &[f32], num_rows: usize) -> XGBResult { let mut handle = ptr::null_mut(); - xgb_call!(xgboost_sys::XGDMatrixCreateFromMat(data.as_ptr(), - num_rows as xgboost_sys::bst_ulong, - (data.len() / num_rows) as xgboost_sys::bst_ulong, - f32::NAN, - &mut handle))?; - Ok(DMatrix::new(handle)?) + xgb_call!(xgboost_sys::XGDMatrixCreateFromMat( + data.as_ptr(), + num_rows as xgboost_sys::bst_ulong, + (data.len() / num_rows) as xgboost_sys::bst_ulong, + f32::NAN, + &mut handle + ))?; + DMatrix::new(handle) } /// Create a new `DMatrix` from a sparse @@ -128,17 +134,18 @@ impl DMatrix { pub fn from_csr(indptr: &[usize], indices: &[usize], data: &[f32], num_cols: Option) -> XGBResult { assert_eq!(indices.len(), data.len()); let mut handle = ptr::null_mut(); - let indptr: Vec = indptr.iter().map(|x| *x as u64).collect(); let indices: Vec = indices.iter().map(|x| *x as u32).collect(); let num_cols = num_cols.unwrap_or(0); // infer from data if 0 - xgb_call!(xgboost_sys::XGDMatrixCreateFromCSREx(indptr.as_ptr(), - indices.as_ptr(), - data.as_ptr(), - indptr.len().try_into().unwrap(), - data.len().try_into().unwrap(), - num_cols.try_into().unwrap(), - &mut handle))?; - Ok(DMatrix::new(handle)?) + xgb_call!(xgboost_sys::XGDMatrixCreateFromCSREx( + indptr.as_ptr(), + indices.as_ptr(), + data.as_ptr(), + indptr.len(), + data.len(), + num_cols, + &mut handle + ))?; + DMatrix::new(handle) } /// Create a new `DMatrix` from a sparse @@ -152,17 +159,18 @@ impl DMatrix { pub fn from_csc(indptr: &[usize], indices: &[usize], data: &[f32], num_rows: Option) -> XGBResult { assert_eq!(indices.len(), data.len()); let mut handle = ptr::null_mut(); - let indptr: Vec = indptr.iter().map(|x| *x as u64).collect(); let indices: Vec = indices.iter().map(|x| *x as u32).collect(); let num_rows = num_rows.unwrap_or(0); // infer from data if 0 - xgb_call!(xgboost_sys::XGDMatrixCreateFromCSCEx(indptr.as_ptr(), - indices.as_ptr(), - data.as_ptr(), - indptr.len().try_into().unwrap(), - data.len().try_into().unwrap(), - num_rows.try_into().unwrap(), - &mut handle))?; - Ok(DMatrix::new(handle)?) + xgb_call!(xgboost_sys::XGDMatrixCreateFromCSCEx( + indptr.as_ptr(), + indices.as_ptr(), + data.as_ptr(), + indptr.len(), + data.len(), + num_rows, + &mut handle + ))?; + DMatrix::new(handle) } /// Create a new `DMatrix` from given file. @@ -191,9 +199,16 @@ impl DMatrix { debug!("Loading DMatrix from: {}", path.as_ref().display()); let mut handle = ptr::null_mut(); let fname = ffi::CString::new(path.as_ref().as_os_str().as_bytes()).unwrap(); - let silent = true; - xgb_call!(xgboost_sys::XGDMatrixCreateFromFile(fname.as_ptr(), silent as i32, &mut handle))?; - Ok(DMatrix::new(handle)?) + xgb_call!(xgboost_sys::XGDMatrixCreateFromURI(fname.as_ptr(), &mut handle))?; + DMatrix::new(handle) + } + + pub fn load_binary>(path: P) -> XGBResult { + debug!("Loading DMatrix from: {}", path.as_ref().display()); + let mut handle = ptr::null_mut(); + let fname = ffi::CString::new(path.as_ref().as_os_str().as_bytes()).unwrap(); + xgb_call!(xgboost_sys::XGDMatrixCreateFromFile(fname.as_ptr(), 1, &mut handle)).unwrap(); + DMatrix::new(handle) } /// Serialise this `DMatrix` as a binary file to given path. @@ -201,7 +216,11 @@ impl DMatrix { debug!("Writing DMatrix to: {}", path.as_ref().display()); let fname = ffi::CString::new(path.as_ref().as_os_str().as_bytes()).unwrap(); let silent = true; - xgb_call!(xgboost_sys::XGDMatrixSaveBinary(self.handle, fname.as_ptr(), silent as i32)) + xgb_call!(xgboost_sys::XGDMatrixSaveBinary( + self.handle, + fname.as_ptr(), + silent as i32 + )) } /// Get the number of rows in this matrix. @@ -224,11 +243,13 @@ impl DMatrix { debug!("Slicing {} rows from DMatrix", indices.len()); let mut out_handle = ptr::null_mut(); let indices: Vec = indices.iter().map(|x| *x as i32).collect(); - xgb_call!(xgboost_sys::XGDMatrixSliceDMatrix(self.handle, - indices.as_ptr(), - indices.len() as xgboost_sys::bst_ulong, - &mut out_handle))?; - Ok(DMatrix::new(out_handle)?) + xgb_call!(xgboost_sys::XGDMatrixSliceDMatrix( + self.handle, + indices.as_ptr(), + indices.len() as xgboost_sys::bst_ulong, + &mut out_handle + ))?; + DMatrix::new(out_handle) } /// Get ground truth labels for each row of this matrix. @@ -282,44 +303,55 @@ impl DMatrix { self.get_uint_info(KEY_GROUP_PTR) } - fn get_float_info(&self, field: &str) -> XGBResult<&[f32]> { let field = ffi::CString::new(field).unwrap(); let mut out_len = 0; let mut out_dptr = ptr::null(); - xgb_call!(xgboost_sys::XGDMatrixGetFloatInfo(self.handle, - field.as_ptr(), - &mut out_len, - &mut out_dptr))?; + xgb_call!(xgboost_sys::XGDMatrixGetFloatInfo( + self.handle, + field.as_ptr(), + &mut out_len, + &mut out_dptr + ))?; - Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_float, out_len as usize) }) + if out_len > 0 { + Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_float, out_len as usize) }) + } else { + Err(XGBError::new("error")) + } } fn set_float_info(&mut self, field: &str, array: &[f32]) -> XGBResult<()> { let field = ffi::CString::new(field).unwrap(); - xgb_call!(xgboost_sys::XGDMatrixSetFloatInfo(self.handle, - field.as_ptr(), - array.as_ptr(), - array.len() as u64)) + xgb_call!(xgboost_sys::XGDMatrixSetFloatInfo( + self.handle, + field.as_ptr(), + array.as_ptr(), + array.len() as u64 + )) } fn get_uint_info(&self, field: &str) -> XGBResult<&[u32]> { let field = ffi::CString::new(field).unwrap(); let mut out_len = 0; let mut out_dptr = ptr::null(); - xgb_call!(xgboost_sys::XGDMatrixGetUIntInfo(self.handle, - field.as_ptr(), - &mut out_len, - &mut out_dptr))?; + xgb_call!(xgboost_sys::XGDMatrixGetUIntInfo( + self.handle, + field.as_ptr(), + &mut out_len, + &mut out_dptr + ))?; Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_uint, out_len as usize) }) } fn set_uint_info(&mut self, field: &str, array: &[u32]) -> XGBResult<()> { let field = ffi::CString::new(field).unwrap(); - xgb_call!(xgboost_sys::XGDMatrixSetUIntInfo(self.handle, - field.as_ptr(), - array.as_ptr(), - array.len() as u64)) + xgb_call!(xgboost_sys::XGDMatrixSetUIntInfo( + self.handle, + field.as_ptr(), + array.as_ptr(), + array.len() as u64 + )) } } @@ -331,10 +363,10 @@ impl Drop for DMatrix { #[cfg(test)] mod tests { - use tempfile; use super::*; + use tempfile; fn read_train_matrix() -> XGBResult { - DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train") + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#) } #[test] @@ -349,7 +381,7 @@ mod tests { #[test] fn read_num_cols() { - assert_eq!(read_train_matrix().unwrap().num_cols(), 126); + assert_eq!(read_train_matrix().unwrap().num_cols(), 127); } #[test] @@ -360,7 +392,7 @@ mod tests { let out_path = tmp_dir.path().join("dmat.bin"); dmat.save(&out_path).unwrap(); - let dmat2 = DMatrix::load(&out_path).unwrap(); + let dmat2 = DMatrix::load_binary(out_path).unwrap(); assert_eq!(dmat.num_rows(), dmat2.num_rows()); assert_eq!(dmat.num_cols(), dmat2.num_cols()); @@ -370,17 +402,21 @@ mod tests { #[test] fn get_set_labels() { let mut dmat = read_train_matrix().unwrap(); - assert_eq!(dmat.get_labels().unwrap().len(), 6513); + let labels = dmat.get_labels(); + assert!(labels.is_ok()); + let mut labels = labels.unwrap().to_vec(); + assert_eq!(labels.len(), 6513); - let label = [0.1, 0.0 -4.5, 11.29842, 333333.33]; - assert!(dmat.set_labels(&label).is_ok()); - assert_eq!(dmat.get_labels().unwrap(), label); + labels[0] = 0.1; + assert_ne!(dmat.get_labels().unwrap(), labels); + assert!(dmat.set_labels(&labels).is_ok()); + assert_eq!(dmat.get_labels().unwrap(), labels); } #[test] fn get_set_weights() { let mut dmat = read_train_matrix().unwrap(); - assert_eq!(dmat.get_weights().unwrap(), &[]); + assert!(dmat.get_weights().unwrap().is_empty()); let weight = [1.0, 10.0, 44.9555]; assert!(dmat.set_weights(&weight).is_ok()); @@ -390,9 +426,11 @@ mod tests { #[test] fn get_set_base_margin() { let mut dmat = read_train_matrix().unwrap(); - assert_eq!(dmat.get_base_margin().unwrap(), &[]); + let base_margin = dmat.get_base_margin(); + assert!(base_margin.is_ok()); + assert!(base_margin.unwrap().is_empty()); - let base_margin = [0.00001, 0.000002, 1.23]; + let base_margin = vec![0.00001; dmat.num_rows()]; assert!(dmat.set_base_margin(&base_margin).is_ok()); assert_eq!(dmat.get_base_margin().unwrap(), base_margin); } @@ -400,7 +438,7 @@ mod tests { #[test] fn get_set_group() { let mut dmat = read_train_matrix().unwrap(); - assert_eq!(dmat.get_group().unwrap(), &[]); + assert!(dmat.get_group().unwrap().is_empty()); let group = [1]; assert!(dmat.set_group(&group).is_ok()); @@ -415,7 +453,7 @@ mod tests { let dmat = DMatrix::from_csr(&indptr, &indices, &data, None).unwrap(); assert_eq!(dmat.num_rows(), 4); - assert_eq!(dmat.num_cols(), 0); // https://github.com/dmlc/xgboost/pull/7265 + assert_eq!(dmat.num_cols(), 0); // https://github.com/dmlc/xgboost/pull/7265 let dmat = DMatrix::from_csr(&indptr, &indices, &data, Some(10)).unwrap(); assert_eq!(dmat.num_rows(), 4); @@ -466,7 +504,8 @@ mod tests { assert_eq!(dmat.slice(&[1]).unwrap().shape(), (1, 2)); assert_eq!(dmat.slice(&[0, 1]).unwrap().shape(), (2, 2)); assert_eq!(dmat.slice(&[3, 2, 1]).unwrap().shape(), (3, 2)); - assert_eq!(dmat.slice(&[10, 11, 12]).unwrap().shape(), (3, 2)); + // slicing out of bounds is not safe and can cause a segfault + // assert_eq!(dmat.slice(&[10, 11, 12]).unwrap().shape(), (3, 2)); } #[test] diff --git a/src/error.rs b/src/error.rs index 5059eea..b379400 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,9 +1,9 @@ //! Functionality related to errors and error handling. use std; +use std::error::Error; use std::ffi::CStr; use std::fmt::{self, Display}; -use std::error::Error; use xgboost_sys; @@ -29,9 +29,9 @@ impl XGBError { /// Meaning of any other return values are undefined, and will cause a panic. pub(crate) fn check_return_value(ret_val: i32) -> XGBResult<()> { match ret_val { - 0 => Ok(()), + 0 => Ok(()), -1 => Err(XGBError::from_xgboost()), - _ => panic!("unexpected return value '{}', expected 0 or -1", ret_val), + _ => panic!("unexpected return value '{}', expected 0 or -1", ret_val), } } @@ -39,7 +39,9 @@ impl XGBError { fn from_xgboost() -> Self { let c_str = unsafe { CStr::from_ptr(xgboost_sys::XGBGetLastError()) }; let str_slice = c_str.to_str().unwrap(); - XGBError { desc: str_slice.to_owned() } + XGBError { + desc: str_slice.to_owned(), + } } } diff --git a/src/lib.rs b/src/lib.rs index 5ba0ee9..b1344e6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -60,10 +60,10 @@ extern crate derive_builder; #[macro_use] extern crate log; -extern crate xgboost_sys; +extern crate indexmap; extern crate libc; extern crate tempfile; -extern crate indexmap; +extern crate xgboost_sys; macro_rules! xgb_call { ($x:expr) => { @@ -72,7 +72,7 @@ macro_rules! xgb_call { } mod error; -pub use error::{XGBResult, XGBError}; +pub use error::{XGBError, XGBResult}; mod dmatrix; pub use dmatrix::DMatrix; diff --git a/src/parameters/booster.rs b/src/parameters/booster.rs index 1b56a64..dcd1b1c 100644 --- a/src/parameters/booster.rs +++ b/src/parameters/booster.rs @@ -20,7 +20,7 @@ //! ``` use std::default::Default; -use super::{tree, linear, dart}; +use super::{dart, linear, tree}; /// Type of booster to use when training a [Booster](../struct.Booster.html) model. #[derive(Clone)] @@ -46,7 +46,9 @@ pub enum BoosterType { } impl Default for BoosterType { - fn default() -> Self { BoosterType::Tree(tree::TreeBoosterParameters::default()) } + fn default() -> Self { + BoosterType::Tree(tree::TreeBoosterParameters::default()) + } } impl BoosterType { @@ -54,7 +56,7 @@ impl BoosterType { match *self { BoosterType::Tree(ref p) => p.as_string_pairs(), BoosterType::Linear(ref p) => p.as_string_pairs(), - BoosterType::Dart(ref p) => p.as_string_pairs() + BoosterType::Dart(ref p) => p.as_string_pairs(), } } } diff --git a/src/parameters/dart.rs b/src/parameters/dart.rs index bf7f942..42f254e 100644 --- a/src/parameters/dart.rs +++ b/src/parameters/dart.rs @@ -6,9 +6,10 @@ use std::default::Default; use super::Interval; /// Type of sampling algorithm. -#[derive(Clone)] +#[derive(Clone, Default)] pub enum SampleType { /// Dropped trees are selected uniformly. + #[default] Uniform, /// Dropped trees are selected in proportion to weight. @@ -24,16 +25,13 @@ impl ToString for SampleType { } } -impl Default for SampleType { - fn default() -> Self { SampleType::Uniform } -} - /// Type of normalization algorithm. -#[derive(Clone)] +#[derive(Clone, Default)] pub enum NormalizeType { /// New trees have the same weight of each of dropped trees. /// * weight of new trees are 1 / (k + learning_rate) /// dropped trees are scaled by a factor of k / (k + learning_rate) + #[default] Tree, /// New trees have the same weight of sum of dropped trees (forest). @@ -52,10 +50,6 @@ impl ToString for NormalizeType { } } -impl Default for NormalizeType { - fn default() -> Self { NormalizeType::Tree } -} - /// Additional parameters for Dart Booster. #[derive(Builder, Clone)] #[builder(build_fn(validate = "Self::validate"))] @@ -96,17 +90,14 @@ impl Default for DartBoosterParameters { impl DartBoosterParameters { pub(crate) fn as_string_pairs(&self) -> Vec<(String, String)> { - let mut v = Vec::new(); - - v.push(("booster".to_owned(), "dart".to_owned())); - - v.push(("sample_type".to_owned(), self.sample_type.to_string())); - v.push(("normalize_type".to_owned(), self.normalize_type.to_string())); - v.push(("rate_drop".to_owned(), self.rate_drop.to_string())); - v.push(("one_drop".to_owned(), (self.one_drop as u8).to_string())); - v.push(("skip_drop".to_owned(), self.skip_drop.to_string())); - - v + vec![ + ("booster".to_owned(), "dart".to_owned()), + ("sample_type".to_owned(), self.sample_type.to_string()), + ("normalize_type".to_owned(), self.normalize_type.to_string()), + ("rate_drop".to_owned(), self.rate_drop.to_string()), + ("one_drop".to_owned(), (self.one_drop as u8).to_string()), + ("skip_drop".to_owned(), self.skip_drop.to_string()), + ] } } diff --git a/src/parameters/learning.rs b/src/parameters/learning.rs index ca88e22..828e70e 100644 --- a/src/parameters/learning.rs +++ b/src/parameters/learning.rs @@ -7,8 +7,10 @@ use std::default::Default; use super::Interval; /// Learning objective used when training a booster model. +#[derive(Default)] pub enum Objective { /// Linear regression. + #[default] RegLinear, /// Logistic regression. @@ -71,17 +73,19 @@ pub enum Objective { impl Copy for Objective {} impl Clone for Objective { - fn clone(&self) -> Self { *self } + fn clone(&self) -> Self { + *self + } } impl ToString for Objective { fn to_string(&self) -> String { match *self { - Objective::RegLinear => "reg:linear".to_owned(), + Objective::RegLinear => "reg:squarederror".to_owned(), Objective::RegLogistic => "reg:logistic".to_owned(), Objective::BinaryLogistic => "binary:logistic".to_owned(), Objective::BinaryLogisticRaw => "binary:logitraw".to_owned(), - Objective::GpuRegLinear => "gpu:reg:linear".to_owned(), + Objective::GpuRegLinear => "gpu:reg:squarederror".to_owned(), Objective::GpuRegLogistic => "gpu:reg:logistic".to_owned(), Objective::GpuBinaryLogistic => "gpu:binary:logistic".to_owned(), Objective::GpuBinaryLogisticRaw => "gpu:binary:logitraw".to_owned(), @@ -96,10 +100,6 @@ impl ToString for Objective { } } -impl Default for Objective { - fn default() -> Self { Objective::RegLinear } -} - /// Type of evaluation metrics to use during learning. #[derive(Clone)] pub enum Metrics { @@ -191,23 +191,23 @@ impl ToString for EvaluationMetric { } else { format!("error@{}", t) } - }, + } EvaluationMetric::MultiClassErrorRate => "merror".to_owned(), - EvaluationMetric::MultiClassLogLoss => "mlogloss".to_owned(), - EvaluationMetric::AUC => "auc".to_owned(), - EvaluationMetric::NDCG => "ndcg".to_owned(), - EvaluationMetric::NDCGCut(n) => format!("ndcg@{}", n), - EvaluationMetric::NDCGNegative => "ndcg-".to_owned(), - EvaluationMetric::NDCGCutNegative(n) => format!("ndcg@{}-", n), - EvaluationMetric::MAP => "map".to_owned(), - EvaluationMetric::MAPCut(n) => format!("map@{}", n), - EvaluationMetric::MAPNegative => "map-".to_owned(), - EvaluationMetric::MAPCutNegative(n) => format!("map@{}-", n), - EvaluationMetric::PoissonLogLoss => "poisson-nloglik".to_owned(), - EvaluationMetric::GammaLogLoss => "gamma-nloglik".to_owned(), - EvaluationMetric::CoxLogLoss => "cox-nloglik".to_owned(), - EvaluationMetric::GammaDeviance => "gamma-deviance".to_owned(), - EvaluationMetric::TweedieLogLoss => "tweedie-nloglik".to_owned(), + EvaluationMetric::MultiClassLogLoss => "mlogloss".to_owned(), + EvaluationMetric::AUC => "auc".to_owned(), + EvaluationMetric::NDCG => "ndcg".to_owned(), + EvaluationMetric::NDCGCut(n) => format!("ndcg@{}", n), + EvaluationMetric::NDCGNegative => "ndcg-".to_owned(), + EvaluationMetric::NDCGCutNegative(n) => format!("ndcg@{}-", n), + EvaluationMetric::MAP => "map".to_owned(), + EvaluationMetric::MAPCut(n) => format!("map@{}", n), + EvaluationMetric::MAPNegative => "map-".to_owned(), + EvaluationMetric::MAPCutNegative(n) => format!("map@{}-", n), + EvaluationMetric::PoissonLogLoss => "poisson-nloglik".to_owned(), + EvaluationMetric::GammaLogLoss => "gamma-nloglik".to_owned(), + EvaluationMetric::CoxLogLoss => "cox-nloglik".to_owned(), + EvaluationMetric::GammaDeviance => "gamma-deviance".to_owned(), + EvaluationMetric::TweedieLogLoss => "tweedie-nloglik".to_owned(), } } } diff --git a/src/parameters/linear.rs b/src/parameters/linear.rs index 3168047..562905d 100644 --- a/src/parameters/linear.rs +++ b/src/parameters/linear.rs @@ -3,10 +3,11 @@ use std::default::Default; /// Linear model algorithm. -#[derive(Clone)] +#[derive(Clone, Default)] pub enum LinearUpdate { /// Parallel coordinate descent algorithm based on shotgun algorithm. Uses ‘hogwild’ parallelism and /// therefore produces a nondeterministic solution on each run. + #[default] Shotgun, /// Ordinary coordinate descent algorithm. Also multithreaded but still produces a deterministic solution. @@ -22,10 +23,6 @@ impl ToString for LinearUpdate { } } -impl Default for LinearUpdate { - fn default() -> Self { LinearUpdate::Shotgun } -} - /// BoosterParameters for Linear Booster. #[derive(Builder, Clone)] #[builder(default)] @@ -48,18 +45,14 @@ pub struct LinearBoosterParameters { updater: LinearUpdate, } - impl LinearBoosterParameters { pub(crate) fn as_string_pairs(&self) -> Vec<(String, String)> { - let mut v = Vec::new(); - - v.push(("booster".to_owned(), "gblinear".to_owned())); - - v.push(("lambda".to_owned(), self.lambda.to_string())); - v.push(("alpha".to_owned(), self.alpha.to_string())); - v.push(("updater".to_owned(), self.updater.to_string())); - - v + vec![ + ("booster".to_owned(), "gblinear".to_owned()), + ("lambda".to_owned(), self.lambda.to_string()), + ("alpha".to_owned(), self.alpha.to_string()), + ("updater".to_owned(), self.updater.to_string()), + ] } } diff --git a/src/parameters/mod.rs b/src/parameters/mod.rs index 35b9af6..9e0ddb2 100644 --- a/src/parameters/mod.rs +++ b/src/parameters/mod.rs @@ -9,19 +9,19 @@ use std::default::Default; use std::fmt::{self, Display}; -pub mod tree; +mod booster; +pub mod dart; pub mod learning; pub mod linear; -pub mod dart; -mod booster; +pub mod tree; -use super::DMatrix; pub use self::booster::BoosterType; use super::booster::CustomObjective; +use super::DMatrix; /// Parameters for training boosters. /// Created using [`BoosterParametersBuilder`](struct.BoosterParametersBuilder.html). -#[derive(Builder, Clone)] +#[derive(Builder, Clone, Default)] #[builder(default)] pub struct BoosterParameters { /// Type of booster (tree, linear or DART) along with its parameters. @@ -43,17 +43,6 @@ pub struct BoosterParameters { threads: Option, } -impl Default for BoosterParameters { - fn default() -> Self { - BoosterParameters { - booster_type: booster::BoosterType::default(), - learning_params: learning::LearningTaskParameters::default(), - verbose: false, - threads: None, - } - } -} - impl BoosterParameters { /// Get type of booster (tree, linear or DART) along with its parameters. pub fn booster_type(&self) -> &booster::BoosterType { @@ -127,41 +116,41 @@ pub struct TrainingParameters<'a> { /// Number of boosting rounds to use during training. /// /// *default*: `10` - #[builder(default="10")] + #[builder(default = "10")] pub(crate) boost_rounds: u32, /// Configuration for the booster model that will be trained. /// /// *default*: `BoosterParameters::default()` - #[builder(default="BoosterParameters::default()")] + #[builder(default = "BoosterParameters::default()")] pub(crate) booster_params: BoosterParameters, - #[builder(default="None")] + #[builder(default = "None")] /// Optional list of DMatrix to evaluate against after each boosting round. /// /// Supplied as a list of tuples of (DMatrix, description). The description is used to differentiate between /// different evaluation datasets when output during training. /// /// *default*: `None` - pub(crate) evaluation_sets: Option<&'a[(&'a DMatrix, &'a str)]>, + pub(crate) evaluation_sets: Option<&'a [(&'a DMatrix, &'a str)]>, /// Optional custom objective function to use for training. /// /// *default*: `None` - #[builder(default="None")] + #[builder(default = "None")] pub(crate) custom_objective_fn: Option, /// Optional custom evaluation function to use during training. /// /// *default*: `None` - #[builder(default="None")] + #[builder(default = "None")] pub(crate) custom_evaluation_fn: Option, // TODO: callbacks } -impl <'a> TrainingParameters<'a> { +impl<'a> TrainingParameters<'a> { pub fn dtrain(&self) -> &'a DMatrix { - &self.dtrain + self.dtrain } pub fn set_dtrain(&mut self, dtrain: &'a DMatrix) { @@ -184,11 +173,11 @@ impl <'a> TrainingParameters<'a> { self.booster_params = booster_params.into(); } - pub fn evaluation_sets(&self) -> &Option<&'a[(&'a DMatrix, &'a str)]> { + pub fn evaluation_sets(&self) -> &Option<&'a [(&'a DMatrix, &'a str)]> { &self.evaluation_sets } - pub fn set_evaluation_sets(&mut self, evaluation_sets: Option<&'a[(&'a DMatrix, &'a str)]>) { + pub fn set_evaluation_sets(&mut self, evaluation_sets: Option<&'a [(&'a DMatrix, &'a str)]>) { self.evaluation_sets = evaluation_sets; } @@ -225,11 +214,11 @@ impl Display for Interval { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let lower = match self.min_inclusion { Inclusion::Closed => '[', - Inclusion::Open => '(', + Inclusion::Open => '(', }; let upper = match self.max_inclusion { Inclusion::Closed => ']', - Inclusion::Open => ')', + Inclusion::Open => ')', }; write!(f, "{}{}, {}{}", lower, self.min, self.max, upper) } @@ -237,7 +226,12 @@ impl Display for Interval { impl Interval { fn new(min: T, min_inclusion: Inclusion, max: T, max_inclusion: Inclusion) -> Self { - Interval { min, min_inclusion, max, max_inclusion } + Interval { + min, + min_inclusion, + max, + max_inclusion, + } } fn new_open_open(min: T, max: T) -> Self { @@ -254,26 +248,45 @@ impl Interval { fn contains(&self, val: &T) -> bool { match self.min_inclusion { - Inclusion::Closed => if !(val >= &self.min) { return false; }, - Inclusion::Open => if !(val > &self.min) { return false; }, + Inclusion::Closed => { + if !(val >= &self.min) { + return false; + } + } + Inclusion::Open => { + if !(val > &self.min) { + return false; + } + } } match self.max_inclusion { - Inclusion::Closed => if !(val <= &self.max) { return false; }, - Inclusion::Open => if !(val < &self.max) { return false; }, + Inclusion::Closed => { + if !(val <= &self.max) { + return false; + } + } + Inclusion::Open => { + if !(val < &self.max) { + return false; + } + } } true } fn validate(&self, val: &Option, name: &str) -> Result<(), String> { - match val { + match &val { Some(ref val) => { - if self.contains(&val) { + if self.contains(val) { Ok(()) } else { - Err(format!("Invalid value for '{}' parameter, {} is not in range {}.", name, &val, self)) + Err(format!( + "Invalid value for '{}' parameter, {} is not in range {}.", + name, &val, self + )) } - }, - None => Ok(()) + } + None => Ok(()), } } } diff --git a/src/parameters/tree.rs b/src/parameters/tree.rs index d20b158..6c7343c 100644 --- a/src/parameters/tree.rs +++ b/src/parameters/tree.rs @@ -9,7 +9,7 @@ use super::Interval; /// [reference paper](http://arxiv.org/abs/1603.02754)). /// /// Distributed and external memory version only support approximate algorithm. -#[derive(Clone)] +#[derive(Clone, Default)] pub enum TreeMethod { /// Use heuristic to choose faster one. /// @@ -17,6 +17,7 @@ pub enum TreeMethod { /// * For very large-dataset, approximate algorithm will be chosen. /// * Because old behavior is always use exact greedy in single machine, user will get a message when /// approximate algorithm is chosen to notify this choice. + #[default] Auto, /// Exact greedy algorithm. @@ -49,33 +50,24 @@ impl ToString for TreeMethod { } } -impl Default for TreeMethod { - fn default() -> Self { TreeMethod::Auto } -} - -impl From for TreeMethod -{ - fn from(s: String) -> Self - { - use std::borrow::Borrow; - Self::from(s.borrow()) +impl From for TreeMethod { + fn from(s: String) -> Self { + use std::borrow::Borrow; + Self::from(s.borrow()) } } -impl<'a> From<&'a str> for TreeMethod -{ - fn from(s: &'a str) -> Self - { - match s - { - "auto" => TreeMethod::Auto, - "exact" => TreeMethod::Exact, - "approx" => TreeMethod::Approx, - "hist" => TreeMethod::Hist, - "gpu_exact" => TreeMethod::GpuExact, - "gpu_hist" => TreeMethod::GpuHist, - _ => panic!("no known tree_method for {}", s) - } +impl<'a> From<&'a str> for TreeMethod { + fn from(s: &'a str) -> Self { + match s { + "auto" => TreeMethod::Auto, + "exact" => TreeMethod::Exact, + "approx" => TreeMethod::Approx, + "hist" => TreeMethod::Hist, + "gpu_exact" => TreeMethod::GpuExact, + "gpu_hist" => TreeMethod::GpuHist, + _ => panic!("no known tree_method for {}", s), + } } } @@ -125,9 +117,10 @@ impl ToString for TreeUpdater { } /// A type of boosting process to run. -#[derive(Clone)] +#[derive(Clone, Default)] pub enum ProcessType { /// The normal boosting process which creates new trees. + #[default] Default, /// Starts from an existing model and only updates its trees. In each boosting iteration, @@ -148,14 +141,11 @@ impl ToString for ProcessType { } } -impl Default for ProcessType { - fn default() -> Self { ProcessType::Default } -} - /// Controls the way new nodes are added to the tree. -#[derive(Clone)] +#[derive(Clone, Default)] pub enum GrowPolicy { /// Split at nodes closest to the root. + #[default] Depthwise, /// Split at noeds with highest loss change. @@ -171,14 +161,11 @@ impl ToString for GrowPolicy { } } -impl Default for GrowPolicy { - fn default() -> Self { GrowPolicy::Depthwise } -} - /// The type of predictor algorithm to use. Provides the same results but allows the use of GPU or CPU. -#[derive(Clone)] +#[derive(Clone, Default)] pub enum Predictor { /// Multicore CPU prediction algorithm. + #[default] Cpu, /// Prediction using GPU. Default for ‘gpu_exact’ and ‘gpu_hist’ tree method. @@ -194,10 +181,6 @@ impl ToString for Predictor { } } -impl Default for Predictor { - fn default() -> Self { Predictor::Cpu } -} - /// BoosterParameters for Tree Booster. Create using /// [`TreeBoosterParametersBuilder`](struct.TreeBoosterParametersBuilder.html). #[derive(Builder, Clone)] @@ -374,39 +357,44 @@ impl Default for TreeBoosterParameters { impl TreeBoosterParameters { pub(crate) fn as_string_pairs(&self) -> Vec<(String, String)> { - let mut v = Vec::new(); - - v.push(("booster".to_owned(), "gbtree".to_owned())); - - v.push(("eta".to_owned(), self.eta.to_string())); - v.push(("gamma".to_owned(), self.gamma.to_string())); - v.push(("max_depth".to_owned(), self.max_depth.to_string())); - v.push(("min_child_weight".to_owned(), self.min_child_weight.to_string())); - v.push(("max_delta_step".to_owned(), self.max_delta_step.to_string())); - v.push(("subsample".to_owned(), self.subsample.to_string())); - v.push(("colsample_bytree".to_owned(), self.colsample_bytree.to_string())); - v.push(("colsample_bylevel".to_owned(), self.colsample_bylevel.to_string())); - v.push(("colsample_bynode".to_owned(), self.colsample_bynode.to_string())); - v.push(("lambda".to_owned(), self.lambda.to_string())); - v.push(("alpha".to_owned(), self.alpha.to_string())); - v.push(("tree_method".to_owned(), self.tree_method.to_string())); - v.push(("sketch_eps".to_owned(), self.sketch_eps.to_string())); - v.push(("scale_pos_weight".to_owned(), self.scale_pos_weight.to_string())); - v.push(("refresh_leaf".to_owned(), (self.refresh_leaf as u8).to_string())); - v.push(("process_type".to_owned(), self.process_type.to_string())); - v.push(("grow_policy".to_owned(), self.grow_policy.to_string())); - v.push(("max_leaves".to_owned(), self.max_leaves.to_string())); - v.push(("max_bin".to_owned(), self.max_bin.to_string())); - v.push(("num_parallel_tree".to_owned(), self.num_parallel_tree.to_string())); - v.push(("predictor".to_owned(), self.predictor.to_string())); + let mut v = vec![ + ("booster".to_owned(), "gbtree".to_owned()), + ("eta".to_owned(), self.eta.to_string()), + ("gamma".to_owned(), self.gamma.to_string()), + ("max_depth".to_owned(), self.max_depth.to_string()), + ("min_child_weight".to_owned(), self.min_child_weight.to_string()), + ("max_delta_step".to_owned(), self.max_delta_step.to_string()), + ("subsample".to_owned(), self.subsample.to_string()), + ("colsample_bytree".to_owned(), self.colsample_bytree.to_string()), + ("colsample_bylevel".to_owned(), self.colsample_bylevel.to_string()), + ("colsample_bynode".to_owned(), self.colsample_bynode.to_string()), + ("lambda".to_owned(), self.lambda.to_string()), + ("alpha".to_owned(), self.alpha.to_string()), + ("tree_method".to_owned(), self.tree_method.to_string()), + ("sketch_eps".to_owned(), self.sketch_eps.to_string()), + ("scale_pos_weight".to_owned(), self.scale_pos_weight.to_string()), + ("refresh_leaf".to_owned(), (self.refresh_leaf as u8).to_string()), + ("process_type".to_owned(), self.process_type.to_string()), + ("grow_policy".to_owned(), self.grow_policy.to_string()), + ("max_leaves".to_owned(), self.max_leaves.to_string()), + ("max_bin".to_owned(), self.max_bin.to_string()), + ("num_parallel_tree".to_owned(), self.num_parallel_tree.to_string()), + ("predictor".to_owned(), self.predictor.to_string()), + ]; // Don't pass anything to XGBoost if the user didn't specify anything. // This allows XGBoost to figure it out on it's own, and suppresses the // warning message during training. // See: https://github.com/davechallis/rust-xgboost/issues/7 - if self.updater.len() != 0 - { - v.push(("updater".to_owned(), self.updater.iter().map(|u| u.to_string()).collect::>().join(","))); + if !self.updater.is_empty() { + v.push(( + "updater".to_owned(), + self.updater + .iter() + .map(|u| u.to_string()) + .collect::>() + .join(","), + )); } v diff --git a/xgboost-sys/.cargo/config b/xgboost-sys/.cargo/config.toml similarity index 100% rename from xgboost-sys/.cargo/config rename to xgboost-sys/.cargo/config.toml diff --git a/xgboost-sys/Cargo.toml b/xgboost-sys/Cargo.toml index cddc0ce..b9749af 100644 --- a/xgboost-sys/Cargo.toml +++ b/xgboost-sys/Cargo.toml @@ -8,10 +8,14 @@ license = "MIT" repository = "https://github.com/davechallis/rust-xgboost" description = "Native bindings to the xgboost library" readme = "README.md" +edition = "2021" [dependencies] libc = "0.2" [build-dependencies] -bindgen = "0.59" +bindgen = "0.71" cmake = "0.1" + +[features] +cuda = [] diff --git a/xgboost-sys/README.md b/xgboost-sys/README.md index df39717..4a42bcc 100644 --- a/xgboost-sys/README.md +++ b/xgboost-sys/README.md @@ -3,4 +3,4 @@ FFI bindings to [XGBoost](https://xgboost.readthedocs.io/), generated at compile time with [bindgen](https://github.com/rust-lang-nursery/rust-bindgen). -Currently uses XGBoost v0.81. +Currently uses XGBoost v2.0. diff --git a/xgboost-sys/build.rs b/xgboost-sys/build.rs index b311d49..7fc9a9a 100644 --- a/xgboost-sys/build.rs +++ b/xgboost-sys/build.rs @@ -2,9 +2,9 @@ extern crate bindgen; extern crate cmake; use cmake::Config; -use std::process::Command; use std::env; use std::path::{Path, PathBuf}; +use std::process::Command; fn main() { let target = env::var("TARGET").unwrap(); @@ -21,43 +21,79 @@ fn main() { }); } + let mut dst = Config::new(&xgb_root); + dst.define("BUILD_STATIC_LIB", "ON").define("CMAKE_CXX_STANDARD", "17"); + // CMake - let dst = Config::new(&xgb_root) - .uses_cxx11() - .define("BUILD_STATIC_LIB", "ON") - .build(); + let mut dst = Config::new(&xgb_root); + let mut dst = dst.define("BUILD_STATIC_LIB", "ON"); + + #[cfg(feature = "cuda")] + let mut dst = dst + .define("USE_CUDA", "ON") + .define("BUILD_WITH_CUDA", "ON") + .define("BUILD_WITH_CUDA_CUB", "ON"); + + #[cfg(target_os = "macos")] + { + let path = PathBuf::from("/opt/homebrew/"); // check for m1 vs intel config + if let Ok(_dir) = std::fs::read_dir(&path) { + dst = dst + .define("CMAKE_C_COMPILER", "/opt/homebrew/opt/llvm/bin/clang") + .define("CMAKE_CXX_COMPILER", "/opt/homebrew/opt/llvm/bin/clang++") + .define("OPENMP_LIBRARIES", "/opt/homebrew/opt/llvm/lib") + .define("OPENMP_INCLUDES", "/opt/homebrew/opt/llvm/include"); + }; + } + let dst = dst.build(); let xgb_root = xgb_root.canonicalize().unwrap(); let bindings = bindgen::Builder::default() .header("wrapper.h") - .clang_args(&["-x", "c++", "-std=c++11"]) + .blocklist_item("std::__1.*") + .clang_args(&["-x", "c++", "-std=c++17"]) .clang_arg(format!("-I{}", xgb_root.join("include").display())) .clang_arg(format!("-I{}", xgb_root.join("rabit/include").display())) - .clang_arg(format!("-I{}", xgb_root.join("dmlc-core/include").display())) + .clang_arg(format!("-I{}", xgb_root.join("dmlc-core/include").display())); + + #[cfg(feature = "cuda")] + let bindings = bindings.clang_arg("-I/usr/local/cuda/include"); + let bindings = bindings .generate() .expect("Unable to generate bindings."); - let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + let out_path = PathBuf::from(out_dir); bindings .write_to_file(out_path.join("bindings.rs")) .expect("Couldn't write bindings."); println!("cargo:rustc-link-search={}", xgb_root.join("lib").display()); + println!("cargo:rustc-link-search={}", xgb_root.join("lib64").display()); println!("cargo:rustc-link-search={}", xgb_root.join("rabit/lib").display()); println!("cargo:rustc-link-search={}", xgb_root.join("dmlc-core").display()); // link to appropriate C++ lib if target.contains("apple") { println!("cargo:rustc-link-lib=c++"); + println!("cargo:rustc-link-search=native=/opt/homebrew/opt/libomp/lib"); println!("cargo:rustc-link-lib=dylib=omp"); } else { + println!("cargo:rustc-cxxflags=-std=c++17"); + println!("cargo:rustc-link-lib=stdc++fs"); println!("cargo:rustc-link-lib=stdc++"); println!("cargo:rustc-link-lib=dylib=gomp"); } println!("cargo:rustc-link-search=native={}", dst.display()); println!("cargo:rustc-link-search=native={}", dst.join("lib").display()); + println!("cargo:rustc-link-search=native={}", dst.join("lib64").display()); println!("cargo:rustc-link-lib=static=dmlc"); println!("cargo:rustc-link-lib=static=xgboost"); + + #[cfg(feature = "cuda")] + { + println!("cargo:rustc-link-search={}", "/usr/local/cuda/lib64"); + println!("cargo:rustc-link-lib=static=cudart_static"); + } } diff --git a/xgboost-sys/src/lib.rs b/xgboost-sys/src/lib.rs index 78b8c72..a365c4c 100644 --- a/xgboost-sys/src/lib.rs +++ b/xgboost-sys/src/lib.rs @@ -26,7 +26,7 @@ mod tests { let mut num_cols = 0; let ret_val = unsafe { XGDMatrixNumCol(handle, &mut num_cols) }; assert_eq!(ret_val, 0); - assert_eq!(num_cols, 127); + assert_eq!(num_cols, 126); let ret_val = unsafe { XGDMatrixFree(handle) }; assert_eq!(ret_val, 0); diff --git a/xgboost-sys/xgboost b/xgboost-sys/xgboost index 61671a8..5e64276 160000 --- a/xgboost-sys/xgboost +++ b/xgboost-sys/xgboost @@ -1 +1 @@ -Subproject commit 61671a80dc42946882b562fda7b004b3967f0556 +Subproject commit 5e64276a9b95df57e6dd8f9e63347636f4e5d331