@@ -18,7 +18,7 @@ impl TryFrom<&str> for ClauseType {
18
18
match value {
19
19
"select" => Ok ( Self :: Select ) ,
20
20
"where" => Ok ( Self :: Where ) ,
21
- "from" => Ok ( Self :: From ) ,
21
+ "from" | "keyword_from" => Ok ( Self :: From ) ,
22
22
"update" => Ok ( Self :: Update ) ,
23
23
"delete" => Ok ( Self :: Delete ) ,
24
24
_ => {
@@ -88,10 +88,22 @@ impl<'a> CompletionContext<'a> {
88
88
89
89
let mut cursor = self . tree . as_ref ( ) . unwrap ( ) . root_node ( ) . walk ( ) ;
90
90
91
- // go to the statement node that matches the position
91
+ /*
92
+ * The head node of any treesitter tree is always the "PROGRAM" node.
93
+ *
94
+ * We want to enter the next layer and focus on the child node that matches the user's cursor position.
95
+ * If there is no node under the users position, however, the cursor won't enter the next level – it
96
+ * will stay on the Program node.
97
+ *
98
+ * This might lead to an unexpected context or infinite recursion.
99
+ *
100
+ * We'll therefore adjust the cursor position such that it meets the last node of the AST.
101
+ * `select * from use {}` becomes `select * from use{}`.
102
+ */
92
103
let current_node_kind = cursor. node ( ) . kind ( ) ;
93
-
94
- cursor. goto_first_child_for_byte ( self . position ) ;
104
+ while cursor. goto_first_child_for_byte ( self . position ) . is_none ( ) && self . position > 0 {
105
+ self . position -= 1 ;
106
+ }
95
107
96
108
self . gather_context_from_node ( cursor, current_node_kind) ;
97
109
}
@@ -104,6 +116,12 @@ impl<'a> CompletionContext<'a> {
104
116
let current_node = cursor. node ( ) ;
105
117
let current_node_kind = current_node. kind ( ) ;
106
118
119
+ // prevent infinite recursion – this can happen if we only have a PROGRAM node
120
+ if current_node_kind == previous_node_kind {
121
+ self . ts_node = Some ( current_node) ;
122
+ return ;
123
+ }
124
+
107
125
match previous_node_kind {
108
126
"statement" => self . wrapping_clause_type = current_node_kind. try_into ( ) . ok ( ) ,
109
127
"invocation" => self . is_invocation = true ,
@@ -127,9 +145,14 @@ impl<'a> CompletionContext<'a> {
127
145
self . wrapping_clause_type = "where" . try_into ( ) . ok ( ) ;
128
146
}
129
147
148
+ "keyword_from" => {
149
+ self . wrapping_clause_type = "keyword_from" . try_into ( ) . ok ( ) ;
150
+ }
151
+
130
152
_ => { }
131
153
}
132
154
155
+ // We have arrived at the leaf node
133
156
if current_node. child_count ( ) == 0 {
134
157
self . ts_node = Some ( current_node) ;
135
158
return ;
@@ -142,7 +165,10 @@ impl<'a> CompletionContext<'a> {
142
165
143
166
#[ cfg( test) ]
144
167
mod tests {
145
- use crate :: { context:: CompletionContext , test_helper:: CURSOR_POS } ;
168
+ use crate :: {
169
+ context:: { ClauseType , CompletionContext } ,
170
+ test_helper:: { get_text_and_position, CURSOR_POS } ,
171
+ } ;
146
172
147
173
fn get_tree ( input : & str ) -> tree_sitter:: Tree {
148
174
let mut parser = tree_sitter:: Parser :: new ( ) ;
@@ -182,11 +208,11 @@ mod tests {
182
208
) ,
183
209
] ;
184
210
185
- for ( text, expected_clause) in test_cases {
186
- let position = text. find ( CURSOR_POS ) . unwrap ( ) ;
187
- let text = text. replace ( CURSOR_POS , "" ) ;
211
+ for ( query, expected_clause) in test_cases {
212
+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
188
213
189
214
let tree = get_tree ( text. as_str ( ) ) ;
215
+
190
216
let params = crate :: CompletionParams {
191
217
position : ( position as u32 ) . into ( ) ,
192
218
text : text,
@@ -215,9 +241,8 @@ mod tests {
215
241
( format!( "Select * from u{}sers()" , CURSOR_POS ) , None ) ,
216
242
] ;
217
243
218
- for ( text, expected_schema) in test_cases {
219
- let position = text. find ( CURSOR_POS ) . unwrap ( ) ;
220
- let text = text. replace ( CURSOR_POS , "" ) ;
244
+ for ( query, expected_schema) in test_cases {
245
+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
221
246
222
247
let tree = get_tree ( text. as_str ( ) ) ;
223
248
let params = crate :: CompletionParams {
@@ -250,9 +275,8 @@ mod tests {
250
275
) ,
251
276
] ;
252
277
253
- for ( text, is_invocation) in test_cases {
254
- let position = text. find ( CURSOR_POS ) . unwrap ( ) ;
255
- let text = text. replace ( CURSOR_POS , "" ) ;
278
+ for ( query, is_invocation) in test_cases {
279
+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
256
280
257
281
let tree = get_tree ( text. as_str ( ) ) ;
258
282
let params = crate :: CompletionParams {
@@ -267,4 +291,110 @@ mod tests {
267
291
assert_eq ! ( ctx. is_invocation, is_invocation) ;
268
292
}
269
293
}
294
+
295
+ #[ test]
296
+ fn does_not_fail_on_leading_whitespace ( ) {
297
+ let cases = vec ! [
298
+ format!( "{} select * from" , CURSOR_POS ) ,
299
+ format!( " {} select * from" , CURSOR_POS ) ,
300
+ ] ;
301
+
302
+ for query in cases {
303
+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
304
+
305
+ let tree = get_tree ( text. as_str ( ) ) ;
306
+
307
+ let params = crate :: CompletionParams {
308
+ position : ( position as u32 ) . into ( ) ,
309
+ text : text,
310
+ tree : Some ( & tree) ,
311
+ schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
312
+ } ;
313
+
314
+ let ctx = CompletionContext :: new ( & params) ;
315
+
316
+ let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
317
+
318
+ assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "select" ) ) ;
319
+
320
+ assert_eq ! (
321
+ ctx. wrapping_clause_type,
322
+ Some ( crate :: context:: ClauseType :: Select )
323
+ ) ;
324
+ }
325
+ }
326
+
327
+ #[ test]
328
+ fn does_not_fail_on_trailing_whitespace ( ) {
329
+ let query = format ! ( "select * from {}" , CURSOR_POS ) ;
330
+
331
+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
332
+
333
+ let tree = get_tree ( text. as_str ( ) ) ;
334
+
335
+ let params = crate :: CompletionParams {
336
+ position : ( position as u32 ) . into ( ) ,
337
+ text : text,
338
+ tree : Some ( & tree) ,
339
+ schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
340
+ } ;
341
+
342
+ let ctx = CompletionContext :: new ( & params) ;
343
+
344
+ let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
345
+
346
+ assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "from" ) ) ;
347
+ assert_eq ! (
348
+ ctx. wrapping_clause_type,
349
+ Some ( crate :: context:: ClauseType :: From )
350
+ ) ;
351
+ }
352
+
353
+ #[ test]
354
+ fn does_not_fail_with_empty_statements ( ) {
355
+ let query = format ! ( "{}" , CURSOR_POS ) ;
356
+
357
+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
358
+
359
+ let tree = get_tree ( text. as_str ( ) ) ;
360
+
361
+ let params = crate :: CompletionParams {
362
+ position : ( position as u32 ) . into ( ) ,
363
+ text : text,
364
+ tree : Some ( & tree) ,
365
+ schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
366
+ } ;
367
+
368
+ let ctx = CompletionContext :: new ( & params) ;
369
+
370
+ let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
371
+
372
+ assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "" ) ) ;
373
+ assert_eq ! ( ctx. wrapping_clause_type, None ) ;
374
+ }
375
+
376
+ #[ test]
377
+ fn does_not_fail_on_incomplete_keywords ( ) {
378
+ // Instead of autocompleting "FROM", we'll assume that the user
379
+ // is selecting a certain column name, such as `frozen_account`.
380
+ let query = format ! ( "select * fro{}" , CURSOR_POS ) ;
381
+
382
+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
383
+
384
+ let tree = get_tree ( text. as_str ( ) ) ;
385
+
386
+ let params = crate :: CompletionParams {
387
+ position : ( position as u32 ) . into ( ) ,
388
+ text : text,
389
+ tree : Some ( & tree) ,
390
+ schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
391
+ } ;
392
+
393
+ let ctx = CompletionContext :: new ( & params) ;
394
+
395
+ let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
396
+
397
+ assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "fro" ) ) ;
398
+ assert_eq ! ( ctx. wrapping_clause_type, Some ( ClauseType :: Select ) ) ;
399
+ }
270
400
}
0 commit comments