Skip to content

Add query! macro providing a more ergonomic way to create parmeterized queries #214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ pub use crate::errors::{
Error, Neo4jClientErrorKind, Neo4jError, Neo4jErrorKind, Neo4jSecurityErrorKind, Result,
};
pub use crate::graph::{query, Graph};
pub use crate::query::Query;
pub use crate::query::{Query, QueryParameter};
pub use crate::row::{Node, Path, Point2D, Point3D, Relation, Row, UnboundedRelation};
pub use crate::stream::{DetachedRowStream, RowItem, RowStream};
pub use crate::txn::Txn;
Expand Down
183 changes: 183 additions & 0 deletions lib/src/query.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::cell::{Cell, RefCell};

#[cfg(feature = "unstable-bolt-protocol-impl-v2")]
use crate::bolt::{Discard, Summary, WrapExtra as _};
use crate::{
Expand Down Expand Up @@ -26,6 +28,11 @@ impl Query {
}
}

pub fn with_params(mut self, params: BoltMap) -> Self {
self.params = params;
self
}

pub fn param<T: Into<BoltType>>(mut self, key: &str, value: T) -> Self {
self.params.put(key.into(), value.into());
self
Expand Down Expand Up @@ -68,6 +75,14 @@ impl Query {
self.extra.value.contains_key(key)
}

pub fn query(&self) -> &str {
&self.query
}

pub fn get_params(&self) -> &BoltMap {
&self.params
}

pub(crate) async fn run(self, connection: &mut ManagedConnection) -> Result<()> {
let request = BoltRequest::run(&self.query, self.params, self.extra);
Self::try_run(request, connection)
Expand Down Expand Up @@ -170,6 +185,15 @@ impl From<&str> for Query {
}
}

impl std::fmt::Debug for Query {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Query")
.field("query", &self.query)
.field("params", &self.params)
.finish_non_exhaustive()
}
}

type QueryResult<T> = Result<T, backoff::Error<Error>>;

fn wrap_error<T>(resp: impl IntoError, req: &'static str) -> QueryResult<T> {
Expand Down Expand Up @@ -217,6 +241,142 @@ fn unwrap_backoff(err: backoff::Error<Error>) -> Error {
}
}

#[doc(hidden)]
pub struct QueryParameter<'x, T> {
value: Cell<Option<T>>,
name: &'static str,
params: &'x RefCell<BoltMap>,
}

impl<'x, T: Into<BoltType>> QueryParameter<'x, T> {
#[allow(dead_code)]
pub fn new(value: T, name: &'static str, params: &'x RefCell<BoltMap>) -> Self {
Self {
value: Cell::new(Some(value)),
name,
params,
}
}
}

impl<T: Into<BoltType>> std::fmt::Display for QueryParameter<'_, T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Some(v) = self.value.replace(None) else {
return Err(std::fmt::Error);
};
self.params.borrow_mut().put(self.name.into(), v.into());
write!(f, "${}", self.name)
}
}

impl<T: Into<BoltType>> std::fmt::Debug for QueryParameter<'_, T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self, f)
}
}

/// Create a query with a format! like syntax
///
/// `query!` works similar to `format!`:
/// - The first argument is the query string with `{<name>}` placeholders
/// - Following that is a list of `name = value` parmeters arguments
/// - All placeholders in the query strings are replaced with query parameters
///
/// The macro is a compiler-supported alternative to using the `params` method on `Query`.
///
/// ## Differences from `format!` and limitations
///
/// - Implicit `{name}` bindings without adding a `name = <value>` argument does not
/// actually create a new parameter; It does default string interpolation instead.
/// - Formatting parameters are largely ignored and have no effect on the query string.
/// - Argument values need to implement `Into<BoltType>` instead of `Display`
/// (and don't need to implement the latter)
/// - Only named placeholders syntax is supported (`{<name>}` instead of `{}`)
/// - This is because query parameters are always named
/// - By extension, adding an unnamed argument (e.g. `<value>` instead of `name = <value>`) is also not supported
///
/// # Examples
///
/// ```
/// use neo4rs::{query, Query};
///
/// // This creates an unparametrized query.
/// let q: Query = query!("MATCH (n) RETURN n");
/// assert_eq!(q.query(), "MATCH (n) RETURN n");
/// assert!(q.get_params().is_empty());
///
/// // This creates a parametrized query.
/// let q: Query = query!("MATCH (n) WHERE n.value = {answer} RETURN n", answer = 42);
/// assert_eq!(q.query(), "MATCH (n) WHERE n.value = $answer RETURN n");
/// assert_eq!(q.get_params().get::<i64>("answer").unwrap(), 42);
///
/// // by contrast, using the implicit string interpolation syntax does not
/// // create a parameter, effectively being the same as `format!`.
/// let answer = 42;
/// let q: Query = query!("MATCH (n) WHERE n.value = {answer} RETURN n");
/// assert_eq!(q.query(), "MATCH (n) WHERE n.value = 42 RETURN n");
/// assert!(q.has_param_key("answer") == false);
///
/// // The value can be any type that implements Into<BoltType>, it does not
/// // need to implement Display or Debug.
/// use neo4rs::{BoltInteger, BoltType};
///
/// struct Answer;
/// impl Into<BoltType> for Answer {
/// fn into(self) -> BoltType {
/// BoltType::Integer(BoltInteger::new(42))
/// }
/// }
///
/// let q: Query = query!("MATCH (n) WHERE n.value = {answer} RETURN n", answer = Answer);
/// assert_eq!(q.query(), "MATCH (n) WHERE n.value = $answer RETURN n");
/// assert_eq!(q.get_params().get::<i64>("answer").unwrap(), 42);
/// ```
#[macro_export]
macro_rules! query {
// Create a unparametrized query
($query:expr) => {
$crate::Query::new(format!($query))
};

// Create a parametrized query with a format! like syntax
($query:expr $(, $($input:tt)*)?) => {
$crate::query!(@internal $query, [] $(; $($input)*)?)
};

(@internal $query:expr, [$($acc:tt)*]; $name:ident = $value:expr $(, $($rest:tt)*)?) => {
$crate::query!(@internal $query, [$($acc)* ($name = $value)] $(; $($rest)*)?)
};

(@internal $query:expr, [$($acc:tt)*]; $value:expr $(, $($rest:tt)*)?) => {
compile_error!("Only named parameter syntax (`name = value`) is supported");
};

(@internal $query:expr, [$($acc:tt)*];) => {
$crate::query!(@final $query; $($acc)*)
};

(@internal $query:expr, [$($acc:tt)*]) => {
$crate::query!(@final $query; $($acc)*)
};

(@final $query:expr; $(($name:ident = $value:expr))*) => {{
let params = $crate::BoltMap::default();
let params = ::std::cell::RefCell::new(params);

let query = format!($query, $(
$name = $crate::QueryParameter::new(
$value,
stringify!($name),
&params,
),
)*);
let params = params.into_inner();

$crate::Query::new(query).with_params(params)
}};
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -238,4 +398,27 @@ mod tests {
assert!(q.has_param_key("name"));
assert!(!q.has_param_key("country"));
}

#[test]
fn query_macro() {
let q = query!(
"MATCH (n) WHERE n.name = {name} AND n.age > {age} RETURN n",
age = 42,
name = "Frobniscante",
);

assert_eq!(
q.query.as_str(),
"MATCH (n) WHERE n.name = $name AND n.age > $age RETURN n"
);

assert_eq!(
q.params.get::<String>("name").unwrap(),
String::from("Frobniscante")
);
assert_eq!(q.params.get::<i64>("age").unwrap(), 42);

assert!(q.has_param_key("name"));
assert!(!q.has_param_key("country"));
}
}
5 changes: 4 additions & 1 deletion lib/tests/missing_properties.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ async fn missing_properties() {

let a_val = None::<String>;
let mut result = graph
.execute(query("CREATE (ts:TestStruct {a: $a}) RETURN ts").param("a", a_val))
.execute(query!(
"CREATE (ts:TestStruct {{a: {a}}}) RETURN ts",
a = a_val
))
.await
.unwrap();
let row = result.next().await.unwrap().unwrap();
Expand Down
23 changes: 9 additions & 14 deletions lib/tests/txn_change_db.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use futures::TryStreamExt;
use neo4rs::*;
use neo4rs::query;
use serde::Deserialize;

mod container;
Expand All @@ -19,7 +19,7 @@ async fn txn_changes_db() {
return;
}

std::panic::panic_any(e);
std::panic::panic_any(e.to_string());
}
};
let graph = neo4j.graph();
Expand All @@ -46,18 +46,13 @@ async fn txn_changes_db() {

let mut txn = graph.start_txn().await.unwrap();
let mut databases = txn
.execute(
query(&format!(
concat!(
"SHOW TRANSACTIONS YIELD * WHERE username = $username AND currentQuery ",
"STARTS WITH $query AND toLower({status_field}) = $status RETURN database"
),
status_field = status_field
))
.param("username", "neo4j")
.param("query", "SHOW TRANSACTIONS YIELD ")
.param("status", "running"),
)
.execute(query!(
"SHOW TRANSACTIONS YIELD * WHERE username = {username} AND currentQuery
STARTS WITH {query} AND toLower({status_field}) = {status} RETURN database",
username = "neo4j",
query = "SHOW TRANSACTIONS YIELD ",
status = "running",
))
.await
.unwrap();

Expand Down
17 changes: 12 additions & 5 deletions lib/tests/use_default_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,21 @@ async fn use_default_db() {

let id = uuid::Uuid::new_v4();
graph
.run(query("CREATE (:Node { uuid: $uuid })").param("uuid", id.to_string()))
.run(query!(
"CREATE (:Node {{ uuid: {uuid} }})",
uuid = id.to_string()
))
.await
.unwrap();

#[cfg(feature = "unstable-bolt-protocol-impl-v2")]
let query_stream = graph
.execute_on(
dbname.as_str(),
query("MATCH (n:Node {uuid: $uuid}) RETURN count(n) AS result")
.param("uuid", id.to_string()),
query!(
"MATCH (n:Node {{uuid: {uuid}}}) RETURN count(n) AS result",
uuid = id.to_string()
),
Operation::Read,
)
.await;
Expand All @@ -72,8 +77,10 @@ async fn use_default_db() {
let query_stream = graph
.execute_on(
dbname.as_str(),
query("MATCH (n:Node {uuid: $uuid}) RETURN count(n) AS result")
.param("uuid", id.to_string()),
query!(
"MATCH (n:Node {{uuid: {uuid}}}) RETURN count(n) AS result",
uuid = id.to_string()
),
)
.await;

Expand Down
Loading