diff --git a/benches/benchmarks.rs b/benches/benchmarks.rs index 8f04185..49adf2f 100644 --- a/benches/benchmarks.rs +++ b/benches/benchmarks.rs @@ -46,6 +46,16 @@ fn do_decode_bench_slice(b: &mut Bencher, &size: &usize) { }); } +fn do_check_encoded(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size * 3 / 4); + fill(&mut v); + let encoded = STANDARD.encode(&v); + + b.iter(|| { + STANDARD.check_encoded(&encoded).unwrap(); + }); +} + fn do_decode_bench_stream(b: &mut Bencher, &size: &usize) { let mut v: Vec = Vec::with_capacity(size * 3 / 4); fill(&mut v); @@ -217,6 +227,11 @@ fn decode_benchmarks(c: &mut Criterion, label: &str, byte_sizes: &[usize]) { size, do_decode_bench_slice, ) + .bench_with_input( + BenchmarkId::new("check_encoded", size), + size, + do_check_encoded, + ) .bench_with_input( BenchmarkId::new("decode_stream", size), size, diff --git a/src/engine/general_purpose/mod.rs b/src/engine/general_purpose/mod.rs index 72a02de..3a76921 100644 --- a/src/engine/general_purpose/mod.rs +++ b/src/engine/general_purpose/mod.rs @@ -2,10 +2,9 @@ //! //! See preconfigured engines like [`STANDARD_NO_PAD`] or [`STANDARD_NO_PAD_INDIFFERENT`]. use crate::{ - alphabet, - alphabet::Alphabet, + alphabet::{self, Alphabet}, engine::{Config, DecodeMetadata, DecodePaddingMode}, - DecodeSliceError, + DecodeError, DecodeSliceError, }; use core::convert::TryInto; @@ -190,6 +189,49 @@ impl super::Engine for GeneralPurpose { fn config(&self) -> &Self::Config { &self.config } + + fn check_encoded>(&self, encoded: T) -> Result<(), DecodeError> { + let input = encoded.as_ref(); + let rem = input.len() % 4; + let suffix_start = match (input.len(), rem) { + // there's no prefix, just suffix + (0..4, _) => 0, + // make last partition the suffix + (4.., 0) => input.len() - 4, + // make last incomplete partition the suffix + (4.., rem) => input.len() - rem, + }; + // partition the input without suffix + let prefix_input = &input[..suffix_start]; + + // try to find an invalid value + let invalid_value = prefix_input + .iter() + .position(|&b| self.decode_table[b as usize] == INVALID_VALUE); + if let Some(pos) = invalid_value { + return Err(DecodeError::InvalidByte(pos, prefix_input[pos])); + } + + // reuse `decode_suffix`, even tho it writes to a buffer it's not on the hot codepath + let mut output_buffer = [0_u8; 3]; + _ = super::decode_suffix( + input, + suffix_start, + &mut output_buffer, + 0, + &self.decode_table, + self.config.decode_allow_trailing_bits, + self.config.decode_padding_mode, + ) + .map_err(|e| match e { + DecodeSliceError::DecodeError(err) => err, + DecodeSliceError::OutputSliceTooSmall => { + unreachable!("output buffer sized conservatively"); + } + })?; + + Ok(()) + } } /// Returns a table mapping a 6-bit index to the ASCII byte encoding of the index diff --git a/src/engine/mod.rs b/src/engine/mod.rs index 93ae5d9..af98f29 100644 --- a/src/engine/mod.rs +++ b/src/engine/mod.rs @@ -3,7 +3,9 @@ use crate::chunked_encoder; use crate::{ encode::{encode_with_padding, EncodeSliceError}, - encoded_len, DecodeError, DecodeSliceError, + encoded_len, + engine::general_purpose::decode_suffix::decode_suffix, + DecodeError, DecodeSliceError, }; #[cfg(any(feature = "alloc", test))] use alloc::vec::Vec; @@ -416,6 +418,10 @@ pub trait Engine: Send + Sync { inner(self, input.as_ref(), output) } + + // TODO: more docs + /// Checks for decoding errors without decoding. + fn check_encoded>(&self, encoded: T) -> Result<(), DecodeError>; } /// The minimal level of configuration that engines must support. diff --git a/src/engine/naive.rs b/src/engine/naive.rs index af509bf..1f9fb2a 100644 --- a/src/engine/naive.rs +++ b/src/engine/naive.rs @@ -154,6 +154,11 @@ impl Engine for Naive { fn config(&self) -> &Self::Config { &self.config } + + fn check_encoded>(&self, encoded: T) -> Result<(), DecodeError> { + _ = self.decode(encoded)?; + Ok(()) + } } pub struct NaiveEstimate { diff --git a/src/engine/tests.rs b/src/engine/tests.rs index 72bbf4b..5d4b09d 100644 --- a/src/engine/tests.rs +++ b/src/engine/tests.rs @@ -74,6 +74,9 @@ fn rfc_test_vectors_std_alphabet(engine_wrapper: E) { .decode_slice_unchecked(encoded_without_padding.as_bytes(), &mut decode_buf[..]) .unwrap(); assert_eq!(orig.len(), decode_len); + assert!(engine_no_padding + .check_encoded(encoded_without_padding) + .is_ok()); assert_eq!( orig, @@ -85,7 +88,11 @@ fn rfc_test_vectors_std_alphabet(engine_wrapper: E) { assert_eq!( Err(DecodeError::InvalidPadding), engine_no_padding.decode(encoded) - ) + ); + assert_eq!( + Err(DecodeError::InvalidPadding), + engine_no_padding.check_encoded(encoded), + ); } } @@ -107,6 +114,7 @@ fn rfc_test_vectors_std_alphabet(engine_wrapper: E) { .decode_slice_unchecked(encoded.as_bytes(), &mut decode_buf[..]) .unwrap(); assert_eq!(orig.len(), decode_len); + assert!(engine.check_encoded(encoded).is_ok()); assert_eq!( orig, @@ -118,7 +126,11 @@ fn rfc_test_vectors_std_alphabet(engine_wrapper: E) { assert_eq!( Err(DecodeError::InvalidPadding), engine.decode(encoded_without_padding) - ) + ); + assert_eq!( + Err(DecodeError::InvalidPadding), + engine.check_encoded(encoded_without_padding) + ); } } } @@ -333,6 +345,8 @@ fn decode_detect_invalid_last_symbol(engine_wrapper: E) { assert_eq!(Ok(vec![0x89, 0x85]), engine.decode("iYU=")); assert_eq!(Ok(vec![0xFF]), engine.decode("/w==")); + assert_eq!(Ok(()), engine.check_encoded("iYU=")); + assert_eq!(Ok(()), engine.check_encoded("/w==")); for (suffix, offset) in vec![ // suffix, offset of bad byte from start of suffix @@ -360,6 +374,13 @@ fn decode_detect_invalid_last_symbol(engine_wrapper: E) { )), engine.decode(encoded.as_str()) ); + assert_eq!( + Err(DecodeError::InvalidLastSymbol( + encoded.len() - 4 + offset, + suffix.as_bytes()[offset], + )), + engine.check_encoded(encoded.as_str()) + ); } } } @@ -373,6 +394,10 @@ fn decode_detect_1_valid_symbol_in_last_quad_invalid_length(en let engine = E::standard_with_pad_mode(true, mode); assert_eq!(Err(DecodeError::InvalidLength(len)), engine.decode(&input)); + assert_eq!( + Err(DecodeError::InvalidLength(len)), + engine.check_encoded(&input) + ); // if we add padding, then the first pad byte in the quad is invalid because it should // be the second symbol for _ in 0..3 { @@ -381,6 +406,10 @@ fn decode_detect_1_valid_symbol_in_last_quad_invalid_length(en Err(DecodeError::InvalidByte(len, PAD_BYTE)), engine.decode(&input) ); + assert_eq!( + Err(DecodeError::InvalidByte(len, PAD_BYTE)), + engine.check_encoded(&input) + ); } } } @@ -399,6 +428,10 @@ fn decode_detect_1_invalid_byte_in_last_quad_invalid_byte(engi Err(DecodeError::InvalidByte(prefix_len, b'*')), engine.decode(&input) ); + assert_eq!( + Err(DecodeError::InvalidByte(prefix_len, b'*')), + engine.check_encoded(&input) + ); // adding padding doesn't matter for _ in 0..3 { input.push(PAD_BYTE); @@ -406,6 +439,10 @@ fn decode_detect_1_invalid_byte_in_last_quad_invalid_byte(engi Err(DecodeError::InvalidByte(prefix_len, b'*')), engine.decode(&input) ); + assert_eq!( + Err(DecodeError::InvalidByte(prefix_len, b'*')), + engine.check_encoded(&input) + ); } } } @@ -452,13 +489,21 @@ fn decode_detect_invalid_last_symbol_every_possible_two_symbols { + assert_eq!( + Err(DecodeError::InvalidLastSymbol(1, s2)), + engine.decode(&symbols[..]) + ); + assert_eq!( + Err(DecodeError::InvalidLastSymbol(1, s2)), + engine.check_encoded(&symbols[..]) + ); } - None => assert_eq!( - Err(DecodeError::InvalidLastSymbol(1, s2)), - engine.decode(&symbols[..]) - ), } } } @@ -519,13 +564,21 @@ fn decode_detect_invalid_last_symbol_every_possible_three_symbols { + assert_eq!( + Err(DecodeError::InvalidLastSymbol(2, s3)), + engine.decode(&symbols[..]) + ); + assert_eq!( + Err(DecodeError::InvalidLastSymbol(2, s3)), + engine.check_encoded(&symbols[..]) + ); } - None => assert_eq!( - Err(DecodeError::InvalidLastSymbol(2, s3)), - engine.decode(&symbols[..]) - ), } } } @@ -554,6 +607,7 @@ fn decode_invalid_trailing_bits_ignored_when_configured(engine Ok(expected_decode_bytes), decoded.map(|v| v[decoded_prefix_len..].to_vec()) ); + assert_eq!(Ok(()), engine.check_encoded(prefixed)); } let mut prefix = String::new(); @@ -567,6 +621,12 @@ fn decode_invalid_trailing_bits_ignored_when_configured(engine assert!(strict .decode(prefixed_data(&mut input, prefix.len(), "iYU=")) .is_ok()); + assert!(strict + .check_encoded(prefixed_data(&mut input, prefix.len(), "/w==")) + .is_ok()); + assert!(strict + .check_encoded(prefixed_data(&mut input, prefix.len(), "iYU=")) + .is_ok()); // trailing 01 assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![255], "/x=="); assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![137, 133], "iYV="); @@ -633,6 +693,10 @@ fn decode_invalid_byte_error(engine_wrapper: E) { &mut decode_buf[..], ) ); + assert_eq!( + Err(DecodeError::InvalidByte(invalid_index, invalid_byte)), + engine.check_encoded(&encode_buf[0..encoded_len_with_padding],) + ); } } @@ -726,6 +790,10 @@ fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad( encoded.len(), String::from_utf8(encoded).unwrap() ); + assert_eq!( + Err(DecodeError::InvalidByte(padding_start, PAD_BYTE)), + engine.check_encoded(&encoded), + ); } } } @@ -770,6 +838,10 @@ fn decode_padding_starts_before_final_chunk_error_invalid_byte_at_first_pad(en suffix_data_len, padding_len ); + assert_eq!( + Err(DecodeError::InvalidByte( + prefix_quad_len * 4 + suffix_data_len, + PAD_BYTE, + )), + engine.check_encoded(&encoded) + ); } } } @@ -820,6 +899,7 @@ fn decode_malleability_test_case_3_byte_suffix_valid(engine_wr b"Hello".as_slice(), &E::standard().decode("SGVsbG8=").unwrap() ); + assert_eq!(Ok(()), E::standard().check_encoded("SGVsbG8=")); } // https://eprint.iacr.org/2022/361.pdf table 2, test 2 @@ -831,6 +911,10 @@ fn decode_malleability_test_case_3_byte_suffix_invalid_trailing_symbol(e DecodeError::InvalidPadding, E::standard().decode("SGVsbA=").unwrap_err() ); + assert_eq!( + DecodeError::InvalidPadding, + E::standard().check_encoded("SGVsbA=").unwrap_err() + ); } // https://eprint.iacr.org/2022/361.pdf table 2, test 6 @@ -869,6 +962,10 @@ fn decode_malleability_test_case_2_byte_suffix_no_padding(engi DecodeError::InvalidPadding, E::standard().decode("SGVsbA").unwrap_err() ); + assert_eq!( + DecodeError::InvalidPadding, + E::standard().check_encoded("SGVsbA").unwrap_err() + ); } // https://eprint.iacr.org/2022/361.pdf table 2, test 7 @@ -882,6 +979,10 @@ fn decode_malleability_test_case_2_byte_suffix_too_much_padding accepts 2 + 2, 3 + 1, 4 + 0 final quad configurations @@ -905,7 +1006,8 @@ fn decode_pad_mode_requires_canonical_rejects_non_canonical(en encoded.push_str(suffix); let res = engine.decode(&encoded); - + assert_eq!(Err(DecodeError::InvalidPadding), res); + let res = engine.check_encoded(&encoded); assert_eq!(Err(DecodeError::InvalidPadding), res); } } @@ -932,7 +1034,8 @@ fn decode_pad_mode_requires_no_padding_rejects_any_padding(eng encoded.push_str(suffix); let res = engine.decode(&encoded); - + assert_eq!(Err(DecodeError::InvalidPadding), res); + let res = engine.check_encoded(&encoded); assert_eq!(Err(DecodeError::InvalidPadding), res); } } @@ -988,6 +1091,16 @@ fn do_invalid_trailing_byte(engine: impl Engine, mode: DecodePaddingMode) { mode, String::from_utf8(input).unwrap() ); + assert_eq!( + Err(DecodeError::InvalidByte( + num_prefix_quads * 4 + 4, + last_byte + )), + engine.check_encoded(&input), + "mode: {:?}, input: {}", + mode, + String::from_utf8(input).unwrap() + ); } } } @@ -1042,6 +1155,14 @@ fn do_invalid_trailing_padding_as_invalid_byte_at_first_padding( mode, s ); + assert_eq!( + // pad after `g`, not the last one + Err(DecodeError::InvalidByte( + num_prefix_quads * 4 + pad_offset, + PAD_BYTE + )), + engine.check_encoded(&s), + ); } } } @@ -1077,6 +1198,7 @@ fn decode_into_slice_fits_in_precisely_sized_slice(engine_wrap .unwrap(); assert_eq!(orig_data.len(), decode_bytes_written); assert_eq!(orig_data, decode_buf); + assert_eq!(Ok(()), engine.check_encoded(encoded_data.as_bytes())); // same for checked variant decode_buf.clear(); @@ -1087,6 +1209,7 @@ fn decode_into_slice_fits_in_precisely_sized_slice(engine_wrap .unwrap(); assert_eq!(orig_data.len(), decode_bytes_written); assert_eq!(orig_data, decode_buf); + assert_eq!(Ok(()), engine.check_encoded(encoded_data.as_bytes())); } } @@ -1115,6 +1238,7 @@ fn inner_decode_reports_padding_position(engine_wrapper: E) { &mut decoded[..], engine.internal_decoded_len_estimate(b64.len()), ); + let check_res = engine.check_encoded(b64.as_bytes()); if pad_position % 4 < 2 { // impossible padding assert_eq!( @@ -1124,6 +1248,10 @@ fn inner_decode_reports_padding_position(engine_wrapper: E) { ))), decode_res ); + assert_eq!( + Err(DecodeError::InvalidByte(pad_position, PAD_BYTE)), + check_res + ); } else { let decoded_bytes = pad_position / 4 * 3 + match pad_position % 4 { @@ -1136,6 +1264,7 @@ fn inner_decode_reports_padding_position(engine_wrapper: E) { Ok(DecodeMetadata::new(decoded_bytes, Some(pad_position))), decode_res ); + assert_eq!(Ok(()), check_res); } } } @@ -1242,6 +1371,7 @@ fn decode_slice_checked_fails_gracefully_at_all_output_lengths engine.decode_slice(&encoded, &mut decode_buf[..]).unwrap() ); assert_eq!(original, decode_buf); + assert_eq!(Ok(()), engine.check_encoded(&encoded)); } } } @@ -1512,6 +1642,10 @@ impl Engine for DecoderReaderEngine { fn config(&self) -> &Self::Config { self.engine.config() } + + fn check_encoded>(&self, encoded: T) -> Result<(), DecodeError> { + self.engine.check_encoded(encoded) + } } struct DecoderReaderEngineWrapper {} @@ -1574,6 +1708,8 @@ fn assert_all_suffixes_ok(engine: E, suffixes: Vec<&str>) { let res = &engine.decode(&encoded); assert!(res.is_ok()); + let res = &engine.check_encoded(&encoded); + assert!(res.is_ok()); } } }