From 13e437bc6f55f7e7a8d6a77f06803c482a723528 Mon Sep 17 00:00:00 2001 From: Mario Alejandro Montoya Cortes Date: Thu, 10 Jul 2025 09:34:07 -0500 Subject: [PATCH 1/2] Add a recursion limit to the evaluation of type_expr & parse_expr --- crates/core/src/sql/execute.rs | 47 +++++++++++++++++++---- crates/expr/src/check.rs | 2 +- crates/expr/src/lib.rs | 26 ++++++++----- crates/expr/src/statement.rs | 4 +- crates/sql-parser/src/parser/errors.rs | 9 +++++ crates/sql-parser/src/parser/mod.rs | 35 +++++++++-------- crates/sql-parser/src/parser/recursion.rs | 34 ++++++++++++++++ crates/sql-parser/src/parser/sql.rs | 6 +-- crates/sql-parser/src/parser/sub.rs | 2 +- 9 files changed, 127 insertions(+), 38 deletions(-) create mode 100644 crates/sql-parser/src/parser/recursion.rs diff --git a/crates/core/src/sql/execute.rs b/crates/core/src/sql/execute.rs index b089db36be6..23e344f7177 100644 --- a/crates/core/src/sql/execute.rs +++ b/crates/core/src/sql/execute.rs @@ -1126,6 +1126,9 @@ pub(crate) mod tests { Ok(()) } + /// Test we are protected against recursion when: + /// 1. The query is too large + /// 2. The AST is too deep #[test] fn test_large_query_no_panic() -> ResultTest<()> { let db = TestDB::durable()?; @@ -1138,16 +1141,46 @@ pub(crate) mod tests { ) .unwrap(); - let mut query = "select * from test where ".to_string(); - for x in 0..1_000 { - for y in 0..1_000 { - let fragment = format!("((x = {x}) and y = {y}) or"); - query.push_str(&fragment); + let build_query = |total| { + let mut sql = "select * from test where ".to_string(); + for x in 0..total { + for y in 0..total { + let fragment = format!("((x = {x}) and (y = {y})) or "); + sql.push_str(&fragment); + } } + sql.push_str("((x = 1000) and (y = 1000))"); + sql + }; + let run = |db: &RelationalDB, sep: char, sql_text: &str| { + run_for_testing(db, sql_text).map_err(|e| e.to_string().split(sep).next().unwrap_or_default().to_string()) + }; + let sql = build_query(1_000); + assert_eq!( + run(&db, ':', &sql), + Err("SQL query exceeds maximum allowed length".to_string()) + ); + + // Exercise the limit [recursion::MAX_RECURSION_EXPR] && [recursion::MAX_RECURSION_TYP_EXPR] + let sql = build_query(8); + assert_eq!(run(&db, ',', &sql), Err("Recursion limit exceeded".to_string())); + + let sql = build_query(7); + assert!(run(&db, ',', &sql).is_ok(), "Expected query to run without panic"); + + // Check no overflow with lot of joins + let mut sql = "SELECT test.* FROM test ".to_string(); + // We could pust up to 700 joins without overflow as long we don't have any conditions, + // but here execution become too slow. + // TODO: Move this test to the `Plan` + for i in 0..200 { + sql.push_str(&format!("JOIN test AS m{i} ON test.x = m{i}.y ")); } - query.push_str("((x = 1000) and (y = 1000))"); - assert!(run_for_testing(&db, &query).is_err()); + assert!( + run(&db, ',', &sql).is_ok(), + "Query with many joins and conditions should not overflow" + ); Ok(()) } diff --git a/crates/expr/src/check.rs b/crates/expr/src/check.rs index 14aaa34a9ed..006878ac26b 100644 --- a/crates/expr/src/check.rs +++ b/crates/expr/src/check.rs @@ -99,7 +99,7 @@ pub trait TypeChecker { vars.insert(rhs.alias.clone(), rhs.schema.clone()); if let Some(on) = on { - if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool))? { + if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool), &mut 0)? { if let (Expr::Field(a), Expr::Field(b)) = (*a, *b) { join = RelExpr::EqJoin(LeftDeepJoin { lhs, rhs }, a, b); continue; diff --git a/crates/expr/src/lib.rs b/crates/expr/src/lib.rs index 4860bdc882e..9905351a02a 100644 --- a/crates/expr/src/lib.rs +++ b/crates/expr/src/lib.rs @@ -19,6 +19,7 @@ use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type; use spacetimedb_sats::algebraic_value::ser::ValueSerializer; use spacetimedb_schema::schema::ColumnSchema; use spacetimedb_sql_parser::ast::{self, BinOp, ProjectElem, SqlExpr, SqlIdent, SqlLiteral}; +use spacetimedb_sql_parser::parser::recursion; pub mod check; pub mod errors; @@ -30,7 +31,7 @@ pub mod statement; pub(crate) fn type_select(input: RelExpr, expr: SqlExpr, vars: &Relvars) -> TypingResult { Ok(RelExpr::Select( Box::new(input), - type_expr(vars, expr, Some(&AlgebraicType::Bool))?, + type_expr(vars, expr, Some(&AlgebraicType::Bool), &mut 0)?, )) } @@ -68,7 +69,7 @@ pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> T return Err(DuplicateName(alias.into_string()).into()); } - if let Expr::Field(p) = type_expr(vars, expr.into(), None)? { + if let Expr::Field(p) = type_expr(vars, expr.into(), None, &mut 0)? { projections.push((alias, p)); } } @@ -79,7 +80,14 @@ pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> T } /// Type check and lower a [SqlExpr] into a logical [Expr]. -pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>) -> TypingResult { +pub(crate) fn type_expr( + vars: &Relvars, + expr: SqlExpr, + expected: Option<&AlgebraicType>, + depth: &mut usize, +) -> TypingResult { + recursion::guard(depth, recursion::MAX_RECURSION_TYP_EXPR, "expr::type_expr")?; + match (expr, expected) { (SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(AlgebraicType::Bool)) => Ok(Expr::bool(v)), (SqlExpr::Lit(SqlLiteral::Bool(_)), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()), @@ -117,21 +125,21 @@ pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&Algebra })) } (SqlExpr::Log(a, b, op), None | Some(AlgebraicType::Bool)) => { - let a = type_expr(vars, *a, Some(&AlgebraicType::Bool))?; - let b = type_expr(vars, *b, Some(&AlgebraicType::Bool))?; + let a = type_expr(vars, *a, Some(&AlgebraicType::Bool), depth)?; + let b = type_expr(vars, *b, Some(&AlgebraicType::Bool), depth)?; Ok(Expr::LogOp(op, Box::new(a), Box::new(b))) } (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) if matches!(&*a, SqlExpr::Lit(_)) => { - let b = type_expr(vars, *b, None)?; - let a = type_expr(vars, *a, Some(b.ty()))?; + let b = type_expr(vars, *b, None, depth)?; + let a = type_expr(vars, *a, Some(b.ty()), depth)?; if !op_supports_type(op, a.ty()) { return Err(InvalidOp::new(op, a.ty()).into()); } Ok(Expr::BinOp(op, Box::new(a), Box::new(b))) } (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) => { - let a = type_expr(vars, *a, None)?; - let b = type_expr(vars, *b, Some(a.ty()))?; + let a = type_expr(vars, *a, None, depth)?; + let b = type_expr(vars, *b, Some(a.ty()), depth)?; if !op_supports_type(op, a.ty()) { return Err(InvalidOp::new(op, a.ty()).into()); } diff --git a/crates/expr/src/statement.rs b/crates/expr/src/statement.rs index 39abf447118..fcdcaa8bda3 100644 --- a/crates/expr/src/statement.rs +++ b/crates/expr/src/statement.rs @@ -162,7 +162,7 @@ pub fn type_delete(delete: SqlDelete, tx: &impl SchemaView) -> TypingResult TypingResult Err(SqlUnsupported::JoinType.into()), @@ -204,15 +208,16 @@ pub(crate) fn parse_proj(expr: Expr) -> SqlParseResult { } /// Parse a scalar expression -pub(crate) fn parse_expr(expr: Expr) -> SqlParseResult { +pub(crate) fn parse_expr(expr: Expr, depth: &mut usize) -> SqlParseResult { fn signed_num(sign: impl Into, expr: Expr) -> Result { match expr { Expr::Value(Value::Number(n, _)) => Ok(SqlExpr::Lit(SqlLiteral::Num((sign.into() + &n).into_boxed_str()))), expr => Err(SqlUnsupported::Expr(expr)), } } + recursion::guard(depth, recursion::MAX_RECURSION_EXPR, "sql-parser::parse_expr")?; match expr { - Expr::Nested(expr) => parse_expr(*expr), + Expr::Nested(expr) => parse_expr(*expr, depth), Expr::Value(Value::Placeholder(param)) if ¶m == ":sender" => Ok(SqlExpr::Param(Parameter::Sender)), Expr::Value(v) => Ok(SqlExpr::Lit(parse_literal(v)?)), Expr::UnaryOp { @@ -238,8 +243,8 @@ pub(crate) fn parse_expr(expr: Expr) -> SqlParseResult { op: BinaryOperator::And, right, } => { - let l = parse_expr(*left)?; - let r = parse_expr(*right)?; + let l = parse_expr(*left, depth)?; + let r = parse_expr(*right, depth)?; Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::And)) } Expr::BinaryOp { @@ -247,13 +252,13 @@ pub(crate) fn parse_expr(expr: Expr) -> SqlParseResult { op: BinaryOperator::Or, right, } => { - let l = parse_expr(*left)?; - let r = parse_expr(*right)?; + let l = parse_expr(*left, depth)?; + let r = parse_expr(*right, depth)?; Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::Or)) } Expr::BinaryOp { left, op, right } => { - let l = parse_expr(*left)?; - let r = parse_expr(*right)?; + let l = parse_expr(*left, depth)?; + let r = parse_expr(*right, depth)?; Ok(SqlExpr::Bin(Box::new(l), Box::new(r), parse_binop(op)?)) } _ => Err(SqlUnsupported::Expr(expr).into()), @@ -261,8 +266,8 @@ pub(crate) fn parse_expr(expr: Expr) -> SqlParseResult { } /// Parse an optional scalar expression -pub(crate) fn parse_expr_opt(opt: Option) -> SqlParseResult> { - opt.map(parse_expr).transpose() +pub(crate) fn parse_expr_opt(opt: Option, depth: &mut usize) -> SqlParseResult> { + opt.map(|expr| parse_expr(expr, depth)).transpose() } /// Parse a scalar binary operator diff --git a/crates/sql-parser/src/parser/recursion.rs b/crates/sql-parser/src/parser/recursion.rs new file mode 100644 index 00000000000..b85bd2dccec --- /dev/null +++ b/crates/sql-parser/src/parser/recursion.rs @@ -0,0 +1,34 @@ +//! A utility for guarding against excessive recursion depth in the SQL parser. +//! +//! Different parts of the parser may have different recursion limits. +//! +//! Removing one could allow the others to be higher, but depending on how the `SQL` is structured, it could lead to a `stack overflow` +//! if is not guarded against, so is incorrect to assume that a limit is sufficient for the next part of the parser. +use crate::parser::errors::{RecursionError, SqlParseError}; +use std::fmt::Display; + +/// A conservative limit for recursion depth on `parse_expr`. +pub const MAX_RECURSION_EXPR: usize = 700; +/// A conservative limit for recursion depth on `type_expr`. +pub const MAX_RECURSION_TYP_EXPR: usize = 5_000; + +/// A utility for guarding against excessive recursion depth. +/// +/// **Usage:** +/// ``` +/// use spacetimedb_sql_parser::parser::recursion; +/// let mut depth = 0; +/// assert!(recursion::guard(&mut depth, 10, "test").is_ok()); +/// ``` +pub fn guard(depth: &mut usize, limit: usize, msg: impl Display) -> Result<(), SqlParseError> { + *depth += 1; + if *depth > limit { + Err(RecursionError { + limit, + message: msg.to_string(), + } + .into()) + } else { + Ok(()) + } +} diff --git a/crates/sql-parser/src/parser/sql.rs b/crates/sql-parser/src/parser/sql.rs index a1eb5078726..caedc28e299 100644 --- a/crates/sql-parser/src/parser/sql.rs +++ b/crates/sql-parser/src/parser/sql.rs @@ -202,7 +202,7 @@ fn parse_statement(stmt: Statement) -> SqlParseResult { } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlAst::Update(SqlUpdate { table: parse_ident(name)?, assignments: parse_assignments(assignments)?, - filter: parse_expr_opt(selection)?, + filter: parse_expr_opt(selection, &mut 0)?, })), Statement::Delete { tables, @@ -297,7 +297,7 @@ fn parse_delete(mut from: Vec, selection: Option) -> SqlPa joins, } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlDelete { table: parse_ident(name)?, - filter: parse_expr_opt(selection)?, + filter: parse_expr_opt(selection, &mut 0)?, }), t => Err(SqlUnsupported::DeleteTable(t).into()), } @@ -395,7 +395,7 @@ fn parse_select(select: Select, limit: Option>) -> SqlParseResult SqlParseResult { { Ok(SqlSelect { from: SubParser::parse_from(from)?, - filter: parse_expr_opt(selection)?, + filter: parse_expr_opt(selection, &mut 0)?, project: parse_projection(projection)?, }) } From 18ae75e20e412736b48175e1936f36635a1e519d Mon Sep 17 00:00:00 2001 From: Mario Alejandro Montoya Cortes Date: Mon, 14 Jul 2025 10:15:38 -0500 Subject: [PATCH 2/2] Apply suggestion from PR, and increase the recursion limit thanks to reduction on `Err` type --- crates/core/src/error.rs | 2 -- crates/core/src/host/module_host.rs | 2 -- crates/core/src/sql/execute.rs | 23 +++++++------ crates/expr/src/check.rs | 2 +- crates/expr/src/errors.rs | 2 -- crates/expr/src/lib.rs | 34 ++++++++++--------- crates/expr/src/statement.rs | 40 ++++++++++++++++++----- crates/sql-parser/src/parser/errors.rs | 23 +++++++++---- crates/sql-parser/src/parser/mod.rs | 31 ++++++++++-------- crates/sql-parser/src/parser/recursion.rs | 28 ++++++---------- crates/sql-parser/src/parser/sql.rs | 6 ++-- crates/sql-parser/src/parser/sub.rs | 4 +-- 12 files changed, 113 insertions(+), 84 deletions(-) diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs index 4e550b234b3..4994ef69431 100644 --- a/crates/core/src/error.rs +++ b/crates/core/src/error.rs @@ -82,8 +82,6 @@ pub enum DatabaseError { DatabasedOpened(PathBuf, anyhow::Error), } -// FIXME: reduce type size -#[expect(clippy::large_enum_variant)] #[derive(Error, Debug, EnumAsInner)] pub enum DBError { #[error("LibError: {0}")] diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index dc8a75640a6..b158ef35ee0 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -539,8 +539,6 @@ pub enum InitDatabaseError { Other(anyhow::Error), } -// FIXME: reduce type size -#[expect(clippy::large_enum_variant)] #[derive(thiserror::Error, Debug)] pub enum ClientConnectedError { #[error(transparent)] diff --git a/crates/core/src/sql/execute.rs b/crates/core/src/sql/execute.rs index 23e344f7177..d12e811ca7a 100644 --- a/crates/core/src/sql/execute.rs +++ b/crates/core/src/sql/execute.rs @@ -1126,9 +1126,11 @@ pub(crate) mod tests { Ok(()) } - /// Test we are protected against recursion when: - /// 1. The query is too large + /// Test we are protected against stack overflows when: + /// 1. The query is too large (too many characters) /// 2. The AST is too deep + /// + /// Exercise the limit [`recursion::MAX_RECURSION_EXPR`] #[test] fn test_large_query_no_panic() -> ResultTest<()> { let db = TestDB::durable()?; @@ -1143,13 +1145,11 @@ pub(crate) mod tests { let build_query = |total| { let mut sql = "select * from test where ".to_string(); - for x in 0..total { - for y in 0..total { - let fragment = format!("((x = {x}) and (y = {y})) or "); - sql.push_str(&fragment); - } + for x in 1..total { + let fragment = format!("x = {x} or "); + sql.push_str(&fragment.repeat((total - 1) as usize)); } - sql.push_str("((x = 1000) and (y = 1000))"); + sql.push_str("(y = 0)"); sql }; let run = |db: &RelationalDB, sep: char, sql_text: &str| { @@ -1161,16 +1161,15 @@ pub(crate) mod tests { Err("SQL query exceeds maximum allowed length".to_string()) ); - // Exercise the limit [recursion::MAX_RECURSION_EXPR] && [recursion::MAX_RECURSION_TYP_EXPR] - let sql = build_query(8); + let sql = build_query(41); // This causes stack overflow without the limit assert_eq!(run(&db, ',', &sql), Err("Recursion limit exceeded".to_string())); - let sql = build_query(7); + let sql = build_query(40); // The max we can with the current limit assert!(run(&db, ',', &sql).is_ok(), "Expected query to run without panic"); // Check no overflow with lot of joins let mut sql = "SELECT test.* FROM test ".to_string(); - // We could pust up to 700 joins without overflow as long we don't have any conditions, + // We could push up to 700 joins without overflow as long we don't have any conditions, // but here execution become too slow. // TODO: Move this test to the `Plan` for i in 0..200 { diff --git a/crates/expr/src/check.rs b/crates/expr/src/check.rs index 006878ac26b..14aaa34a9ed 100644 --- a/crates/expr/src/check.rs +++ b/crates/expr/src/check.rs @@ -99,7 +99,7 @@ pub trait TypeChecker { vars.insert(rhs.alias.clone(), rhs.schema.clone()); if let Some(on) = on { - if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool), &mut 0)? { + if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool))? { if let (Expr::Field(a), Expr::Field(b)) = (*a, *b) { join = RelExpr::EqJoin(LeftDeepJoin { lhs, rhs }, a, b); continue; diff --git a/crates/expr/src/errors.rs b/crates/expr/src/errors.rs index d611ec945d2..c523fb9452e 100644 --- a/crates/expr/src/errors.rs +++ b/crates/expr/src/errors.rs @@ -122,8 +122,6 @@ pub struct DuplicateName(pub String); #[error("`filter!` does not support column projections; Must return table rows")] pub struct FilterReturnType; -// FIXME: reduce type size -#[expect(clippy::large_enum_variant)] #[derive(Error, Debug)] pub enum TypingError { #[error(transparent)] diff --git a/crates/expr/src/lib.rs b/crates/expr/src/lib.rs index 9905351a02a..2d9b3cdc5ab 100644 --- a/crates/expr/src/lib.rs +++ b/crates/expr/src/lib.rs @@ -31,7 +31,7 @@ pub mod statement; pub(crate) fn type_select(input: RelExpr, expr: SqlExpr, vars: &Relvars) -> TypingResult { Ok(RelExpr::Select( Box::new(input), - type_expr(vars, expr, Some(&AlgebraicType::Bool), &mut 0)?, + type_expr(vars, expr, Some(&AlgebraicType::Bool))?, )) } @@ -69,7 +69,7 @@ pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> T return Err(DuplicateName(alias.into_string()).into()); } - if let Expr::Field(p) = type_expr(vars, expr.into(), None, &mut 0)? { + if let Expr::Field(p) = type_expr(vars, expr.into(), None)? { projections.push((alias, p)); } } @@ -79,13 +79,12 @@ pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> T } } -/// Type check and lower a [SqlExpr] into a logical [Expr]. -pub(crate) fn type_expr( - vars: &Relvars, - expr: SqlExpr, - expected: Option<&AlgebraicType>, - depth: &mut usize, -) -> TypingResult { +// These types determine the size of each stack frame during type checking. +// Changing their sizes will require updating the recursion limit to avoid stack overflows. +const _: () = assert!(size_of::>() == 64); +const _: () = assert!(size_of::() == 40); + +fn _type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>, depth: usize) -> TypingResult { recursion::guard(depth, recursion::MAX_RECURSION_TYP_EXPR, "expr::type_expr")?; match (expr, expected) { @@ -125,21 +124,21 @@ pub(crate) fn type_expr( })) } (SqlExpr::Log(a, b, op), None | Some(AlgebraicType::Bool)) => { - let a = type_expr(vars, *a, Some(&AlgebraicType::Bool), depth)?; - let b = type_expr(vars, *b, Some(&AlgebraicType::Bool), depth)?; + let a = _type_expr(vars, *a, Some(&AlgebraicType::Bool), depth + 1)?; + let b = _type_expr(vars, *b, Some(&AlgebraicType::Bool), depth + 1)?; Ok(Expr::LogOp(op, Box::new(a), Box::new(b))) } (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) if matches!(&*a, SqlExpr::Lit(_)) => { - let b = type_expr(vars, *b, None, depth)?; - let a = type_expr(vars, *a, Some(b.ty()), depth)?; + let b = _type_expr(vars, *b, None, depth + 1)?; + let a = _type_expr(vars, *a, Some(b.ty()), depth + 1)?; if !op_supports_type(op, a.ty()) { return Err(InvalidOp::new(op, a.ty()).into()); } Ok(Expr::BinOp(op, Box::new(a), Box::new(b))) } (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) => { - let a = type_expr(vars, *a, None, depth)?; - let b = type_expr(vars, *b, Some(a.ty()), depth)?; + let a = _type_expr(vars, *a, None, depth + 1)?; + let b = _type_expr(vars, *b, Some(a.ty()), depth + 1)?; if !op_supports_type(op, a.ty()) { return Err(InvalidOp::new(op, a.ty()).into()); } @@ -152,6 +151,11 @@ pub(crate) fn type_expr( } } +/// Type check and lower a [SqlExpr] into a logical [Expr]. +pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>) -> TypingResult { + _type_expr(vars, expr, expected, 0) +} + /// Is this type compatible with this binary operator? fn op_supports_type(_op: BinOp, t: &AlgebraicType) -> bool { t.is_bool() diff --git a/crates/expr/src/statement.rs b/crates/expr/src/statement.rs index fcdcaa8bda3..50e9bdb4c22 100644 --- a/crates/expr/src/statement.rs +++ b/crates/expr/src/statement.rs @@ -162,7 +162,7 @@ pub fn type_delete(delete: SqlDelete, tx: &impl SchemaView) -> TypingResult TypingResult(sql: &'a str, tx: &impl SchemaView, auth: &AuthCtx) #[cfg(test)] mod tests { - use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, ProductType}; - use spacetimedb_schema::def::ModuleDef; - + use super::Statement; + use crate::ast::LogOp; use crate::check::{ test_utils::{build_module_def, SchemaViewer}, - SchemaView, TypingResult, + Relvars, SchemaView, TypingResult, }; - - use super::Statement; + use crate::type_expr; + use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, ProductType}; + use spacetimedb_schema::def::ModuleDef; + use spacetimedb_sql_parser::ast::{SqlExpr, SqlLiteral}; fn module_def() -> ModuleDef { build_module_def(vec![ @@ -519,4 +520,27 @@ mod tests { assert!(result.is_err()); } } + + /// Manually build the AST for a recursive query, + /// because we limit the length of the query to prevent stack overflow on parsing. + /// Exercise the limit [`recursion::MAX_RECURSION_TYP_EXPR`] + #[test] + fn typing_recursion() { + let build_query = |total, sep: char| { + let mut expr = SqlExpr::Lit(SqlLiteral::Bool(true)); + for _ in 1..total { + let next = SqlExpr::Log( + Box::new(SqlExpr::Lit(SqlLiteral::Bool(true))), + Box::new(SqlExpr::Lit(SqlLiteral::Bool(false))), + LogOp::And, + ); + expr = SqlExpr::Log(Box::new(expr), Box::new(next), LogOp::And); + } + type_expr(&Relvars::default(), expr, Some(&AlgebraicType::Bool)) + .map_err(|e| e.to_string().split(sep).next().unwrap_or_default().to_string()) + }; + assert_eq!(build_query(2_501, ','), Err("Recursion limit exceeded".to_string())); + + assert!(build_query(2_500, ',').is_ok()); + } } diff --git a/crates/sql-parser/src/parser/errors.rs b/crates/sql-parser/src/parser/errors.rs index 07c50658684..953a031b8b8 100644 --- a/crates/sql-parser/src/parser/errors.rs +++ b/crates/sql-parser/src/parser/errors.rs @@ -50,7 +50,7 @@ pub enum SqlUnsupported { #[error("Unsupported FROM expression: {0}")] From(TableFactor), #[error("Unsupported set operation: {0}")] - SetOp(SetExpr), + SetOp(Box), #[error("Unsupported INSERT expression: {0}")] Insert(Query), #[error("Unsupported INSERT value: {0}")] @@ -94,18 +94,17 @@ pub enum SqlRequired { } #[derive(Error, Debug)] -#[error("Recursion limit exceeded, `{message}` limit: {limit}")] +#[error("Recursion limit exceeded, `{source_}`")] pub struct RecursionError { - pub(crate) limit: usize, - pub(crate) message: String, + pub(crate) source_: &'static str, } #[derive(Error, Debug)] pub enum SqlParseError { #[error(transparent)] - SqlUnsupported(#[from] SqlUnsupported), + SqlUnsupported(#[from] Box), #[error(transparent)] - SubscriptionUnsupported(#[from] SubscriptionUnsupported), + SubscriptionUnsupported(#[from] Box), #[error(transparent)] SqlRequired(#[from] SqlRequired), #[error(transparent)] @@ -113,3 +112,15 @@ pub enum SqlParseError { #[error(transparent)] Recursion(#[from] RecursionError), } + +impl From for SqlParseError { + fn from(value: SubscriptionUnsupported) -> Self { + SqlParseError::SubscriptionUnsupported(Box::new(value)) + } +} + +impl From for SqlParseError { + fn from(value: SqlUnsupported) -> Self { + SqlParseError::SqlUnsupported(Box::new(value)) + } +} diff --git a/crates/sql-parser/src/parser/mod.rs b/crates/sql-parser/src/parser/mod.rs index 176b5630f98..9e6e5642bda 100644 --- a/crates/sql-parser/src/parser/mod.rs +++ b/crates/sql-parser/src/parser/mod.rs @@ -68,7 +68,7 @@ trait RelParser { op: BinaryOperator::Eq, right, }, - &mut 0, + 0, )?), }) } @@ -207,17 +207,22 @@ pub(crate) fn parse_proj(expr: Expr) -> SqlParseResult { } } +// These types determine the size of [`parse_expr`]'s stack frame. +// Changing their sizes will require updating the recursion limit to avoid stack overflows. +const _: () = assert!(size_of::() == 168); +const _: () = assert!(size_of::>() == 40); + /// Parse a scalar expression -pub(crate) fn parse_expr(expr: Expr, depth: &mut usize) -> SqlParseResult { - fn signed_num(sign: impl Into, expr: Expr) -> Result { +fn parse_expr(expr: Expr, depth: usize) -> SqlParseResult { + fn signed_num(sign: impl Into, expr: Expr) -> Result> { match expr { Expr::Value(Value::Number(n, _)) => Ok(SqlExpr::Lit(SqlLiteral::Num((sign.into() + &n).into_boxed_str()))), - expr => Err(SqlUnsupported::Expr(expr)), + expr => Err(SqlUnsupported::Expr(expr).into()), } } recursion::guard(depth, recursion::MAX_RECURSION_EXPR, "sql-parser::parse_expr")?; match expr { - Expr::Nested(expr) => parse_expr(*expr, depth), + Expr::Nested(expr) => parse_expr(*expr, depth + 1), Expr::Value(Value::Placeholder(param)) if ¶m == ":sender" => Ok(SqlExpr::Param(Parameter::Sender)), Expr::Value(v) => Ok(SqlExpr::Lit(parse_literal(v)?)), Expr::UnaryOp { @@ -243,8 +248,8 @@ pub(crate) fn parse_expr(expr: Expr, depth: &mut usize) -> SqlParseResult { - let l = parse_expr(*left, depth)?; - let r = parse_expr(*right, depth)?; + let l = parse_expr(*left, depth + 1)?; + let r = parse_expr(*right, depth + 1)?; Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::And)) } Expr::BinaryOp { @@ -252,13 +257,13 @@ pub(crate) fn parse_expr(expr: Expr, depth: &mut usize) -> SqlParseResult { - let l = parse_expr(*left, depth)?; - let r = parse_expr(*right, depth)?; + let l = parse_expr(*left, depth + 1)?; + let r = parse_expr(*right, depth + 1)?; Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::Or)) } Expr::BinaryOp { left, op, right } => { - let l = parse_expr(*left, depth)?; - let r = parse_expr(*right, depth)?; + let l = parse_expr(*left, depth + 1)?; + let r = parse_expr(*right, depth + 1)?; Ok(SqlExpr::Bin(Box::new(l), Box::new(r), parse_binop(op)?)) } _ => Err(SqlUnsupported::Expr(expr).into()), @@ -266,8 +271,8 @@ pub(crate) fn parse_expr(expr: Expr, depth: &mut usize) -> SqlParseResult, depth: &mut usize) -> SqlParseResult> { - opt.map(|expr| parse_expr(expr, depth)).transpose() +pub(crate) fn parse_expr_opt(opt: Option) -> SqlParseResult> { + opt.map(|expr| parse_expr(expr, 0)).transpose() } /// Parse a scalar binary operator diff --git a/crates/sql-parser/src/parser/recursion.rs b/crates/sql-parser/src/parser/recursion.rs index b85bd2dccec..f4e2b1dec65 100644 --- a/crates/sql-parser/src/parser/recursion.rs +++ b/crates/sql-parser/src/parser/recursion.rs @@ -1,33 +1,25 @@ -//! A utility for guarding against excessive recursion depth in the SQL parser. +//! A utility for guarding against stack overflows in the SQL parser. //! -//! Different parts of the parser may have different recursion limits. -//! -//! Removing one could allow the others to be higher, but depending on how the `SQL` is structured, it could lead to a `stack overflow` -//! if is not guarded against, so is incorrect to assume that a limit is sufficient for the next part of the parser. +//! Different parts of the parser may have different recursion limits, based in the size of the structures they parse. + use crate::parser::errors::{RecursionError, SqlParseError}; -use std::fmt::Display; /// A conservative limit for recursion depth on `parse_expr`. -pub const MAX_RECURSION_EXPR: usize = 700; +pub const MAX_RECURSION_EXPR: usize = 1_600; /// A conservative limit for recursion depth on `type_expr`. -pub const MAX_RECURSION_TYP_EXPR: usize = 5_000; +pub const MAX_RECURSION_TYP_EXPR: usize = 2_500; -/// A utility for guarding against excessive recursion depth. +/// A utility for guarding against stack overflows in the SQL parser. /// /// **Usage:** /// ``` /// use spacetimedb_sql_parser::parser::recursion; /// let mut depth = 0; -/// assert!(recursion::guard(&mut depth, 10, "test").is_ok()); +/// assert!(recursion::guard(depth, 10, "test").is_ok()); /// ``` -pub fn guard(depth: &mut usize, limit: usize, msg: impl Display) -> Result<(), SqlParseError> { - *depth += 1; - if *depth > limit { - Err(RecursionError { - limit, - message: msg.to_string(), - } - .into()) +pub fn guard(depth: usize, limit: usize, source: &'static str) -> Result<(), SqlParseError> { + if depth > limit { + Err(RecursionError { source_: source }.into()) } else { Ok(()) } diff --git a/crates/sql-parser/src/parser/sql.rs b/crates/sql-parser/src/parser/sql.rs index caedc28e299..a1eb5078726 100644 --- a/crates/sql-parser/src/parser/sql.rs +++ b/crates/sql-parser/src/parser/sql.rs @@ -202,7 +202,7 @@ fn parse_statement(stmt: Statement) -> SqlParseResult { } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlAst::Update(SqlUpdate { table: parse_ident(name)?, assignments: parse_assignments(assignments)?, - filter: parse_expr_opt(selection, &mut 0)?, + filter: parse_expr_opt(selection)?, })), Statement::Delete { tables, @@ -297,7 +297,7 @@ fn parse_delete(mut from: Vec, selection: Option) -> SqlPa joins, } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlDelete { table: parse_ident(name)?, - filter: parse_expr_opt(selection, &mut 0)?, + filter: parse_expr_opt(selection)?, }), t => Err(SqlUnsupported::DeleteTable(t).into()), } @@ -395,7 +395,7 @@ fn parse_select(select: Select, limit: Option>) -> SqlParseResult SqlParseResult { match expr { SetExpr::Select(select) => parse_select(*select).map(SqlSelect::qualify_vars), - _ => Err(SqlUnsupported::SetOp(expr).into()), + _ => Err(SqlUnsupported::SetOp(Box::new(expr)).into()), } } @@ -142,7 +142,7 @@ fn parse_select(select: Select) -> SqlParseResult { { Ok(SqlSelect { from: SubParser::parse_from(from)?, - filter: parse_expr_opt(selection, &mut 0)?, + filter: parse_expr_opt(selection)?, project: parse_projection(projection)?, }) }