Skip to content

Commit 7dd52c5

Browse files
authored
feat: add macros to LSP completion (#4667)
1 parent 8efbe5f commit 7dd52c5

File tree

8 files changed

+210
-46
lines changed

8 files changed

+210
-46
lines changed

sqlmesh/lsp/completions.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
from functools import lru_cache
22
from sqlglot import Dialect, Tokenizer
33
from sqlmesh.lsp.custom import AllModelsResponse
4+
from sqlmesh import macro
45
import typing as t
56
from sqlmesh.lsp.context import AuditTarget, LSPContext, ModelTarget
67
from sqlmesh.lsp.uri import URI
78

89

910
def get_sql_completions(
10-
context: t.Optional[LSPContext], file_uri: t.Optional[URI], content: t.Optional[str] = None
11+
context: t.Optional[LSPContext] = None,
12+
file_uri: t.Optional[URI] = None,
13+
content: t.Optional[str] = None,
1114
) -> AllModelsResponse:
1215
"""
1316
Return a list of completions for a given file.
@@ -26,6 +29,7 @@ def get_sql_completions(
2629
return AllModelsResponse(
2730
models=list(get_models(context, file_uri)),
2831
keywords=all_keywords,
32+
macros=list(get_macros(context, file_uri)),
2933
)
3034

3135

@@ -56,6 +60,17 @@ def get_models(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.
5660
return all_models
5761

5862

63+
def get_macros(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Set[str]:
64+
"""Return a set of all macros with the ``@`` prefix."""
65+
names = set(macro.get_registry())
66+
try:
67+
if context is not None:
68+
names.update(context.context._macros)
69+
except Exception:
70+
pass
71+
return names
72+
73+
5974
def get_keywords(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Set[str]:
6075
"""
6176
Return a list of sql keywords for a given file.
@@ -138,6 +153,7 @@ def get_dialect(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t
138153
def extract_keywords_from_content(content: str, dialect: t.Optional[str] = None) -> t.Set[str]:
139154
"""
140155
Extract identifiers from SQL content using the tokenizer.
156+
141157
Only extracts identifiers (variable names, table names, column names, etc.)
142158
that are not SQL keywords.
143159
"""
@@ -155,7 +171,7 @@ def extract_keywords_from_content(content: str, dialect: t.Optional[str] = None)
155171
keywords.add(token.text)
156172

157173
except Exception:
158-
# If tokenization fails, return empty set
174+
# If tokenization fails, return an empty set
159175
pass
160176

161177
return keywords

sqlmesh/lsp/context.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -176,18 +176,13 @@ def list_of_models_for_rendering(self) -> t.List[ModelForRendering]:
176176
if audit._path is not None
177177
]
178178

179-
def get_autocomplete(
180-
self, uri: t.Optional[URI], content: t.Optional[str] = None
179+
@staticmethod
180+
def get_completions(
181+
self: t.Optional["LSPContext"] = None,
182+
uri: t.Optional[URI] = None,
183+
file_content: t.Optional[str] = None,
181184
) -> AllModelsResponse:
182-
"""Get autocomplete suggestions for a file.
183-
184-
Args:
185-
uri: The URI of the file to get autocomplete suggestions for.
186-
content: The content of the file (optional).
187-
188-
Returns:
189-
AllModelsResponse containing models and keywords.
190-
"""
185+
"""Get completion suggestions for a file"""
191186
from sqlmesh.lsp.completions import get_sql_completions
192187

193-
return get_sql_completions(self, uri, content)
188+
return get_sql_completions(self, uri, file_content)

sqlmesh/lsp/custom.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class AllModelsResponse(PydanticModel):
2020

2121
models: t.List[str]
2222
keywords: t.List[str]
23+
macros: t.List[str]
2324

2425

2526
RENDER_MODEL_FEATURE = "sqlmesh/render_model"

sqlmesh/lsp/helpers.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from lsprotocol.types import Range, Position
2+
3+
from sqlmesh.core.linter.helpers import (
4+
Range as SQLMeshRange,
5+
Position as SQLMeshPosition,
6+
)
7+
8+
9+
def to_lsp_range(
10+
range: SQLMeshRange,
11+
) -> Range:
12+
"""
13+
Converts a SQLMesh Range to an LSP Range.
14+
"""
15+
return Range(
16+
start=Position(line=range.start.line, character=range.start.character),
17+
end=Position(line=range.end.line, character=range.end.character),
18+
)
19+
20+
21+
def to_lsp_position(
22+
position: SQLMeshPosition,
23+
) -> Position:
24+
"""
25+
Converts a SQLMesh Position to an LSP Position.
26+
"""
27+
return Position(line=position.line, character=position.character)

sqlmesh/lsp/main.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def all_models(ls: LanguageServer, params: AllModelsRequest) -> AllModelsRespons
145145

146146
try:
147147
context = self._context_get_or_load(uri)
148-
return context.get_autocomplete(uri, content)
148+
return LSPContext.get_completions(context, uri, content)
149149
except Exception as e:
150150
from sqlmesh.lsp.completions import get_sql_completions
151151

@@ -565,7 +565,10 @@ def workspace_diagnostic(
565565
)
566566
return types.WorkspaceDiagnosticReport(items=[])
567567

568-
@self.server.feature(types.TEXT_DOCUMENT_COMPLETION)
568+
@self.server.feature(
569+
types.TEXT_DOCUMENT_COMPLETION,
570+
types.CompletionOptions(trigger_characters=["@"]), # advertise "@" for macros
571+
)
569572
def completion(
570573
ls: LanguageServer, params: types.CompletionParams
571574
) -> t.Optional[types.CompletionList]:
@@ -583,7 +586,7 @@ def completion(
583586
pass
584587

585588
# Get completions using the existing completions module
586-
completion_response = context.get_autocomplete(uri, content)
589+
completion_response = LSPContext.get_completions(context, uri, content)
587590

588591
completion_items = []
589592
# Add model completions
@@ -595,7 +598,26 @@ def completion(
595598
detail="SQLMesh Model",
596599
)
597600
)
598-
# Add keyword completions
601+
# Add macro completions
602+
triggered_by_at = (
603+
params.context is not None
604+
and getattr(params.context, "trigger_character", None) == "@"
605+
)
606+
607+
for macro_name in completion_response.macros:
608+
insert_text = macro_name if triggered_by_at else f"@{macro_name}"
609+
610+
completion_items.append(
611+
types.CompletionItem(
612+
label=f"@{macro_name}",
613+
insert_text=insert_text,
614+
insert_text_format=types.InsertTextFormat.PlainText,
615+
filter_text=macro_name,
616+
kind=types.CompletionItemKind.Function,
617+
detail="SQLMesh Macro",
618+
)
619+
)
620+
599621
for keyword in completion_response.keywords:
600622
completion_items.append(
601623
types.CompletionItem(

sqlmesh/lsp/reference.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
from sqlmesh.core.dialect import normalize_model_name
88
from sqlmesh.core.linter.helpers import (
99
TokenPositionDetails,
10-
Range as SQLMeshRange,
11-
Position as SQLMeshPosition,
1210
)
1311
from sqlmesh.core.model.definition import SqlModel
1412
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
1513
from sqlglot import exp
1614
from sqlmesh.lsp.description import generate_markdown_description
1715
from sqlglot.optimizer.scope import build_scope
16+
17+
from sqlmesh.lsp.helpers import to_lsp_range, to_lsp_position
1818
from sqlmesh.lsp.uri import URI
1919
from sqlmesh.utils.pydantic import PydanticModel
2020
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
@@ -624,24 +624,3 @@ def _position_within_range(position: Position, range: Range) -> bool:
624624
range.end.line > position.line
625625
or (range.end.line == position.line and range.end.character >= position.character)
626626
)
627-
628-
629-
def to_lsp_range(
630-
range: SQLMeshRange,
631-
) -> Range:
632-
"""
633-
Converts a SQLMesh Range to an LSP Range.
634-
"""
635-
return Range(
636-
start=Position(line=range.start.line, character=range.start.character),
637-
end=Position(line=range.end.line, character=range.end.character),
638-
)
639-
640-
641-
def to_lsp_position(
642-
position: SQLMeshPosition,
643-
) -> Position:
644-
"""
645-
Converts a SQLMesh Position to an LSP Position.
646-
"""
647-
return Position(line=position.line, character=position.character)

tests/lsp/test_completions.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,27 @@ def test_get_sql_completions_no_context():
2222
assert len(completions.models) == 0
2323

2424

25+
def test_get_macros():
26+
context = Context(paths=["examples/sushi"])
27+
lsp_context = LSPContext(context)
28+
29+
file_path = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql")
30+
with open(file_path, "r", encoding="utf-8") as f:
31+
file_content = f.read()
32+
33+
file_uri = URI.from_path(file_path)
34+
completions = LSPContext.get_completions(lsp_context, file_uri, file_content)
35+
36+
assert "each" in completions.macros
37+
assert "add_one" in completions.macros
38+
39+
2540
def test_get_sql_completions_with_context_no_file_uri():
2641
context = Context(paths=["examples/sushi"])
2742
lsp_context = LSPContext(context)
2843

29-
completions = lsp_context.get_autocomplete(None)
30-
assert len(completions.keywords) > len(TOKENIZER_KEYWORDS)
44+
completions = LSPContext.get_completions(lsp_context, None)
45+
assert len(completions.keywords) >= len(TOKENIZER_KEYWORDS)
3146
assert "sushi.active_customers" in completions.models
3247
assert "sushi.customers" in completions.models
3348

@@ -37,7 +52,7 @@ def test_get_sql_completions_with_context_and_file_uri():
3752
lsp_context = LSPContext(context)
3853

3954
file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql")
40-
completions = lsp_context.get_autocomplete(URI.from_path(file_uri))
55+
completions = LSPContext.get_completions(lsp_context, URI.from_path(file_uri))
4156
assert len(completions.keywords) > len(TOKENIZER_KEYWORDS)
4257
assert "sushi.active_customers" not in completions.models
4358

@@ -84,7 +99,7 @@ def test_get_sql_completions_with_file_content():
8499
"""
85100

86101
file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql")
87-
completions = lsp_context.get_autocomplete(URI.from_path(file_uri), content)
102+
completions = LSPContext.get_completions(lsp_context, URI.from_path(file_uri), content)
88103

89104
# Check that SQL keywords are included
90105
assert any(k in ["SELECT", "FROM", "WHERE", "JOIN"] for k in completions.keywords)
@@ -129,7 +144,7 @@ def test_get_sql_completions_with_partial_cte_query():
129144
"""
130145

131146
file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql")
132-
completions = lsp_context.get_autocomplete(URI.from_path(file_uri), content)
147+
completions = LSPContext.get_completions(lsp_context, URI.from_path(file_uri), content)
133148

134149
# Check that CTE names are included in the keywords
135150
keywords_list = completions.keywords

vscode/extension/tests/completions.spec.ts

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,118 @@ test('Autocomplete for model names', async () => {
4747
expect(
4848
await window.locator('text=sushi.waiter_as_customer_by_day').count(),
4949
).toBe(1)
50+
expect(await window.locator('text=SQLMesh Model').count()).toBe(1)
5051

5152
await close()
5253
} finally {
5354
await fs.remove(tempDir)
5455
}
5556
})
57+
58+
// Skip the macro completions test as regular checks because they are flaky and
59+
// covered in other non-integration tests.
60+
test.describe('Macro Completions', () => {
61+
test('Completion for inbuilt macros', async () => {
62+
const tempDir = await fs.mkdtemp(
63+
path.join(os.tmpdir(), 'vscode-test-sushi-'),
64+
)
65+
await fs.copy(SUSHI_SOURCE_PATH, tempDir)
66+
67+
try {
68+
const { window, close } = await startVSCode(tempDir)
69+
70+
// Wait for the models folder to be visible
71+
await window.waitForSelector('text=models')
72+
73+
// Click on the models folder
74+
await window
75+
.getByRole('treeitem', { name: 'models', exact: true })
76+
.locator('a')
77+
.click()
78+
79+
// Open the top_waiters model
80+
await window
81+
.getByRole('treeitem', { name: 'customers.sql', exact: true })
82+
.locator('a')
83+
.click()
84+
85+
await window.waitForSelector('text=grain')
86+
await window.waitForSelector('text=Loaded SQLMesh Context')
87+
88+
await window.locator('text=grain').first().click()
89+
90+
// Move to the end of the file
91+
await window.keyboard.press('Control+End')
92+
93+
// Add a new line
94+
await window.keyboard.press('Enter')
95+
96+
await window.waitForTimeout(500)
97+
98+
// Hit the '@' key to trigger autocomplete for inbuilt macros
99+
await window.keyboard.press('@')
100+
await window.keyboard.type('eac')
101+
102+
// Wait a moment for autocomplete to appear
103+
await window.waitForTimeout(500)
104+
105+
// Check if the autocomplete suggestion for inbuilt macros is visible
106+
expect(await window.locator('text=@each').count()).toBe(1)
107+
108+
await close()
109+
} finally {
110+
await fs.remove(tempDir)
111+
}
112+
})
113+
114+
test('Completion for custom macros', async () => {
115+
const tempDir = await fs.mkdtemp(
116+
path.join(os.tmpdir(), 'vscode-test-sushi-'),
117+
)
118+
await fs.copy(SUSHI_SOURCE_PATH, tempDir)
119+
120+
try {
121+
const { window, close } = await startVSCode(tempDir)
122+
123+
// Wait for the models folder to be visible
124+
await window.waitForSelector('text=models')
125+
126+
// Click on the models folder
127+
await window
128+
.getByRole('treeitem', { name: 'models', exact: true })
129+
.locator('a')
130+
.click()
131+
132+
// Open the top_waiters model
133+
await window
134+
.getByRole('treeitem', { name: 'customers.sql', exact: true })
135+
.locator('a')
136+
.click()
137+
138+
await window.waitForSelector('text=grain')
139+
await window.waitForSelector('text=Loaded SQLMesh Context')
140+
141+
await window.locator('text=grain').first().click()
142+
143+
// Move to the end of the file
144+
await window.keyboard.press('Control+End')
145+
146+
// Add a new line
147+
await window.keyboard.press('Enter')
148+
149+
// Type the beginning of a macro to trigger autocomplete
150+
await window.keyboard.press('@')
151+
await window.keyboard.type('add_o')
152+
153+
// Wait a moment for autocomplete to appear
154+
await window.waitForTimeout(500)
155+
156+
// Check if the autocomplete suggestion for custom macros is visible
157+
expect(await window.locator('text=@add_one').count()).toBe(1)
158+
159+
await close()
160+
} finally {
161+
await fs.remove(tempDir)
162+
}
163+
})
164+
})

0 commit comments

Comments
 (0)