Skip to content

Commit f0de6a2

Browse files
Curve Fit - Add linear n-points curve fit solver.
This allows any number of linear points to be fit to an input line. GitHub issue #268.
1 parent cbdf4fd commit f0de6a2

13 files changed

+1218
-50
lines changed

lib/rust/mmscenegraph/src/math/curve_fit.rs

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ pub struct Point2 {
3737
}
3838

3939
impl Point2 {
40+
pub fn new(x: f64, y: f64) -> Self {
41+
Self { x, y }
42+
}
43+
4044
pub fn x(self) -> Real {
4145
self.x
4246
}
@@ -433,3 +437,255 @@ pub fn nonlinear_line_n3(
433437
None => bail!("Solve failed."),
434438
}
435439
}
440+
441+
#[derive(Debug)]
442+
pub struct CurveFitLinearNPointsProblem {
443+
// Curve values we are trying to fit to
444+
reference_values: Vec<(f64, f64)>,
445+
reference_values_first_x: usize,
446+
447+
// X-coordinates of control points (fixed)
448+
control_points_x: Vec<f64>,
449+
}
450+
451+
impl CurveFitLinearNPointsProblem {
452+
fn new(control_points_x: Vec<f64>, reference_curve: &[(f64, f64)]) -> Self {
453+
let x_first: usize = reference_curve[0].0.floor() as usize;
454+
let reference_values: Vec<(f64, f64)> =
455+
reference_curve.iter().copied().collect();
456+
457+
Self {
458+
reference_values,
459+
reference_values_first_x: x_first,
460+
control_points_x,
461+
}
462+
}
463+
464+
fn parameter_count(&self) -> usize {
465+
// Only Y coordinates need to be solved for.
466+
self.control_points_x.len()
467+
}
468+
469+
fn reference_y_value_at_value_x(&self, value_x: f64) -> f64 {
470+
let value_start = self.control_points_x[0].floor() as usize;
471+
let value_end = self.control_points_x.last().unwrap().ceil() as usize;
472+
473+
let mut value_index = value_x.round() as usize;
474+
if value_index < value_start {
475+
value_index = value_start;
476+
} else if value_index > value_end {
477+
value_index = value_end;
478+
}
479+
480+
let mut index = value_index - self.reference_values_first_x;
481+
if index >= self.reference_values.len() {
482+
index = self.reference_values.len() - 1;
483+
}
484+
485+
self.reference_values[index].1
486+
}
487+
488+
fn interpolate_y_value_at_x(
489+
&self,
490+
value_x: f64,
491+
control_points_y: &[f64],
492+
) -> f64 {
493+
debug_assert_eq!(self.control_points_x.len(), control_points_y.len());
494+
495+
// Find the segment containing value_x.
496+
for i in 0..self.control_points_x.len() - 1 {
497+
if value_x <= self.control_points_x[i + 1] {
498+
if value_x == self.control_points_x[i + 1] {
499+
return control_points_y[i + 1];
500+
}
501+
return linear_interpolate_point_y_value_at_value_x(
502+
value_x,
503+
self.control_points_x[i],
504+
control_points_y[i],
505+
self.control_points_x[i + 1],
506+
control_points_y[i + 1],
507+
);
508+
}
509+
}
510+
511+
// If we get here, value_x is beyond the last point.
512+
*control_points_y.last().unwrap()
513+
}
514+
515+
fn residuals(&self, control_points_y: &[f64]) -> Vec<f64> {
516+
self.reference_values
517+
.iter()
518+
.map(|&(value_x, data_y)| {
519+
let curve_y =
520+
self.interpolate_y_value_at_x(value_x, control_points_y);
521+
(curve_y - data_y).abs()
522+
})
523+
.collect()
524+
}
525+
}
526+
527+
impl argmin::core::CostFunction for CurveFitLinearNPointsProblem {
528+
type Param = Array1<f64>;
529+
type Output = f64;
530+
531+
fn cost(
532+
&self,
533+
parameters: &Self::Param,
534+
) -> Result<Self::Output, argmin::core::Error> {
535+
debug!("Cost: parameters={parameters:?}");
536+
assert_eq!(parameters.len(), self.parameter_count());
537+
538+
let residuals_data = self.residuals(parameters.as_slice().unwrap());
539+
let residuals = Array1::from_vec(residuals_data);
540+
541+
let residuals_sum = residuals.sum();
542+
debug!("residuals_sum: {residuals_sum}");
543+
544+
Ok(residuals_sum * residuals_sum)
545+
}
546+
}
547+
548+
impl argmin::core::Gradient for CurveFitLinearNPointsProblem {
549+
type Param = Array1<f64>;
550+
type Gradient = Array1<f64>;
551+
552+
fn gradient(
553+
&self,
554+
parameters: &Self::Param,
555+
) -> Result<Self::Gradient, argmin::core::Error> {
556+
debug!("Gradient: parameters={parameters:?}");
557+
558+
assert_eq!(parameters.len(), self.parameter_count());
559+
560+
let vector = (*parameters).forward_diff(&|x| {
561+
let sum: f64 =
562+
self.residuals(x.as_slice().unwrap()).into_iter().sum();
563+
debug!("forward_diff residuals_sum: {sum}");
564+
sum * sum
565+
});
566+
567+
Ok(vector)
568+
}
569+
}
570+
571+
impl argmin::core::Hessian for CurveFitLinearNPointsProblem {
572+
type Param = Array1<f64>;
573+
type Hessian = Array2<f64>;
574+
575+
fn hessian(
576+
&self,
577+
parameters: &Self::Param,
578+
) -> Result<Self::Hessian, argmin::core::Error> {
579+
debug!("Hessian: parameters={parameters:?}");
580+
assert_eq!(parameters.len(), self.parameter_count());
581+
582+
let matrix =
583+
(*parameters).forward_hessian(&|x| self.gradient(x).unwrap());
584+
585+
Ok(matrix)
586+
}
587+
}
588+
589+
pub fn nonlinear_line_n_points(
590+
x_values: &[f64],
591+
y_values: &[f64],
592+
control_point_count: usize,
593+
) -> Result<Vec<Point2>> {
594+
assert_eq!(x_values.len(), y_values.len());
595+
let value_count = x_values.len();
596+
assert!(value_count > 2);
597+
assert!(
598+
control_point_count >= 3,
599+
"Must have at least 3 control points"
600+
);
601+
602+
// TODO: Normalize the input values to a known range, say 0.0 to
603+
// 1.0, solve and then scale back to the original time and value
604+
// range.
605+
606+
// First get initial guess using linear regression
607+
let mut point_x = 0.0;
608+
let mut point_y = 0.0;
609+
let mut angle = 0.0;
610+
curve_fit_linear_regression_type1(
611+
x_values,
612+
y_values,
613+
&mut point_x,
614+
&mut point_y,
615+
&mut angle,
616+
);
617+
618+
let dir_x = angle.cos();
619+
let dir_y = angle.sin();
620+
debug!("point_x={point_x} point_y={point_y}");
621+
debug!("dir_x={dir_x} dir_y={dir_y}");
622+
623+
// Calculate the range of x values
624+
let x_first = x_values[0];
625+
let x_last = x_values[value_count - 1];
626+
let x_range = x_last - x_first;
627+
628+
// Generate evenly spaced control points along x-axis
629+
let mut control_points_x = Vec::with_capacity(control_point_count);
630+
let mut initial_y_values = Vec::with_capacity(control_point_count);
631+
for i in 0..control_point_count {
632+
let mix = i as f64 / (control_point_count - 1) as f64;
633+
let x = (x_first + (mix * x_range)).floor();
634+
control_points_x.push(x);
635+
636+
// Initial y-values based on linear regression line.
637+
let y = point_y + (dir_y * (x - point_x) / dir_x);
638+
initial_y_values.push(y);
639+
}
640+
641+
// Create reference values
642+
let reference_values: Vec<(f64, f64)> = x_values
643+
.iter()
644+
.zip(y_values)
645+
.map(|(&x, &y)| (x, y))
646+
.collect();
647+
648+
// Define the problem
649+
let problem = CurveFitLinearNPointsProblem::new(
650+
control_points_x.clone(),
651+
&reference_values,
652+
);
653+
654+
// Set up solver
655+
let epsilon = 1e-3;
656+
let condition =
657+
argmin::solver::linesearch::condition::ArmijoCondition::new(1e-5)?;
658+
let linesearch =
659+
argmin::solver::linesearch::BacktrackingLineSearch::new(condition)
660+
.rho(0.5)?;
661+
let solver = argmin::solver::quasinewton::BFGS::new(linesearch)
662+
.with_tolerance_cost(epsilon)?;
663+
664+
// Run solver
665+
let initial_parameters = Array1::from(initial_y_values);
666+
let initial_hessian: Array2<f64> = Array2::eye(control_point_count);
667+
let result = argmin::core::Executor::new(problem, solver)
668+
.configure(|state| {
669+
state
670+
.param(initial_parameters)
671+
.inv_hessian(initial_hessian)
672+
.max_iters(50)
673+
})
674+
.run()?;
675+
676+
debug!("Solver Result: {result}");
677+
678+
match result.state().get_best_param() {
679+
Some(parameters) => {
680+
let mut control_points = Vec::with_capacity(control_point_count);
681+
for i in 0..control_point_count {
682+
control_points.push(Point2 {
683+
x: control_points_x[i],
684+
y: parameters[i],
685+
});
686+
}
687+
Ok(control_points)
688+
}
689+
None => bail!("Solve failed."),
690+
}
691+
}

0 commit comments

Comments
 (0)