diff --git a/Cargo.lock b/Cargo.lock index 84fac274..f36aea5e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2313,6 +2313,7 @@ dependencies = [ "async-std", "pg_schema_cache", "pg_test_utils", + "pg_treesitter_queries", "serde", "serde_json", "sqlx", @@ -2668,6 +2669,15 @@ dependencies = [ "text-size", ] +[[package]] +name = "pg_treesitter_queries" +version = "0.0.0" +dependencies = [ + "clap", + "tree-sitter", + "tree_sitter_sql", +] + [[package]] name = "pg_type_resolver" version = "0.0.0" diff --git a/Cargo.toml b/Cargo.toml index 54a18bd8..5b6fb00a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,6 +75,7 @@ pg_schema_cache = { path = "./crates/pg_schema_cache", version = "0.0. pg_statement_splitter = { path = "./crates/pg_statement_splitter", version = "0.0.0" } pg_syntax = { path = "./crates/pg_syntax", version = "0.0.0" } pg_text_edit = { path = "./crates/pg_text_edit", version = "0.0.0" } +pg_treesitter_queries = { path = "./crates/pg_treesitter_queries", version = "0.0.0" } pg_type_resolver = { path = "./crates/pg_type_resolver", version = "0.0.0" } pg_typecheck = { path = "./crates/pg_typecheck", version = "0.0.0" } pg_workspace = { path = "./crates/pg_workspace", version = "0.0.0" } diff --git a/crates/pg_completions/Cargo.toml b/crates/pg_completions/Cargo.toml index c1cf8afe..140ef910 100644 --- a/crates/pg_completions/Cargo.toml +++ b/crates/pg_completions/Cargo.toml @@ -16,11 +16,12 @@ async-std = "1.12.0" text-size.workspace = true -serde = { workspace = true, features = ["derive"] } -serde_json = { workspace = true } -pg_schema_cache.workspace = true -tree-sitter.workspace = true -tree_sitter_sql.workspace = true +pg_schema_cache.workspace = true +pg_treesitter_queries.workspace = true +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +tree-sitter.workspace = true +tree_sitter_sql.workspace = true sqlx.workspace = true diff --git a/crates/pg_completions/src/complete.rs b/crates/pg_completions/src/complete.rs index fd417f46..c45d01ac 100644 --- a/crates/pg_completions/src/complete.rs +++ b/crates/pg_completions/src/complete.rs @@ -5,7 +5,7 @@ use crate::{ builder::CompletionBuilder, context::CompletionContext, item::CompletionItem, - providers::{complete_functions, complete_tables}, + providers::{complete_columns, complete_functions, complete_tables}, }; pub const LIMIT: usize = 50; @@ -38,6 +38,7 @@ pub fn complete(params: CompletionParams) -> CompletionResult { complete_tables(&ctx, &mut builder); complete_functions(&ctx, &mut builder); + complete_columns(&ctx, &mut builder); builder.finish() } diff --git a/crates/pg_completions/src/context.rs b/crates/pg_completions/src/context.rs index 912a7cec..a5fb0c6b 100644 --- a/crates/pg_completions/src/context.rs +++ b/crates/pg_completions/src/context.rs @@ -1,4 +1,10 @@ +use std::collections::{HashMap, HashSet}; + use pg_schema_cache::SchemaCache; +use pg_treesitter_queries::{ + queries::{self, QueryResult}, + TreeSitterQueriesExecutor, +}; use crate::CompletionParams; @@ -52,6 +58,9 @@ pub(crate) struct CompletionContext<'a> { pub schema_name: Option, pub wrapping_clause_type: Option, pub is_invocation: bool, + pub wrapping_statement_range: Option, + + pub mentioned_relations: HashMap, HashSet>, } impl<'a> CompletionContext<'a> { @@ -61,18 +70,56 @@ impl<'a> CompletionContext<'a> { text: ¶ms.text, schema_cache: params.schema, position: usize::from(params.position), - ts_node: None, schema_name: None, wrapping_clause_type: None, + wrapping_statement_range: None, is_invocation: false, + mentioned_relations: HashMap::new(), }; ctx.gather_tree_context(); + ctx.gather_info_from_ts_queries(); ctx } + fn gather_info_from_ts_queries(&mut self) { + let tree = match self.tree.as_ref() { + None => return, + Some(t) => t, + }; + + let stmt_range = self.wrapping_statement_range.as_ref(); + let sql = self.text; + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + for relation_match in executor.get_iter(stmt_range) { + match relation_match { + QueryResult::Relation(r) => { + let schema_name = r.get_schema(sql); + let table_name = r.get_table(sql); + + let current = self.mentioned_relations.get_mut(&schema_name); + + match current { + Some(c) => { + c.insert(table_name); + } + None => { + let mut new = HashSet::new(); + new.insert(table_name); + self.mentioned_relations.insert(schema_name, new); + } + }; + } + }; + } + } + pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option<&'a str> { let source = self.text; match ts_node.utf8_text(source.as_bytes()) { @@ -100,36 +147,38 @@ impl<'a> CompletionContext<'a> { * We'll therefore adjust the cursor position such that it meets the last node of the AST. * `select * from use {}` becomes `select * from use{}`. */ - let current_node_kind = cursor.node().kind(); + let current_node = cursor.node(); while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 { self.position -= 1; } - self.gather_context_from_node(cursor, current_node_kind); + self.gather_context_from_node(cursor, current_node); } fn gather_context_from_node( &mut self, mut cursor: tree_sitter::TreeCursor<'a>, - previous_node_kind: &str, + previous_node: tree_sitter::Node<'a>, ) { let current_node = cursor.node(); - let current_node_kind = current_node.kind(); // prevent infinite recursion – this can happen if we only have a PROGRAM node - if current_node_kind == previous_node_kind { + if current_node.kind() == previous_node.kind() { self.ts_node = Some(current_node); return; } - match previous_node_kind { - "statement" => self.wrapping_clause_type = current_node_kind.try_into().ok(), + match previous_node.kind() { + "statement" | "subquery" => { + self.wrapping_clause_type = current_node.kind().try_into().ok(); + self.wrapping_statement_range = Some(previous_node.range()); + } "invocation" => self.is_invocation = true, _ => {} } - match current_node_kind { + match current_node.kind() { "object_reference" => { let txt = self.get_ts_node_content(current_node); if let Some(txt) = txt { @@ -159,7 +208,7 @@ impl<'a> CompletionContext<'a> { } cursor.goto_first_child_for_byte(self.position); - self.gather_context_from_node(cursor, current_node_kind); + self.gather_context_from_node(cursor, current_node); } } @@ -209,7 +258,7 @@ mod tests { ]; for (query, expected_clause) in test_cases { - let (position, text) = get_text_and_position(query.as_str()); + let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); @@ -242,7 +291,7 @@ mod tests { ]; for (query, expected_schema) in test_cases { - let (position, text) = get_text_and_position(query.as_str()); + let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); let params = crate::CompletionParams { @@ -276,7 +325,7 @@ mod tests { ]; for (query, is_invocation) in test_cases { - let (position, text) = get_text_and_position(query.as_str()); + let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); let params = crate::CompletionParams { @@ -300,7 +349,7 @@ mod tests { ]; for query in cases { - let (position, text) = get_text_and_position(query.as_str()); + let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); @@ -328,7 +377,7 @@ mod tests { fn does_not_fail_on_trailing_whitespace() { let query = format!("select * from {}", CURSOR_POS); - let (position, text) = get_text_and_position(query.as_str()); + let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); @@ -354,7 +403,7 @@ mod tests { fn does_not_fail_with_empty_statements() { let query = format!("{}", CURSOR_POS); - let (position, text) = get_text_and_position(query.as_str()); + let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); @@ -379,7 +428,7 @@ mod tests { // is selecting a certain column name, such as `frozen_account`. let query = format!("select * fro{}", CURSOR_POS); - let (position, text) = get_text_and_position(query.as_str()); + let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); diff --git a/crates/pg_completions/src/item.rs b/crates/pg_completions/src/item.rs index 06771f92..d14485c2 100644 --- a/crates/pg_completions/src/item.rs +++ b/crates/pg_completions/src/item.rs @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize}; pub enum CompletionItemKind { Table, Function, + Column, } #[derive(Debug, Serialize, Deserialize)] diff --git a/crates/pg_completions/src/providers/columns.rs b/crates/pg_completions/src/providers/columns.rs new file mode 100644 index 00000000..87539c02 --- /dev/null +++ b/crates/pg_completions/src/providers/columns.rs @@ -0,0 +1,114 @@ +use crate::{ + builder::CompletionBuilder, context::CompletionContext, relevance::CompletionRelevanceData, + CompletionItem, CompletionItemKind, +}; + +pub fn complete_columns(ctx: &CompletionContext, builder: &mut CompletionBuilder) { + let available_columns = &ctx.schema_cache.columns; + + for col in available_columns { + let item = CompletionItem { + label: col.name.clone(), + score: CompletionRelevanceData::Column(col).get_score(ctx), + description: format!("Table: {}.{}", col.schema_name, col.table_name), + preselected: false, + kind: CompletionItemKind::Column, + }; + + builder.add_item(item); + } +} + +#[cfg(test)] +mod tests { + use crate::{ + complete, + test_helper::{get_test_deps, get_test_params, InputQuery, CURSOR_POS}, + CompletionItem, + }; + + struct TestCase { + query: String, + message: &'static str, + label: &'static str, + description: &'static str, + } + + impl TestCase { + fn get_input_query(&self) -> InputQuery { + let strs: Vec<&str> = self.query.split_whitespace().collect(); + strs.join(" ").as_str().into() + } + } + + #[tokio::test] + async fn completes_columns() { + let setup = r#" + create schema private; + + create table public.users ( + id serial primary key, + name text + ); + + create table public.audio_books ( + id serial primary key, + narrator text + ); + + create table private.audio_books ( + id serial primary key, + narrator_id text + ); + "#; + + let queries: Vec = vec![ + TestCase { + message: "correctly prefers the columns of present tables", + query: format!(r#"select na{} from public.audio_books;"#, CURSOR_POS), + label: "narrator", + description: "Table: public.audio_books", + }, + TestCase { + message: "correctly handles nested queries", + query: format!( + r#" + select + * + from ( + select id, na{} + from private.audio_books + ) as subquery + join public.users u + on u.id = subquery.id; + "#, + CURSOR_POS + ), + label: "narrator_id", + description: "Table: private.audio_books", + }, + TestCase { + message: "works without a schema", + query: format!(r#"select na{} from users;"#, CURSOR_POS), + label: "name", + description: "Table: public.users", + }, + ]; + + for q in queries { + let (tree, cache) = get_test_deps(setup, q.get_input_query()).await; + let params = get_test_params(&tree, &cache, q.get_input_query()); + let results = complete(params); + + let CompletionItem { + label, description, .. + } = results + .into_iter() + .next() + .expect("Should return at least one completion item"); + + assert_eq!(label, q.label, "{}", q.message); + assert_eq!(description, q.description, "{}", q.message); + } + } +} diff --git a/crates/pg_completions/src/providers/functions.rs b/crates/pg_completions/src/providers/functions.rs index d6c9db4c..e8e53020 100644 --- a/crates/pg_completions/src/providers/functions.rs +++ b/crates/pg_completions/src/providers/functions.rs @@ -43,8 +43,8 @@ mod tests { let query = format!("select coo{}", CURSOR_POS); - let (tree, cache) = get_test_deps(setup, &query).await; - let params = get_test_params(&tree, &cache, &query); + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); let CompletionItem { label, .. } = results @@ -76,8 +76,8 @@ mod tests { let query = format!(r#"select * from coo{}()"#, CURSOR_POS); - let (tree, cache) = get_test_deps(setup, &query).await; - let params = get_test_params(&tree, &cache, &query); + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); let CompletionItem { label, kind, .. } = results @@ -110,8 +110,8 @@ mod tests { let query = format!(r#"select coo{}"#, CURSOR_POS); - let (tree, cache) = get_test_deps(setup, &query).await; - let params = get_test_params(&tree, &cache, &query); + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); let CompletionItem { label, kind, .. } = results @@ -144,8 +144,8 @@ mod tests { let query = format!(r#"select * from coo{}()"#, CURSOR_POS); - let (tree, cache) = get_test_deps(setup, &query).await; - let params = get_test_params(&tree, &cache, &query); + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); let CompletionItem { label, kind, .. } = results diff --git a/crates/pg_completions/src/providers/mod.rs b/crates/pg_completions/src/providers/mod.rs index 10548206..93055129 100644 --- a/crates/pg_completions/src/providers/mod.rs +++ b/crates/pg_completions/src/providers/mod.rs @@ -1,5 +1,7 @@ +mod columns; mod functions; mod tables; +pub use columns::*; pub use functions::*; pub use tables::*; diff --git a/crates/pg_completions/src/providers/tables.rs b/crates/pg_completions/src/providers/tables.rs index 70574ec8..c3d92425 100644 --- a/crates/pg_completions/src/providers/tables.rs +++ b/crates/pg_completions/src/providers/tables.rs @@ -23,6 +23,7 @@ pub fn complete_tables(ctx: &CompletionContext, builder: &mut CompletionBuilder) #[cfg(test)] mod tests { + use crate::{ complete, test_helper::{get_test_deps, get_test_params, CURSOR_POS}, @@ -41,8 +42,8 @@ mod tests { let query = format!("select * from u{}", CURSOR_POS); - let (tree, cache) = get_test_deps(setup, &query).await; - let params = get_test_params(&tree, &cache, &query); + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); assert!(!results.items.is_empty()); @@ -79,8 +80,8 @@ mod tests { ]; for (query, expected_label) in test_cases { - let (tree, cache) = get_test_deps(setup, &query).await; - let params = get_test_params(&tree, &cache, &query); + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); assert!(!results.items.is_empty()); @@ -124,8 +125,8 @@ mod tests { ]; for (query, expected_label) in test_cases { - let (tree, cache) = get_test_deps(setup, &query).await; - let params = get_test_params(&tree, &cache, &query); + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); assert!(!results.items.is_empty()); @@ -161,8 +162,8 @@ mod tests { let query = format!(r#"select * from coo{}"#, CURSOR_POS); - let (tree, cache) = get_test_deps(setup, &query).await; - let params = get_test_params(&tree, &cache, &query); + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); let CompletionItem { label, kind, .. } = results diff --git a/crates/pg_completions/src/relevance.rs b/crates/pg_completions/src/relevance.rs index 5408a8e4..f7a42b16 100644 --- a/crates/pg_completions/src/relevance.rs +++ b/crates/pg_completions/src/relevance.rs @@ -4,6 +4,7 @@ use crate::context::{ClauseType, CompletionContext}; pub(crate) enum CompletionRelevanceData<'a> { Table(&'a pg_schema_cache::Table), Function(&'a pg_schema_cache::Function), + Column(&'a pg_schema_cache::Column), } impl<'a> CompletionRelevanceData<'a> { @@ -34,6 +35,7 @@ impl<'a> CompletionRelevance<'a> { self.check_if_catalog(ctx); self.check_is_invocation(ctx); self.check_matching_clause_type(ctx); + self.check_relations_in_stmt(ctx); self.score } @@ -49,6 +51,7 @@ impl<'a> CompletionRelevance<'a> { let name = match self.data { CompletionRelevanceData::Function(f) => f.name.as_str(), CompletionRelevanceData::Table(t) => t.name.as_str(), + CompletionRelevanceData::Column(c) => c.name.as_str(), }; if name.starts_with(content) { @@ -67,6 +70,8 @@ impl<'a> CompletionRelevance<'a> { Some(ct) => ct, }; + let has_mentioned_tables = ctx.mentioned_relations.len() > 0; + self.score += match self.data { CompletionRelevanceData::Table(_) => match clause_type { ClauseType::From => 5, @@ -75,29 +80,26 @@ impl<'a> CompletionRelevance<'a> { _ => -50, }, CompletionRelevanceData::Function(_) => match clause_type { - ClauseType::Select => 5, + ClauseType::Select if !has_mentioned_tables => 15, + ClauseType::Select if has_mentioned_tables => 0, ClauseType::From => 0, _ => -50, }, + CompletionRelevanceData::Column(_) => match clause_type { + ClauseType::Select if has_mentioned_tables => 10, + ClauseType::Select if !has_mentioned_tables => 0, + ClauseType::Where => 10, + _ => -15, + }, } } fn check_is_invocation(&mut self, ctx: &CompletionContext) { self.score += match self.data { - CompletionRelevanceData::Function(_) => { - if ctx.is_invocation { - 30 - } else { - -30 - } - } - _ => { - if ctx.is_invocation { - -10 - } else { - 0 - } - } + CompletionRelevanceData::Function(_) if ctx.is_invocation => 30, + CompletionRelevanceData::Function(_) if !ctx.is_invocation => -10, + _ if ctx.is_invocation => -10, + _ => 0, }; } @@ -107,10 +109,7 @@ impl<'a> CompletionRelevance<'a> { Some(n) => n, }; - let data_schema = match self.data { - CompletionRelevanceData::Function(f) => f.schema.as_str(), - CompletionRelevanceData::Table(t) => t.schema.as_str(), - }; + let data_schema = self.get_schema_name(); if schema_name == data_schema { self.score += 25; @@ -119,6 +118,22 @@ impl<'a> CompletionRelevance<'a> { } } + fn get_schema_name(&self) -> &str { + match self.data { + CompletionRelevanceData::Function(f) => f.schema.as_str(), + CompletionRelevanceData::Table(t) => t.schema.as_str(), + CompletionRelevanceData::Column(c) => c.schema_name.as_str(), + } + } + + fn get_table_name(&self) -> Option<&str> { + match self.data { + CompletionRelevanceData::Column(c) => Some(c.table_name.as_str()), + CompletionRelevanceData::Table(t) => Some(t.name.as_str()), + _ => None, + } + } + fn check_if_catalog(&mut self, ctx: &CompletionContext) { if ctx.schema_name.as_ref().is_some_and(|n| n == "pg_catalog") { return; @@ -126,4 +141,31 @@ impl<'a> CompletionRelevance<'a> { self.score -= 5; // unlikely that the user wants schema data } + + fn check_relations_in_stmt(&mut self, ctx: &CompletionContext) { + match self.data { + CompletionRelevanceData::Table(_) | CompletionRelevanceData::Function(_) => return, + _ => {} + } + + let schema = self.get_schema_name().to_string(); + let table_name = match self.get_table_name() { + Some(t) => t, + None => return, + }; + + if ctx + .mentioned_relations + .get(&Some(schema.to_string())) + .is_some_and(|tables| tables.contains(table_name)) + { + self.score += 45; + } else if ctx + .mentioned_relations + .get(&None) + .is_some_and(|tables| tables.contains(table_name)) + { + self.score += 30; + } + } } diff --git a/crates/pg_completions/src/test_helper.rs b/crates/pg_completions/src/test_helper.rs index 4c29d1e7..83f9cdd9 100644 --- a/crates/pg_completions/src/test_helper.rs +++ b/crates/pg_completions/src/test_helper.rs @@ -6,9 +6,34 @@ use crate::CompletionParams; pub static CURSOR_POS: char = '€'; +pub struct InputQuery { + sql: String, + position: usize, +} + +impl From<&str> for InputQuery { + fn from(value: &str) -> Self { + let position = value + .find(CURSOR_POS) + .map(|p| p.saturating_sub(1)) + .expect("Insert Cursor Position into your Query."); + + InputQuery { + sql: value.replace(CURSOR_POS, ""), + position, + } + } +} + +impl ToString for InputQuery { + fn to_string(&self) -> String { + self.sql.clone() + } +} + pub(crate) async fn get_test_deps( setup: &str, - input: &str, + input: InputQuery, ) -> (tree_sitter::Tree, pg_schema_cache::SchemaCache) { let test_db = get_new_test_db().await; @@ -26,27 +51,19 @@ pub(crate) async fn get_test_deps( .set_language(tree_sitter_sql::language()) .expect("Error loading sql language"); - let tree = parser.parse(input, None).unwrap(); + let tree = parser.parse(&input.to_string(), None).unwrap(); (tree, schema_cache) } -pub(crate) fn get_text_and_position(sql: &str) -> (usize, String) { - // the cursor is to the left of the `CURSOR_POS` - let position = sql - .find(CURSOR_POS) - .expect("Please insert the CURSOR_POS into your query.") - .saturating_sub(1); - - let text = sql.replace(CURSOR_POS, ""); - - (position, text) +pub(crate) fn get_text_and_position(q: InputQuery) -> (usize, String) { + (q.position, q.sql) } pub(crate) fn get_test_params<'a>( tree: &'a tree_sitter::Tree, schema_cache: &'a pg_schema_cache::SchemaCache, - sql: &'a str, + sql: InputQuery, ) -> CompletionParams<'a> { let (position, text) = get_text_and_position(sql); diff --git a/crates/pg_lsp/src/utils/to_lsp_types.rs b/crates/pg_lsp/src/utils/to_lsp_types.rs index ca5f3f42..24dcc443 100644 --- a/crates/pg_lsp/src/utils/to_lsp_types.rs +++ b/crates/pg_lsp/src/utils/to_lsp_types.rs @@ -6,5 +6,6 @@ pub fn to_completion_kind( match kind { pg_completions::CompletionItemKind::Table => lsp_types::CompletionItemKind::CLASS, pg_completions::CompletionItemKind::Function => lsp_types::CompletionItemKind::FUNCTION, + pg_completions::CompletionItemKind::Column => lsp_types::CompletionItemKind::FIELD, } } diff --git a/crates/pg_lsp_new/Cargo.toml b/crates/pg_lsp_new/Cargo.toml index 0454893e..8e20b521 100644 --- a/crates/pg_lsp_new/Cargo.toml +++ b/crates/pg_lsp_new/Cargo.toml @@ -16,10 +16,10 @@ anyhow = { workspace = true } biome_deserialize = { workspace = true } futures = "0.3.31" pg_analyse = { workspace = true } +pg_completions = { workspace = true } pg_configuration = { workspace = true } pg_console = { workspace = true } pg_diagnostics = { workspace = true } -pg_completions = { workspace = true } pg_fs = { workspace = true } pg_lsp_converters = { workspace = true } pg_text_edit = { workspace = true } diff --git a/crates/pg_lsp_new/src/handlers/completions.rs b/crates/pg_lsp_new/src/handlers/completions.rs index 5f7a1309..4efba210 100644 --- a/crates/pg_lsp_new/src/handlers/completions.rs +++ b/crates/pg_lsp_new/src/handlers/completions.rs @@ -54,7 +54,8 @@ fn to_lsp_types_completion_item_kind( pg_comp_kind: pg_completions::CompletionItemKind, ) -> lsp_types::CompletionItemKind { match pg_comp_kind { - pg_completions::CompletionItemKind::Function - | pg_completions::CompletionItemKind::Table => lsp_types::CompletionItemKind::CLASS, + pg_completions::CompletionItemKind::Function => lsp_types::CompletionItemKind::FUNCTION, + pg_completions::CompletionItemKind::Table => lsp_types::CompletionItemKind::CLASS, + pg_completions::CompletionItemKind::Column => lsp_types::CompletionItemKind::FIELD, } } diff --git a/crates/pg_schema_cache/src/lib.rs b/crates/pg_schema_cache/src/lib.rs index 719da404..c6dad0b7 100644 --- a/crates/pg_schema_cache/src/lib.rs +++ b/crates/pg_schema_cache/src/lib.rs @@ -10,6 +10,7 @@ mod tables; mod types; mod versions; +pub use columns::*; pub use functions::{Behavior, Function, FunctionArg, FunctionArgs}; pub use schema_cache::SchemaCache; pub use tables::{ReplicaIdentity, Table}; diff --git a/crates/pg_treesitter_queries/Cargo.toml b/crates/pg_treesitter_queries/Cargo.toml new file mode 100644 index 00000000..bb85c448 --- /dev/null +++ b/crates/pg_treesitter_queries/Cargo.toml @@ -0,0 +1,24 @@ +[package] +authors.workspace = true +categories.workspace = true +description = "" +edition.workspace = true +homepage.workspace = true +keywords.workspace = true +license.workspace = true +name = "pg_treesitter_queries" +repository.workspace = true +version = "0.0.0" + + +[dependencies] +clap = { version = "4.5.23", features = ["derive"] } +tree-sitter.workspace = true +tree_sitter_sql.workspace = true + +[dev-dependencies] + +[lib] +doctest = false + +[features] diff --git a/crates/pg_treesitter_queries/src/lib.rs b/crates/pg_treesitter_queries/src/lib.rs new file mode 100644 index 00000000..8d29db38 --- /dev/null +++ b/crates/pg_treesitter_queries/src/lib.rs @@ -0,0 +1,188 @@ +pub mod queries; + +use std::slice::Iter; + +use queries::{Query, QueryResult}; + +pub struct TreeSitterQueriesExecutor<'a> { + root_node: tree_sitter::Node<'a>, + stmt: &'a str, + results: Vec>, +} + +impl<'a> TreeSitterQueriesExecutor<'a> { + pub fn new(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Self { + Self { + root_node, + stmt, + results: vec![], + } + } + + #[allow(private_bounds)] + pub fn add_query_results>(&mut self) { + let mut results = Q::execute(self.root_node, &self.stmt); + self.results.append(&mut results); + } + + pub fn get_iter(&self, range: Option<&'a tree_sitter::Range>) -> QueryResultIter { + match range { + Some(r) => QueryResultIter::new(&self.results).within_range(r), + None => QueryResultIter::new(&self.results), + } + } +} + +pub struct QueryResultIter<'a> { + inner: Iter<'a, QueryResult<'a>>, + range: Option<&'a tree_sitter::Range>, +} + +impl<'a> QueryResultIter<'a> { + pub(crate) fn new(results: &'a Vec>) -> Self { + Self { + inner: results.iter(), + range: None, + } + } + + fn within_range(mut self, r: &'a tree_sitter::Range) -> Self { + self.range = Some(r); + self + } +} + +impl<'a> Iterator for QueryResultIter<'a> { + type Item = &'a QueryResult<'a>; + fn next(&mut self) -> Option { + let next = self.inner.next()?; + + if self.range.as_ref().is_some_and(|r| !next.within_range(r)) { + return self.next(); + } + + Some(next) + } +} + +#[cfg(test)] +mod tests { + + use crate::{queries::RelationMatch, TreeSitterQueriesExecutor}; + + #[test] + fn finds_all_relations_and_ignores_functions() { + let sql = r#" +select + * +from + ( + select + something + from + public.cool_table pu + join private.cool_tableau pr on pu.id = pr.id + where + x = '123' + union + select + something_else + from + another_table puat + inner join private.another_tableau prat on puat.id = prat.id + union + select + x, + y + from + public.get_something_cool () + ) +where + col = 17; +"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(&sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), &sql); + + executor.add_query_results::(); + + let results: Vec<&RelationMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results[0].get_schema(sql), Some("public".into())); + assert_eq!(results[0].get_table(sql), "cool_table"); + + assert_eq!(results[1].get_schema(sql), Some("private".into())); + assert_eq!(results[1].get_table(sql), "cool_tableau"); + + assert_eq!(results[2].get_schema(sql), None); + assert_eq!(results[2].get_table(sql), "another_table"); + + assert_eq!(results[3].get_schema(sql), Some("private".into())); + assert_eq!(results[3].get_table(sql), "another_tableau"); + + // we have exhausted the matches: function invocations are ignored. + assert!(results.len() == 4); + } + + #[test] + fn only_considers_nodes_in_requested_range() { + let sql = r#" +select + * +from ( + select * + from ( + select * + from private.something + ) as sq2 + join private.tableau pt1 + on sq2.id = pt1.id + ) as sq1 +join private.table pt +on sq1.id = pt.id; +"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(&sql, None).unwrap(); + + // trust me bro + let range = { + let mut cursor = tree.root_node().walk(); + cursor.goto_first_child(); // statement + cursor.goto_first_child(); // select + cursor.goto_next_sibling(); // from + cursor.goto_first_child(); // keyword_from + cursor.goto_next_sibling(); // relation + cursor.goto_first_child(); // subquery (1) + cursor.goto_first_child(); // "(" + cursor.goto_next_sibling(); // select + cursor.goto_next_sibling(); // from + cursor.goto_first_child(); // keyword_from + cursor.goto_next_sibling(); // relation + cursor.goto_first_child(); // subquery (2) + cursor.node().range() + }; + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), &sql); + + executor.add_query_results::(); + + let results: Vec<&RelationMatch> = executor + .get_iter(Some(&range)) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].get_schema(sql), Some("private".into())); + assert_eq!(results[0].get_table(sql), "something"); + } +} diff --git a/crates/pg_treesitter_queries/src/queries/mod.rs b/crates/pg_treesitter_queries/src/queries/mod.rs new file mode 100644 index 00000000..92e3b06c --- /dev/null +++ b/crates/pg_treesitter_queries/src/queries/mod.rs @@ -0,0 +1,35 @@ +mod relations; + +pub use relations::*; + +#[derive(Debug)] +pub enum QueryResult<'a> { + Relation(RelationMatch<'a>), +} + +impl<'a> QueryResult<'a> { + pub fn within_range(&self, range: &tree_sitter::Range) -> bool { + match self { + Self::Relation(rm) => { + let start = match rm.schema { + Some(s) => s.start_position(), + None => rm.table.start_position(), + }; + + let end = rm.table.end_position(); + + start >= range.start_point && end <= range.end_point + } + } + } +} + +// This trait enforces that for any `Self` that implements `Query`, +// its &Self must implement TryFrom<&QueryResult> +pub(crate) trait QueryTryFrom<'a>: Sized { + type Ref: for<'any> TryFrom<&'a QueryResult<'a>, Error = String>; +} + +pub(crate) trait Query<'a>: QueryTryFrom<'a> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec>; +} diff --git a/crates/pg_treesitter_queries/src/queries/relations.rs b/crates/pg_treesitter_queries/src/queries/relations.rs new file mode 100644 index 00000000..2ca27a05 --- /dev/null +++ b/crates/pg_treesitter_queries/src/queries/relations.rs @@ -0,0 +1,93 @@ +use std::sync::LazyLock; + +use crate::{Query, QueryResult}; + +use super::QueryTryFrom; + +static TS_QUERY: LazyLock = LazyLock::new(|| { + static QUERY_STR: &'static str = r#" + (relation + (object_reference + . + (identifier) @schema_or_table + "."? + (identifier)? @table + )+ + ) +"#; + tree_sitter::Query::new(tree_sitter_sql::language(), &QUERY_STR).expect("Invalid TS Query") +}); + +#[derive(Debug)] +pub struct RelationMatch<'a> { + pub(crate) schema: Option>, + pub(crate) table: tree_sitter::Node<'a>, +} + +impl<'a> RelationMatch<'a> { + pub fn get_schema(&self, sql: &str) -> Option { + let str = self + .schema + .as_ref()? + .utf8_text(sql.as_bytes()) + .expect("Failed to get schema from RelationMatch"); + + Some(str.to_string()) + } + + pub fn get_table(&self, sql: &str) -> String { + self.table + .utf8_text(sql.as_bytes()) + .expect("Failed to get schema from RelationMatch") + .to_string() + } +} + +impl<'a> TryFrom<&'a QueryResult<'a>> for &'a RelationMatch<'a> { + type Error = String; + + fn try_from(q: &'a QueryResult<'a>) -> Result { + match q { + QueryResult::Relation(r) => Ok(&r), + + #[allow(unreachable_patterns)] + _ => Err("Invalid QueryResult type".into()), + } + } +} + +impl<'a> QueryTryFrom<'a> for RelationMatch<'a> { + type Ref = &'a RelationMatch<'a>; +} + +impl<'a> Query<'a> for RelationMatch<'a> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { + let mut cursor = tree_sitter::QueryCursor::new(); + + let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); + + let mut to_return = vec![]; + + for m in matches { + if m.captures.len() == 1 { + let capture = m.captures[0].node; + to_return.push(QueryResult::Relation(RelationMatch { + schema: None, + table: capture, + })); + } + + if m.captures.len() == 2 { + let schema = m.captures[0].node; + let table = m.captures[1].node; + + to_return.push(QueryResult::Relation(RelationMatch { + schema: Some(schema), + table, + })); + } + } + + to_return + } +} diff --git a/crates/pg_workspace_new/Cargo.toml b/crates/pg_workspace_new/Cargo.toml index 9da718cf..c48bb6e2 100644 --- a/crates/pg_workspace_new/Cargo.toml +++ b/crates/pg_workspace_new/Cargo.toml @@ -18,10 +18,10 @@ futures = "0.3.31" ignore = { workspace = true } pg_analyse = { workspace = true, features = ["serde"] } pg_analyser = { workspace = true } +pg_completions = { workspace = true } pg_configuration = { workspace = true } pg_console = { workspace = true } pg_diagnostics = { workspace = true } -pg_completions = { workspace = true } pg_fs = { workspace = true, features = ["serde"] } pg_query_ext = { workspace = true } pg_schema_cache = { workspace = true }