1- use core:: panic;
21use std:: env;
32use std:: collections:: HashSet ;
43use std:: io:: { BufWriter , Write } ;
@@ -9,19 +8,52 @@ use rayon::prelude::*; // For parallel processing
98
109/// The main function parses command-line arguments, processes the input CSV file,
1110/// optionally uses a zones CSV file, and writes the output in MTX format.
12- fn main ( ) {
11+ fn main ( ) -> std :: io :: Result < ( ) > {
1312 let arg: Vec < String > = env:: args ( ) . collect ( ) ;
1413
1514 if arg. len ( ) < 3 {
1615 println ! ( "Usage: csv_to_mtx <input.csv> <output.mtx/.mtx.gz> [zones.csv]" ) ;
17- return ;
16+ return Ok ( ( ) ) ;
1817 }
1918
20- let data = read_csv ( & arg[ 1 ] ) ;
21- let all_zones = get_all_zones ( & arg, & data) ;
19+ let zones_file = if arg. len ( ) > 3 {
20+ Some ( & arg[ 3 ] as & str )
21+ } else {
22+ None
23+ } ;
24+ convert_csv_to_mtx ( & arg[ 1 ] , & arg[ 2 ] , zones_file)
25+ }
26+
27+ /// Converts the input CSV file to MTX format and writes it to the output file.
28+ ///
29+ fn convert_csv_to_mtx (
30+ input_file : & str ,
31+ output_file : & str ,
32+ zones_file : Option < & str > ,
33+ ) -> std:: io:: Result < ( ) > {
34+ let data = match read_csv ( input_file) {
35+ Ok ( data) => data,
36+ Err ( e) => {
37+ eprintln ! ( "Error reading CSV file: {}" , e) ;
38+ return Err ( e) ;
39+ }
40+ } ;
41+ let all_zones = match get_all_zones ( zones_file, & data) {
42+ Ok ( zones) => zones,
43+ Err ( e) => {
44+ eprintln ! ( "Error reading zones file: {}" , e) ;
45+ return Err ( e) ;
46+ }
47+ } ;
2248 println ! ( "Found {} zones" , all_zones. len( ) ) ;
2349 let matrix = build_matrix ( & data, & all_zones) ;
24- write_mtx_file ( & arg[ 2 ] , & all_zones, & matrix) ;
50+ match write_mtx_file ( output_file, & all_zones, & matrix) {
51+ Err ( e) => {
52+ eprintln ! ( "Error writing MTX file: {}" , e) ;
53+ Err ( e)
54+ }
55+ _ => Ok ( ( ) ) ,
56+ }
2557}
2658
2759/// Reads the input CSV file and extracts the data as a vector of tuples containing
@@ -34,13 +66,9 @@ fn main() {
3466///
3567/// # Returns
3668/// A vector of tuples `(i32, i32, f32)` representing the origin, destination, and value.
37- fn read_csv ( input_file : & str ) -> Vec < ( i32 , i32 , f32 ) > {
38- let file = match File :: open ( input_file) {
39- Ok ( f) => f,
40- Err ( e) => {
41- panic ! ( "Error opening file {input_file}: {e}" ) ;
42- }
43- } ;
69+ fn read_csv ( input_file : & str ) -> std:: io:: Result < Vec < ( i32 , i32 , f32 ) > > {
70+ let file = File :: open ( input_file) ?;
71+
4472 let mut rdr = csv:: ReaderBuilder :: new ( )
4573 . has_headers ( false )
4674 . from_reader ( file) ;
@@ -73,13 +101,13 @@ fn read_csv(input_file: &str) -> Vec<(i32, i32, f32)> {
73101 }
74102 }
75103
76- data
104+ Ok ( data)
77105 } else {
78106 // Rectangular format - pass the first record and remaining iterator
79- read_rectangular_csv_from_records ( first_record, records)
107+ Ok ( read_rectangular_csv_from_records ( first_record, records) )
80108 }
81109 } else {
82- Vec :: new ( )
110+ Ok ( Vec :: new ( ) )
83111 }
84112}
85113
@@ -130,29 +158,29 @@ fn read_rectangular_csv_from_records(
130158/// or by extracting unique origins and destinations from the input data.
131159///
132160/// # Arguments
133- /// * `arg ` - The command-line arguments .
161+ /// * `zones_file ` - Optional path to the zones CSV file .
134162/// * `data` - The vector of tuples `(i32, i32, f32)` representing the input data.
135163///
136164/// # Returns
137165/// A sorted vector of unique zone numbers.
138- fn get_all_zones ( arg : & [ String ] , data : & [ ( i32 , i32 , f32 ) ] ) -> Vec < i32 > {
139- if arg . len ( ) > 3 {
140- let zone_file = File :: open ( & arg [ 3 ] ) . unwrap ( ) ;
166+ fn get_all_zones ( zones_file : Option < & str > , data : & [ ( i32 , i32 , f32 ) ] ) -> std :: io :: Result < Vec < i32 > > {
167+ if let Some ( zone_file ) = zones_file {
168+ let zone_file = File :: open ( zone_file ) ? ;
141169 let mut zone_rdr = csv:: Reader :: from_reader ( zone_file) ;
142170 let mut zones: Vec < i32 > = zone_rdr
143171 . records ( )
144172 . filter_map ( |result| result. ok ( ) ?. get ( 0 ) ?. parse ( ) . ok ( ) )
145173 . collect ( ) ;
146174 zones. sort_unstable ( ) ;
147- zones
175+ Ok ( zones)
148176 } else {
149177 let zones: HashSet < i32 > = data
150178 . par_iter ( )
151179 . flat_map ( |( origin, destination, _) | vec ! [ * origin, * destination] )
152180 . collect ( ) ;
153181 let mut zones: Vec < i32 > = zones. into_iter ( ) . collect ( ) ;
154182 zones. sort_unstable ( ) ;
155- zones
183+ Ok ( zones)
156184 }
157185}
158186
@@ -194,8 +222,8 @@ fn build_matrix(data: &[(i32, i32, f32)], all_zones: &[i32]) -> Vec<f32> {
194222///
195223/// # Panics
196224/// This function will panic if it fails to create or write to the output file.
197- fn write_mtx_file ( output_file_name : & str , all_zones : & [ i32 ] , matrix : & [ f32 ] ) {
198- let output_file = File :: create ( output_file_name) . unwrap ( ) ;
225+ fn write_mtx_file ( output_file_name : & str , all_zones : & [ i32 ] , matrix : & [ f32 ] ) -> std :: io :: Result < ( ) > {
226+ let output_file = File :: create ( output_file_name) ? ;
199227 let mut writer: Box < dyn Write > = if output_file_name. ends_with ( ".gz" ) {
200228 Box :: new ( BufWriter :: new ( GzEncoder :: new ( output_file, Compression :: default ( ) ) ) )
201229 } else {
@@ -204,41 +232,62 @@ fn write_mtx_file(output_file_name: &str, all_zones: &[i32], matrix: &[f32]) {
204232
205233 let zone_count = all_zones. len ( ) as i32 ;
206234
207- writer. write_all ( & 0xC4D4F1B2u32 . to_le_bytes ( ) ) . unwrap ( ) ; // Magic Number
208- writer. write_all ( & 1i32 . to_le_bytes ( ) ) . unwrap ( ) ; // Version Number
209- writer. write_all ( & 1i32 . to_le_bytes ( ) ) . unwrap ( ) ; // Type
210- writer. write_all ( & 2i32 . to_le_bytes ( ) ) . unwrap ( ) ; // Dimensions
211- writer. write_all ( & zone_count. to_le_bytes ( ) ) . unwrap ( ) ; // Index size for origin
212- writer. write_all ( & zone_count. to_le_bytes ( ) ) . unwrap ( ) ; // Index size for destination
235+ writer. write_all ( & 0xC4D4F1B2u32 . to_le_bytes ( ) ) ? ; // Magic Number
236+ writer. write_all ( & 1i32 . to_le_bytes ( ) ) ? ; // Version Number
237+ writer. write_all ( & 1i32 . to_le_bytes ( ) ) ? ; // Type
238+ writer. write_all ( & 2i32 . to_le_bytes ( ) ) ? ; // Dimensions
239+ writer. write_all ( & zone_count. to_le_bytes ( ) ) ? ; // Index size for origin
240+ writer. write_all ( & zone_count. to_le_bytes ( ) ) ? ; // Index size for destination
213241
214242 let is_little_endian = cfg ! ( target_endian = "little" ) ;
215243
216244 if is_little_endian {
217245 // Write all origin zone numbers in a single call
218246 let origin_zone_bytes: & [ u8 ] = bytemuck:: cast_slice ( all_zones) ;
219- writer. write_all ( origin_zone_bytes) . unwrap ( ) ; // Zone Numbers for Origin
247+ writer. write_all ( origin_zone_bytes) ? ; // Zone Numbers for Origin
220248
221249 // Write all destination zone numbers in a single call
222- writer. write_all ( origin_zone_bytes) . unwrap ( ) ; // Zone Numbers for Destination
250+ writer. write_all ( origin_zone_bytes) ? ; // Zone Numbers for Destination
223251
224252 // Write all matrix values in a single call
225253 let matrix_bytes: & [ u8 ] = bytemuck:: cast_slice ( matrix) ;
226- writer. write_all ( matrix_bytes) . unwrap ( ) ;
254+ writer. write_all ( matrix_bytes) ? ;
227255
228256 } else {
229257 // Convert all_zones to little-endian
230258 let origin_zone_bytes: Vec < u8 > = all_zones
231259 . par_iter ( )
232260 . flat_map ( |& zone| zone. to_le_bytes ( ) )
233261 . collect ( ) ;
234- writer. write_all ( & origin_zone_bytes) . unwrap ( ) ; // Zone Numbers for Origin
235- writer. write_all ( & origin_zone_bytes) . unwrap ( ) ; // Zone Numbers for Destination
262+ writer. write_all ( & origin_zone_bytes) ? ; // Zone Numbers for Origin
263+ writer. write_all ( & origin_zone_bytes) ? ; // Zone Numbers for Destination
236264
237265 // Convert matrix to little-endian
238266 let matrix_bytes: Vec < u8 > = matrix
239267 . par_iter ( )
240268 . flat_map ( |& value| value. to_le_bytes ( ) )
241269 . collect ( ) ;
242- writer. write_all ( & matrix_bytes) . unwrap ( ) ;
270+ writer. write_all ( & matrix_bytes) ?;
271+ }
272+ Ok ( ( ) )
273+ }
274+
275+ // Write a test using test.csv to make sure that it converts to an mtx file
276+ #[ cfg( test) ]
277+ mod tests {
278+ use super :: * ;
279+ #[ test]
280+ fn test_csv_to_mtx ( ) -> std:: io:: Result < ( ) > {
281+ let input_file = "test/test.csv" ;
282+ let output_file = "test/test_output.mtx" ;
283+
284+ convert_csv_to_mtx ( input_file, output_file, None ) ?;
285+
286+ // Compare against a known good output file
287+ let expected_output_file = "test/test_expected.mtx" ;
288+ let output_data = std:: fs:: read ( output_file) ?;
289+ let expected_data = std:: fs:: read ( expected_output_file) ?;
290+ assert_eq ! ( output_data, expected_data) ;
291+ Ok ( ( ) )
243292 }
244293}
0 commit comments