1
- use std:: { slice, ffi, ptr, path:: Path } ;
2
- use libc:: { c_uint, c_float} ;
3
- use std:: os:: unix:: ffi:: OsStrExt ;
1
+ use libc:: { c_float, c_uint} ;
4
2
use std:: convert:: TryInto ;
3
+ use std:: os:: unix:: ffi:: OsStrExt ;
4
+ use std:: { ffi, path:: Path , ptr, slice} ;
5
5
6
6
use xgboost_sys;
7
7
8
- use super :: { XGBResult , XGBError } ;
8
+ use super :: { XGBError , XGBResult } ;
9
9
10
- static KEY_GROUP_PTR : & ' static str = "group_ptr" ;
11
- static KEY_GROUP : & ' static str = "group" ;
12
- static KEY_LABEL : & ' static str = "label" ;
13
- static KEY_WEIGHT : & ' static str = "weight" ;
14
- static KEY_BASE_MARGIN : & ' static str = "base_margin" ;
10
+ static KEY_GROUP_PTR : & str = "group_ptr" ;
11
+ static KEY_GROUP : & str = "group" ;
12
+ static KEY_LABEL : & str = "label" ;
13
+ static KEY_WEIGHT : & str = "weight" ;
14
+ static KEY_BASE_MARGIN : & str = "base_margin" ;
15
15
16
16
/// Data matrix used throughout XGBoost for training/predicting [`Booster`](struct.Booster.html) models.
17
17
///
@@ -88,7 +88,11 @@ impl DMatrix {
88
88
let num_cols = out as usize ;
89
89
90
90
info ! ( "Loaded DMatrix with shape: {}x{}" , num_rows, num_cols) ;
91
- Ok ( DMatrix { handle, num_rows, num_cols } )
91
+ Ok ( DMatrix {
92
+ handle,
93
+ num_rows,
94
+ num_cols,
95
+ } )
92
96
}
93
97
94
98
/// Create a new `DMatrix` from dense array in row-major order.
@@ -109,11 +113,13 @@ impl DMatrix {
109
113
/// ```
110
114
pub fn from_dense ( data : & [ f32 ] , num_rows : usize ) -> XGBResult < Self > {
111
115
let mut handle = ptr:: null_mut ( ) ;
112
- xgb_call ! ( xgboost_sys:: XGDMatrixCreateFromMat ( data. as_ptr( ) ,
113
- num_rows as xgboost_sys:: bst_ulong,
114
- ( data. len( ) / num_rows) as xgboost_sys:: bst_ulong,
115
- f32 :: NAN ,
116
- & mut handle) ) ?;
116
+ xgb_call ! ( xgboost_sys:: XGDMatrixCreateFromMat (
117
+ data. as_ptr( ) ,
118
+ num_rows as xgboost_sys:: bst_ulong,
119
+ ( data. len( ) / num_rows) as xgboost_sys:: bst_ulong,
120
+ f32 :: NAN ,
121
+ & mut handle
122
+ ) ) ?;
117
123
Ok ( DMatrix :: new ( handle) ?)
118
124
}
119
125
@@ -130,13 +136,15 @@ impl DMatrix {
130
136
let mut handle = ptr:: null_mut ( ) ;
131
137
let indices: Vec < u32 > = indices. iter ( ) . map ( |x| * x as u32 ) . collect ( ) ;
132
138
let num_cols = num_cols. unwrap_or ( 0 ) ; // infer from data if 0
133
- xgb_call ! ( xgboost_sys:: XGDMatrixCreateFromCSREx ( indptr. as_ptr( ) ,
134
- indices. as_ptr( ) ,
135
- data. as_ptr( ) ,
136
- indptr. len( ) . try_into( ) . unwrap( ) ,
137
- data. len( ) . try_into( ) . unwrap( ) ,
138
- num_cols. try_into( ) . unwrap( ) ,
139
- & mut handle) ) ?;
139
+ xgb_call ! ( xgboost_sys:: XGDMatrixCreateFromCSREx (
140
+ indptr. as_ptr( ) ,
141
+ indices. as_ptr( ) ,
142
+ data. as_ptr( ) ,
143
+ indptr. len( ) . try_into( ) . unwrap( ) ,
144
+ data. len( ) . try_into( ) . unwrap( ) ,
145
+ num_cols. try_into( ) . unwrap( ) ,
146
+ & mut handle
147
+ ) ) ?;
140
148
Ok ( DMatrix :: new ( handle) ?)
141
149
}
142
150
@@ -153,13 +161,15 @@ impl DMatrix {
153
161
let mut handle = ptr:: null_mut ( ) ;
154
162
let indices: Vec < u32 > = indices. iter ( ) . map ( |x| * x as u32 ) . collect ( ) ;
155
163
let num_rows = num_rows. unwrap_or ( 0 ) ; // infer from data if 0
156
- xgb_call ! ( xgboost_sys:: XGDMatrixCreateFromCSCEx ( indptr. as_ptr( ) ,
157
- indices. as_ptr( ) ,
158
- data. as_ptr( ) ,
159
- indptr. len( ) . try_into( ) . unwrap( ) ,
160
- data. len( ) . try_into( ) . unwrap( ) ,
161
- num_rows. try_into( ) . unwrap( ) ,
162
- & mut handle) ) ?;
164
+ xgb_call ! ( xgboost_sys:: XGDMatrixCreateFromCSCEx (
165
+ indptr. as_ptr( ) ,
166
+ indices. as_ptr( ) ,
167
+ data. as_ptr( ) ,
168
+ indptr. len( ) . try_into( ) . unwrap( ) ,
169
+ data. len( ) . try_into( ) . unwrap( ) ,
170
+ num_rows. try_into( ) . unwrap( ) ,
171
+ & mut handle
172
+ ) ) ?;
163
173
Ok ( DMatrix :: new ( handle) ?)
164
174
}
165
175
@@ -190,7 +200,11 @@ impl DMatrix {
190
200
let mut handle = ptr:: null_mut ( ) ;
191
201
let fname = ffi:: CString :: new ( path. as_ref ( ) . as_os_str ( ) . as_bytes ( ) ) . unwrap ( ) ;
192
202
let silent = true ;
193
- xgb_call ! ( xgboost_sys:: XGDMatrixCreateFromFile ( fname. as_ptr( ) , silent as i32 , & mut handle) ) ?;
203
+ xgb_call ! ( xgboost_sys:: XGDMatrixCreateFromFile (
204
+ fname. as_ptr( ) ,
205
+ silent as i32 ,
206
+ & mut handle
207
+ ) ) ?;
194
208
Ok ( DMatrix :: new ( handle) ?)
195
209
}
196
210
@@ -199,7 +213,11 @@ impl DMatrix {
199
213
debug ! ( "Writing DMatrix to: {}" , path. as_ref( ) . display( ) ) ;
200
214
let fname = ffi:: CString :: new ( path. as_ref ( ) . as_os_str ( ) . as_bytes ( ) ) . unwrap ( ) ;
201
215
let silent = true ;
202
- xgb_call ! ( xgboost_sys:: XGDMatrixSaveBinary ( self . handle, fname. as_ptr( ) , silent as i32 ) )
216
+ xgb_call ! ( xgboost_sys:: XGDMatrixSaveBinary (
217
+ self . handle,
218
+ fname. as_ptr( ) ,
219
+ silent as i32
220
+ ) )
203
221
}
204
222
205
223
/// Get the number of rows in this matrix.
@@ -222,10 +240,12 @@ impl DMatrix {
222
240
debug ! ( "Slicing {} rows from DMatrix" , indices. len( ) ) ;
223
241
let mut out_handle = ptr:: null_mut ( ) ;
224
242
let indices: Vec < i32 > = indices. iter ( ) . map ( |x| * x as i32 ) . collect ( ) ;
225
- xgb_call ! ( xgboost_sys:: XGDMatrixSliceDMatrix ( self . handle,
226
- indices. as_ptr( ) ,
227
- indices. len( ) as xgboost_sys:: bst_ulong,
228
- & mut out_handle) ) ?;
243
+ xgb_call ! ( xgboost_sys:: XGDMatrixSliceDMatrix (
244
+ self . handle,
245
+ indices. as_ptr( ) ,
246
+ indices. len( ) as xgboost_sys:: bst_ulong,
247
+ & mut out_handle
248
+ ) ) ?;
229
249
Ok ( DMatrix :: new ( out_handle) ?)
230
250
}
231
251
@@ -280,44 +300,51 @@ impl DMatrix {
280
300
self . get_uint_info ( KEY_GROUP_PTR )
281
301
}
282
302
283
-
284
303
fn get_float_info ( & self , field : & str ) -> XGBResult < & [ f32 ] > {
285
304
let field = ffi:: CString :: new ( field) . unwrap ( ) ;
286
305
let mut out_len = 0 ;
287
306
let mut out_dptr = ptr:: null ( ) ;
288
- xgb_call ! ( xgboost_sys:: XGDMatrixGetFloatInfo ( self . handle,
289
- field. as_ptr( ) ,
290
- & mut out_len,
291
- & mut out_dptr) ) ?;
307
+ xgb_call ! ( xgboost_sys:: XGDMatrixGetFloatInfo (
308
+ self . handle,
309
+ field. as_ptr( ) ,
310
+ & mut out_len,
311
+ & mut out_dptr
312
+ ) ) ?;
292
313
293
314
Ok ( unsafe { slice:: from_raw_parts ( out_dptr as * mut c_float , out_len as usize ) } )
294
315
}
295
316
296
317
fn set_float_info ( & mut self , field : & str , array : & [ f32 ] ) -> XGBResult < ( ) > {
297
318
let field = ffi:: CString :: new ( field) . unwrap ( ) ;
298
- xgb_call ! ( xgboost_sys:: XGDMatrixSetFloatInfo ( self . handle,
299
- field. as_ptr( ) ,
300
- array. as_ptr( ) ,
301
- array. len( ) as u64 ) )
319
+ xgb_call ! ( xgboost_sys:: XGDMatrixSetFloatInfo (
320
+ self . handle,
321
+ field. as_ptr( ) ,
322
+ array. as_ptr( ) ,
323
+ array. len( ) as u64
324
+ ) )
302
325
}
303
326
304
327
fn get_uint_info ( & self , field : & str ) -> XGBResult < & [ u32 ] > {
305
328
let field = ffi:: CString :: new ( field) . unwrap ( ) ;
306
329
let mut out_len = 0 ;
307
330
let mut out_dptr = ptr:: null ( ) ;
308
- xgb_call ! ( xgboost_sys:: XGDMatrixGetUIntInfo ( self . handle,
309
- field. as_ptr( ) ,
310
- & mut out_len,
311
- & mut out_dptr) ) ?;
331
+ xgb_call ! ( xgboost_sys:: XGDMatrixGetUIntInfo (
332
+ self . handle,
333
+ field. as_ptr( ) ,
334
+ & mut out_len,
335
+ & mut out_dptr
336
+ ) ) ?;
312
337
Ok ( unsafe { slice:: from_raw_parts ( out_dptr as * mut c_uint , out_len as usize ) } )
313
338
}
314
339
315
340
fn set_uint_info ( & mut self , field : & str , array : & [ u32 ] ) -> XGBResult < ( ) > {
316
341
let field = ffi:: CString :: new ( field) . unwrap ( ) ;
317
- xgb_call ! ( xgboost_sys:: XGDMatrixSetUIntInfo ( self . handle,
318
- field. as_ptr( ) ,
319
- array. as_ptr( ) ,
320
- array. len( ) as u64 ) )
342
+ xgb_call ! ( xgboost_sys:: XGDMatrixSetUIntInfo (
343
+ self . handle,
344
+ field. as_ptr( ) ,
345
+ array. as_ptr( ) ,
346
+ array. len( ) as u64
347
+ ) )
321
348
}
322
349
}
323
350
@@ -329,8 +356,8 @@ impl Drop for DMatrix {
329
356
330
357
#[ cfg( test) ]
331
358
mod tests {
332
- use tempfile;
333
359
use super :: * ;
360
+ use tempfile;
334
361
fn read_train_matrix ( ) -> XGBResult < DMatrix > {
335
362
DMatrix :: load ( "xgboost-sys/xgboost/demo/data/agaricus.txt.train" )
336
363
}
@@ -370,7 +397,7 @@ mod tests {
370
397
let mut dmat = read_train_matrix ( ) . unwrap ( ) ;
371
398
assert_eq ! ( dmat. get_labels( ) . unwrap( ) . len( ) , 6513 ) ;
372
399
373
- let label = [ 0.1 , 0.0 -4.5 , 11.29842 , 333333.33 ] ;
400
+ let label = [ 0.1 , 0.0 - 4.5 , 11.29842 , 333333.33 ] ;
374
401
assert ! ( dmat. set_labels( & label) . is_ok( ) ) ;
375
402
assert_eq ! ( dmat. get_labels( ) . unwrap( ) , label) ;
376
403
}
@@ -416,7 +443,7 @@ mod tests {
416
443
417
444
let dmat = DMatrix :: from_csr ( & indptr, & indices, & data, None ) . unwrap ( ) ;
418
445
assert_eq ! ( dmat. num_rows( ) , 4 ) ;
419
- assert_eq ! ( dmat. num_cols( ) , 0 ) ; // https://github.com/dmlc/xgboost/pull/7265
446
+ assert_eq ! ( dmat. num_cols( ) , 0 ) ; // https://github.com/dmlc/xgboost/pull/7265
420
447
421
448
let dmat = DMatrix :: from_csr ( & indptr, & indices, & data, Some ( 10 ) ) . unwrap ( ) ;
422
449
assert_eq ! ( dmat. num_rows( ) , 4 ) ;
0 commit comments