Skip to content

Commit e3bc0c2

Browse files
authored
Merge pull request #161 from juleswritescode/fix/tree-sitter-errors
2 parents 88653d1 + fda6044 commit e3bc0c2

File tree

5 files changed

+170
-22
lines changed

5 files changed

+170
-22
lines changed

crates/pg_completions/src/context.rs

+144-14
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ impl TryFrom<&str> for ClauseType {
1818
match value {
1919
"select" => Ok(Self::Select),
2020
"where" => Ok(Self::Where),
21-
"from" => Ok(Self::From),
21+
"from" | "keyword_from" => Ok(Self::From),
2222
"update" => Ok(Self::Update),
2323
"delete" => Ok(Self::Delete),
2424
_ => {
@@ -88,10 +88,22 @@ impl<'a> CompletionContext<'a> {
8888

8989
let mut cursor = self.tree.as_ref().unwrap().root_node().walk();
9090

91-
// go to the statement node that matches the position
91+
/*
92+
* The head node of any treesitter tree is always the "PROGRAM" node.
93+
*
94+
* We want to enter the next layer and focus on the child node that matches the user's cursor position.
95+
* If there is no node under the users position, however, the cursor won't enter the next level – it
96+
* will stay on the Program node.
97+
*
98+
* This might lead to an unexpected context or infinite recursion.
99+
*
100+
* We'll therefore adjust the cursor position such that it meets the last node of the AST.
101+
* `select * from use {}` becomes `select * from use{}`.
102+
*/
92103
let current_node_kind = cursor.node().kind();
93-
94-
cursor.goto_first_child_for_byte(self.position);
104+
while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 {
105+
self.position -= 1;
106+
}
95107

96108
self.gather_context_from_node(cursor, current_node_kind);
97109
}
@@ -104,6 +116,12 @@ impl<'a> CompletionContext<'a> {
104116
let current_node = cursor.node();
105117
let current_node_kind = current_node.kind();
106118

119+
// prevent infinite recursion – this can happen if we only have a PROGRAM node
120+
if current_node_kind == previous_node_kind {
121+
self.ts_node = Some(current_node);
122+
return;
123+
}
124+
107125
match previous_node_kind {
108126
"statement" => self.wrapping_clause_type = current_node_kind.try_into().ok(),
109127
"invocation" => self.is_invocation = true,
@@ -127,9 +145,14 @@ impl<'a> CompletionContext<'a> {
127145
self.wrapping_clause_type = "where".try_into().ok();
128146
}
129147

148+
"keyword_from" => {
149+
self.wrapping_clause_type = "keyword_from".try_into().ok();
150+
}
151+
130152
_ => {}
131153
}
132154

155+
// We have arrived at the leaf node
133156
if current_node.child_count() == 0 {
134157
self.ts_node = Some(current_node);
135158
return;
@@ -142,7 +165,10 @@ impl<'a> CompletionContext<'a> {
142165

143166
#[cfg(test)]
144167
mod tests {
145-
use crate::{context::CompletionContext, test_helper::CURSOR_POS};
168+
use crate::{
169+
context::{ClauseType, CompletionContext},
170+
test_helper::{get_text_and_position, CURSOR_POS},
171+
};
146172

147173
fn get_tree(input: &str) -> tree_sitter::Tree {
148174
let mut parser = tree_sitter::Parser::new();
@@ -182,11 +208,11 @@ mod tests {
182208
),
183209
];
184210

185-
for (text, expected_clause) in test_cases {
186-
let position = text.find(CURSOR_POS).unwrap();
187-
let text = text.replace(CURSOR_POS, "");
211+
for (query, expected_clause) in test_cases {
212+
let (position, text) = get_text_and_position(query.as_str());
188213

189214
let tree = get_tree(text.as_str());
215+
190216
let params = crate::CompletionParams {
191217
position: (position as u32).into(),
192218
text: text,
@@ -215,9 +241,8 @@ mod tests {
215241
(format!("Select * from u{}sers()", CURSOR_POS), None),
216242
];
217243

218-
for (text, expected_schema) in test_cases {
219-
let position = text.find(CURSOR_POS).unwrap();
220-
let text = text.replace(CURSOR_POS, "");
244+
for (query, expected_schema) in test_cases {
245+
let (position, text) = get_text_and_position(query.as_str());
221246

222247
let tree = get_tree(text.as_str());
223248
let params = crate::CompletionParams {
@@ -250,9 +275,8 @@ mod tests {
250275
),
251276
];
252277

253-
for (text, is_invocation) in test_cases {
254-
let position = text.find(CURSOR_POS).unwrap();
255-
let text = text.replace(CURSOR_POS, "");
278+
for (query, is_invocation) in test_cases {
279+
let (position, text) = get_text_and_position(query.as_str());
256280

257281
let tree = get_tree(text.as_str());
258282
let params = crate::CompletionParams {
@@ -267,4 +291,110 @@ mod tests {
267291
assert_eq!(ctx.is_invocation, is_invocation);
268292
}
269293
}
294+
295+
#[test]
296+
fn does_not_fail_on_leading_whitespace() {
297+
let cases = vec![
298+
format!("{} select * from", CURSOR_POS),
299+
format!(" {} select * from", CURSOR_POS),
300+
];
301+
302+
for query in cases {
303+
let (position, text) = get_text_and_position(query.as_str());
304+
305+
let tree = get_tree(text.as_str());
306+
307+
let params = crate::CompletionParams {
308+
position: (position as u32).into(),
309+
text: text,
310+
tree: Some(&tree),
311+
schema: &pg_schema_cache::SchemaCache::new(),
312+
};
313+
314+
let ctx = CompletionContext::new(&params);
315+
316+
let node = ctx.ts_node.map(|n| n.clone()).unwrap();
317+
318+
assert_eq!(ctx.get_ts_node_content(node), Some("select"));
319+
320+
assert_eq!(
321+
ctx.wrapping_clause_type,
322+
Some(crate::context::ClauseType::Select)
323+
);
324+
}
325+
}
326+
327+
#[test]
328+
fn does_not_fail_on_trailing_whitespace() {
329+
let query = format!("select * from {}", CURSOR_POS);
330+
331+
let (position, text) = get_text_and_position(query.as_str());
332+
333+
let tree = get_tree(text.as_str());
334+
335+
let params = crate::CompletionParams {
336+
position: (position as u32).into(),
337+
text: text,
338+
tree: Some(&tree),
339+
schema: &pg_schema_cache::SchemaCache::new(),
340+
};
341+
342+
let ctx = CompletionContext::new(&params);
343+
344+
let node = ctx.ts_node.map(|n| n.clone()).unwrap();
345+
346+
assert_eq!(ctx.get_ts_node_content(node), Some("from"));
347+
assert_eq!(
348+
ctx.wrapping_clause_type,
349+
Some(crate::context::ClauseType::From)
350+
);
351+
}
352+
353+
#[test]
354+
fn does_not_fail_with_empty_statements() {
355+
let query = format!("{}", CURSOR_POS);
356+
357+
let (position, text) = get_text_and_position(query.as_str());
358+
359+
let tree = get_tree(text.as_str());
360+
361+
let params = crate::CompletionParams {
362+
position: (position as u32).into(),
363+
text: text,
364+
tree: Some(&tree),
365+
schema: &pg_schema_cache::SchemaCache::new(),
366+
};
367+
368+
let ctx = CompletionContext::new(&params);
369+
370+
let node = ctx.ts_node.map(|n| n.clone()).unwrap();
371+
372+
assert_eq!(ctx.get_ts_node_content(node), Some(""));
373+
assert_eq!(ctx.wrapping_clause_type, None);
374+
}
375+
376+
#[test]
377+
fn does_not_fail_on_incomplete_keywords() {
378+
// Instead of autocompleting "FROM", we'll assume that the user
379+
// is selecting a certain column name, such as `frozen_account`.
380+
let query = format!("select * fro{}", CURSOR_POS);
381+
382+
let (position, text) = get_text_and_position(query.as_str());
383+
384+
let tree = get_tree(text.as_str());
385+
386+
let params = crate::CompletionParams {
387+
position: (position as u32).into(),
388+
text: text,
389+
tree: Some(&tree),
390+
schema: &pg_schema_cache::SchemaCache::new(),
391+
};
392+
393+
let ctx = CompletionContext::new(&params);
394+
395+
let node = ctx.ts_node.map(|n| n.clone()).unwrap();
396+
397+
assert_eq!(ctx.get_ts_node_content(node), Some("fro"));
398+
assert_eq!(ctx.wrapping_clause_type, Some(ClauseType::Select));
399+
}
270400
}

crates/pg_completions/src/providers/tables.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,7 @@ mod tests {
7575
let test_cases = vec![
7676
(format!("select * from us{}", CURSOR_POS), "users"),
7777
(format!("select * from em{}", CURSOR_POS), "emails"),
78-
// TODO: Fix queries with tree-sitter errors.
79-
// (format!("select * from {}", CURSOR_POS), "addresses"),
78+
(format!("select * from {}", CURSOR_POS), "addresses"),
8079
];
8180

8281
for (query, expected_label) in test_cases {

crates/pg_completions/src/test_helper.rs

+14-5
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,25 @@ pub(crate) async fn get_test_deps(
3131
(tree, schema_cache)
3232
}
3333

34+
pub(crate) fn get_text_and_position(sql: &str) -> (usize, String) {
35+
// the cursor is to the left of the `CURSOR_POS`
36+
let position = sql
37+
.find(|c| c == CURSOR_POS)
38+
.expect("Please insert the CURSOR_POS into your query.")
39+
.checked_sub(1)
40+
.unwrap_or(0);
41+
42+
let text = sql.replace(CURSOR_POS, "");
43+
44+
(position, text)
45+
}
46+
3447
pub(crate) fn get_test_params<'a>(
3548
tree: &'a tree_sitter::Tree,
3649
schema_cache: &'a pg_schema_cache::SchemaCache,
3750
sql: &'a str,
3851
) -> CompletionParams<'a> {
39-
let position = sql
40-
.find(|c| c == CURSOR_POS)
41-
.expect("Please insert the CURSOR_POS into your query.");
42-
43-
let text = sql.replace(CURSOR_POS, "");
52+
let (position, text) = get_text_and_position(sql);
4453

4554
CompletionParams {
4655
position: (position as u32).into(),

crates/pg_test_utils/src/bin/tree_print.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,13 @@ fn main() {
2929

3030
fn print_tree(node: &tree_sitter::Node, source: &str, level: usize) {
3131
let indent = " ".repeat(level);
32-
let node_text = node.utf8_text(source.as_bytes()).unwrap_or("NO_NAME");
32+
33+
let node_text = node
34+
.utf8_text(source.as_bytes())
35+
.unwrap_or("NO_NAME")
36+
.split_whitespace()
37+
.collect::<Vec<&str>>()
38+
.join(" ");
3339

3440
println!(
3541
"{}{} [{}..{}] '{}'",

justfile

+4
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ new-crate name:
130130
cargo new --lib crates/{{snakecase(name)}}
131131
cargo run -p xtask_codegen -- new-crate --name={{snakecase(name)}}
132132

133+
# Prints the treesitter tree of the given SQL file
134+
tree-print file:
135+
cargo run --bin tree_print -- -f {{file}}
136+
133137
# Creates a new changeset for the final changelog
134138
# new-changeset:
135139
# knope document-change

0 commit comments

Comments
 (0)