Skip to content

Commit a69f1a4

Browse files
authored
feat(lsp): add go to definition for ctes (#4543)
1 parent 4dbff92 commit a69f1a4

File tree

3 files changed

+176
-68
lines changed

3 files changed

+176
-68
lines changed

sqlmesh/lsp/main.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -279,21 +279,31 @@ def goto_definition(
279279
raise RuntimeError(f"No context found for document: {document.path}")
280280

281281
references = get_references(self.lsp_context, uri, params.position)
282-
return [
283-
types.LocationLink(
284-
target_uri=reference.uri,
285-
target_selection_range=types.Range(
282+
location_links = []
283+
for reference in references:
284+
# Use target_range if available (for CTEs), otherwise default to start of file
285+
if reference.target_range:
286+
target_range = reference.target_range
287+
target_selection_range = reference.target_range
288+
else:
289+
target_range = types.Range(
286290
start=types.Position(line=0, character=0),
287291
end=types.Position(line=0, character=0),
288-
),
289-
target_range=types.Range(
292+
)
293+
target_selection_range = types.Range(
290294
start=types.Position(line=0, character=0),
291295
end=types.Position(line=0, character=0),
292-
),
293-
origin_selection_range=reference.range,
296+
)
297+
298+
location_links.append(
299+
types.LocationLink(
300+
target_uri=reference.uri,
301+
target_selection_range=target_selection_range,
302+
target_range=target_range,
303+
origin_selection_range=reference.range,
304+
)
294305
)
295-
for reference in references
296-
]
306+
return location_links
297307
except Exception as e:
298308
ls.show_message(f"Error getting references: {e}", types.MessageType.Error)
299309
return []

sqlmesh/lsp/reference.py

Lines changed: 92 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,27 @@
66
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
77
from sqlglot import exp
88
from sqlmesh.lsp.description import generate_markdown_description
9+
from sqlglot.optimizer.scope import build_scope
910
from sqlmesh.lsp.uri import URI
1011
from sqlmesh.utils.pydantic import PydanticModel
12+
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1113

1214

1315
class Reference(PydanticModel):
1416
"""
15-
A reference to a model.
17+
A reference to a model or CTE.
1618
1719
Attributes:
1820
range: The range of the reference in the source file
1921
uri: The uri of the referenced model
2022
markdown_description: The markdown description of the referenced model
23+
target_range: The range of the definition for go-to-definition (optional, used for CTEs)
2124
"""
2225

2326
range: Range
2427
uri: str
2528
markdown_description: t.Optional[str] = None
29+
target_range: t.Optional[Range] = None
2630

2731

2832
def by_position(position: Position) -> t.Callable[[Reference], bool]:
@@ -88,6 +92,7 @@ def get_model_definitions_for_a_path(
8892
- Need to normalize it before matching
8993
- Try get_model before normalization
9094
- Match to models that the model refers to
95+
- Also find CTE references within the query
9196
"""
9297
path = document_uri.to_path()
9398
if path.suffix != ".sql":
@@ -126,66 +131,95 @@ def get_model_definitions_for_a_path(
126131
# Find all possible references
127132
references = []
128133

129-
# Get SQL query and find all table references
130-
tables = list(query.find_all(exp.Table))
131-
if len(tables) == 0:
132-
return []
133-
134134
with open(file_path, "r", encoding="utf-8") as file:
135135
read_file = file.readlines()
136136

137-
for table in tables:
138-
# Normalize the table reference
139-
unaliased = table.copy()
140-
if unaliased.args.get("alias") is not None:
141-
unaliased.set("alias", None)
142-
reference_name = unaliased.sql(dialect=dialect)
143-
try:
144-
normalized_reference_name = normalize_model_name(
145-
reference_name,
146-
default_catalog=lint_context.context.default_catalog,
147-
dialect=dialect,
148-
)
149-
if normalized_reference_name not in depends_on:
150-
continue
151-
except Exception:
152-
# Skip references that cannot be normalized
153-
continue
154-
155-
# Get the referenced model uri
156-
referenced_model = lint_context.context.get_model(
157-
model_or_snapshot=normalized_reference_name, raise_if_missing=False
158-
)
159-
if referenced_model is None:
160-
continue
161-
referenced_model_path = referenced_model._path
162-
# Check whether the path exists
163-
if not referenced_model_path.is_file():
164-
continue
165-
referenced_model_uri = URI.from_path(referenced_model_path)
166-
167-
# Extract metadata for positioning
168-
table_meta = TokenPositionDetails.from_meta(table.this.meta)
169-
table_range = _range_from_token_position_details(table_meta, read_file)
170-
start_pos = table_range.start
171-
end_pos = table_range.end
172-
173-
# If there's a catalog or database qualifier, adjust the start position
174-
catalog_or_db = table.args.get("catalog") or table.args.get("db")
175-
if catalog_or_db is not None:
176-
catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta)
177-
catalog_or_db_range = _range_from_token_position_details(catalog_or_db_meta, read_file)
178-
start_pos = catalog_or_db_range.start
179-
180-
description = generate_markdown_description(referenced_model)
181-
182-
references.append(
183-
Reference(
184-
uri=referenced_model_uri.value,
185-
range=Range(start=start_pos, end=end_pos),
186-
markdown_description=description,
187-
)
188-
)
137+
# Build scope tree to properly handle nested CTEs
138+
query = normalize_identifiers(query.copy(), dialect=dialect)
139+
root_scope = build_scope(query)
140+
141+
if root_scope:
142+
# Traverse all scopes to find CTE definitions and table references
143+
for scope in root_scope.traverse():
144+
for table in scope.tables:
145+
table_name = table.name
146+
147+
# Check if this table reference is a CTE in the current scope
148+
if cte_scope := scope.cte_sources.get(table_name):
149+
cte = cte_scope.expression.parent
150+
alias = cte.args["alias"]
151+
if isinstance(alias, exp.TableAlias):
152+
identifier = alias.this
153+
if isinstance(identifier, exp.Identifier):
154+
target_range = _range_from_token_position_details(
155+
TokenPositionDetails.from_meta(identifier.meta), read_file
156+
)
157+
table_range = _range_from_token_position_details(
158+
TokenPositionDetails.from_meta(table.this.meta), read_file
159+
)
160+
references.append(
161+
Reference(
162+
uri=document_uri.value, # Same file
163+
range=table_range,
164+
target_range=target_range,
165+
)
166+
)
167+
continue
168+
169+
# For non-CTE tables, process as before (external model references)
170+
# Normalize the table reference
171+
unaliased = table.copy()
172+
if unaliased.args.get("alias") is not None:
173+
unaliased.set("alias", None)
174+
reference_name = unaliased.sql(dialect=dialect)
175+
try:
176+
normalized_reference_name = normalize_model_name(
177+
reference_name,
178+
default_catalog=lint_context.context.default_catalog,
179+
dialect=dialect,
180+
)
181+
if normalized_reference_name not in depends_on:
182+
continue
183+
except Exception:
184+
# Skip references that cannot be normalized
185+
continue
186+
187+
# Get the referenced model uri
188+
referenced_model = lint_context.context.get_model(
189+
model_or_snapshot=normalized_reference_name, raise_if_missing=False
190+
)
191+
if referenced_model is None:
192+
continue
193+
referenced_model_path = referenced_model._path
194+
# Check whether the path exists
195+
if not referenced_model_path.is_file():
196+
continue
197+
referenced_model_uri = URI.from_path(referenced_model_path)
198+
199+
# Extract metadata for positioning
200+
table_meta = TokenPositionDetails.from_meta(table.this.meta)
201+
table_range = _range_from_token_position_details(table_meta, read_file)
202+
start_pos = table_range.start
203+
end_pos = table_range.end
204+
205+
# If there's a catalog or database qualifier, adjust the start position
206+
catalog_or_db = table.args.get("catalog") or table.args.get("db")
207+
if catalog_or_db is not None:
208+
catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta)
209+
catalog_or_db_range = _range_from_token_position_details(
210+
catalog_or_db_meta, read_file
211+
)
212+
start_pos = catalog_or_db_range.start
213+
214+
description = generate_markdown_description(referenced_model)
215+
216+
references.append(
217+
Reference(
218+
uri=referenced_model_uri.value,
219+
range=Range(start=start_pos, end=end_pos),
220+
markdown_description=description,
221+
)
222+
)
189223

190224
return references
191225

tests/lsp/test_reference_cte.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import re
2+
from sqlmesh.core.context import Context
3+
from sqlmesh.lsp.context import LSPContext, ModelTarget
4+
from sqlmesh.lsp.reference import get_references
5+
from sqlmesh.lsp.uri import URI
6+
from lsprotocol.types import Range, Position
7+
import typing as t
8+
9+
10+
def test_cte_parsing():
11+
context = Context(paths=["examples/sushi"])
12+
lsp_context = LSPContext(context)
13+
14+
# Find model URIs
15+
sushi_customers_path = next(
16+
path
17+
for path, info in lsp_context.map.items()
18+
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
19+
)
20+
21+
with open(sushi_customers_path, "r", encoding="utf-8") as file:
22+
read_file = file.readlines()
23+
24+
# Find position of the cte reference
25+
ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)")
26+
assert len(ranges) == 2
27+
position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4)
28+
references = get_references(lsp_context, URI.from_path(sushi_customers_path), position)
29+
assert len(references) == 1
30+
assert references[0].uri == URI.from_path(sushi_customers_path).value
31+
assert references[0].markdown_description is None
32+
assert (
33+
references[0].range.start.line == ranges[1].start.line
34+
) # The reference location (where we clicked)
35+
assert (
36+
references[0].target_range.start.line == ranges[0].start.line
37+
) # The CTE definition location
38+
39+
# Find the position of the current_marketing_outer reference
40+
ranges = find_ranges_from_regex(read_file, r"current_marketing_outer")
41+
assert len(ranges) == 2
42+
position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4)
43+
references = get_references(lsp_context, URI.from_path(sushi_customers_path), position)
44+
assert len(references) == 1
45+
assert references[0].uri == URI.from_path(sushi_customers_path).value
46+
assert references[0].markdown_description is None
47+
assert (
48+
references[0].range.start.line == ranges[1].start.line
49+
) # The reference location (where we clicked)
50+
assert (
51+
references[0].target_range.start.line == ranges[0].start.line
52+
) # The CTE definition location
53+
54+
55+
def find_ranges_from_regex(read_file: t.List[str], regex: str) -> t.List[Range]:
56+
"""Find all ranges in the read file that match the regex."""
57+
return [
58+
Range(
59+
start=Position(line=line_number, character=match.start()),
60+
end=Position(line=line_number, character=match.end()),
61+
)
62+
for line_number, line in enumerate(read_file)
63+
for match in [m for m in [re.search(regex, line)] if m]
64+
]

0 commit comments

Comments
 (0)