|
18 | 18 | // ====================================================================
|
19 | 19 | //
|
20 | 20 |
|
| 21 | +use anyhow::bail; |
21 | 22 | use anyhow::Result;
|
| 23 | +use argmin; |
| 24 | +use argmin::core::Gradient; |
| 25 | +use argmin::core::State; |
| 26 | +use finitediff::FiniteDiff; |
22 | 27 | use log::debug;
|
| 28 | +use ndarray::{Array1, Array2}; |
23 | 29 |
|
24 | 30 | use crate::constant::Real;
|
25 | 31 | use crate::math::line::curve_fit_linear_regression_type1;
|
@@ -92,3 +98,338 @@ pub fn linear_regression(
|
92 | 98 |
|
93 | 99 | Ok((point, angle))
|
94 | 100 | }
|
| 101 | + |
| 102 | +/// Return 'min_value' to 'max_value' linearly, for a 'mix' value |
| 103 | +/// between 0.0 and 1.0. |
| 104 | +fn lerp_f64(min_value: f64, max_value: f64, mix: f64) -> f64 { |
| 105 | + ((1.0 - mix) * min_value) + (mix * max_value) |
| 106 | +} |
| 107 | + |
| 108 | +/// Return a value between 0.0 and 1.0 for a value in an input range |
| 109 | +/// 'from' to 'to'. |
| 110 | +fn inverse_lerp_f64(from: f64, to: f64, value: f64) -> f64 { |
| 111 | + (value - from) / (to - from) |
| 112 | +} |
| 113 | + |
| 114 | +// /// Remap from an 'original' value range to a 'target' value range. |
| 115 | +// fn remap_f64( |
| 116 | +// original_from: f64, |
| 117 | +// original_to: f64, |
| 118 | +// target_from: f64, |
| 119 | +// target_to: f64, |
| 120 | +// value: f64, |
| 121 | +// ) -> f64 { |
| 122 | +// let map_to_original_range = |
| 123 | +// inverse_lerp_f64(original_from, original_to, value); |
| 124 | +// lerp_f64(target_from, target_to, map_to_original_range) |
| 125 | +// } |
| 126 | + |
| 127 | +fn linear_interpolate_point_y_value_at_value_x( |
| 128 | + value_x: f64, |
| 129 | + point_a_x: f64, |
| 130 | + point_a_y: f64, |
| 131 | + point_b_x: f64, |
| 132 | + point_b_y: f64, |
| 133 | +) -> f64 { |
| 134 | + let mix_x = inverse_lerp_f64(point_a_x, point_b_x, value_x); |
| 135 | + let mix_y = lerp_f64(point_a_y, point_b_y, mix_x); |
| 136 | + mix_y |
| 137 | +} |
| 138 | + |
| 139 | +fn linear_interpolate_y_value_at_value_x( |
| 140 | + value_x: f64, |
| 141 | + point_a_x: f64, |
| 142 | + point_a_y: f64, |
| 143 | + point_b_x: f64, |
| 144 | + point_b_y: f64, |
| 145 | + point_c_x: f64, |
| 146 | + point_c_y: f64, |
| 147 | +) -> f64 { |
| 148 | + if value_x < point_b_x { |
| 149 | + linear_interpolate_point_y_value_at_value_x( |
| 150 | + value_x, point_a_x, point_a_y, point_b_x, point_b_y, |
| 151 | + ) |
| 152 | + } else if value_x > point_b_x { |
| 153 | + linear_interpolate_point_y_value_at_value_x( |
| 154 | + value_x, point_b_x, point_b_y, point_c_x, point_c_y, |
| 155 | + ) |
| 156 | + } else { |
| 157 | + point_b_y |
| 158 | + } |
| 159 | +} |
| 160 | + |
| 161 | +#[derive(Debug)] |
| 162 | +struct CurveFitLinearN3Problem { |
| 163 | + // Curve values we are trying to fit to. |
| 164 | + reference_values: Vec<(f64, f64)>, |
| 165 | + reference_values_first_x: usize, |
| 166 | + // reference_values_last_x: usize, |
| 167 | + |
| 168 | + // These are parameters that are hard-coded. |
| 169 | + point_a_x: f64, |
| 170 | + point_b_x: f64, |
| 171 | + point_c_x: f64, |
| 172 | +} |
| 173 | + |
| 174 | +impl CurveFitLinearN3Problem { |
| 175 | + fn new( |
| 176 | + point_a_x: f64, |
| 177 | + point_b_x: f64, |
| 178 | + point_c_x: f64, |
| 179 | + reference_curve: &[(f64, f64)], |
| 180 | + ) -> Self { |
| 181 | + // let count = reference_curve.len(); |
| 182 | + let x_first: usize = reference_curve[0].0.floor() as usize; |
| 183 | + // let x_last: usize = reference_curve[count - 1].0.ceil() as usize; |
| 184 | + |
| 185 | + let reference_values: Vec<(f64, f64)> = |
| 186 | + reference_curve.iter().map(|x| *x).collect(); |
| 187 | + |
| 188 | + Self { |
| 189 | + reference_values, |
| 190 | + reference_values_first_x: x_first, |
| 191 | + // reference_values_last_x: x_last, |
| 192 | + point_a_x, |
| 193 | + point_b_x, |
| 194 | + point_c_x, |
| 195 | + } |
| 196 | + } |
| 197 | + |
| 198 | + fn parameter_count(&self) -> usize { |
| 199 | + // 3 x 2D points is 6, but the X axis values are always |
| 200 | + // locked, therefore 3 values are known, and we do not need to |
| 201 | + // solve for them. |
| 202 | + 3 |
| 203 | + } |
| 204 | + |
| 205 | + fn reference_y_value_at_value_x(&self, value_x: f64) -> f64 { |
| 206 | + let value_start = self.point_a_x.floor() as usize; |
| 207 | + let value_end = self.point_c_x.ceil() as usize; |
| 208 | + |
| 209 | + let mut value_index = value_x.round() as usize; |
| 210 | + if value_index < value_start { |
| 211 | + value_index = value_start; |
| 212 | + } else if value_index > value_end { |
| 213 | + value_index = value_end; |
| 214 | + } |
| 215 | + |
| 216 | + let mut index = value_index - self.reference_values_first_x; |
| 217 | + if index > self.reference_values.len() { |
| 218 | + index = self.reference_values.len(); |
| 219 | + } |
| 220 | + |
| 221 | + self.reference_values[index].1 |
| 222 | + } |
| 223 | + |
| 224 | + fn residuals( |
| 225 | + &self, |
| 226 | + point_a_y: f64, |
| 227 | + point_b_y: f64, |
| 228 | + point_c_y: f64, |
| 229 | + ) -> Vec<f64> { |
| 230 | + self.reference_values |
| 231 | + .iter() |
| 232 | + .map(|x| { |
| 233 | + let value_x = x.0; |
| 234 | + let curve_y = linear_interpolate_y_value_at_value_x( |
| 235 | + value_x, |
| 236 | + self.point_a_x, |
| 237 | + point_a_y, |
| 238 | + self.point_b_x, |
| 239 | + point_b_y, |
| 240 | + self.point_c_x, |
| 241 | + point_c_y, |
| 242 | + ); |
| 243 | + let data_y = self.reference_y_value_at_value_x(value_x); |
| 244 | + (curve_y - data_y).abs() |
| 245 | + }) |
| 246 | + .collect() |
| 247 | + } |
| 248 | +} |
| 249 | + |
| 250 | +impl argmin::core::CostFunction for CurveFitLinearN3Problem { |
| 251 | + type Param = Array1<f64>; |
| 252 | + type Output = f64; |
| 253 | + |
| 254 | + fn cost( |
| 255 | + &self, |
| 256 | + parameters: &Self::Param, |
| 257 | + ) -> Result<Self::Output, argmin::core::Error> { |
| 258 | + debug!("Cost: parameters={parameters:?}"); |
| 259 | + |
| 260 | + let parameter_count = self.parameter_count(); |
| 261 | + assert_eq!(parameters.len(), parameter_count); |
| 262 | + |
| 263 | + let residuals_data = |
| 264 | + self.residuals(parameters[0], parameters[1], parameters[2]); |
| 265 | + let residuals = Array1::from_vec(residuals_data); |
| 266 | + |
| 267 | + let residuals_sum = residuals.sum(); |
| 268 | + debug!("residuals_sum: {residuals_sum}"); |
| 269 | + |
| 270 | + Ok(residuals_sum * residuals_sum) |
| 271 | + } |
| 272 | +} |
| 273 | + |
| 274 | +impl argmin::core::Gradient for CurveFitLinearN3Problem { |
| 275 | + type Param = Array1<f64>; |
| 276 | + type Gradient = Array1<f64>; |
| 277 | + |
| 278 | + fn gradient( |
| 279 | + &self, |
| 280 | + parameters: &Self::Param, |
| 281 | + ) -> Result<Self::Gradient, argmin::core::Error> { |
| 282 | + debug!("Gradient: parameters={parameters:?}"); |
| 283 | + |
| 284 | + let parameter_count = self.parameter_count(); |
| 285 | + assert_eq!(parameters.len(), parameter_count); |
| 286 | + |
| 287 | + let vector = (*parameters).forward_diff(&|x| { |
| 288 | + let sum: f64 = self.residuals(x[0], x[1], x[2]).into_iter().sum(); |
| 289 | + debug!("forward_diff residuals_sum: {sum}"); |
| 290 | + sum * sum |
| 291 | + }); |
| 292 | + |
| 293 | + Ok(vector) |
| 294 | + } |
| 295 | +} |
| 296 | + |
| 297 | +impl argmin::core::Hessian for CurveFitLinearN3Problem { |
| 298 | + type Param = Array1<f64>; |
| 299 | + type Hessian = Array2<f64>; |
| 300 | + |
| 301 | + fn hessian( |
| 302 | + &self, |
| 303 | + parameters: &Self::Param, |
| 304 | + ) -> Result<Self::Hessian, argmin::core::Error> { |
| 305 | + debug!("Hessian: parameters={parameters:?}"); |
| 306 | + |
| 307 | + let parameter_count = self.parameter_count(); |
| 308 | + assert_eq!(parameters.len(), parameter_count); |
| 309 | + |
| 310 | + let matrix = |
| 311 | + (*parameters).forward_hessian(&|x| self.gradient(x).unwrap()); |
| 312 | + |
| 313 | + Ok(matrix) |
| 314 | + } |
| 315 | +} |
| 316 | + |
| 317 | +/// Perform a non-linear least-squares fits for a line with 3 points. |
| 318 | +/// |
| 319 | +/// The approach here is to start with a linear regression, to get a |
| 320 | +/// starting point, and then refine the solution using a solver. |
| 321 | +/// |
| 322 | +/// Rather than solve a direct solution we aim to solve successively |
| 323 | +/// more difficult problem in multiple steps, as each one improves the |
| 324 | +/// overall fit to the source data values. |
| 325 | +/// |
| 326 | +/// In the future we may wish to allow different types of curve |
| 327 | +/// interpolation between the 3 points. |
| 328 | +pub fn nonlinear_line_n3( |
| 329 | + x_values: &[f64], |
| 330 | + y_values: &[f64], |
| 331 | +) -> Result<(Point2, Point2, Point2)> { |
| 332 | + assert_eq!(x_values.len(), y_values.len()); |
| 333 | + let value_count = x_values.len(); |
| 334 | + assert!(value_count > 2); |
| 335 | + |
| 336 | + let mut point_x = 0.0; |
| 337 | + let mut point_y = 0.0; |
| 338 | + let mut angle = 0.0; |
| 339 | + curve_fit_linear_regression_type1( |
| 340 | + &x_values, |
| 341 | + &y_values, |
| 342 | + &mut point_x, |
| 343 | + &mut point_y, |
| 344 | + &mut angle, |
| 345 | + ); |
| 346 | + debug!("angle={angle}"); |
| 347 | + |
| 348 | + let dir_x = angle.cos(); |
| 349 | + let dir_y = angle.sin(); |
| 350 | + debug!("point_x={point_x} point_y={point_y}"); |
| 351 | + debug!("dir_x={dir_x} dir_y={dir_y}"); |
| 352 | + |
| 353 | + let x_first = x_values[0]; |
| 354 | + let y_first = y_values[0]; |
| 355 | + let x_last = x_values[value_count - 1]; |
| 356 | + let y_last = y_values[value_count - 1]; |
| 357 | + let x_diff = (x_last - x_first) as f64 / 2.0; |
| 358 | + let y_diff = (y_last - y_first) as f64 / 2.0; |
| 359 | + debug!("x_first={x_first}"); |
| 360 | + debug!("y_first={y_first}"); |
| 361 | + debug!("x_last={x_last}"); |
| 362 | + debug!("y_last={y_last}"); |
| 363 | + debug!("x_diff={x_diff}"); |
| 364 | + debug!("y_diff={y_diff}"); |
| 365 | + |
| 366 | + // Scale up to X axis range. |
| 367 | + let dir_x = dir_x * x_diff; |
| 368 | + let dir_y = dir_y * x_diff; |
| 369 | + |
| 370 | + // Define initial parameter vector |
| 371 | + let point_a = Point2 { |
| 372 | + x: point_x - dir_x, |
| 373 | + y: point_y - dir_y, |
| 374 | + }; |
| 375 | + let point_b = Point2 { |
| 376 | + x: point_x, |
| 377 | + y: point_y, |
| 378 | + }; |
| 379 | + let point_c = Point2 { |
| 380 | + x: point_x + dir_x, |
| 381 | + y: point_y + dir_y, |
| 382 | + }; |
| 383 | + // NOTE: point_a.x, point_a.y and point_c.x are omitted as these |
| 384 | + // are hard-coded as known parameters and do not need to be |
| 385 | + // solved. |
| 386 | + let initial_parameters_vec = vec![point_a.y, point_b.y, point_c.y]; |
| 387 | + let initial_parameters: Array1<f64> = Array1::from(initial_parameters_vec); |
| 388 | + println!("initial_parameters={initial_parameters:?}"); |
| 389 | + |
| 390 | + // Define the problem |
| 391 | + let reference_values: Vec<(f64, f64)> = x_values |
| 392 | + .iter() |
| 393 | + .zip(y_values) |
| 394 | + .map(|x| (*x.0, *x.1)) |
| 395 | + .collect(); |
| 396 | + let problem = CurveFitLinearN3Problem::new( |
| 397 | + // Known parameters |
| 398 | + point_a.x, |
| 399 | + point_b.x, |
| 400 | + point_c.x, |
| 401 | + // Curve values to match-to. |
| 402 | + &reference_values, |
| 403 | + ); |
| 404 | + |
| 405 | + // Set up the subproblem |
| 406 | + let subproblem: argmin::solver::trustregion::CauchyPoint<f64> = |
| 407 | + argmin::solver::trustregion::CauchyPoint::new(); |
| 408 | + |
| 409 | + // Set up solver |
| 410 | + let solver = argmin::solver::trustregion::TrustRegion::new(subproblem); |
| 411 | + |
| 412 | + // Run solver |
| 413 | + let result = argmin::core::Executor::new(problem, solver) |
| 414 | + .configure(|state| state.param(initial_parameters).max_iters(30)) |
| 415 | + .run()?; |
| 416 | + debug!("Solver Result: {result}"); |
| 417 | + |
| 418 | + match result.state().get_best_param() { |
| 419 | + Some(parameters) => Ok(( |
| 420 | + Point2 { |
| 421 | + x: point_a.x, |
| 422 | + y: parameters[0], |
| 423 | + }, |
| 424 | + Point2 { |
| 425 | + x: point_b.x, |
| 426 | + y: parameters[1], |
| 427 | + }, |
| 428 | + Point2 { |
| 429 | + x: point_c.x, |
| 430 | + y: parameters[2], |
| 431 | + }, |
| 432 | + )), |
| 433 | + None => bail!("Solve failed."), |
| 434 | + } |
| 435 | +} |
0 commit comments