|
9 | 9 | //! |
10 | 10 | //! Lasso coefficient estimates solve the problem: |
11 | 11 | //! |
12 | | -//! \\[\underset{\beta}{minimize} \space \space \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\] |
| 12 | +//! \\[\underset{\beta}{minimize} \space \space \frac{1}{n} \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\] |
13 | 13 | //! |
14 | 14 | //! This problem is solved with an interior-point method that is comparable to coordinate descent when solving large problems to modest accuracy, |
15 | 15 | //! but can also solve them to high accuracy at relatively small additional computational cost. |
@@ -246,7 +246,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las |
246 | 246 | pub fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result<Lasso<TX, TY, X, Y>, Failed> { |
247 | 247 | let (n, p) = x.shape(); |
248 | 248 |
|
249 | | - if n <= p { |
| 249 | + if n < p { |
250 | 250 | return Err(Failed::fit( |
251 | 251 | "Number of rows in X should be >= number of columns in X", |
252 | 252 | )); |
@@ -369,6 +369,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las |
369 | 369 | #[cfg(test)] |
370 | 370 | mod tests { |
371 | 371 | use super::*; |
| 372 | + use crate::linalg::basic::arrays::Array; |
372 | 373 | use crate::linalg::basic::matrix::DenseMatrix; |
373 | 374 | use crate::metrics::mean_absolute_error; |
374 | 375 |
|
@@ -448,6 +449,36 @@ mod tests { |
448 | 449 | assert!(mean_absolute_error(&y_hat, &y) < 2.0); |
449 | 450 | } |
450 | 451 |
|
| 452 | + #[cfg_attr(
| 453 | + all(target_arch = "wasm32", not(target_os = "wasi")),
| 454 | + wasm_bindgen_test::wasm_bindgen_test
| 455 | + )]
| 456 | + #[test]
| 457 | + fn test_full_rank_x() {
| 458 | + // Square 3x3 design matrix (n == p): x = randn(3,3) * 10, columns demeaned, then rounded to 2 decimal places
| 459 | + // y = x @ [10.0, 0.2, -3.0], rounded to 2 decimal places
| 460 | + let param = LassoParameters::default()
| 461 | + .with_normalize(false)
| 462 | + .with_alpha(200.0);
| 463 | + let x = DenseMatrix::from_2d_array(&[
| 464 | + &[-8.9, -2.24, 8.89],
| 465 | + &[-4.02, 8.89, 12.33],
| 466 | + &[12.92, -6.65, -21.22],
| 467 | + ])
| 468 | + .unwrap();
| 469 | + |
| 470 | + let y = vec![-116.12, -75.41, 191.53];
| 471 | + let w = Lasso::fit(&x, &y, param)
| 472 | + .unwrap()
| 473 | + .coefficients()
| 474 | + .iterator(0)
| 475 | + .copied()
| 476 | + .collect();
| 477 | + |
| 478 | + let expected_w = vec![5.20289531, 0., -5.32823882]; // reference coefficients obtained by coordinate descent
| 479 | + assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // the actual mean_absolute_error is about 2e-4
| 480 | + }
| 481 | + |
451 | 482 | // TODO: serialization for the new DenseMatrix needs to be implemented |
452 | 483 | // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)] |
453 | 484 | // #[test] |
|
0 commit comments