Skip to content

Commit 87db871

Browse files
authored
Merge pull request #6 from infinite-Joy/fix.readme_example
changing the README example as the main api has changed.
2 parents c6a855d + c197ba6 commit 87db871

File tree

1 file changed

+27
-4
lines changed

1 file changed

+27
-4
lines changed

README.md

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Basic usage example:
1212
```rust
1313
extern crate xgboost;
1414

15-
use xgboost::{parameters, dmatrix::DMatrix, booster::Booster};
15+
use xgboost::{parameters, DMatrix, Booster};
1616

1717
fn main() {
1818
// training matrix with 5 training examples and 3 features
@@ -37,14 +37,37 @@ fn main() {
3737
let mut dtest = DMatrix::from_dense(x_test, num_rows).unwrap();
3838
dtest.set_labels(y_test).unwrap();
3939

40-
// build overall training parameters
41-
let params = parameters::ParametersBuilder::default().build().unwrap();
40+
// configure objectives, metrics, etc.
41+
let learning_params = parameters::learning::LearningTaskParametersBuilder::default()
42+
.objective(parameters::learning::Objective::BinaryLogistic)
43+
.build().unwrap();
44+
45+
// configure the tree-based learning model's parameters
46+
let tree_params = parameters::tree::TreeBoosterParametersBuilder::default()
47+
.max_depth(2)
48+
.eta(1.0)
49+
.build().unwrap();
50+
51+
// overall configuration for Booster
52+
let booster_params = parameters::BoosterParametersBuilder::default()
53+
.booster_type(parameters::BoosterType::Tree(tree_params))
54+
.learning_params(learning_params)
55+
.verbose(true)
56+
.build().unwrap();
4257

4358
// specify datasets to evaluate against during training
4459
let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")];
4560

61+
// overall configuration for training/evaluation
62+
let params = parameters::TrainingParametersBuilder::default()
63+
.dtrain(&dtrain) // dataset to train with
64+
.boost_rounds(2) // number of training iterations
65+
.booster_params(booster_params) // model parameters
66+
.evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration
67+
.build().unwrap();
68+
4669
// train model, and print evaluation data
47-
let bst = Booster::train(&params, &dtrain, 3, evaluation_sets).unwrap();
70+
let bst = Booster::train(&params).unwrap();
4871

4972
println!("{:?}", bst.predict(&dtest).unwrap());
5073
}

0 commit comments

Comments
 (0)