Skip to content

Commit c509dd6

Browse files
feat(completions): ts_query package, column autocompletion (#168)
* initial commit * jeez * jeez * not necessary * beautiful * fixie fixie * randomly change score until tests do what i want * i like the syntax * format TOML * use lazyLock
1 parent 603a3c9 commit c509dd6

File tree

21 files changed

+657
-75
lines changed

21 files changed

+657
-75
lines changed

Cargo.lock

+10
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ pg_schema_cache = { path = "./crates/pg_schema_cache", version = "0.0.
7575
pg_statement_splitter = { path = "./crates/pg_statement_splitter", version = "0.0.0" }
7676
pg_syntax = { path = "./crates/pg_syntax", version = "0.0.0" }
7777
pg_text_edit = { path = "./crates/pg_text_edit", version = "0.0.0" }
78+
pg_treesitter_queries = { path = "./crates/pg_treesitter_queries", version = "0.0.0" }
7879
pg_type_resolver = { path = "./crates/pg_type_resolver", version = "0.0.0" }
7980
pg_typecheck = { path = "./crates/pg_typecheck", version = "0.0.0" }
8081
pg_workspace = { path = "./crates/pg_workspace", version = "0.0.0" }

crates/pg_completions/Cargo.toml

+6-5
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@ async-std = "1.12.0"
1616

1717
text-size.workspace = true
1818

19-
serde = { workspace = true, features = ["derive"] }
20-
serde_json = { workspace = true }
21-
pg_schema_cache.workspace = true
22-
tree-sitter.workspace = true
23-
tree_sitter_sql.workspace = true
19+
pg_schema_cache.workspace = true
20+
pg_treesitter_queries.workspace = true
21+
serde = { workspace = true, features = ["derive"] }
22+
serde_json = { workspace = true }
23+
tree-sitter.workspace = true
24+
tree_sitter_sql.workspace = true
2425

2526
sqlx.workspace = true
2627

crates/pg_completions/src/complete.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::{
55
builder::CompletionBuilder,
66
context::CompletionContext,
77
item::CompletionItem,
8-
providers::{complete_functions, complete_tables},
8+
providers::{complete_columns, complete_functions, complete_tables},
99
};
1010

1111
pub const LIMIT: usize = 50;
@@ -38,6 +38,7 @@ pub fn complete(params: CompletionParams) -> CompletionResult {
3838

3939
complete_tables(&ctx, &mut builder);
4040
complete_functions(&ctx, &mut builder);
41+
complete_columns(&ctx, &mut builder);
4142

4243
builder.finish()
4344
}

crates/pg_completions/src/context.rs

+66-17
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1+
use std::collections::{HashMap, HashSet};
2+
13
use pg_schema_cache::SchemaCache;
4+
use pg_treesitter_queries::{
5+
queries::{self, QueryResult},
6+
TreeSitterQueriesExecutor,
7+
};
28

39
use crate::CompletionParams;
410

@@ -52,6 +58,9 @@ pub(crate) struct CompletionContext<'a> {
5258
pub schema_name: Option<String>,
5359
pub wrapping_clause_type: Option<ClauseType>,
5460
pub is_invocation: bool,
61+
pub wrapping_statement_range: Option<tree_sitter::Range>,
62+
63+
pub mentioned_relations: HashMap<Option<String>, HashSet<String>>,
5564
}
5665

5766
impl<'a> CompletionContext<'a> {
@@ -61,18 +70,56 @@ impl<'a> CompletionContext<'a> {
6170
text: &params.text,
6271
schema_cache: params.schema,
6372
position: usize::from(params.position),
64-
6573
ts_node: None,
6674
schema_name: None,
6775
wrapping_clause_type: None,
76+
wrapping_statement_range: None,
6877
is_invocation: false,
78+
mentioned_relations: HashMap::new(),
6979
};
7080

7181
ctx.gather_tree_context();
82+
ctx.gather_info_from_ts_queries();
7283

7384
ctx
7485
}
7586

87+
fn gather_info_from_ts_queries(&mut self) {
88+
let tree = match self.tree.as_ref() {
89+
None => return,
90+
Some(t) => t,
91+
};
92+
93+
let stmt_range = self.wrapping_statement_range.as_ref();
94+
let sql = self.text;
95+
96+
let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql);
97+
98+
executor.add_query_results::<queries::RelationMatch>();
99+
100+
for relation_match in executor.get_iter(stmt_range) {
101+
match relation_match {
102+
QueryResult::Relation(r) => {
103+
let schema_name = r.get_schema(sql);
104+
let table_name = r.get_table(sql);
105+
106+
let current = self.mentioned_relations.get_mut(&schema_name);
107+
108+
match current {
109+
Some(c) => {
110+
c.insert(table_name);
111+
}
112+
None => {
113+
let mut new = HashSet::new();
114+
new.insert(table_name);
115+
self.mentioned_relations.insert(schema_name, new);
116+
}
117+
};
118+
}
119+
};
120+
}
121+
}
122+
76123
pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option<&'a str> {
77124
let source = self.text;
78125
match ts_node.utf8_text(source.as_bytes()) {
@@ -100,36 +147,38 @@ impl<'a> CompletionContext<'a> {
100147
* We'll therefore adjust the cursor position such that it meets the last node of the AST.
101148
* `select * from use {}` becomes `select * from use{}`.
102149
*/
103-
let current_node_kind = cursor.node().kind();
150+
let current_node = cursor.node();
104151
while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 {
105152
self.position -= 1;
106153
}
107154

108-
self.gather_context_from_node(cursor, current_node_kind);
155+
self.gather_context_from_node(cursor, current_node);
109156
}
110157

111158
fn gather_context_from_node(
112159
&mut self,
113160
mut cursor: tree_sitter::TreeCursor<'a>,
114-
previous_node_kind: &str,
161+
previous_node: tree_sitter::Node<'a>,
115162
) {
116163
let current_node = cursor.node();
117-
let current_node_kind = current_node.kind();
118164

119165
// prevent infinite recursion – this can happen if we only have a PROGRAM node
120-
if current_node_kind == previous_node_kind {
166+
if current_node.kind() == previous_node.kind() {
121167
self.ts_node = Some(current_node);
122168
return;
123169
}
124170

125-
match previous_node_kind {
126-
"statement" => self.wrapping_clause_type = current_node_kind.try_into().ok(),
171+
match previous_node.kind() {
172+
"statement" | "subquery" => {
173+
self.wrapping_clause_type = current_node.kind().try_into().ok();
174+
self.wrapping_statement_range = Some(previous_node.range());
175+
}
127176
"invocation" => self.is_invocation = true,
128177

129178
_ => {}
130179
}
131180

132-
match current_node_kind {
181+
match current_node.kind() {
133182
"object_reference" => {
134183
let txt = self.get_ts_node_content(current_node);
135184
if let Some(txt) = txt {
@@ -159,7 +208,7 @@ impl<'a> CompletionContext<'a> {
159208
}
160209

161210
cursor.goto_first_child_for_byte(self.position);
162-
self.gather_context_from_node(cursor, current_node_kind);
211+
self.gather_context_from_node(cursor, current_node);
163212
}
164213
}
165214

@@ -209,7 +258,7 @@ mod tests {
209258
];
210259

211260
for (query, expected_clause) in test_cases {
212-
let (position, text) = get_text_and_position(query.as_str());
261+
let (position, text) = get_text_and_position(query.as_str().into());
213262

214263
let tree = get_tree(text.as_str());
215264

@@ -242,7 +291,7 @@ mod tests {
242291
];
243292

244293
for (query, expected_schema) in test_cases {
245-
let (position, text) = get_text_and_position(query.as_str());
294+
let (position, text) = get_text_and_position(query.as_str().into());
246295

247296
let tree = get_tree(text.as_str());
248297
let params = crate::CompletionParams {
@@ -276,7 +325,7 @@ mod tests {
276325
];
277326

278327
for (query, is_invocation) in test_cases {
279-
let (position, text) = get_text_and_position(query.as_str());
328+
let (position, text) = get_text_and_position(query.as_str().into());
280329

281330
let tree = get_tree(text.as_str());
282331
let params = crate::CompletionParams {
@@ -300,7 +349,7 @@ mod tests {
300349
];
301350

302351
for query in cases {
303-
let (position, text) = get_text_and_position(query.as_str());
352+
let (position, text) = get_text_and_position(query.as_str().into());
304353

305354
let tree = get_tree(text.as_str());
306355

@@ -328,7 +377,7 @@ mod tests {
328377
fn does_not_fail_on_trailing_whitespace() {
329378
let query = format!("select * from {}", CURSOR_POS);
330379

331-
let (position, text) = get_text_and_position(query.as_str());
380+
let (position, text) = get_text_and_position(query.as_str().into());
332381

333382
let tree = get_tree(text.as_str());
334383

@@ -354,7 +403,7 @@ mod tests {
354403
fn does_not_fail_with_empty_statements() {
355404
let query = format!("{}", CURSOR_POS);
356405

357-
let (position, text) = get_text_and_position(query.as_str());
406+
let (position, text) = get_text_and_position(query.as_str().into());
358407

359408
let tree = get_tree(text.as_str());
360409

@@ -379,7 +428,7 @@ mod tests {
379428
// is selecting a certain column name, such as `frozen_account`.
380429
let query = format!("select * fro{}", CURSOR_POS);
381430

382-
let (position, text) = get_text_and_position(query.as_str());
431+
let (position, text) = get_text_and_position(query.as_str().into());
383432

384433
let tree = get_tree(text.as_str());
385434

crates/pg_completions/src/item.rs

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize};
44
pub enum CompletionItemKind {
55
Table,
66
Function,
7+
Column,
78
}
89

910
#[derive(Debug, Serialize, Deserialize)]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
use crate::{
2+
builder::CompletionBuilder, context::CompletionContext, relevance::CompletionRelevanceData,
3+
CompletionItem, CompletionItemKind,
4+
};
5+
6+
pub fn complete_columns(ctx: &CompletionContext, builder: &mut CompletionBuilder) {
7+
let available_columns = &ctx.schema_cache.columns;
8+
9+
for col in available_columns {
10+
let item = CompletionItem {
11+
label: col.name.clone(),
12+
score: CompletionRelevanceData::Column(col).get_score(ctx),
13+
description: format!("Table: {}.{}", col.schema_name, col.table_name),
14+
preselected: false,
15+
kind: CompletionItemKind::Column,
16+
};
17+
18+
builder.add_item(item);
19+
}
20+
}
21+
22+
#[cfg(test)]
23+
mod tests {
24+
use crate::{
25+
complete,
26+
test_helper::{get_test_deps, get_test_params, InputQuery, CURSOR_POS},
27+
CompletionItem,
28+
};
29+
30+
struct TestCase {
31+
query: String,
32+
message: &'static str,
33+
label: &'static str,
34+
description: &'static str,
35+
}
36+
37+
impl TestCase {
38+
fn get_input_query(&self) -> InputQuery {
39+
let strs: Vec<&str> = self.query.split_whitespace().collect();
40+
strs.join(" ").as_str().into()
41+
}
42+
}
43+
44+
#[tokio::test]
45+
async fn completes_columns() {
46+
let setup = r#"
47+
create schema private;
48+
49+
create table public.users (
50+
id serial primary key,
51+
name text
52+
);
53+
54+
create table public.audio_books (
55+
id serial primary key,
56+
narrator text
57+
);
58+
59+
create table private.audio_books (
60+
id serial primary key,
61+
narrator_id text
62+
);
63+
"#;
64+
65+
let queries: Vec<TestCase> = vec![
66+
TestCase {
67+
message: "correctly prefers the columns of present tables",
68+
query: format!(r#"select na{} from public.audio_books;"#, CURSOR_POS),
69+
label: "narrator",
70+
description: "Table: public.audio_books",
71+
},
72+
TestCase {
73+
message: "correctly handles nested queries",
74+
query: format!(
75+
r#"
76+
select
77+
*
78+
from (
79+
select id, na{}
80+
from private.audio_books
81+
) as subquery
82+
join public.users u
83+
on u.id = subquery.id;
84+
"#,
85+
CURSOR_POS
86+
),
87+
label: "narrator_id",
88+
description: "Table: private.audio_books",
89+
},
90+
TestCase {
91+
message: "works without a schema",
92+
query: format!(r#"select na{} from users;"#, CURSOR_POS),
93+
label: "name",
94+
description: "Table: public.users",
95+
},
96+
];
97+
98+
for q in queries {
99+
let (tree, cache) = get_test_deps(setup, q.get_input_query()).await;
100+
let params = get_test_params(&tree, &cache, q.get_input_query());
101+
let results = complete(params);
102+
103+
let CompletionItem {
104+
label, description, ..
105+
} = results
106+
.into_iter()
107+
.next()
108+
.expect("Should return at least one completion item");
109+
110+
assert_eq!(label, q.label, "{}", q.message);
111+
assert_eq!(description, q.description, "{}", q.message);
112+
}
113+
}
114+
}

0 commit comments

Comments
 (0)