@@ -3,105 +3,101 @@ use crate::numerical_derivative::derivator::DerivatorMultiVariable;
33use num_complex:: ComplexFloat ;
44
55#[ derive( Debug ) ]
6- pub struct LinearApproximationResult < T : ComplexFloat , const NUM_VARS : usize >
7- {
6+ pub struct LinearApproximationResult < T : ComplexFloat , const NUM_VARS : usize > {
87 pub intercept : T ,
9- pub coefficients : [ T ; NUM_VARS ]
8+ pub coefficients : [ T ; NUM_VARS ] ,
109}
1110
1211#[ derive( Debug ) ]
13- pub struct LinearApproximationPredictionMetrics < T : ComplexFloat >
14- {
12+ pub struct LinearApproximationPredictionMetrics < T : ComplexFloat > {
1513 pub mean_absolute_error : T :: Real ,
1614 pub mean_squared_error : T :: Real ,
1715 pub root_mean_squared_error : T :: Real ,
1816 pub r_squared : T :: Real ,
19- pub adjusted_r_squared : T :: Real
17+ pub adjusted_r_squared : T :: Real ,
2018}
2119
22- impl < T : ComplexFloat , const NUM_VARS : usize > LinearApproximationResult < T , NUM_VARS >
23- {
20+ impl < T : ComplexFloat , const NUM_VARS : usize > LinearApproximationResult < T , NUM_VARS > {
2421 ///Helper function if you don't care about the details and just want the predictor directly
25- pub fn get_prediction_value ( & self , args : & [ T ; NUM_VARS ] ) -> T
26- {
22+ pub fn get_prediction_value ( & self , args : & [ T ; NUM_VARS ] ) -> T {
2723 let mut result = self . intercept ;
28- for ( iter, arg) in args. iter ( ) . enumerate ( ) . take ( NUM_VARS )
29- {
30- result = result + self . coefficients [ iter] * * arg;
24+ for ( iter, arg) in args. iter ( ) . enumerate ( ) . take ( NUM_VARS ) {
25+ result = result + self . coefficients [ iter] * * arg;
3126 }
32-
27+
3328 return result;
3429 }
3530
3631 //get prediction metrics by feeding a list of points and the original function
37- pub fn get_prediction_metrics < const NUM_POINTS : usize > ( & self , points : & [ [ T ; NUM_VARS ] ; NUM_POINTS ] , original_function : & dyn Fn ( & [ T ; NUM_VARS ] ) -> T ) -> LinearApproximationPredictionMetrics < T >
38- {
32+ pub fn get_prediction_metrics < const NUM_POINTS : usize > (
33+ & self ,
34+ points : & [ [ T ; NUM_VARS ] ; NUM_POINTS ] ,
35+ original_function : & dyn Fn ( & [ T ; NUM_VARS ] ) -> T ,
36+ ) -> LinearApproximationPredictionMetrics < T > {
3937 //let num_points = NUM_POINTS as f64;
4038 let mut mae = T :: zero ( ) ;
4139 let mut mse = T :: zero ( ) ;
42-
43- for point in points. iter ( ) . take ( NUM_POINTS )
44- {
40+
41+ for point in points. iter ( ) . take ( NUM_POINTS ) {
4542 let predicted_y = self . get_prediction_value ( point) ;
46-
43+
4744 mae = mae + ( predicted_y - original_function ( point) ) ;
4845 mse = mse + num_complex:: ComplexFloat :: powi ( predicted_y - original_function ( point) , 2 ) ;
4946 }
5047
51- mae = mae/ T :: from ( NUM_POINTS ) . unwrap ( ) ;
52- mse = mse/ T :: from ( NUM_POINTS ) . unwrap ( ) ;
48+ mae = mae / T :: from ( NUM_POINTS ) . unwrap ( ) ;
49+ mse = mse / T :: from ( NUM_POINTS ) . unwrap ( ) ;
5350
5451 let rmse = mse. sqrt ( ) . abs ( ) ;
5552
5653 let mut r2_numerator = T :: zero ( ) ;
5754 let mut r2_denominator = T :: zero ( ) ;
5855
59- for point in points. iter ( ) . take ( NUM_POINTS )
60- {
56+ for point in points. iter ( ) . take ( NUM_POINTS ) {
6157 let predicted_y = self . get_prediction_value ( point) ;
6258
63- r2_numerator = r2_numerator + num_complex:: ComplexFloat :: powi ( predicted_y - original_function ( point) , 2 ) ;
64- r2_denominator = r2_numerator + num_complex:: ComplexFloat :: powi ( mae - original_function ( point) , 2 ) ;
59+ r2_numerator = r2_numerator
60+ + num_complex:: ComplexFloat :: powi ( predicted_y - original_function ( point) , 2 ) ;
61+ r2_denominator =
62+ r2_numerator + num_complex:: ComplexFloat :: powi ( mae - original_function ( point) , 2 ) ;
6563 }
6664
67- let r2 = T :: one ( ) - ( r2_numerator/ r2_denominator) ;
65+ let r2 = T :: one ( ) - ( r2_numerator / r2_denominator) ;
6866
69- let r2_adj = T :: one ( ) - ( T :: one ( ) - r2) * ( T :: from ( NUM_POINTS ) . unwrap ( ) ) /( T :: from ( NUM_POINTS ) . unwrap ( ) - T :: from ( 2.0 ) . unwrap ( ) ) ;
67+ let r2_adj = T :: one ( )
68+ - ( T :: one ( ) - r2) * ( T :: from ( NUM_POINTS ) . unwrap ( ) )
69+ / ( T :: from ( NUM_POINTS ) . unwrap ( ) - T :: from ( 2.0 ) . unwrap ( ) ) ;
7070
71- return LinearApproximationPredictionMetrics
72- {
71+ return LinearApproximationPredictionMetrics {
7372 mean_absolute_error : mae. abs ( ) ,
7473 mean_squared_error : mse. abs ( ) ,
7574 root_mean_squared_error : rmse,
7675 r_squared : r2. abs ( ) ,
77- adjusted_r_squared : r2_adj. abs ( )
76+ adjusted_r_squared : r2_adj. abs ( ) ,
7877 } ;
7978 }
8079}
8180
82- pub struct LinearApproximator < D : DerivatorMultiVariable >
83- {
84- derivator : D
81+ pub struct LinearApproximator < D : DerivatorMultiVariable > {
82+ derivator : D ,
8583}
8684
87- impl < D : DerivatorMultiVariable > Default for LinearApproximator < D >
88- {
89- fn default ( ) -> Self
90- {
91- return LinearApproximator { derivator : D :: default ( ) } ;
85+ impl < D : DerivatorMultiVariable > Default for LinearApproximator < D > {
86+ fn default ( ) -> Self {
87+ return LinearApproximator {
88+ derivator : D :: default ( ) ,
89+ } ;
9290 }
9391}
9492
95- impl < D : DerivatorMultiVariable > LinearApproximator < D >
96- {
97- pub fn from_derivator ( derivator : D ) -> Self
98- {
99- return LinearApproximator { derivator}
93+ impl < D : DerivatorMultiVariable > LinearApproximator < D > {
94+ pub fn from_derivator ( derivator : D ) -> Self {
95+ return LinearApproximator { derivator } ;
10096 }
10197
10298 /// For an n-dimensional approximation, the equation is linearized as:
10399 /// coefficient[0]*var_1 + coefficient[1]*var_2 + ... + coefficient[n-1]*var_n + intercept
104- ///
100+ ///
105101 /// NOTE: Returns a Result<T, &'static str>
106102 /// Possible &'static str are:
107103 /// NumberOfStepsCannotBeZero -> if the derivative step size is zero
@@ -110,9 +106,9 @@ impl<D: DerivatorMultiVariable> LinearApproximator<D>
110106 ///```
111107 ///use multicalc::approximation::linear_approximation::*;
112108 ///use multicalc::numerical_derivative::finite_difference::MultiVariableSolver;
113- ///
109+ ///
114110 ///let function_to_approximate = | args: &[f64; 3] | -> f64
115- ///{
111+ ///{
116112 /// return args[0] + args[1].powf(2.0) + args[2].powf(3.0);
117113 ///};
118114 ///
@@ -123,29 +119,30 @@ impl<D: DerivatorMultiVariable> LinearApproximator<D>
123119 ///assert!(f64::abs(function_to_approximate(&point) - result.get_prediction_value(&point)) < 1e-9);
124120 /// ```
125121 /// you can also inspect the results of the approximation. For an n-dimensional approximation, the equation is linearized as
126- ///
122+ ///
127123 /// [`LinearApproximationResult::intercept`] gives you the required intercept
128124 /// [`LinearApproximationResult::coefficients`] gives you the required coefficients in order
129- ///
125+ ///
130126 /// if you don't care about the results and want the predictor directly, use [`LinearApproximationResult::get_prediction_value()`]
131127 /// you can also inspect the prediction metrics by providing list of points, use [`LinearApproximationResult::get_prediction_metrics()`]
132128 ///
133- pub fn get < T : ComplexFloat , const NUM_VARS : usize > ( & self , function : & dyn Fn ( & [ T ; NUM_VARS ] ) -> T , point : & [ T ; NUM_VARS ] ) -> Result < LinearApproximationResult < T , NUM_VARS > , & ' static str >
134- {
129+ pub fn get < T : ComplexFloat , const NUM_VARS : usize > (
130+ & self ,
131+ function : & dyn Fn ( & [ T ; NUM_VARS ] ) -> T ,
132+ point : & [ T ; NUM_VARS ] ,
133+ ) -> Result < LinearApproximationResult < T , NUM_VARS > , & ' static str > {
135134 let mut slopes_ = [ T :: zero ( ) ; NUM_VARS ] ;
136135
137136 let mut intercept_ = function ( point) ;
138137
139- for iter in 0 ..NUM_VARS
140- {
138+ for iter in 0 ..NUM_VARS {
141139 slopes_[ iter] = self . derivator . get ( 1 , function, & [ iter] , point) ?;
142- intercept_ = intercept_ - slopes_[ iter] * point[ iter] ;
140+ intercept_ = intercept_ - slopes_[ iter] * point[ iter] ;
143141 }
144142
145- return Ok ( LinearApproximationResult
146- {
143+ return Ok ( LinearApproximationResult {
147144 intercept : intercept_,
148- coefficients : slopes_
145+ coefficients : slopes_,
149146 } ) ;
150147 }
151- }
148+ }
0 commit comments