diff --git a/crates/pg_completions/src/context.rs b/crates/pg_completions/src/context.rs index 82b35b30..79354cf6 100644 --- a/crates/pg_completions/src/context.rs +++ b/crates/pg_completions/src/context.rs @@ -18,7 +18,7 @@ impl TryFrom<&str> for ClauseType { match value { "select" => Ok(Self::Select), "where" => Ok(Self::Where), - "from" => Ok(Self::From), + "from" | "keyword_from" => Ok(Self::From), "update" => Ok(Self::Update), "delete" => Ok(Self::Delete), _ => { @@ -88,10 +88,22 @@ impl<'a> CompletionContext<'a> { let mut cursor = self.tree.as_ref().unwrap().root_node().walk(); - // go to the statement node that matches the position + /* + * The head node of any treesitter tree is always the "PROGRAM" node. + * + * We want to enter the next layer and focus on the child node that matches the user's cursor position. + * If there is no node under the users position, however, the cursor won't enter the next level – it + * will stay on the Program node. + * + * This might lead to an unexpected context or infinite recursion. + * + * We'll therefore adjust the cursor position such that it meets the last node of the AST. + * `select * from use {}` becomes `select * from use{}`. + */ let current_node_kind = cursor.node().kind(); - - cursor.goto_first_child_for_byte(self.position); + while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 { + self.position -= 1; + } self.gather_context_from_node(cursor, current_node_kind); } @@ -104,6 +116,12 @@ impl<'a> CompletionContext<'a> { let current_node = cursor.node(); let current_node_kind = current_node.kind(); + // prevent infinite recursion – this can happen if we only have a PROGRAM node + if current_node_kind == previous_node_kind { + self.ts_node = Some(current_node); + return; + } + match previous_node_kind { "statement" => self.wrapping_clause_type = current_node_kind.try_into().ok(), "invocation" => self.is_invocation = true, @@ -127,9 +145,14 @@ impl<'a> CompletionContext<'a> { self.wrapping_clause_type = "where".try_into().ok(); } + "keyword_from" => { + self.wrapping_clause_type = "keyword_from".try_into().ok(); + } + _ => {} } + // We have arrived at the leaf node if current_node.child_count() == 0 { self.ts_node = Some(current_node); return; @@ -142,7 +165,10 @@ impl<'a> CompletionContext<'a> { #[cfg(test)] mod tests { - use crate::{context::CompletionContext, test_helper::CURSOR_POS}; + use crate::{ + context::{ClauseType, CompletionContext}, + test_helper::{get_text_and_position, CURSOR_POS}, + }; fn get_tree(input: &str) -> tree_sitter::Tree { let mut parser = tree_sitter::Parser::new(); @@ -182,11 +208,11 @@ mod tests { ), ]; - for (text, expected_clause) in test_cases { - let position = text.find(CURSOR_POS).unwrap(); - let text = text.replace(CURSOR_POS, ""); + for (query, expected_clause) in test_cases { + let (position, text) = get_text_and_position(query.as_str()); let tree = get_tree(text.as_str()); + let params = crate::CompletionParams { position: (position as u32).into(), text: text, @@ -215,9 +241,8 @@ mod tests { (format!("Select * from u{}sers()", CURSOR_POS), None), ]; - for (text, expected_schema) in test_cases { - let position = text.find(CURSOR_POS).unwrap(); - let text = text.replace(CURSOR_POS, ""); + for (query, expected_schema) in test_cases { + let (position, text) = get_text_and_position(query.as_str()); let tree = get_tree(text.as_str()); let params = crate::CompletionParams { @@ -250,9 +275,8 @@ mod tests { ), ]; - for (text, is_invocation) in test_cases { - let position = text.find(CURSOR_POS).unwrap(); - let text = text.replace(CURSOR_POS, ""); + for (query, is_invocation) in test_cases { + let (position, text) = get_text_and_position(query.as_str()); let tree = get_tree(text.as_str()); let params = crate::CompletionParams { @@ -267,4 +291,110 @@ mod tests { assert_eq!(ctx.is_invocation, is_invocation); } } + + #[test] + fn does_not_fail_on_leading_whitespace() { + let cases = vec![ + format!("{} select * from", CURSOR_POS), + format!(" {} select * from", CURSOR_POS), + ]; + + for query in cases { + let (position, text) = get_text_and_position(query.as_str()); + + let tree = get_tree(text.as_str()); + + let params = crate::CompletionParams { + position: (position as u32).into(), + text: text, + tree: Some(&tree), + schema: &pg_schema_cache::SchemaCache::new(), + }; + + let ctx = CompletionContext::new(¶ms); + + let node = ctx.ts_node.map(|n| n.clone()).unwrap(); + + assert_eq!(ctx.get_ts_node_content(node), Some("select")); + + assert_eq!( + ctx.wrapping_clause_type, + Some(crate::context::ClauseType::Select) + ); + } + } + + #[test] + fn does_not_fail_on_trailing_whitespace() { + let query = format!("select * from {}", CURSOR_POS); + + let (position, text) = get_text_and_position(query.as_str()); + + let tree = get_tree(text.as_str()); + + let params = crate::CompletionParams { + position: (position as u32).into(), + text: text, + tree: Some(&tree), + schema: &pg_schema_cache::SchemaCache::new(), + }; + + let ctx = CompletionContext::new(¶ms); + + let node = ctx.ts_node.map(|n| n.clone()).unwrap(); + + assert_eq!(ctx.get_ts_node_content(node), Some("from")); + assert_eq!( + ctx.wrapping_clause_type, + Some(crate::context::ClauseType::From) + ); + } + + #[test] + fn does_not_fail_with_empty_statements() { + let query = format!("{}", CURSOR_POS); + + let (position, text) = get_text_and_position(query.as_str()); + + let tree = get_tree(text.as_str()); + + let params = crate::CompletionParams { + position: (position as u32).into(), + text: text, + tree: Some(&tree), + schema: &pg_schema_cache::SchemaCache::new(), + }; + + let ctx = CompletionContext::new(¶ms); + + let node = ctx.ts_node.map(|n| n.clone()).unwrap(); + + assert_eq!(ctx.get_ts_node_content(node), Some("")); + assert_eq!(ctx.wrapping_clause_type, None); + } + + #[test] + fn does_not_fail_on_incomplete_keywords() { + // Instead of autocompleting "FROM", we'll assume that the user + // is selecting a certain column name, such as `frozen_account`. + let query = format!("select * fro{}", CURSOR_POS); + + let (position, text) = get_text_and_position(query.as_str()); + + let tree = get_tree(text.as_str()); + + let params = crate::CompletionParams { + position: (position as u32).into(), + text: text, + tree: Some(&tree), + schema: &pg_schema_cache::SchemaCache::new(), + }; + + let ctx = CompletionContext::new(¶ms); + + let node = ctx.ts_node.map(|n| n.clone()).unwrap(); + + assert_eq!(ctx.get_ts_node_content(node), Some("fro")); + assert_eq!(ctx.wrapping_clause_type, Some(ClauseType::Select)); + } } diff --git a/crates/pg_completions/src/providers/tables.rs b/crates/pg_completions/src/providers/tables.rs index 5faa710e..70574ec8 100644 --- a/crates/pg_completions/src/providers/tables.rs +++ b/crates/pg_completions/src/providers/tables.rs @@ -75,8 +75,7 @@ mod tests { let test_cases = vec![ (format!("select * from us{}", CURSOR_POS), "users"), (format!("select * from em{}", CURSOR_POS), "emails"), - // TODO: Fix queries with tree-sitter errors. - // (format!("select * from {}", CURSOR_POS), "addresses"), + (format!("select * from {}", CURSOR_POS), "addresses"), ]; for (query, expected_label) in test_cases { diff --git a/crates/pg_completions/src/test_helper.rs b/crates/pg_completions/src/test_helper.rs index f1511b94..08a3af2e 100644 --- a/crates/pg_completions/src/test_helper.rs +++ b/crates/pg_completions/src/test_helper.rs @@ -31,16 +31,25 @@ pub(crate) async fn get_test_deps( (tree, schema_cache) } +pub(crate) fn get_text_and_position(sql: &str) -> (usize, String) { + // the cursor is to the left of the `CURSOR_POS` + let position = sql + .find(|c| c == CURSOR_POS) + .expect("Please insert the CURSOR_POS into your query.") + .checked_sub(1) + .unwrap_or(0); + + let text = sql.replace(CURSOR_POS, ""); + + (position, text) +} + pub(crate) fn get_test_params<'a>( tree: &'a tree_sitter::Tree, schema_cache: &'a pg_schema_cache::SchemaCache, sql: &'a str, ) -> CompletionParams<'a> { - let position = sql - .find(|c| c == CURSOR_POS) - .expect("Please insert the CURSOR_POS into your query."); - - let text = sql.replace(CURSOR_POS, ""); + let (position, text) = get_text_and_position(sql); CompletionParams { position: (position as u32).into(), diff --git a/crates/pg_test_utils/src/bin/tree_print.rs b/crates/pg_test_utils/src/bin/tree_print.rs index 8a04365e..469dcc8e 100644 --- a/crates/pg_test_utils/src/bin/tree_print.rs +++ b/crates/pg_test_utils/src/bin/tree_print.rs @@ -29,7 +29,13 @@ fn main() { fn print_tree(node: &tree_sitter::Node, source: &str, level: usize) { let indent = " ".repeat(level); - let node_text = node.utf8_text(source.as_bytes()).unwrap_or("NO_NAME"); + + let node_text = node + .utf8_text(source.as_bytes()) + .unwrap_or("NO_NAME") + .split_whitespace() + .collect::>() + .join(" "); println!( "{}{} [{}..{}] '{}'", diff --git a/justfile b/justfile index 55002e49..5d33f513 100644 --- a/justfile +++ b/justfile @@ -130,6 +130,10 @@ new-crate name: cargo new --lib crates/{{snakecase(name)}} cargo run -p xtask_codegen -- new-crate --name={{snakecase(name)}} +# Prints the treesitter tree of the given SQL file +tree-print file: + cargo run --bin tree_print -- -f {{file}} + # Creates a new changeset for the final changelog # new-changeset: # knope document-change