@@ -37,6 +37,10 @@ pub struct Point2 {
37
37
}
38
38
39
39
impl Point2 {
40
+ pub fn new ( x : f64 , y : f64 ) -> Self {
41
+ Self { x, y }
42
+ }
43
+
40
44
pub fn x ( self ) -> Real {
41
45
self . x
42
46
}
@@ -433,3 +437,255 @@ pub fn nonlinear_line_n3(
433
437
None => bail ! ( "Solve failed." ) ,
434
438
}
435
439
}
440
+
441
+ #[ derive( Debug ) ]
442
+ pub struct CurveFitLinearNPointsProblem {
443
+ // Curve values we are trying to fit to
444
+ reference_values : Vec < ( f64 , f64 ) > ,
445
+ reference_values_first_x : usize ,
446
+
447
+ // X-coordinates of control points (fixed)
448
+ control_points_x : Vec < f64 > ,
449
+ }
450
+
451
+ impl CurveFitLinearNPointsProblem {
452
+ fn new ( control_points_x : Vec < f64 > , reference_curve : & [ ( f64 , f64 ) ] ) -> Self {
453
+ let x_first: usize = reference_curve[ 0 ] . 0 . floor ( ) as usize ;
454
+ let reference_values: Vec < ( f64 , f64 ) > =
455
+ reference_curve. iter ( ) . copied ( ) . collect ( ) ;
456
+
457
+ Self {
458
+ reference_values,
459
+ reference_values_first_x : x_first,
460
+ control_points_x,
461
+ }
462
+ }
463
+
464
+ fn parameter_count ( & self ) -> usize {
465
+ // Only Y coordinates need to be solved for.
466
+ self . control_points_x . len ( )
467
+ }
468
+
469
+ fn reference_y_value_at_value_x ( & self , value_x : f64 ) -> f64 {
470
+ let value_start = self . control_points_x [ 0 ] . floor ( ) as usize ;
471
+ let value_end = self . control_points_x . last ( ) . unwrap ( ) . ceil ( ) as usize ;
472
+
473
+ let mut value_index = value_x. round ( ) as usize ;
474
+ if value_index < value_start {
475
+ value_index = value_start;
476
+ } else if value_index > value_end {
477
+ value_index = value_end;
478
+ }
479
+
480
+ let mut index = value_index - self . reference_values_first_x ;
481
+ if index >= self . reference_values . len ( ) {
482
+ index = self . reference_values . len ( ) - 1 ;
483
+ }
484
+
485
+ self . reference_values [ index] . 1
486
+ }
487
+
488
+ fn interpolate_y_value_at_x (
489
+ & self ,
490
+ value_x : f64 ,
491
+ control_points_y : & [ f64 ] ,
492
+ ) -> f64 {
493
+ debug_assert_eq ! ( self . control_points_x. len( ) , control_points_y. len( ) ) ;
494
+
495
+ // Find the segment containing value_x.
496
+ for i in 0 ..self . control_points_x . len ( ) - 1 {
497
+ if value_x <= self . control_points_x [ i + 1 ] {
498
+ if value_x == self . control_points_x [ i + 1 ] {
499
+ return control_points_y[ i + 1 ] ;
500
+ }
501
+ return linear_interpolate_point_y_value_at_value_x (
502
+ value_x,
503
+ self . control_points_x [ i] ,
504
+ control_points_y[ i] ,
505
+ self . control_points_x [ i + 1 ] ,
506
+ control_points_y[ i + 1 ] ,
507
+ ) ;
508
+ }
509
+ }
510
+
511
+ // If we get here, value_x is beyond the last point.
512
+ * control_points_y. last ( ) . unwrap ( )
513
+ }
514
+
515
+ fn residuals ( & self , control_points_y : & [ f64 ] ) -> Vec < f64 > {
516
+ self . reference_values
517
+ . iter ( )
518
+ . map ( |& ( value_x, data_y) | {
519
+ let curve_y =
520
+ self . interpolate_y_value_at_x ( value_x, control_points_y) ;
521
+ ( curve_y - data_y) . abs ( )
522
+ } )
523
+ . collect ( )
524
+ }
525
+ }
526
+
527
+ impl argmin:: core:: CostFunction for CurveFitLinearNPointsProblem {
528
+ type Param = Array1 < f64 > ;
529
+ type Output = f64 ;
530
+
531
+ fn cost (
532
+ & self ,
533
+ parameters : & Self :: Param ,
534
+ ) -> Result < Self :: Output , argmin:: core:: Error > {
535
+ debug ! ( "Cost: parameters={parameters:?}" ) ;
536
+ assert_eq ! ( parameters. len( ) , self . parameter_count( ) ) ;
537
+
538
+ let residuals_data = self . residuals ( parameters. as_slice ( ) . unwrap ( ) ) ;
539
+ let residuals = Array1 :: from_vec ( residuals_data) ;
540
+
541
+ let residuals_sum = residuals. sum ( ) ;
542
+ debug ! ( "residuals_sum: {residuals_sum}" ) ;
543
+
544
+ Ok ( residuals_sum * residuals_sum)
545
+ }
546
+ }
547
+
548
+ impl argmin:: core:: Gradient for CurveFitLinearNPointsProblem {
549
+ type Param = Array1 < f64 > ;
550
+ type Gradient = Array1 < f64 > ;
551
+
552
+ fn gradient (
553
+ & self ,
554
+ parameters : & Self :: Param ,
555
+ ) -> Result < Self :: Gradient , argmin:: core:: Error > {
556
+ debug ! ( "Gradient: parameters={parameters:?}" ) ;
557
+
558
+ assert_eq ! ( parameters. len( ) , self . parameter_count( ) ) ;
559
+
560
+ let vector = ( * parameters) . forward_diff ( & |x| {
561
+ let sum: f64 =
562
+ self . residuals ( x. as_slice ( ) . unwrap ( ) ) . into_iter ( ) . sum ( ) ;
563
+ debug ! ( "forward_diff residuals_sum: {sum}" ) ;
564
+ sum * sum
565
+ } ) ;
566
+
567
+ Ok ( vector)
568
+ }
569
+ }
570
+
571
+ impl argmin:: core:: Hessian for CurveFitLinearNPointsProblem {
572
+ type Param = Array1 < f64 > ;
573
+ type Hessian = Array2 < f64 > ;
574
+
575
+ fn hessian (
576
+ & self ,
577
+ parameters : & Self :: Param ,
578
+ ) -> Result < Self :: Hessian , argmin:: core:: Error > {
579
+ debug ! ( "Hessian: parameters={parameters:?}" ) ;
580
+ assert_eq ! ( parameters. len( ) , self . parameter_count( ) ) ;
581
+
582
+ let matrix =
583
+ ( * parameters) . forward_hessian ( & |x| self . gradient ( x) . unwrap ( ) ) ;
584
+
585
+ Ok ( matrix)
586
+ }
587
+ }
588
+
589
+ pub fn nonlinear_line_n_points (
590
+ x_values : & [ f64 ] ,
591
+ y_values : & [ f64 ] ,
592
+ control_point_count : usize ,
593
+ ) -> Result < Vec < Point2 > > {
594
+ assert_eq ! ( x_values. len( ) , y_values. len( ) ) ;
595
+ let value_count = x_values. len ( ) ;
596
+ assert ! ( value_count > 2 ) ;
597
+ assert ! (
598
+ control_point_count >= 3 ,
599
+ "Must have at least 3 control points"
600
+ ) ;
601
+
602
+ // TODO: Normalize the input values to a known range, say 0.0 to
603
+ // 1.0, solve and then scale back to the original time and value
604
+ // range.
605
+
606
+ // First get initial guess using linear regression
607
+ let mut point_x = 0.0 ;
608
+ let mut point_y = 0.0 ;
609
+ let mut angle = 0.0 ;
610
+ curve_fit_linear_regression_type1 (
611
+ x_values,
612
+ y_values,
613
+ & mut point_x,
614
+ & mut point_y,
615
+ & mut angle,
616
+ ) ;
617
+
618
+ let dir_x = angle. cos ( ) ;
619
+ let dir_y = angle. sin ( ) ;
620
+ debug ! ( "point_x={point_x} point_y={point_y}" ) ;
621
+ debug ! ( "dir_x={dir_x} dir_y={dir_y}" ) ;
622
+
623
+ // Calculate the range of x values
624
+ let x_first = x_values[ 0 ] ;
625
+ let x_last = x_values[ value_count - 1 ] ;
626
+ let x_range = x_last - x_first;
627
+
628
+ // Generate evenly spaced control points along x-axis
629
+ let mut control_points_x = Vec :: with_capacity ( control_point_count) ;
630
+ let mut initial_y_values = Vec :: with_capacity ( control_point_count) ;
631
+ for i in 0 ..control_point_count {
632
+ let mix = i as f64 / ( control_point_count - 1 ) as f64 ;
633
+ let x = ( x_first + ( mix * x_range) ) . floor ( ) ;
634
+ control_points_x. push ( x) ;
635
+
636
+ // Initial y-values based on linear regression line.
637
+ let y = point_y + ( dir_y * ( x - point_x) / dir_x) ;
638
+ initial_y_values. push ( y) ;
639
+ }
640
+
641
+ // Create reference values
642
+ let reference_values: Vec < ( f64 , f64 ) > = x_values
643
+ . iter ( )
644
+ . zip ( y_values)
645
+ . map ( |( & x, & y) | ( x, y) )
646
+ . collect ( ) ;
647
+
648
+ // Define the problem
649
+ let problem = CurveFitLinearNPointsProblem :: new (
650
+ control_points_x. clone ( ) ,
651
+ & reference_values,
652
+ ) ;
653
+
654
+ // Set up solver
655
+ let epsilon = 1e-3 ;
656
+ let condition =
657
+ argmin:: solver:: linesearch:: condition:: ArmijoCondition :: new ( 1e-5 ) ?;
658
+ let linesearch =
659
+ argmin:: solver:: linesearch:: BacktrackingLineSearch :: new ( condition)
660
+ . rho ( 0.5 ) ?;
661
+ let solver = argmin:: solver:: quasinewton:: BFGS :: new ( linesearch)
662
+ . with_tolerance_cost ( epsilon) ?;
663
+
664
+ // Run solver
665
+ let initial_parameters = Array1 :: from ( initial_y_values) ;
666
+ let initial_hessian: Array2 < f64 > = Array2 :: eye ( control_point_count) ;
667
+ let result = argmin:: core:: Executor :: new ( problem, solver)
668
+ . configure ( |state| {
669
+ state
670
+ . param ( initial_parameters)
671
+ . inv_hessian ( initial_hessian)
672
+ . max_iters ( 50 )
673
+ } )
674
+ . run ( ) ?;
675
+
676
+ debug ! ( "Solver Result: {result}" ) ;
677
+
678
+ match result. state ( ) . get_best_param ( ) {
679
+ Some ( parameters) => {
680
+ let mut control_points = Vec :: with_capacity ( control_point_count) ;
681
+ for i in 0 ..control_point_count {
682
+ control_points. push ( Point2 {
683
+ x : control_points_x[ i] ,
684
+ y : parameters[ i] ,
685
+ } ) ;
686
+ }
687
+ Ok ( control_points)
688
+ }
689
+ None => bail ! ( "Solve failed." ) ,
690
+ }
691
+ }
0 commit comments