@@ -12,7 +12,7 @@ Basic usage example:
12
12
``` rust
13
13
extern crate xgboost;
14
14
15
- use xgboost :: {parameters, dmatrix :: DMatrix , booster :: Booster };
15
+ use xgboost :: {parameters, DMatrix , Booster };
16
16
17
17
fn main () {
18
18
// training matrix with 5 training examples and 3 features
@@ -37,14 +37,37 @@ fn main() {
37
37
let mut dtest = DMatrix :: from_dense (x_test , num_rows ). unwrap ();
38
38
dtest . set_labels (y_test ). unwrap ();
39
39
40
- // build overall training parameters
41
- let params = parameters :: ParametersBuilder :: default (). build (). unwrap ();
40
+ // configure objectives, metrics, etc.
41
+ let learning_params = parameters :: learning :: LearningTaskParametersBuilder :: default ()
42
+ . objective (parameters :: learning :: Objective :: BinaryLogistic )
43
+ . build (). unwrap ();
44
+
45
+ // configure the tree-based learning model's parameters
46
+ let tree_params = parameters :: tree :: TreeBoosterParametersBuilder :: default ()
47
+ . max_depth (2 )
48
+ . eta (1.0 )
49
+ . build (). unwrap ();
50
+
51
+ // overall configuration for Booster
52
+ let booster_params = parameters :: BoosterParametersBuilder :: default ()
53
+ . booster_type (parameters :: BoosterType :: Tree (tree_params ))
54
+ . learning_params (learning_params )
55
+ . verbose (true )
56
+ . build (). unwrap ();
42
57
43
58
// specify datasets to evaluate against during training
44
59
let evaluation_sets = & [(& dtrain , " train" ), (& dtest , " test" )];
45
60
61
+ // overall configuration for training/evaluation
62
+ let params = parameters :: TrainingParametersBuilder :: default ()
63
+ . dtrain (& dtrain ) // dataset to train with
64
+ . boost_rounds (2 ) // number of training iterations
65
+ . booster_params (booster_params ) // model parameters
66
+ . evaluation_sets (Some (evaluation_sets )) // optional datasets to evaluate against in each iteration
67
+ . build (). unwrap ();
68
+
46
69
// train model, and print evaluation data
47
- let bst = Booster :: train (& params , & dtrain , 3 , evaluation_sets ). unwrap ();
70
+ let bst = Booster :: train (& params ). unwrap ();
48
71
49
72
println! (" {:?}" , bst . predict (& dtest ). unwrap ());
50
73
}
0 commit comments