1
+ use std:: collections:: { HashMap , HashSet } ;
2
+
1
3
use pg_schema_cache:: SchemaCache ;
4
+ use pg_treesitter_queries:: {
5
+ queries:: { self , QueryResult } ,
6
+ TreeSitterQueriesExecutor ,
7
+ } ;
2
8
3
9
use crate :: CompletionParams ;
4
10
@@ -52,6 +58,9 @@ pub(crate) struct CompletionContext<'a> {
52
58
pub schema_name : Option < String > ,
53
59
pub wrapping_clause_type : Option < ClauseType > ,
54
60
pub is_invocation : bool ,
61
+ pub wrapping_statement_range : Option < tree_sitter:: Range > ,
62
+
63
+ pub mentioned_relations : HashMap < Option < String > , HashSet < String > > ,
55
64
}
56
65
57
66
impl < ' a > CompletionContext < ' a > {
@@ -61,18 +70,56 @@ impl<'a> CompletionContext<'a> {
61
70
text : & params. text ,
62
71
schema_cache : params. schema ,
63
72
position : usize:: from ( params. position ) ,
64
-
65
73
ts_node : None ,
66
74
schema_name : None ,
67
75
wrapping_clause_type : None ,
76
+ wrapping_statement_range : None ,
68
77
is_invocation : false ,
78
+ mentioned_relations : HashMap :: new ( ) ,
69
79
} ;
70
80
71
81
ctx. gather_tree_context ( ) ;
82
+ ctx. gather_info_from_ts_queries ( ) ;
72
83
73
84
ctx
74
85
}
75
86
87
+ fn gather_info_from_ts_queries ( & mut self ) {
88
+ let tree = match self . tree . as_ref ( ) {
89
+ None => return ,
90
+ Some ( t) => t,
91
+ } ;
92
+
93
+ let stmt_range = self . wrapping_statement_range . as_ref ( ) ;
94
+ let sql = self . text ;
95
+
96
+ let mut executor = TreeSitterQueriesExecutor :: new ( tree. root_node ( ) , sql) ;
97
+
98
+ executor. add_query_results :: < queries:: RelationMatch > ( ) ;
99
+
100
+ for relation_match in executor. get_iter ( stmt_range) {
101
+ match relation_match {
102
+ QueryResult :: Relation ( r) => {
103
+ let schema_name = r. get_schema ( sql) ;
104
+ let table_name = r. get_table ( sql) ;
105
+
106
+ let current = self . mentioned_relations . get_mut ( & schema_name) ;
107
+
108
+ match current {
109
+ Some ( c) => {
110
+ c. insert ( table_name) ;
111
+ }
112
+ None => {
113
+ let mut new = HashSet :: new ( ) ;
114
+ new. insert ( table_name) ;
115
+ self . mentioned_relations . insert ( schema_name, new) ;
116
+ }
117
+ } ;
118
+ }
119
+ } ;
120
+ }
121
+ }
122
+
76
123
pub fn get_ts_node_content ( & self , ts_node : tree_sitter:: Node < ' a > ) -> Option < & ' a str > {
77
124
let source = self . text ;
78
125
match ts_node. utf8_text ( source. as_bytes ( ) ) {
@@ -100,36 +147,38 @@ impl<'a> CompletionContext<'a> {
100
147
* We'll therefore adjust the cursor position such that it meets the last node of the AST.
101
148
* `select * from use {}` becomes `select * from use{}`.
102
149
*/
103
- let current_node_kind = cursor. node ( ) . kind ( ) ;
150
+ let current_node = cursor. node ( ) ;
104
151
while cursor. goto_first_child_for_byte ( self . position ) . is_none ( ) && self . position > 0 {
105
152
self . position -= 1 ;
106
153
}
107
154
108
- self . gather_context_from_node ( cursor, current_node_kind ) ;
155
+ self . gather_context_from_node ( cursor, current_node ) ;
109
156
}
110
157
111
158
fn gather_context_from_node (
112
159
& mut self ,
113
160
mut cursor : tree_sitter:: TreeCursor < ' a > ,
114
- previous_node_kind : & str ,
161
+ previous_node : tree_sitter :: Node < ' a > ,
115
162
) {
116
163
let current_node = cursor. node ( ) ;
117
- let current_node_kind = current_node. kind ( ) ;
118
164
119
165
// prevent infinite recursion – this can happen if we only have a PROGRAM node
120
- if current_node_kind == previous_node_kind {
166
+ if current_node . kind ( ) == previous_node . kind ( ) {
121
167
self . ts_node = Some ( current_node) ;
122
168
return ;
123
169
}
124
170
125
- match previous_node_kind {
126
- "statement" => self . wrapping_clause_type = current_node_kind. try_into ( ) . ok ( ) ,
171
+ match previous_node. kind ( ) {
172
+ "statement" | "subquery" => {
173
+ self . wrapping_clause_type = current_node. kind ( ) . try_into ( ) . ok ( ) ;
174
+ self . wrapping_statement_range = Some ( previous_node. range ( ) ) ;
175
+ }
127
176
"invocation" => self . is_invocation = true ,
128
177
129
178
_ => { }
130
179
}
131
180
132
- match current_node_kind {
181
+ match current_node . kind ( ) {
133
182
"object_reference" => {
134
183
let txt = self . get_ts_node_content ( current_node) ;
135
184
if let Some ( txt) = txt {
@@ -159,7 +208,7 @@ impl<'a> CompletionContext<'a> {
159
208
}
160
209
161
210
cursor. goto_first_child_for_byte ( self . position ) ;
162
- self . gather_context_from_node ( cursor, current_node_kind ) ;
211
+ self . gather_context_from_node ( cursor, current_node ) ;
163
212
}
164
213
}
165
214
@@ -209,7 +258,7 @@ mod tests {
209
258
] ;
210
259
211
260
for ( query, expected_clause) in test_cases {
212
- let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
261
+ let ( position, text) = get_text_and_position ( query. as_str ( ) . into ( ) ) ;
213
262
214
263
let tree = get_tree ( text. as_str ( ) ) ;
215
264
@@ -242,7 +291,7 @@ mod tests {
242
291
] ;
243
292
244
293
for ( query, expected_schema) in test_cases {
245
- let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
294
+ let ( position, text) = get_text_and_position ( query. as_str ( ) . into ( ) ) ;
246
295
247
296
let tree = get_tree ( text. as_str ( ) ) ;
248
297
let params = crate :: CompletionParams {
@@ -276,7 +325,7 @@ mod tests {
276
325
] ;
277
326
278
327
for ( query, is_invocation) in test_cases {
279
- let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
328
+ let ( position, text) = get_text_and_position ( query. as_str ( ) . into ( ) ) ;
280
329
281
330
let tree = get_tree ( text. as_str ( ) ) ;
282
331
let params = crate :: CompletionParams {
@@ -300,7 +349,7 @@ mod tests {
300
349
] ;
301
350
302
351
for query in cases {
303
- let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
352
+ let ( position, text) = get_text_and_position ( query. as_str ( ) . into ( ) ) ;
304
353
305
354
let tree = get_tree ( text. as_str ( ) ) ;
306
355
@@ -328,7 +377,7 @@ mod tests {
328
377
fn does_not_fail_on_trailing_whitespace ( ) {
329
378
let query = format ! ( "select * from {}" , CURSOR_POS ) ;
330
379
331
- let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
380
+ let ( position, text) = get_text_and_position ( query. as_str ( ) . into ( ) ) ;
332
381
333
382
let tree = get_tree ( text. as_str ( ) ) ;
334
383
@@ -354,7 +403,7 @@ mod tests {
354
403
fn does_not_fail_with_empty_statements ( ) {
355
404
let query = format ! ( "{}" , CURSOR_POS ) ;
356
405
357
- let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
406
+ let ( position, text) = get_text_and_position ( query. as_str ( ) . into ( ) ) ;
358
407
359
408
let tree = get_tree ( text. as_str ( ) ) ;
360
409
@@ -379,7 +428,7 @@ mod tests {
379
428
// is selecting a certain column name, such as `frozen_account`.
380
429
let query = format ! ( "select * fro{}" , CURSOR_POS ) ;
381
430
382
- let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
431
+ let ( position, text) = get_text_and_position ( query. as_str ( ) . into ( ) ) ;
383
432
384
433
let tree = get_tree ( text. as_str ( ) ) ;
385
434
0 commit comments