Skip to content

Commit 7113f00

Browse files
Add automatic format checking (kmolan#5)
* Add format checking via `cargo fmt` to github build workflow * Format all rust files with cargo fmt
1 parent 06466b5 commit 7113f00

32 files changed

+4896
-4316
lines changed

.github/workflows/build-tests.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ jobs:
1717

1818
steps:
1919
- uses: actions/checkout@v4
20+
- name: format
21+
run: cargo fmt --check
2022
- name: Build
2123
run: cargo build --verbose
2224
- name: Run tests default

src/approximation/linear_approximation.rs

Lines changed: 55 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,105 +3,101 @@ use crate::numerical_derivative::derivator::DerivatorMultiVariable;
33
use num_complex::ComplexFloat;
44

55
#[derive(Debug)]
6-
pub struct LinearApproximationResult<T: ComplexFloat, const NUM_VARS: usize>
7-
{
6+
pub struct LinearApproximationResult<T: ComplexFloat, const NUM_VARS: usize> {
87
pub intercept: T,
9-
pub coefficients: [T; NUM_VARS]
8+
pub coefficients: [T; NUM_VARS],
109
}
1110

1211
#[derive(Debug)]
13-
pub struct LinearApproximationPredictionMetrics<T: ComplexFloat>
14-
{
12+
pub struct LinearApproximationPredictionMetrics<T: ComplexFloat> {
1513
pub mean_absolute_error: T::Real,
1614
pub mean_squared_error: T::Real,
1715
pub root_mean_squared_error: T::Real,
1816
pub r_squared: T::Real,
19-
pub adjusted_r_squared: T::Real
17+
pub adjusted_r_squared: T::Real,
2018
}
2119

22-
impl<T: ComplexFloat, const NUM_VARS: usize> LinearApproximationResult<T, NUM_VARS>
23-
{
20+
impl<T: ComplexFloat, const NUM_VARS: usize> LinearApproximationResult<T, NUM_VARS> {
2421
///Helper function if you don't care about the details and just want the predictor directly
25-
pub fn get_prediction_value(&self, args: &[T; NUM_VARS]) -> T
26-
{
22+
pub fn get_prediction_value(&self, args: &[T; NUM_VARS]) -> T {
2723
let mut result = self.intercept;
28-
for (iter, arg) in args.iter().enumerate().take(NUM_VARS)
29-
{
30-
result = result + self.coefficients[iter]**arg;
24+
for (iter, arg) in args.iter().enumerate().take(NUM_VARS) {
25+
result = result + self.coefficients[iter] * *arg;
3126
}
32-
27+
3328
return result;
3429
}
3530

3631
//get prediction metrics by feeding a list of points and the original function
37-
pub fn get_prediction_metrics<const NUM_POINTS: usize>(&self, points: &[[T; NUM_VARS]; NUM_POINTS], original_function: &dyn Fn(&[T; NUM_VARS]) -> T) -> LinearApproximationPredictionMetrics<T>
38-
{
32+
pub fn get_prediction_metrics<const NUM_POINTS: usize>(
33+
&self,
34+
points: &[[T; NUM_VARS]; NUM_POINTS],
35+
original_function: &dyn Fn(&[T; NUM_VARS]) -> T,
36+
) -> LinearApproximationPredictionMetrics<T> {
3937
//let num_points = NUM_POINTS as f64;
4038
let mut mae = T::zero();
4139
let mut mse = T::zero();
42-
43-
for point in points.iter().take(NUM_POINTS)
44-
{
40+
41+
for point in points.iter().take(NUM_POINTS) {
4542
let predicted_y = self.get_prediction_value(point);
46-
43+
4744
mae = mae + (predicted_y - original_function(point));
4845
mse = mse + num_complex::ComplexFloat::powi(predicted_y - original_function(point), 2);
4946
}
5047

51-
mae = mae/T::from(NUM_POINTS).unwrap();
52-
mse = mse/T::from(NUM_POINTS).unwrap();
48+
mae = mae / T::from(NUM_POINTS).unwrap();
49+
mse = mse / T::from(NUM_POINTS).unwrap();
5350

5451
let rmse = mse.sqrt().abs();
5552

5653
let mut r2_numerator = T::zero();
5754
let mut r2_denominator = T::zero();
5855

59-
for point in points.iter().take(NUM_POINTS)
60-
{
56+
for point in points.iter().take(NUM_POINTS) {
6157
let predicted_y = self.get_prediction_value(point);
6258

63-
r2_numerator = r2_numerator + num_complex::ComplexFloat::powi(predicted_y - original_function(point), 2);
64-
r2_denominator = r2_numerator + num_complex::ComplexFloat::powi(mae - original_function(point), 2);
59+
r2_numerator = r2_numerator
60+
+ num_complex::ComplexFloat::powi(predicted_y - original_function(point), 2);
61+
r2_denominator =
62+
r2_numerator + num_complex::ComplexFloat::powi(mae - original_function(point), 2);
6563
}
6664

67-
let r2 = T::one() - (r2_numerator/r2_denominator);
65+
let r2 = T::one() - (r2_numerator / r2_denominator);
6866

69-
let r2_adj = T::one() - (T::one() - r2)*(T::from(NUM_POINTS).unwrap())/(T::from(NUM_POINTS).unwrap() - T::from(2.0).unwrap());
67+
let r2_adj = T::one()
68+
- (T::one() - r2) * (T::from(NUM_POINTS).unwrap())
69+
/ (T::from(NUM_POINTS).unwrap() - T::from(2.0).unwrap());
7070

71-
return LinearApproximationPredictionMetrics
72-
{
71+
return LinearApproximationPredictionMetrics {
7372
mean_absolute_error: mae.abs(),
7473
mean_squared_error: mse.abs(),
7574
root_mean_squared_error: rmse,
7675
r_squared: r2.abs(),
77-
adjusted_r_squared: r2_adj.abs()
76+
adjusted_r_squared: r2_adj.abs(),
7877
};
7978
}
8079
}
8180

82-
pub struct LinearApproximator<D: DerivatorMultiVariable>
83-
{
84-
derivator: D
81+
pub struct LinearApproximator<D: DerivatorMultiVariable> {
82+
derivator: D,
8583
}
8684

87-
impl<D: DerivatorMultiVariable> Default for LinearApproximator<D>
88-
{
89-
fn default() -> Self
90-
{
91-
return LinearApproximator { derivator: D::default() };
85+
impl<D: DerivatorMultiVariable> Default for LinearApproximator<D> {
86+
fn default() -> Self {
87+
return LinearApproximator {
88+
derivator: D::default(),
89+
};
9290
}
9391
}
9492

95-
impl<D: DerivatorMultiVariable> LinearApproximator<D>
96-
{
97-
pub fn from_derivator(derivator: D) -> Self
98-
{
99-
return LinearApproximator {derivator}
93+
impl<D: DerivatorMultiVariable> LinearApproximator<D> {
94+
pub fn from_derivator(derivator: D) -> Self {
95+
return LinearApproximator { derivator };
10096
}
10197

10298
/// For an n-dimensional approximation, the equation is linearized as:
10399
/// coefficient[0]*var_1 + coefficient[1]*var_2 + ... + coefficient[n-1]*var_n + intercept
104-
///
100+
///
105101
/// NOTE: Returns a Result<T, &'static str>
106102
/// Possible &'static str are:
107103
/// NumberOfStepsCannotBeZero -> if the derivative step size is zero
@@ -110,9 +106,9 @@ impl<D: DerivatorMultiVariable> LinearApproximator<D>
110106
///```
111107
///use multicalc::approximation::linear_approximation::*;
112108
///use multicalc::numerical_derivative::finite_difference::MultiVariableSolver;
113-
///
109+
///
114110
///let function_to_approximate = | args: &[f64; 3] | -> f64
115-
///{
111+
///{
116112
/// return args[0] + args[1].powf(2.0) + args[2].powf(3.0);
117113
///};
118114
///
@@ -123,29 +119,30 @@ impl<D: DerivatorMultiVariable> LinearApproximator<D>
123119
///assert!(f64::abs(function_to_approximate(&point) - result.get_prediction_value(&point)) < 1e-9);
124120
/// ```
125121
/// you can also inspect the results of the approximation. For an n-dimensional approximation, the equation is linearized as
126-
///
122+
///
127123
/// [`LinearApproximationResult::intercept`] gives you the required intercept
128124
/// [`LinearApproximationResult::coefficients`] gives you the required coefficients in order
129-
///
125+
///
130126
/// if you don't care about the results and want the predictor directly, use [`LinearApproximationResult::get_prediction_value()`]
131127
/// you can also inspect the prediction metrics by providing list of points, use [`LinearApproximationResult::get_prediction_metrics()`]
132128
///
133-
pub fn get<T: ComplexFloat, const NUM_VARS: usize>(&self, function: &dyn Fn(&[T; NUM_VARS]) -> T, point: &[T; NUM_VARS]) -> Result<LinearApproximationResult<T, NUM_VARS>, &'static str>
134-
{
129+
pub fn get<T: ComplexFloat, const NUM_VARS: usize>(
130+
&self,
131+
function: &dyn Fn(&[T; NUM_VARS]) -> T,
132+
point: &[T; NUM_VARS],
133+
) -> Result<LinearApproximationResult<T, NUM_VARS>, &'static str> {
135134
let mut slopes_ = [T::zero(); NUM_VARS];
136135

137136
let mut intercept_ = function(point);
138137

139-
for iter in 0..NUM_VARS
140-
{
138+
for iter in 0..NUM_VARS {
141139
slopes_[iter] = self.derivator.get(1, function, &[iter], point)?;
142-
intercept_ = intercept_ - slopes_[iter]*point[iter];
140+
intercept_ = intercept_ - slopes_[iter] * point[iter];
143141
}
144142

145-
return Ok(LinearApproximationResult
146-
{
143+
return Ok(LinearApproximationResult {
147144
intercept: intercept_,
148-
coefficients: slopes_
145+
coefficients: slopes_,
149146
});
150147
}
151-
}
148+
}

src/approximation/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ pub mod linear_approximation;
22
pub mod quadratic_approximation;
33

44
#[cfg(test)]
5-
mod test;
5+
mod test;

0 commit comments

Comments
 (0)