Skip to content

Commit cb400f7

Browse files
committed
checkpoint
1 parent dc882f7 commit cb400f7

File tree

10 files changed

+177
-189
lines changed

10 files changed

+177
-189
lines changed

examples/basic/src/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ fn main() {
1212

1313
// load train and test matrices from text files (in LibSVM format).
1414
println!("Loading train and test matrices...");
15-
let dtrain = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap();
15+
let dtrain = DMatrix::load(r#"{"uri": "../../xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
1616
println!("Train matrix: {}x{}", dtrain.num_rows(), dtrain.num_cols());
17-
let dtest = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap();
17+
let dtest = DMatrix::load(r#"{"uri": "../../xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap();
1818
println!("Test matrix: {}x{}", dtest.num_rows(), dtest.num_cols());
1919

2020
// configure objectives, metrics, etc.

examples/custom_objective/src/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ use xgboost::{parameters, DMatrix, Booster};
66
fn main() {
77
// load train and test matrices from text files (in LibSVM format)
88
println!("Custom objective example...");
9-
let dtrain = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap();
10-
let dtest = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap();
9+
let dtrain = DMatrix::load(r#"{"uri": "../../xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
10+
let dtest = DMatrix::load(r#"{"uri": "../../xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap();
1111

1212
// specify datasets to evaluate against during training
1313
let evaluation_sets = [(&dtest, "test"), (&dtrain, "train")];

src/booster.rs

Lines changed: 88 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -591,14 +591,14 @@ impl Booster {
591591
for part in eval.split('\t').skip(1) {
592592
for evname in evnames {
593593
if part.starts_with(evname) {
594-
let metric_parts: Vec<&str> = part[evname.len() + 1..].split(':').into_iter().collect();
594+
let metric_parts: Vec<&str> = part[evname.len() + 1..].split(':').collect();
595595
assert_eq!(metric_parts.len(), 2);
596596
let metric = metric_parts[0];
597597
let score = metric_parts[1]
598598
.parse::<f32>()
599599
.unwrap_or_else(|_| panic!("Unable to parse XGBoost metrics output: {}", eval));
600600

601-
let metric_map = result.entry(evname.to_string()).or_insert_with(IndexMap::new);
601+
let metric_map = result.entry(evname.to_string()).or_default();
602602
metric_map.insert(metric.to_owned(), score);
603603
}
604604
}
@@ -669,7 +669,7 @@ impl FeatureMap {
669669
};
670670

671671
let feature_name = &parts[1];
672-
let feature_type = match FeatureType::from_str(&parts[2]) {
672+
let feature_type = match FeatureType::from_str(parts[2]) {
673673
Ok(feature_type) => feature_type,
674674
Err(msg) => {
675675
let msg = format!("Unable to parse features from line {}: {}", i + 1, msg);
@@ -727,7 +727,7 @@ mod tests {
727727
use parameters::{self, learning, tree};
728728

729729
fn read_train_matrix() -> XGBResult<DMatrix> {
730-
DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train")
730+
DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#)
731731
}
732732

733733
fn load_test_booster() -> Booster {
@@ -761,7 +761,7 @@ mod tests {
761761

762762
#[test]
763763
fn save_and_load_from_buffer() {
764-
let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap();
764+
let dmat_train = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
765765
let mut booster = Booster::new_with_cached_dmats(&BoosterParameters::default(), &[&dmat_train]).unwrap();
766766
let attr = booster.get_attribute("foo").expect("Getting attribute failed");
767767
assert_eq!(attr, None);
@@ -804,8 +804,8 @@ mod tests {
804804

805805
#[test]
806806
fn predict() {
807-
let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap();
808-
let dmat_test = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap();
807+
let dmat_train = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
808+
let dmat_test =DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap();
809809

810810
let tree_params = tree::TreeBoosterParametersBuilder::default()
811811
.max_depth(2)
@@ -835,11 +835,11 @@ mod tests {
835835

836836
let train_metrics = booster.evaluate(&dmat_train).unwrap();
837837
assert_eq!(*train_metrics.get("logloss").unwrap(), 0.006634271);
838-
assert_eq!(*train_metrics.get("map@4-").unwrap(), 0.0012738854);
838+
assert_eq!(*train_metrics.get("map@4-").unwrap(), 1.0);
839839

840840
let test_metrics = booster.evaluate(&dmat_test).unwrap();
841841
assert_eq!(*test_metrics.get("logloss").unwrap(), 0.006919953);
842-
assert_eq!(*test_metrics.get("map@4-").unwrap(), 0.005154639);
842+
assert_eq!(*test_metrics.get("map@4-").unwrap(), 1.0);
843843

844844
let v = booster.predict(&dmat_test).unwrap();
845845
assert_eq!(v.len(), dmat_test.num_rows());
@@ -886,8 +886,8 @@ mod tests {
886886

887887
#[test]
888888
fn predict_leaf() {
889-
let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap();
890-
let dmat_test = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap();
889+
let dmat_train = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
890+
let dmat_test = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap();
891891

892892
let tree_params = tree::TreeBoosterParametersBuilder::default()
893893
.max_depth(2)
@@ -919,8 +919,8 @@ mod tests {
919919

920920
#[test]
921921
fn predict_contributions() {
922-
let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap();
923-
let dmat_test = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap();
922+
let dmat_train = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
923+
let dmat_test = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap();
924924

925925
let tree_params = tree::TreeBoosterParametersBuilder::default()
926926
.max_depth(2)
@@ -953,8 +953,8 @@ mod tests {
953953

954954
#[test]
955955
fn predict_interactions() {
956-
let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap();
957-
let dmat_test = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap();
956+
let dmat_train = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
957+
let dmat_test = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap();
958958

959959
let tree_params = tree::TreeBoosterParametersBuilder::default()
960960
.max_depth(2)
@@ -1005,7 +1005,7 @@ mod tests {
10051005

10061006
#[test]
10071007
fn dump_model() {
1008-
let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap();
1008+
let dmat_train = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
10091009

10101010
println!("{:?}", dmat_train.shape());
10111011

@@ -1033,82 +1033,79 @@ mod tests {
10331033
.unwrap();
10341034
let booster = Booster::train(&training_params).unwrap();
10351035

1036-
let features = FeatureMap::from_file("xgboost-sys/xgboost/demo/data/featmap.txt")
1037-
.expect("failed to parse feature map file");
1038-
10391036
assert_eq!(
1040-
booster.dump_model(true, Some(&features)).unwrap(),
1041-
"0:[odor=none] yes=2,no=1,gain=4000.53101,cover=1628.25
1042-
1:[stalk-root=club] yes=4,no=3,gain=1158.21204,cover=924.5
1043-
3:leaf=1.71217716,cover=812
1044-
4:leaf=-1.70044053,cover=112.5
1045-
2:[spore-print-color=green] yes=6,no=5,gain=198.173828,cover=703.75
1046-
5:leaf=-1.94070864,cover=690.5
1047-
6:leaf=1.85964918,cover=13.25
1048-
1049-
0:[stalk-root=rooted] yes=2,no=1,gain=832.545044,cover=788.852051
1050-
1:[odor=none] yes=4,no=3,gain=569.725098,cover=768.389709
1051-
3:leaf=0.78471756,cover=458.936859
1052-
4:leaf=-0.968530357,cover=309.45282
1053-
2:leaf=-6.23624468,cover=20.462389
1054-
1055-
0:[ring-type=pendant] yes=2,no=1,gain=368.744568,cover=457.069458
1056-
1:[stalk-surface-below-ring=scaly] yes=4,no=3,gain=226.33696,cover=221.051468
1057-
3:leaf=0.658725023,cover=212.999451
1058-
4:leaf=5.77228642,cover=8.05200672
1059-
2:[spore-print-color=purple] yes=6,no=5,gain=258.184265,cover=236.018005
1060-
5:leaf=-0.791407049,cover=233.487625
1061-
6:leaf=-9.421422,cover=2.53038669
1062-
1063-
0:[odor=foul] yes=2,no=1,gain=140.486069,cover=364.119354
1064-
1:[gill-size=broad] yes=4,no=3,gain=139.860504,cover=274.101959
1065-
3:leaf=0.614153326,cover=95.8599854
1066-
4:leaf=-0.877905607,cover=178.241974
1067-
2:leaf=1.07747853,cover=90.0174103
1068-
1069-
0:[spore-print-color=green] yes=2,no=1,gain=112.605011,cover=189.202194
1070-
1:[gill-spacing=close] yes=4,no=3,gain=66.4029999,cover=177.771835
1071-
3:leaf=-1.26934469,cover=42.277401
1072-
4:leaf=0.152607277,cover=135.494431
1073-
2:leaf=2.92190909,cover=11.4303684
1074-
1075-
0:[odor=almond] yes=2,no=1,gain=52.5610275,cover=170.612762
1076-
1:[odor=anise] yes=4,no=3,gain=67.3869553,cover=150.881165
1077-
3:leaf=0.431742132,cover=131.902222
1078-
4:leaf=-1.53846073,cover=18.9789505
1079-
2:[gill-spacing=close] yes=6,no=5,gain=12.4420624,cover=19.731596
1080-
5:leaf=-3.02413678,cover=3.65769386
1081-
6:leaf=-1.02315068,cover=16.0739021
1082-
1083-
0:[odor=none] yes=2,no=1,gain=66.2389145,cover=142.360611
1084-
1:[odor=anise] yes=4,no=3,gain=31.2294312,cover=72.7557373
1085-
3:leaf=0.777142286,cover=64.5309982
1086-
4:leaf=-1.19710124,cover=8.22473907
1087-
2:[spore-print-color=green] yes=6,no=5,gain=12.1987419,cover=69.6048737
1088-
5:leaf=-0.912605286,cover=66.1211166
1089-
6:leaf=0.836115122,cover=3.48375821
1090-
1091-
0:[gill-size=broad] yes=2,no=1,gain=20.6531773,cover=79.4027634
1092-
1:[spore-print-color=white] yes=4,no=3,gain=16.0703697,cover=34.9289207
1093-
3:leaf=-0.0180106498,cover=25.0319824
1094-
4:leaf=1.4361918,cover=9.89693928
1095-
2:[odor=foul] yes=6,no=5,gain=22.1144333,cover=44.4738464
1096-
5:leaf=-0.908311546,cover=36.982872
1097-
6:leaf=0.890622675,cover=7.49097395
1098-
1099-
0:[odor=almond] yes=2,no=1,gain=11.7128553,cover=53.3251991
1100-
1:[ring-type=pendant] yes=4,no=3,gain=12.546154,cover=44.299942
1101-
3:leaf=-0.515293062,cover=15.7899179
1102-
4:leaf=0.56883812,cover=28.5100231
1103-
2:leaf=-1.01502442,cover=9.02525806
1104-
1105-
0:[population=clustered] yes=2,no=1,gain=14.8892794,cover=45.9312019
1106-
1:[odor=none] yes=4,no=3,gain=10.1308851,cover=43.0564575
1107-
3:leaf=0.217203051,cover=22.3283749
1108-
4:leaf=-0.734555721,cover=20.7280827
1109-
2:[stalk-root=missing] yes=6,no=5,gain=19.3462334,cover=2.87474418
1110-
5:leaf=3.63442755,cover=1.34154534
1111-
6:leaf=-0.609474957,cover=1.53319895
1037+
booster.dump_model(true, None).unwrap(),
1038+
"0:[f29<2.00001001] yes=1,no=2,missing=2,gain=4000.53101,cover=1628.25
1039+
1:[f109<2.00001001] yes=3,no=4,missing=4,gain=198.173828,cover=703.75
1040+
3:leaf=1.85964918,cover=13.25
1041+
4:leaf=-1.94070864,cover=690.5
1042+
2:[f56<2.00001001] yes=5,no=6,missing=6,gain=1158.21204,cover=924.5
1043+
5:leaf=-1.70044053,cover=112.5
1044+
6:leaf=1.71217716,cover=812
1045+
1046+
0:[f60<2.00001001] yes=1,no=2,missing=2,gain=832.544983,cover=788.852051
1047+
1:leaf=-6.23624468,cover=20.462389
1048+
2:[f29<2.00001001] yes=3,no=4,missing=4,gain=569.725098,cover=768.389709
1049+
3:leaf=-0.968530357,cover=309.45282
1050+
4:leaf=0.78471756,cover=458.936859
1051+
1052+
0:[f102<2.00001001] yes=1,no=2,missing=2,gain=368.744568,cover=457.069458
1053+
1:[f111<2.00001001] yes=3,no=4,missing=4,gain=258.184326,cover=236.018005
1054+
3:leaf=-9.421422,cover=2.53038669
1055+
4:leaf=-0.791407049,cover=233.487625
1056+
2:[f67<2.00001001] yes=5,no=6,missing=6,gain=226.336975,cover=221.051468
1057+
5:leaf=5.77228642,cover=8.05200672
1058+
6:leaf=0.658725023,cover=212.999451
1059+
1060+
0:[f27<2.00001001] yes=1,no=2,missing=2,gain=140.486053,cover=364.119354
1061+
1:leaf=1.07747853,cover=90.0174103
1062+
2:[f39<2.00001001] yes=3,no=4,missing=4,gain=139.860519,cover=274.101959
1063+
3:leaf=-0.877905607,cover=178.241974
1064+
4:leaf=0.614153326,cover=95.8599854
1065+
1066+
0:[f109<2.00001001] yes=1,no=2,missing=2,gain=112.605019,cover=189.202194
1067+
1:leaf=2.92190909,cover=11.4303684
1068+
2:[f36<2.00001001] yes=3,no=4,missing=4,gain=66.4029999,cover=177.771835
1069+
3:leaf=0.152607277,cover=135.494431
1070+
4:leaf=-1.26934469,cover=42.277401
1071+
1072+
0:[f23<2.00001001] yes=1,no=2,missing=2,gain=52.5610313,cover=170.612762
1073+
1:[f36<2.00001001] yes=3,no=4,missing=4,gain=12.4420547,cover=19.731596
1074+
3:leaf=-1.02315068,cover=16.0739021
1075+
4:leaf=-3.02413678,cover=3.65769386
1076+
2:[f24<2.00001001] yes=5,no=6,missing=6,gain=67.3869553,cover=150.881165
1077+
5:leaf=-1.53846073,cover=18.9789505
1078+
6:leaf=0.431742132,cover=131.902222
1079+
1080+
0:[f29<2.00001001] yes=1,no=2,missing=2,gain=66.2389145,cover=142.360611
1081+
1:[f109<2.00001001] yes=3,no=4,missing=4,gain=12.1987419,cover=69.6048737
1082+
3:leaf=0.836115122,cover=3.48375821
1083+
4:leaf=-0.912605286,cover=66.1211166
1084+
2:[f24<2.00001001] yes=5,no=6,missing=6,gain=31.229435,cover=72.7557373
1085+
5:leaf=-1.19710124,cover=8.22473907
1086+
6:leaf=0.777142286,cover=64.5309982
1087+
1088+
0:[f39<2.00001001] yes=1,no=2,missing=2,gain=20.6531773,cover=79.4027634
1089+
1:[f27<2.00001001] yes=3,no=4,missing=4,gain=22.1144371,cover=44.4738464
1090+
3:leaf=0.890622675,cover=7.49097395
1091+
4:leaf=-0.908311546,cover=36.982872
1092+
2:[f112<2.00001001] yes=5,no=6,missing=6,gain=16.0703697,cover=34.9289207
1093+
5:leaf=1.4361918,cover=9.89693928
1094+
6:leaf=-0.0180106498,cover=25.0319824
1095+
1096+
0:[f23<2.00001001] yes=1,no=2,missing=2,gain=11.7128553,cover=53.3251991
1097+
1:leaf=-1.01502442,cover=9.02525806
1098+
2:[f102<2.00001001] yes=3,no=4,missing=4,gain=12.5461531,cover=44.299942
1099+
3:leaf=0.56883812,cover=28.5100231
1100+
4:leaf=-0.515293062,cover=15.7899179
1101+
1102+
0:[f115<2.00001001] yes=1,no=2,missing=2,gain=14.8892794,cover=45.9312019
1103+
1:[f61<2.00001001] yes=3,no=4,missing=4,gain=19.3462334,cover=2.87474418
1104+
3:leaf=-0.609474957,cover=1.53319895
1105+
4:leaf=3.63442755,cover=1.34154534
1106+
2:[f29<2.00001001] yes=5,no=6,missing=6,gain=10.1308861,cover=43.0564575
1107+
5:leaf=-0.734555721,cover=20.7280827
1108+
6:leaf=0.217203051,cover=22.3283749
11121109
"
11131110
);
11141111
}

0 commit comments

Comments
 (0)