Skip to content

Commit e6c8707

Browse files
authored
Merge pull request #9 from KyussCaesar/master
Change types for some `TreeBoosterParameters` to match types in XGBoost
2 parents 87db871 + 1af35db commit e6c8707

File tree

3 files changed

+68
-18
lines changed

3 files changed

+68
-18
lines changed

src/booster.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ impl Booster {
178178
for (dmat, dmat_name) in eval_sets {
179179
let margin = bst.predict_margin(dmat)?;
180180
let eval_result = eval_fn(&margin, dmat);
181-
let mut eval_results = dmat_eval_results.entry(eval_name.to_string())
181+
let eval_results = dmat_eval_results.entry(eval_name.to_string())
182182
.or_insert_with(IndexMap::new);
183183
eval_results.insert(dmat_name.to_string(), eval_result);
184184
}
@@ -188,7 +188,7 @@ impl Booster {
188188
let mut eval_dmat_results = BTreeMap::new();
189189
for (dmat_name, eval_results) in &dmat_eval_results {
190190
for (eval_name, result) in eval_results {
191-
let mut dmat_results = eval_dmat_results.entry(eval_name).or_insert_with(BTreeMap::new);
191+
let dmat_results = eval_dmat_results.entry(eval_name).or_insert_with(BTreeMap::new);
192192
dmat_results.insert(dmat_name, result);
193193
}
194194
}
@@ -548,7 +548,7 @@ impl Booster {
548548
let score = metric_parts[1].parse::<f32>()
549549
.unwrap_or_else(|_| panic!("Unable to parse XGBoost metrics output: {}", eval));
550550

551-
let mut metric_map = result.entry(evname.to_string()).or_insert_with(IndexMap::new);
551+
let metric_map = result.entry(evname.to_string()).or_insert_with(IndexMap::new);
552552
metric_map.insert(metric.to_owned(), score);
553553
}
554554
}
@@ -712,7 +712,7 @@ mod tests {
712712
let attr = booster.get_attribute("foo").expect("Getting attribute failed");
713713
assert_eq!(attr, Some("bar".to_owned()));
714714

715-
let mut dir = tempfile::tempdir().expect("create temp dir");
715+
let dir = tempfile::tempdir().expect("create temp dir");
716716
let path = dir.path().join("test-xgboost-model");
717717
booster.save(&path).expect("saving booster");
718718
drop(booster);

src/parameters/booster.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
//!
1010
//! let tree_params = TreeBoosterParametersBuilder::default()
1111
//! .eta(0.2)
12-
//! .gamma(3)
12+
//! .gamma(3.0)
1313
//! .subsample(0.75)
1414
//! .build()
1515
//! .unwrap();

src/parameters/tree.rs

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,32 @@ impl Default for TreeMethod {
5353
fn default() -> Self { TreeMethod::Auto }
5454
}
5555

56+
impl From<String> for TreeMethod
57+
{
58+
fn from(s: String) -> Self
59+
{
60+
use std::borrow::Borrow;
61+
Self::from(s.borrow())
62+
}
63+
}
64+
65+
impl<'a> From<&'a str> for TreeMethod
66+
{
67+
fn from(s: &'a str) -> Self
68+
{
69+
match s
70+
{
71+
"auto" => TreeMethod::Auto,
72+
"exact" => TreeMethod::Exact,
73+
"approx" => TreeMethod::Approx,
74+
"hist" => TreeMethod::Hist,
75+
"gpu_exact" => TreeMethod::GpuExact,
76+
"gpu_hist" => TreeMethod::GpuHist,
77+
_ => panic!("no known tree_method for {}", s)
78+
}
79+
}
80+
}
81+
5682
/// Provides a modular way to construct and to modify the trees. This is an advanced parameter that is usually set
5783
/// automatically, depending on some other parameters. However, it could be also set explicitly by a user.
5884
#[derive(Clone)]
@@ -191,7 +217,7 @@ pub struct TreeBoosterParameters {
191217
///
192218
/// * range: [0,∞]
193219
/// * default: 0
194-
gamma: u32,
220+
gamma: f32,
195221

196222
/// Maximum depth of a tree, increase this value will make the model more complex / likely to be overfitting.
197223
/// 0 indicates no limit, limit is required for depth-wise grow policy.
@@ -208,7 +234,7 @@ pub struct TreeBoosterParameters {
208234
///
209235
/// * range: [0,∞]
210236
/// * default: 1
211-
min_child_weight: u32,
237+
min_child_weight: f32,
212238

213239
/// Maximum delta step we allow each tree’s weight estimation to be.
214240
/// If the value is set to 0, it means there is no constraint. If it is set to a positive value,
@@ -218,7 +244,7 @@ pub struct TreeBoosterParameters {
218244
///
219245
/// * range: [0,∞]
220246
/// * default: 0
221-
max_delta_step: u32,
247+
max_delta_step: f32,
222248

223249
/// Subsample ratio of the training instance. Setting it to 0.5 means that XGBoost randomly collected half
224250
/// of the data instances to grow trees and this will prevent overfitting.
@@ -239,15 +265,21 @@ pub struct TreeBoosterParameters {
239265
/// * default: 1.0
240266
colsample_bylevel: f32,
241267

268+
/// Subsample ratio of columns for each node.
269+
///
270+
/// * range: (0.0, 1.0]
271+
/// * default: 1.0
272+
colsample_bynode: f32,
273+
242274
/// L2 regularization term on weights, increase this value will make model more conservative.
243275
///
244276
/// * default: 1
245-
lambda: u32,
277+
lambda: f32,
246278

247279
/// L1 regularization term on weights, increase this value will make model more conservative.
248280
///
249281
/// * default: 0
250-
alpha: u32,
282+
alpha: f32,
251283

252284
/// The tree construction algorithm used in XGBoost.
253285
#[builder(default = "TreeMethod::default()")]
@@ -270,7 +302,7 @@ pub struct TreeBoosterParameters {
270302

271303
/// Sequence of tree updaters to run, providing a modular way to construct and to modify the trees.
272304
///
273-
/// * default: [TreeUpdater::GrowColMaker, TreeUpdater::Prune]
305+
/// * default: vec![]
274306
updater: Vec<TreeUpdater>,
275307

276308
/// This is a parameter of the ‘refresh’ updater plugin. When this flag is true, tree leafs as well as tree nodes'
@@ -300,6 +332,11 @@ pub struct TreeBoosterParameters {
300332
/// * default: 256
301333
max_bin: u32,
302334

335+
/// Number of trees to train in parallel for boosted random forest.
336+
///
337+
/// * default: 1
338+
num_parallel_tree: u32,
339+
303340
/// The type of predictor algorithm to use. Provides the same results but allows the use of GPU or CPU.
304341
///
305342
/// * default: [`Predictor::Cpu`](enum.Predictor.html#variant.Cpu)
@@ -310,24 +347,26 @@ impl Default for TreeBoosterParameters {
310347
fn default() -> Self {
311348
TreeBoosterParameters {
312349
eta: 0.3,
313-
gamma: 0,
350+
gamma: 0.0,
314351
max_depth: 6,
315-
min_child_weight: 1,
316-
max_delta_step: 0,
352+
min_child_weight: 1.0,
353+
max_delta_step: 0.0,
317354
subsample: 1.0,
318355
colsample_bytree: 1.0,
319356
colsample_bylevel: 1.0,
320-
lambda: 1,
321-
alpha: 0,
357+
colsample_bynode: 1.0,
358+
lambda: 1.0,
359+
alpha: 0.0,
322360
tree_method: TreeMethod::default(),
323361
sketch_eps: 0.03,
324362
scale_pos_weight: 1.0,
325-
updater: vec![TreeUpdater::GrowColMaker, TreeUpdater::Prune],
363+
updater: Vec::new(),
326364
refresh_leaf: true,
327365
process_type: ProcessType::default(),
328366
grow_policy: GrowPolicy::default(),
329367
max_leaves: 0,
330368
max_bin: 256,
369+
num_parallel_tree: 1,
331370
predictor: Predictor::default(),
332371
}
333372
}
@@ -347,19 +386,29 @@ impl TreeBoosterParameters {
347386
v.push(("subsample".to_owned(), self.subsample.to_string()));
348387
v.push(("colsample_bytree".to_owned(), self.colsample_bytree.to_string()));
349388
v.push(("colsample_bylevel".to_owned(), self.colsample_bylevel.to_string()));
389+
v.push(("colsample_bynode".to_owned(), self.colsample_bynode.to_string()));
350390
v.push(("lambda".to_owned(), self.lambda.to_string()));
351391
v.push(("alpha".to_owned(), self.alpha.to_string()));
352392
v.push(("tree_method".to_owned(), self.tree_method.to_string()));
353393
v.push(("sketch_eps".to_owned(), self.sketch_eps.to_string()));
354394
v.push(("scale_pos_weight".to_owned(), self.scale_pos_weight.to_string()));
355-
v.push(("updater".to_owned(), self.updater.iter().map(|u| u.to_string()).collect::<Vec<String>>().join(",")));
356395
v.push(("refresh_leaf".to_owned(), (self.refresh_leaf as u8).to_string()));
357396
v.push(("process_type".to_owned(), self.process_type.to_string()));
358397
v.push(("grow_policy".to_owned(), self.grow_policy.to_string()));
359398
v.push(("max_leaves".to_owned(), self.max_leaves.to_string()));
360399
v.push(("max_bin".to_owned(), self.max_bin.to_string()));
400+
v.push(("num_parallel_tree".to_owned(), self.num_parallel_tree.to_string()));
361401
v.push(("predictor".to_owned(), self.predictor.to_string()));
362402

403+
// Don't pass anything to XGBoost if the user didn't specify anything.
404+
// This allows XGBoost to figure it out on it's own, and suppresses the
405+
// warning message during training.
406+
// See: https://github.com/davechallis/rust-xgboost/issues/7
407+
if self.updater.len() != 0
408+
{
409+
v.push(("updater".to_owned(), self.updater.iter().map(|u| u.to_string()).collect::<Vec<String>>().join(",")));
410+
}
411+
363412
v
364413
}
365414
}
@@ -370,6 +419,7 @@ impl TreeBoosterParametersBuilder {
370419
Interval::new_open_closed(0.0, 1.0).validate(&self.subsample, "subsample")?;
371420
Interval::new_open_closed(0.0, 1.0).validate(&self.colsample_bytree, "colsample_bytree")?;
372421
Interval::new_open_closed(0.0, 1.0).validate(&self.colsample_bylevel, "colsample_bylevel")?;
422+
Interval::new_open_closed(0.0, 1.0).validate(&self.colsample_bynode, "colsample_bynode")?;
373423
Interval::new_open_open(0.0, 1.0).validate(&self.sketch_eps, "sketch_eps")?;
374424
Ok(())
375425
}

0 commit comments

Comments
 (0)