Skip to content

Commit dc882f7

Browse files
committed
checkpoint
1 parent d321b19 commit dc882f7

File tree

13 files changed

+406
-292
lines changed

13 files changed

+406
-292
lines changed

Cargo.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ readme = "README.md"
1212
[dependencies]
1313
xgboost-sys = { path = "xgboost-sys" }
1414
libc = "0.2"
15-
derive_builder = "0.11"
15+
derive_builder = "0.12"
1616
log = "0.4"
17-
tempfile = "3.0"
18-
indexmap = "1.0"
17+
tempfile = "3.9"
18+
indexmap = "2.1"
1919

2020
[features]
2121
cuda = ["xgboost-sys/cuda"]

rustfmt.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
max_width = 120
2+
single_line_if_else_max_width = 80

src/booster.rs

Lines changed: 191 additions & 121 deletions
Large diffs are not rendered by default.

src/dmatrix.rs

Lines changed: 82 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
use std::{slice, ffi, ptr, path::Path};
2-
use libc::{c_uint, c_float};
3-
use std::os::unix::ffi::OsStrExt;
1+
use libc::{c_float, c_uint};
42
use std::convert::TryInto;
3+
use std::os::unix::ffi::OsStrExt;
4+
use std::{ffi, path::Path, ptr, slice};
55

66
use xgboost_sys;
77

8-
use super::{XGBResult, XGBError};
8+
use super::{XGBError, XGBResult};
99

10-
static KEY_GROUP_PTR: &'static str = "group_ptr";
11-
static KEY_GROUP: &'static str = "group";
12-
static KEY_LABEL: &'static str = "label";
13-
static KEY_WEIGHT: &'static str = "weight";
14-
static KEY_BASE_MARGIN: &'static str = "base_margin";
10+
static KEY_GROUP_PTR: &str = "group_ptr";
11+
static KEY_GROUP: &str = "group";
12+
static KEY_LABEL: &str = "label";
13+
static KEY_WEIGHT: &str = "weight";
14+
static KEY_BASE_MARGIN: &str = "base_margin";
1515

1616
/// Data matrix used throughout XGBoost for training/predicting [`Booster`](struct.Booster.html) models.
1717
///
@@ -88,7 +88,11 @@ impl DMatrix {
8888
let num_cols = out as usize;
8989

9090
info!("Loaded DMatrix with shape: {}x{}", num_rows, num_cols);
91-
Ok(DMatrix { handle, num_rows, num_cols })
91+
Ok(DMatrix {
92+
handle,
93+
num_rows,
94+
num_cols,
95+
})
9296
}
9397

9498
/// Create a new `DMatrix` from dense array in row-major order.
@@ -109,11 +113,13 @@ impl DMatrix {
109113
/// ```
110114
pub fn from_dense(data: &[f32], num_rows: usize) -> XGBResult<Self> {
111115
let mut handle = ptr::null_mut();
112-
xgb_call!(xgboost_sys::XGDMatrixCreateFromMat(data.as_ptr(),
113-
num_rows as xgboost_sys::bst_ulong,
114-
(data.len() / num_rows) as xgboost_sys::bst_ulong,
115-
f32::NAN,
116-
&mut handle))?;
116+
xgb_call!(xgboost_sys::XGDMatrixCreateFromMat(
117+
data.as_ptr(),
118+
num_rows as xgboost_sys::bst_ulong,
119+
(data.len() / num_rows) as xgboost_sys::bst_ulong,
120+
f32::NAN,
121+
&mut handle
122+
))?;
117123
Ok(DMatrix::new(handle)?)
118124
}
119125

@@ -130,13 +136,15 @@ impl DMatrix {
130136
let mut handle = ptr::null_mut();
131137
let indices: Vec<u32> = indices.iter().map(|x| *x as u32).collect();
132138
let num_cols = num_cols.unwrap_or(0); // infer from data if 0
133-
xgb_call!(xgboost_sys::XGDMatrixCreateFromCSREx(indptr.as_ptr(),
134-
indices.as_ptr(),
135-
data.as_ptr(),
136-
indptr.len().try_into().unwrap(),
137-
data.len().try_into().unwrap(),
138-
num_cols.try_into().unwrap(),
139-
&mut handle))?;
139+
xgb_call!(xgboost_sys::XGDMatrixCreateFromCSREx(
140+
indptr.as_ptr(),
141+
indices.as_ptr(),
142+
data.as_ptr(),
143+
indptr.len().try_into().unwrap(),
144+
data.len().try_into().unwrap(),
145+
num_cols.try_into().unwrap(),
146+
&mut handle
147+
))?;
140148
Ok(DMatrix::new(handle)?)
141149
}
142150

@@ -153,13 +161,15 @@ impl DMatrix {
153161
let mut handle = ptr::null_mut();
154162
let indices: Vec<u32> = indices.iter().map(|x| *x as u32).collect();
155163
let num_rows = num_rows.unwrap_or(0); // infer from data if 0
156-
xgb_call!(xgboost_sys::XGDMatrixCreateFromCSCEx(indptr.as_ptr(),
157-
indices.as_ptr(),
158-
data.as_ptr(),
159-
indptr.len().try_into().unwrap(),
160-
data.len().try_into().unwrap(),
161-
num_rows.try_into().unwrap(),
162-
&mut handle))?;
164+
xgb_call!(xgboost_sys::XGDMatrixCreateFromCSCEx(
165+
indptr.as_ptr(),
166+
indices.as_ptr(),
167+
data.as_ptr(),
168+
indptr.len().try_into().unwrap(),
169+
data.len().try_into().unwrap(),
170+
num_rows.try_into().unwrap(),
171+
&mut handle
172+
))?;
163173
Ok(DMatrix::new(handle)?)
164174
}
165175

@@ -190,7 +200,11 @@ impl DMatrix {
190200
let mut handle = ptr::null_mut();
191201
let fname = ffi::CString::new(path.as_ref().as_os_str().as_bytes()).unwrap();
192202
let silent = true;
193-
xgb_call!(xgboost_sys::XGDMatrixCreateFromFile(fname.as_ptr(), silent as i32, &mut handle))?;
203+
xgb_call!(xgboost_sys::XGDMatrixCreateFromFile(
204+
fname.as_ptr(),
205+
silent as i32,
206+
&mut handle
207+
))?;
194208
Ok(DMatrix::new(handle)?)
195209
}
196210

@@ -199,7 +213,11 @@ impl DMatrix {
199213
debug!("Writing DMatrix to: {}", path.as_ref().display());
200214
let fname = ffi::CString::new(path.as_ref().as_os_str().as_bytes()).unwrap();
201215
let silent = true;
202-
xgb_call!(xgboost_sys::XGDMatrixSaveBinary(self.handle, fname.as_ptr(), silent as i32))
216+
xgb_call!(xgboost_sys::XGDMatrixSaveBinary(
217+
self.handle,
218+
fname.as_ptr(),
219+
silent as i32
220+
))
203221
}
204222

205223
/// Get the number of rows in this matrix.
@@ -222,10 +240,12 @@ impl DMatrix {
222240
debug!("Slicing {} rows from DMatrix", indices.len());
223241
let mut out_handle = ptr::null_mut();
224242
let indices: Vec<i32> = indices.iter().map(|x| *x as i32).collect();
225-
xgb_call!(xgboost_sys::XGDMatrixSliceDMatrix(self.handle,
226-
indices.as_ptr(),
227-
indices.len() as xgboost_sys::bst_ulong,
228-
&mut out_handle))?;
243+
xgb_call!(xgboost_sys::XGDMatrixSliceDMatrix(
244+
self.handle,
245+
indices.as_ptr(),
246+
indices.len() as xgboost_sys::bst_ulong,
247+
&mut out_handle
248+
))?;
229249
Ok(DMatrix::new(out_handle)?)
230250
}
231251

@@ -280,44 +300,51 @@ impl DMatrix {
280300
self.get_uint_info(KEY_GROUP_PTR)
281301
}
282302

283-
284303
fn get_float_info(&self, field: &str) -> XGBResult<&[f32]> {
285304
let field = ffi::CString::new(field).unwrap();
286305
let mut out_len = 0;
287306
let mut out_dptr = ptr::null();
288-
xgb_call!(xgboost_sys::XGDMatrixGetFloatInfo(self.handle,
289-
field.as_ptr(),
290-
&mut out_len,
291-
&mut out_dptr))?;
307+
xgb_call!(xgboost_sys::XGDMatrixGetFloatInfo(
308+
self.handle,
309+
field.as_ptr(),
310+
&mut out_len,
311+
&mut out_dptr
312+
))?;
292313

293314
Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_float, out_len as usize) })
294315
}
295316

296317
fn set_float_info(&mut self, field: &str, array: &[f32]) -> XGBResult<()> {
297318
let field = ffi::CString::new(field).unwrap();
298-
xgb_call!(xgboost_sys::XGDMatrixSetFloatInfo(self.handle,
299-
field.as_ptr(),
300-
array.as_ptr(),
301-
array.len() as u64))
319+
xgb_call!(xgboost_sys::XGDMatrixSetFloatInfo(
320+
self.handle,
321+
field.as_ptr(),
322+
array.as_ptr(),
323+
array.len() as u64
324+
))
302325
}
303326

304327
fn get_uint_info(&self, field: &str) -> XGBResult<&[u32]> {
305328
let field = ffi::CString::new(field).unwrap();
306329
let mut out_len = 0;
307330
let mut out_dptr = ptr::null();
308-
xgb_call!(xgboost_sys::XGDMatrixGetUIntInfo(self.handle,
309-
field.as_ptr(),
310-
&mut out_len,
311-
&mut out_dptr))?;
331+
xgb_call!(xgboost_sys::XGDMatrixGetUIntInfo(
332+
self.handle,
333+
field.as_ptr(),
334+
&mut out_len,
335+
&mut out_dptr
336+
))?;
312337
Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_uint, out_len as usize) })
313338
}
314339

315340
fn set_uint_info(&mut self, field: &str, array: &[u32]) -> XGBResult<()> {
316341
let field = ffi::CString::new(field).unwrap();
317-
xgb_call!(xgboost_sys::XGDMatrixSetUIntInfo(self.handle,
318-
field.as_ptr(),
319-
array.as_ptr(),
320-
array.len() as u64))
342+
xgb_call!(xgboost_sys::XGDMatrixSetUIntInfo(
343+
self.handle,
344+
field.as_ptr(),
345+
array.as_ptr(),
346+
array.len() as u64
347+
))
321348
}
322349
}
323350

@@ -329,8 +356,8 @@ impl Drop for DMatrix {
329356

330357
#[cfg(test)]
331358
mod tests {
332-
use tempfile;
333359
use super::*;
360+
use tempfile;
334361
fn read_train_matrix() -> XGBResult<DMatrix> {
335362
DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train")
336363
}
@@ -370,7 +397,7 @@ mod tests {
370397
let mut dmat = read_train_matrix().unwrap();
371398
assert_eq!(dmat.get_labels().unwrap().len(), 6513);
372399

373-
let label = [0.1, 0.0 -4.5, 11.29842, 333333.33];
400+
let label = [0.1, 0.0 - 4.5, 11.29842, 333333.33];
374401
assert!(dmat.set_labels(&label).is_ok());
375402
assert_eq!(dmat.get_labels().unwrap(), label);
376403
}
@@ -416,7 +443,7 @@ mod tests {
416443

417444
let dmat = DMatrix::from_csr(&indptr, &indices, &data, None).unwrap();
418445
assert_eq!(dmat.num_rows(), 4);
419-
assert_eq!(dmat.num_cols(), 0); // https://github.com/dmlc/xgboost/pull/7265
446+
assert_eq!(dmat.num_cols(), 0); // https://github.com/dmlc/xgboost/pull/7265
420447

421448
let dmat = DMatrix::from_csr(&indptr, &indices, &data, Some(10)).unwrap();
422449
assert_eq!(dmat.num_rows(), 4);

src/error.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
//! Functionality related to errors and error handling.
22
33
use std;
4+
use std::error::Error;
45
use std::ffi::CStr;
56
use std::fmt::{self, Display};
6-
use std::error::Error;
77

88
use xgboost_sys;
99

@@ -29,17 +29,19 @@ impl XGBError {
2929
/// Meaning of any other return values are undefined, and will cause a panic.
3030
pub(crate) fn check_return_value(ret_val: i32) -> XGBResult<()> {
3131
match ret_val {
32-
0 => Ok(()),
32+
0 => Ok(()),
3333
-1 => Err(XGBError::from_xgboost()),
34-
_ => panic!("unexpected return value '{}', expected 0 or -1", ret_val),
34+
_ => panic!("unexpected return value '{}', expected 0 or -1", ret_val),
3535
}
3636
}
3737

3838
/// Get the last error message from XGBoost.
3939
fn from_xgboost() -> Self {
4040
let c_str = unsafe { CStr::from_ptr(xgboost_sys::XGBGetLastError()) };
4141
let str_slice = c_str.to_str().unwrap();
42-
XGBError { desc: str_slice.to_owned() }
42+
XGBError {
43+
desc: str_slice.to_owned(),
44+
}
4345
}
4446
}
4547

src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@
6060
extern crate derive_builder;
6161
#[macro_use]
6262
extern crate log;
63-
extern crate xgboost_sys;
63+
extern crate indexmap;
6464
extern crate libc;
6565
extern crate tempfile;
66-
extern crate indexmap;
66+
extern crate xgboost_sys;
6767

6868
macro_rules! xgb_call {
6969
($x:expr) => {
@@ -72,7 +72,7 @@ macro_rules! xgb_call {
7272
}
7373

7474
mod error;
75-
pub use error::{XGBResult, XGBError};
75+
pub use error::{XGBError, XGBResult};
7676

7777
mod dmatrix;
7878
pub use dmatrix::DMatrix;

src/parameters/booster.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
//! ```
2121
use std::default::Default;
2222

23-
use super::{tree, linear, dart};
23+
use super::{dart, linear, tree};
2424

2525
/// Type of booster to use when training a [Booster](../struct.Booster.html) model.
2626
#[derive(Clone)]
@@ -46,15 +46,17 @@ pub enum BoosterType {
4646
}
4747

4848
impl Default for BoosterType {
49-
fn default() -> Self { BoosterType::Tree(tree::TreeBoosterParameters::default()) }
49+
fn default() -> Self {
50+
BoosterType::Tree(tree::TreeBoosterParameters::default())
51+
}
5052
}
5153

5254
impl BoosterType {
5355
pub(crate) fn as_string_pairs(&self) -> Vec<(String, String)> {
5456
match *self {
5557
BoosterType::Tree(ref p) => p.as_string_pairs(),
5658
BoosterType::Linear(ref p) => p.as_string_pairs(),
57-
BoosterType::Dart(ref p) => p.as_string_pairs()
59+
BoosterType::Dart(ref p) => p.as_string_pairs(),
5860
}
5961
}
6062
}

src/parameters/dart.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ impl ToString for SampleType {
2525
}
2626

2727
impl Default for SampleType {
28-
fn default() -> Self { SampleType::Uniform }
28+
fn default() -> Self {
29+
SampleType::Uniform
30+
}
2931
}
3032

3133
/// Type of normalization algorithm.
@@ -53,7 +55,9 @@ impl ToString for NormalizeType {
5355
}
5456

5557
impl Default for NormalizeType {
56-
fn default() -> Self { NormalizeType::Tree }
58+
fn default() -> Self {
59+
NormalizeType::Tree
60+
}
5761
}
5862

5963
/// Additional parameters for Dart Booster.

0 commit comments

Comments
 (0)