Skip to content

Commit d606457

Browse files
authored
Added error messages (#2)
Removed unwraps and added better message. Added a unit test to make sure that we generate an expected .mtx file.
1 parent 0f2be56 commit d606457

4 files changed

Lines changed: 94 additions & 36 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
/target
2+
/test/test_output.mtx

src/main.rs

Lines changed: 85 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use core::panic;
21
use std::env;
32
use std::collections::HashSet;
43
use std::io::{BufWriter, Write};
@@ -9,19 +8,52 @@ use rayon::prelude::*; // For parallel processing
98

109
/// The main function parses command-line arguments, processes the input CSV file,
1110
/// optionally uses a zones CSV file, and writes the output in MTX format.
12-
fn main() {
11+
fn main() -> std::io::Result<()> {
1312
let arg: Vec<String> = env::args().collect();
1413

1514
if arg.len() < 3 {
1615
println!("Usage: csv_to_mtx <input.csv> <output.mtx/.mtx.gz> [zones.csv]");
17-
return;
16+
return Ok(());
1817
}
1918

20-
let data = read_csv(&arg[1]);
21-
let all_zones = get_all_zones(&arg, &data);
19+
let zones_file = if arg.len() > 3 {
20+
Some(&arg[3] as &str)
21+
} else {
22+
None
23+
};
24+
convert_csv_to_mtx(&arg[1], &arg[2], zones_file)
25+
}
26+
27+
/// Converts the input CSV file to MTX format and writes it to the output file.
28+
///
29+
fn convert_csv_to_mtx(
30+
input_file: &str,
31+
output_file: &str,
32+
zones_file: Option<&str>,
33+
) -> std::io::Result<()> {
34+
let data = match read_csv(input_file) {
35+
Ok(data) => data,
36+
Err(e) => {
37+
eprintln!("Error reading CSV file: {}", e);
38+
return Err(e);
39+
}
40+
};
41+
let all_zones = match get_all_zones(zones_file, &data) {
42+
Ok(zones) => zones,
43+
Err(e) => {
44+
eprintln!("Error reading zones file: {}", e);
45+
return Err(e);
46+
}
47+
};
2248
println!("Found {} zones", all_zones.len());
2349
let matrix = build_matrix(&data, &all_zones);
24-
write_mtx_file(&arg[2], &all_zones, &matrix);
50+
match write_mtx_file(output_file, &all_zones, &matrix) {
51+
Err(e) => {
52+
eprintln!("Error writing MTX file: {}", e);
53+
Err(e)
54+
}
55+
_ => Ok(()),
56+
}
2557
}
2658

2759
/// Reads the input CSV file and extracts the data as a vector of tuples containing
@@ -34,13 +66,9 @@ fn main() {
3466
///
3567
/// # Returns
3668
/// A vector of tuples `(i32, i32, f32)` representing the origin, destination, and value.
37-
fn read_csv(input_file: &str) -> Vec<(i32, i32, f32)> {
38-
let file = match File::open(input_file) {
39-
Ok(f) => f,
40-
Err(e) => {
41-
panic!("Error opening file {input_file}: {e}");
42-
}
43-
};
69+
fn read_csv(input_file: &str) -> std::io::Result<Vec<(i32, i32, f32)>> {
70+
let file = File::open(input_file)?;
71+
4472
let mut rdr = csv::ReaderBuilder::new()
4573
.has_headers(false)
4674
.from_reader(file);
@@ -73,13 +101,13 @@ fn read_csv(input_file: &str) -> Vec<(i32, i32, f32)> {
73101
}
74102
}
75103

76-
data
104+
Ok(data)
77105
} else {
78106
// Rectangular format - pass the first record and remaining iterator
79-
read_rectangular_csv_from_records(first_record, records)
107+
Ok(read_rectangular_csv_from_records(first_record, records))
80108
}
81109
} else {
82-
Vec::new()
110+
Ok(Vec::new())
83111
}
84112
}
85113

@@ -130,29 +158,29 @@ fn read_rectangular_csv_from_records(
130158
/// or by extracting unique origins and destinations from the input data.
131159
///
132160
/// # Arguments
133-
/// * `arg` - The command-line arguments.
161+
/// * `zones_file` - Optional path to the zones CSV file.
134162
/// * `data` - The vector of tuples `(i32, i32, f32)` representing the input data.
135163
///
136164
/// # Returns
137165
/// A sorted vector of unique zone numbers.
138-
fn get_all_zones(arg: &[String], data: &[(i32, i32, f32)]) -> Vec<i32> {
139-
if arg.len() > 3 {
140-
let zone_file = File::open(&arg[3]).unwrap();
166+
fn get_all_zones(zones_file: Option<&str>, data: &[(i32, i32, f32)]) -> std::io::Result<Vec<i32>> {
167+
if let Some(zone_file) = zones_file {
168+
let zone_file = File::open(zone_file)?;
141169
let mut zone_rdr = csv::Reader::from_reader(zone_file);
142170
let mut zones: Vec<i32> = zone_rdr
143171
.records()
144172
.filter_map(|result| result.ok()?.get(0)?.parse().ok())
145173
.collect();
146174
zones.sort_unstable();
147-
zones
175+
Ok(zones)
148176
} else {
149177
let zones: HashSet<i32> = data
150178
.par_iter()
151179
.flat_map(|(origin, destination, _)| vec![*origin, *destination])
152180
.collect();
153181
let mut zones: Vec<i32> = zones.into_iter().collect();
154182
zones.sort_unstable();
155-
zones
183+
Ok(zones)
156184
}
157185
}
158186

@@ -194,8 +222,8 @@ fn build_matrix(data: &[(i32, i32, f32)], all_zones: &[i32]) -> Vec<f32> {
194222
///
195223
/// # Panics
196224
/// This function will panic if it fails to create or write to the output file.
197-
fn write_mtx_file(output_file_name: &str, all_zones: &[i32], matrix: &[f32]) {
198-
let output_file = File::create(output_file_name).unwrap();
225+
fn write_mtx_file(output_file_name: &str, all_zones: &[i32], matrix: &[f32]) -> std::io::Result<()> {
226+
let output_file = File::create(output_file_name)?;
199227
let mut writer: Box<dyn Write> = if output_file_name.ends_with(".gz") {
200228
Box::new(BufWriter::new(GzEncoder::new(output_file, Compression::default())))
201229
} else {
@@ -204,41 +232,62 @@ fn write_mtx_file(output_file_name: &str, all_zones: &[i32], matrix: &[f32]) {
204232

205233
let zone_count = all_zones.len() as i32;
206234

207-
writer.write_all(&0xC4D4F1B2u32.to_le_bytes()).unwrap(); // Magic Number
208-
writer.write_all(&1i32.to_le_bytes()).unwrap(); // Version Number
209-
writer.write_all(&1i32.to_le_bytes()).unwrap(); // Type
210-
writer.write_all(&2i32.to_le_bytes()).unwrap(); // Dimensions
211-
writer.write_all(&zone_count.to_le_bytes()).unwrap(); // Index size for origin
212-
writer.write_all(&zone_count.to_le_bytes()).unwrap(); // Index size for destination
235+
writer.write_all(&0xC4D4F1B2u32.to_le_bytes())?; // Magic Number
236+
writer.write_all(&1i32.to_le_bytes())?; // Version Number
237+
writer.write_all(&1i32.to_le_bytes())?; // Type
238+
writer.write_all(&2i32.to_le_bytes())?; // Dimensions
239+
writer.write_all(&zone_count.to_le_bytes())?; // Index size for origin
240+
writer.write_all(&zone_count.to_le_bytes())?; // Index size for destination
213241

214242
let is_little_endian = cfg!(target_endian = "little");
215243

216244
if is_little_endian {
217245
// Write all origin zone numbers in a single call
218246
let origin_zone_bytes: &[u8] = bytemuck::cast_slice(all_zones);
219-
writer.write_all(origin_zone_bytes).unwrap(); // Zone Numbers for Origin
247+
writer.write_all(origin_zone_bytes)?; // Zone Numbers for Origin
220248

221249
// Write all destination zone numbers in a single call
222-
writer.write_all(origin_zone_bytes).unwrap(); // Zone Numbers for Destination
250+
writer.write_all(origin_zone_bytes)?; // Zone Numbers for Destination
223251

224252
// Write all matrix values in a single call
225253
let matrix_bytes: &[u8] = bytemuck::cast_slice(matrix);
226-
writer.write_all(matrix_bytes).unwrap();
254+
writer.write_all(matrix_bytes)?;
227255

228256
} else {
229257
// Convert all_zones to little-endian
230258
let origin_zone_bytes: Vec<u8> = all_zones
231259
.par_iter()
232260
.flat_map(|&zone| zone.to_le_bytes())
233261
.collect();
234-
writer.write_all(&origin_zone_bytes).unwrap(); // Zone Numbers for Origin
235-
writer.write_all(&origin_zone_bytes).unwrap(); // Zone Numbers for Destination
262+
writer.write_all(&origin_zone_bytes)?; // Zone Numbers for Origin
263+
writer.write_all(&origin_zone_bytes)?; // Zone Numbers for Destination
236264

237265
// Convert matrix to little-endian
238266
let matrix_bytes: Vec<u8> = matrix
239267
.par_iter()
240268
.flat_map(|&value| value.to_le_bytes())
241269
.collect();
242-
writer.write_all(&matrix_bytes).unwrap();
270+
writer.write_all(&matrix_bytes)?;
271+
}
272+
Ok(())
273+
}
274+
275+
// Write a test using test.csv to make sure that it converts to an mtx file
276+
#[cfg(test)]
277+
mod tests {
278+
use super::*;
279+
#[test]
280+
fn test_csv_to_mtx() -> std::io::Result<()> {
281+
let input_file = "test/test.csv";
282+
let output_file = "test/test_output.mtx";
283+
284+
convert_csv_to_mtx(input_file, output_file, None)?;
285+
286+
// Compare against a known good output file
287+
let expected_output_file = "test/test_expected.mtx";
288+
let output_data = std::fs::read(output_file)?;
289+
let expected_data = std::fs::read(expected_output_file)?;
290+
assert_eq!(output_data, expected_data);
291+
Ok(())
243292
}
244293
}

test/test.csv

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Origin,Destination,Value
2+
1,1,0.1
3+
1,2,0.2
4+
1,3,0.3
5+
2,1,1
6+
2,2,2
7+
2,3,3
8+
4,4,0.1

test/test_expected.mtx

120 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)