Skip to content

Commit 2bf5f7a

Browse files
georethZhou Xiaozhou
andauthored
Fix LASSO (first two of #342) (#343)
* Fix LASSO (#342) * change loss function in doc to match code * allow `n == p` case * lasso add test_full_rank_x --------- Co-authored-by: Zhou Xiaozhou <zxz@jiweifund.com>
1 parent 0caa830 commit 2bf5f7a

File tree

1 file changed

+33
-2
lines changed

1 file changed

+33
-2
lines changed

src/linear/lasso.rs

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
//!
1010
//! Lasso coefficient estimates solve the problem:
1111
//!
12-
//! \\[\underset{\beta}{minimize} \space \space \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
12+
//! \\[\underset{\beta}{minimize} \space \space \frac{1}{n} \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
1313
//!
1414
//! This problem is solved with an interior-point method that is comparable to coordinate descent in solving large problems with modest accuracy,
1515
//! but is able to solve them with high accuracy with relatively small additional computational cost.
@@ -246,7 +246,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
246246
pub fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result<Lasso<TX, TY, X, Y>, Failed> {
247247
let (n, p) = x.shape();
248248

249-
if n <= p {
249+
if n < p {
250250
return Err(Failed::fit(
251251
"Number of rows in X should be >= number of columns in X",
252252
));
@@ -369,6 +369,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
369369
#[cfg(test)]
370370
mod tests {
371371
use super::*;
372+
use crate::linalg::basic::arrays::Array;
372373
use crate::linalg::basic::matrix::DenseMatrix;
373374
use crate::metrics::mean_absolute_error;
374375

@@ -448,6 +449,36 @@ mod tests {
448449
assert!(mean_absolute_error(&y_hat, &y) < 2.0);
449450
}
450451

452+
#[cfg_attr(
453+
all(target_arch = "wasm32", not(target_os = "wasi")),
454+
wasm_bindgen_test::wasm_bindgen_test
455+
)]
456+
#[test]
457+
fn test_full_rank_x() {
458+
// x: randn(3,3) * 10, demean, then round to 2 decimal points
459+
// y = x @ [10.0, 0.2, -3.0], round to 2 decimal points
460+
let param = LassoParameters::default()
461+
.with_normalize(false)
462+
.with_alpha(200.0);
463+
let x = DenseMatrix::from_2d_array(&[
464+
&[-8.9, -2.24, 8.89],
465+
&[-4.02, 8.89, 12.33],
466+
&[12.92, -6.65, -21.22],
467+
])
468+
.unwrap();
469+
470+
let y = vec![-116.12, -75.41, 191.53];
471+
let w = Lasso::fit(&x, &y, param)
472+
.unwrap()
473+
.coefficients()
474+
.iterator(0)
475+
.copied()
476+
.collect();
477+
478+
let expected_w = vec![5.20289531, 0., -5.32823882]; // by coordinate descent
479+
assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4
480+
}
481+
451482
// TODO: serialization for the new DenseMatrix needs to be implemented
452483
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
453484
// #[test]

0 commit comments

Comments
 (0)