Skip to content

feat(completions): ts_query package, column autocompletion #168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pg_schema_cache = { path = "./crates/pg_schema_cache", version = "0.0.
pg_statement_splitter = { path = "./crates/pg_statement_splitter", version = "0.0.0" }
pg_syntax = { path = "./crates/pg_syntax", version = "0.0.0" }
pg_text_edit = { path = "./crates/pg_text_edit", version = "0.0.0" }
pg_treesitter_queries = { path = "./crates/pg_treesitter_queries", version = "0.0.0" }
pg_type_resolver = { path = "./crates/pg_type_resolver", version = "0.0.0" }
pg_typecheck = { path = "./crates/pg_typecheck", version = "0.0.0" }
pg_workspace = { path = "./crates/pg_workspace", version = "0.0.0" }
Expand Down
1 change: 1 addition & 0 deletions crates/pg_completions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ serde_json = { workspace = true }
pg_schema_cache.workspace = true
tree-sitter.workspace = true
tree_sitter_sql.workspace = true
pg_treesitter_queries.workspace = true

sqlx.workspace = true

Expand Down
3 changes: 2 additions & 1 deletion crates/pg_completions/src/complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
builder::CompletionBuilder,
context::CompletionContext,
item::CompletionItem,
providers::{complete_functions, complete_tables},
providers::{complete_columns, complete_functions, complete_tables},
};

pub const LIMIT: usize = 50;
Expand Down Expand Up @@ -38,6 +38,7 @@ pub fn complete(params: CompletionParams) -> CompletionResult {

complete_tables(&ctx, &mut builder);
complete_functions(&ctx, &mut builder);
complete_columns(&ctx, &mut builder);

builder.finish()
}
83 changes: 66 additions & 17 deletions crates/pg_completions/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
use std::collections::{HashMap, HashSet};

use pg_schema_cache::SchemaCache;
use pg_treesitter_queries::{
queries::{self, QueryResult},
TreeSitterQueriesExecutor,
};

use crate::CompletionParams;

Expand Down Expand Up @@ -52,6 +58,9 @@ pub(crate) struct CompletionContext<'a> {
pub schema_name: Option<String>,
pub wrapping_clause_type: Option<ClauseType>,
pub is_invocation: bool,
pub wrapping_statement_range: Option<tree_sitter::Range>,

pub mentioned_relations: HashMap<Option<String>, HashSet<String>>,
}

impl<'a> CompletionContext<'a> {
Expand All @@ -61,18 +70,56 @@ impl<'a> CompletionContext<'a> {
text: &params.text,
schema_cache: params.schema,
position: usize::from(params.position),

ts_node: None,
schema_name: None,
wrapping_clause_type: None,
wrapping_statement_range: None,
is_invocation: false,
mentioned_relations: HashMap::new(),
};

ctx.gather_tree_context();
ctx.gather_info_from_ts_queries();

ctx
}

fn gather_info_from_ts_queries(&mut self) {
let tree = match self.tree.as_ref() {
None => return,
Some(t) => t,
};

let stmt_range = self.wrapping_statement_range.as_ref();
let sql = self.text;

let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql);

executor.add_query_results::<queries::RelationMatch>();

for relation_match in executor.get_iter(stmt_range) {
match relation_match {
QueryResult::Relation(r) => {
let schema_name = r.get_schema(sql);
let table_name = r.get_table(sql);

let current = self.mentioned_relations.get_mut(&schema_name);

match current {
Some(c) => {
c.insert(table_name);
}
None => {
let mut new = HashSet::new();
new.insert(table_name);
self.mentioned_relations.insert(schema_name, new);
}
};
}
};
}
}

pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option<&'a str> {
let source = self.text;
match ts_node.utf8_text(source.as_bytes()) {
Expand Down Expand Up @@ -100,36 +147,38 @@ impl<'a> CompletionContext<'a> {
* We'll therefore adjust the cursor position such that it meets the last node of the AST.
* `select * from use {}` becomes `select * from use{}`.
*/
let current_node_kind = cursor.node().kind();
let current_node = cursor.node();
while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 {
self.position -= 1;
}

self.gather_context_from_node(cursor, current_node_kind);
self.gather_context_from_node(cursor, current_node);
}

fn gather_context_from_node(
&mut self,
mut cursor: tree_sitter::TreeCursor<'a>,
previous_node_kind: &str,
previous_node: tree_sitter::Node<'a>,
) {
let current_node = cursor.node();
let current_node_kind = current_node.kind();

// prevent infinite recursion – this can happen if we only have a PROGRAM node
if current_node_kind == previous_node_kind {
if current_node.kind() == previous_node.kind() {
self.ts_node = Some(current_node);
return;
}

match previous_node_kind {
"statement" => self.wrapping_clause_type = current_node_kind.try_into().ok(),
match previous_node.kind() {
"statement" | "subquery" => {
self.wrapping_clause_type = current_node.kind().try_into().ok();
self.wrapping_statement_range = Some(previous_node.range());
}
"invocation" => self.is_invocation = true,

_ => {}
}

match current_node_kind {
match current_node.kind() {
"object_reference" => {
let txt = self.get_ts_node_content(current_node);
if let Some(txt) = txt {
Expand Down Expand Up @@ -159,7 +208,7 @@ impl<'a> CompletionContext<'a> {
}

cursor.goto_first_child_for_byte(self.position);
self.gather_context_from_node(cursor, current_node_kind);
self.gather_context_from_node(cursor, current_node);
}
}

Expand Down Expand Up @@ -209,7 +258,7 @@ mod tests {
];

for (query, expected_clause) in test_cases {
let (position, text) = get_text_and_position(query.as_str());
let (position, text) = get_text_and_position(query.as_str().into());

let tree = get_tree(text.as_str());

Expand Down Expand Up @@ -242,7 +291,7 @@ mod tests {
];

for (query, expected_schema) in test_cases {
let (position, text) = get_text_and_position(query.as_str());
let (position, text) = get_text_and_position(query.as_str().into());

let tree = get_tree(text.as_str());
let params = crate::CompletionParams {
Expand Down Expand Up @@ -276,7 +325,7 @@ mod tests {
];

for (query, is_invocation) in test_cases {
let (position, text) = get_text_and_position(query.as_str());
let (position, text) = get_text_and_position(query.as_str().into());

let tree = get_tree(text.as_str());
let params = crate::CompletionParams {
Expand All @@ -300,7 +349,7 @@ mod tests {
];

for query in cases {
let (position, text) = get_text_and_position(query.as_str());
let (position, text) = get_text_and_position(query.as_str().into());

let tree = get_tree(text.as_str());

Expand Down Expand Up @@ -328,7 +377,7 @@ mod tests {
fn does_not_fail_on_trailing_whitespace() {
let query = format!("select * from {}", CURSOR_POS);

let (position, text) = get_text_and_position(query.as_str());
let (position, text) = get_text_and_position(query.as_str().into());

let tree = get_tree(text.as_str());

Expand All @@ -354,7 +403,7 @@ mod tests {
fn does_not_fail_with_empty_statements() {
let query = format!("{}", CURSOR_POS);

let (position, text) = get_text_and_position(query.as_str());
let (position, text) = get_text_and_position(query.as_str().into());

let tree = get_tree(text.as_str());

Expand All @@ -379,7 +428,7 @@ mod tests {
// is selecting a certain column name, such as `frozen_account`.
let query = format!("select * fro{}", CURSOR_POS);

let (position, text) = get_text_and_position(query.as_str());
let (position, text) = get_text_and_position(query.as_str().into());

let tree = get_tree(text.as_str());

Expand Down
1 change: 1 addition & 0 deletions crates/pg_completions/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize};
pub enum CompletionItemKind {
Table,
Function,
Column,
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down
114 changes: 114 additions & 0 deletions crates/pg_completions/src/providers/columns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use crate::{
builder::CompletionBuilder, context::CompletionContext, relevance::CompletionRelevanceData,
CompletionItem, CompletionItemKind,
};

pub fn complete_columns(ctx: &CompletionContext, builder: &mut CompletionBuilder) {
let available_columns = &ctx.schema_cache.columns;

for col in available_columns {
let item = CompletionItem {
label: col.name.clone(),
score: CompletionRelevanceData::Column(col).get_score(ctx),
description: format!("Table: {}.{}", col.schema_name, col.table_name),
preselected: false,
kind: CompletionItemKind::Column,
};

builder.add_item(item);
}
}

#[cfg(test)]
mod tests {
use crate::{
complete,
test_helper::{get_test_deps, get_test_params, InputQuery, CURSOR_POS},
CompletionItem,
};

struct TestCase {
query: String,
message: &'static str,
label: &'static str,
description: &'static str,
}

impl TestCase {
fn get_input_query(&self) -> InputQuery {
let strs: Vec<&str> = self.query.split_whitespace().collect();
strs.join(" ").as_str().into()
}
}

#[tokio::test]
async fn completes_columns() {
let setup = r#"
create schema private;

create table public.users (
id serial primary key,
name text
);

create table public.audio_books (
id serial primary key,
narrator text
);

create table private.audio_books (
id serial primary key,
narrator_id text
);
"#;

let queries: Vec<TestCase> = vec![
TestCase {
message: "correctly prefers the columns of present tables",
query: format!(r#"select na{} from public.audio_books;"#, CURSOR_POS),
label: "narrator",
description: "Table: public.audio_books",
},
TestCase {
message: "correctly handles nested queries",
query: format!(
r#"
select
*
from (
select id, na{}
from private.audio_books
) as subquery
join public.users u
on u.id = subquery.id;
"#,
CURSOR_POS
),
label: "narrator_id",
description: "Table: private.audio_books",
},
TestCase {
message: "works without a schema",
query: format!(r#"select na{} from users;"#, CURSOR_POS),
label: "name",
description: "Table: public.users",
},
];

for q in queries {
let (tree, cache) = get_test_deps(setup, q.get_input_query()).await;
let params = get_test_params(&tree, &cache, q.get_input_query());
let results = complete(params);

let CompletionItem {
label, description, ..
} = results
.into_iter()
.next()
.expect("Should return at least one completion item");

assert_eq!(label, q.label, "{}", q.message);
assert_eq!(description, q.description, "{}", q.message);
}
}
}
Loading
Loading