Skip to content

Commit 31cfa81

Browse files
Curve Fit - Add 3-point Line solver
This uses the "argmin" and "ndarray" crates (and includes the Intel MKL libraries) GitHub issue #268.
1 parent cd6ef90 commit 31cfa81

16 files changed

+2045
-17
lines changed

Cargo.lock

Lines changed: 783 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,18 @@ publish = false
2020
[workspace.dependencies]
2121
anyhow = "1.0.89"
2222
approx = "0.5.1"
23+
argmin = { version = "0.10.0", default-features = true, features = ["serde1"] } # ,
24+
argmin-math = { version = "0.4", default-features = true, features = ["primitives", "vec", "ndarray_latest"] }
2325
criterion = { version = "0.5.1", default-features = false, features = ["html_reports"] }
2426
exr = "1.72.0"
2527
fastapprox = "0.3.1"
28+
finitediff = { version = "0.1.4", features = ["ndarray"] }
2629
log = "0.4.22"
27-
nalgebra = { version = "0.33.0", default-features = false, features = ["std", "matrixmultiply"] }
30+
nalgebra = { version = "0.33.1", default-features = false, features = ["std", "matrixmultiply"] }
31+
ndarray = "0.15.6"
32+
ndarray-linalg = { version = "0.16.0", features = ["intel-mkl-static"] }
2833
num = "0.4.3"
29-
num-traits = "0.2"
34+
num-traits = "0.2.19"
3035
num_cpus = "1.16.0"
3136
petgraph = { version = "0.6", default-features = false, features = ["stable_graph"] }
3237
plotters = { version = "0.3.7", default-features = false, features = ["image", "bitmap_encoder", "bitmap_backend", "line_series", "ttf"] }

lib/rust/mmscenegraph/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,14 @@ harness = false
1717
[dependencies]
1818
anyhow = { workspace = true }
1919
approx = { workspace = true }
20+
argmin = { workspace = true }
21+
argmin-math = { workspace = true }
2022
fastapprox = { workspace = true }
23+
finitediff = { workspace = true }
2124
log = { workspace = true }
2225
nalgebra = { workspace = true }
26+
ndarray = { workspace = true }
27+
ndarray-linalg = { workspace = true }
2328
num-traits = { workspace = true }
2429
petgraph = { workspace = true }
2530
rand = { workspace = true }

lib/rust/mmscenegraph/src/math/curve_fit.rs

Lines changed: 341 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,14 @@
1818
// ====================================================================
1919
//
2020

21+
use anyhow::bail;
2122
use anyhow::Result;
23+
use argmin;
24+
use argmin::core::Gradient;
25+
use argmin::core::State;
26+
use finitediff::FiniteDiff;
2227
use log::debug;
28+
use ndarray::{Array1, Array2};
2329

2430
use crate::constant::Real;
2531
use crate::math::line::curve_fit_linear_regression_type1;
@@ -92,3 +98,338 @@ pub fn linear_regression(
9298

9399
Ok((point, angle))
94100
}
101+
102+
/// Return 'min_value' to 'max_value' linearly, for a 'mix' value
103+
/// between 0.0 and 1.0.
104+
fn lerp_f64(min_value: f64, max_value: f64, mix: f64) -> f64 {
105+
((1.0 - mix) * min_value) + (mix * max_value)
106+
}
107+
108+
/// Return a value between 0.0 and 1.0 for a value in an input range
109+
/// 'from' to 'to'.
110+
fn inverse_lerp_f64(from: f64, to: f64, value: f64) -> f64 {
111+
(value - from) / (to - from)
112+
}
113+
114+
// /// Remap from an 'original' value range to a 'target' value range.
115+
// fn remap_f64(
116+
// original_from: f64,
117+
// original_to: f64,
118+
// target_from: f64,
119+
// target_to: f64,
120+
// value: f64,
121+
// ) -> f64 {
122+
// let map_to_original_range =
123+
// inverse_lerp_f64(original_from, original_to, value);
124+
// lerp_f64(target_from, target_to, map_to_original_range)
125+
// }
126+
127+
fn linear_interpolate_point_y_value_at_value_x(
128+
value_x: f64,
129+
point_a_x: f64,
130+
point_a_y: f64,
131+
point_b_x: f64,
132+
point_b_y: f64,
133+
) -> f64 {
134+
let mix_x = inverse_lerp_f64(point_a_x, point_b_x, value_x);
135+
let mix_y = lerp_f64(point_a_y, point_b_y, mix_x);
136+
mix_y
137+
}
138+
139+
fn linear_interpolate_y_value_at_value_x(
140+
value_x: f64,
141+
point_a_x: f64,
142+
point_a_y: f64,
143+
point_b_x: f64,
144+
point_b_y: f64,
145+
point_c_x: f64,
146+
point_c_y: f64,
147+
) -> f64 {
148+
if value_x < point_b_x {
149+
linear_interpolate_point_y_value_at_value_x(
150+
value_x, point_a_x, point_a_y, point_b_x, point_b_y,
151+
)
152+
} else if value_x > point_b_x {
153+
linear_interpolate_point_y_value_at_value_x(
154+
value_x, point_b_x, point_b_y, point_c_x, point_c_y,
155+
)
156+
} else {
157+
point_b_y
158+
}
159+
}
160+
161+
#[derive(Debug)]
162+
struct CurveFitLinearN3Problem {
163+
// Curve values we are trying to fit to.
164+
reference_values: Vec<(f64, f64)>,
165+
reference_values_first_x: usize,
166+
// reference_values_last_x: usize,
167+
168+
// These are parameters that are hard-coded.
169+
point_a_x: f64,
170+
point_b_x: f64,
171+
point_c_x: f64,
172+
}
173+
174+
impl CurveFitLinearN3Problem {
175+
fn new(
176+
point_a_x: f64,
177+
point_b_x: f64,
178+
point_c_x: f64,
179+
reference_curve: &[(f64, f64)],
180+
) -> Self {
181+
// let count = reference_curve.len();
182+
let x_first: usize = reference_curve[0].0.floor() as usize;
183+
// let x_last: usize = reference_curve[count - 1].0.ceil() as usize;
184+
185+
let reference_values: Vec<(f64, f64)> =
186+
reference_curve.iter().map(|x| *x).collect();
187+
188+
Self {
189+
reference_values,
190+
reference_values_first_x: x_first,
191+
// reference_values_last_x: x_last,
192+
point_a_x,
193+
point_b_x,
194+
point_c_x,
195+
}
196+
}
197+
198+
fn parameter_count(&self) -> usize {
199+
// 3 x 2D points is 6, but the X axis values are always
200+
// locked, therefore 3 values are known, and we do not need to
201+
// solve for them.
202+
3
203+
}
204+
205+
fn reference_y_value_at_value_x(&self, value_x: f64) -> f64 {
206+
let value_start = self.point_a_x.floor() as usize;
207+
let value_end = self.point_c_x.ceil() as usize;
208+
209+
let mut value_index = value_x.round() as usize;
210+
if value_index < value_start {
211+
value_index = value_start;
212+
} else if value_index > value_end {
213+
value_index = value_end;
214+
}
215+
216+
let mut index = value_index - self.reference_values_first_x;
217+
if index > self.reference_values.len() {
218+
index = self.reference_values.len();
219+
}
220+
221+
self.reference_values[index].1
222+
}
223+
224+
fn residuals(
225+
&self,
226+
point_a_y: f64,
227+
point_b_y: f64,
228+
point_c_y: f64,
229+
) -> Vec<f64> {
230+
self.reference_values
231+
.iter()
232+
.map(|x| {
233+
let value_x = x.0;
234+
let curve_y = linear_interpolate_y_value_at_value_x(
235+
value_x,
236+
self.point_a_x,
237+
point_a_y,
238+
self.point_b_x,
239+
point_b_y,
240+
self.point_c_x,
241+
point_c_y,
242+
);
243+
let data_y = self.reference_y_value_at_value_x(value_x);
244+
(curve_y - data_y).abs()
245+
})
246+
.collect()
247+
}
248+
}
249+
250+
impl argmin::core::CostFunction for CurveFitLinearN3Problem {
251+
type Param = Array1<f64>;
252+
type Output = f64;
253+
254+
fn cost(
255+
&self,
256+
parameters: &Self::Param,
257+
) -> Result<Self::Output, argmin::core::Error> {
258+
debug!("Cost: parameters={parameters:?}");
259+
260+
let parameter_count = self.parameter_count();
261+
assert_eq!(parameters.len(), parameter_count);
262+
263+
let residuals_data =
264+
self.residuals(parameters[0], parameters[1], parameters[2]);
265+
let residuals = Array1::from_vec(residuals_data);
266+
267+
let residuals_sum = residuals.sum();
268+
debug!("residuals_sum: {residuals_sum}");
269+
270+
Ok(residuals_sum * residuals_sum)
271+
}
272+
}
273+
274+
impl argmin::core::Gradient for CurveFitLinearN3Problem {
275+
type Param = Array1<f64>;
276+
type Gradient = Array1<f64>;
277+
278+
fn gradient(
279+
&self,
280+
parameters: &Self::Param,
281+
) -> Result<Self::Gradient, argmin::core::Error> {
282+
debug!("Gradient: parameters={parameters:?}");
283+
284+
let parameter_count = self.parameter_count();
285+
assert_eq!(parameters.len(), parameter_count);
286+
287+
let vector = (*parameters).forward_diff(&|x| {
288+
let sum: f64 = self.residuals(x[0], x[1], x[2]).into_iter().sum();
289+
debug!("forward_diff residuals_sum: {sum}");
290+
sum * sum
291+
});
292+
293+
Ok(vector)
294+
}
295+
}
296+
297+
impl argmin::core::Hessian for CurveFitLinearN3Problem {
298+
type Param = Array1<f64>;
299+
type Hessian = Array2<f64>;
300+
301+
fn hessian(
302+
&self,
303+
parameters: &Self::Param,
304+
) -> Result<Self::Hessian, argmin::core::Error> {
305+
debug!("Hessian: parameters={parameters:?}");
306+
307+
let parameter_count = self.parameter_count();
308+
assert_eq!(parameters.len(), parameter_count);
309+
310+
let matrix =
311+
(*parameters).forward_hessian(&|x| self.gradient(x).unwrap());
312+
313+
Ok(matrix)
314+
}
315+
}
316+
317+
/// Perform a non-linear least-squares fits for a line with 3 points.
318+
///
319+
/// The approach here is to start with a linear regression, to get a
320+
/// starting point, and then refine the solution using a solver.
321+
///
322+
/// Rather than solve a direct solution we aim to solve successively
323+
/// more difficult problem in multiple steps, as each one improves the
324+
/// overall fit to the source data values.
325+
///
326+
/// In the future we may wish to allow different types of curve
327+
/// interpolation between the 3 points.
328+
pub fn nonlinear_line_n3(
329+
x_values: &[f64],
330+
y_values: &[f64],
331+
) -> Result<(Point2, Point2, Point2)> {
332+
assert_eq!(x_values.len(), y_values.len());
333+
let value_count = x_values.len();
334+
assert!(value_count > 2);
335+
336+
let mut point_x = 0.0;
337+
let mut point_y = 0.0;
338+
let mut angle = 0.0;
339+
curve_fit_linear_regression_type1(
340+
&x_values,
341+
&y_values,
342+
&mut point_x,
343+
&mut point_y,
344+
&mut angle,
345+
);
346+
debug!("angle={angle}");
347+
348+
let dir_x = angle.cos();
349+
let dir_y = angle.sin();
350+
debug!("point_x={point_x} point_y={point_y}");
351+
debug!("dir_x={dir_x} dir_y={dir_y}");
352+
353+
let x_first = x_values[0];
354+
let y_first = y_values[0];
355+
let x_last = x_values[value_count - 1];
356+
let y_last = y_values[value_count - 1];
357+
let x_diff = (x_last - x_first) as f64 / 2.0;
358+
let y_diff = (y_last - y_first) as f64 / 2.0;
359+
debug!("x_first={x_first}");
360+
debug!("y_first={y_first}");
361+
debug!("x_last={x_last}");
362+
debug!("y_last={y_last}");
363+
debug!("x_diff={x_diff}");
364+
debug!("y_diff={y_diff}");
365+
366+
// Scale up to X axis range.
367+
let dir_x = dir_x * x_diff;
368+
let dir_y = dir_y * x_diff;
369+
370+
// Define initial parameter vector
371+
let point_a = Point2 {
372+
x: point_x - dir_x,
373+
y: point_y - dir_y,
374+
};
375+
let point_b = Point2 {
376+
x: point_x,
377+
y: point_y,
378+
};
379+
let point_c = Point2 {
380+
x: point_x + dir_x,
381+
y: point_y + dir_y,
382+
};
383+
// NOTE: point_a.x, point_a.y and point_c.x are omitted as these
384+
// are hard-coded as known parameters and do not need to be
385+
// solved.
386+
let initial_parameters_vec = vec![point_a.y, point_b.y, point_c.y];
387+
let initial_parameters: Array1<f64> = Array1::from(initial_parameters_vec);
388+
println!("initial_parameters={initial_parameters:?}");
389+
390+
// Define the problem
391+
let reference_values: Vec<(f64, f64)> = x_values
392+
.iter()
393+
.zip(y_values)
394+
.map(|x| (*x.0, *x.1))
395+
.collect();
396+
let problem = CurveFitLinearN3Problem::new(
397+
// Known parameters
398+
point_a.x,
399+
point_b.x,
400+
point_c.x,
401+
// Curve values to match-to.
402+
&reference_values,
403+
);
404+
405+
// Set up the subproblem
406+
let subproblem: argmin::solver::trustregion::CauchyPoint<f64> =
407+
argmin::solver::trustregion::CauchyPoint::new();
408+
409+
// Set up solver
410+
let solver = argmin::solver::trustregion::TrustRegion::new(subproblem);
411+
412+
// Run solver
413+
let result = argmin::core::Executor::new(problem, solver)
414+
.configure(|state| state.param(initial_parameters).max_iters(30))
415+
.run()?;
416+
debug!("Solver Result: {result}");
417+
418+
match result.state().get_best_param() {
419+
Some(parameters) => Ok((
420+
Point2 {
421+
x: point_a.x,
422+
y: parameters[0],
423+
},
424+
Point2 {
425+
x: point_b.x,
426+
y: parameters[1],
427+
},
428+
Point2 {
429+
x: point_c.x,
430+
y: parameters[2],
431+
},
432+
)),
433+
None => bail!("Solve failed."),
434+
}
435+
}

0 commit comments

Comments
 (0)