Skip to content

Commit ca73536

Browse files
committed
feat: add go to definition to lsp
1 parent 3271ae1 commit ca73536

File tree

7 files changed

+290
-18
lines changed

7 files changed

+290
-18
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ guard-%:
109109
fi
110110

111111
engine-%-install:
112-
pip3 install -e ".[dev,web,slack,${*}]" ./examples/custom_materializations
112+
pip3 install -e ".[dev,web,slack,lsp,${*}]" ./examples/custom_materializations
113113

114114
engine-docker-%-up:
115115
docker compose -f ./tests/core/engine_adapter/integration/docker/compose.${*}.yaml up -d

sqlmesh/lsp/context.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from collections import defaultdict
2+
from pathlib import Path
3+
from sqlmesh.core.context import Context
4+
import typing as t
5+
6+
7+
class LSPContext:
8+
"""
9+
A context that is used for linting. It contains the context and a reverse map of file uri to model names .
10+
"""
11+
12+
def __init__(self, context: Context) -> None:
13+
self.context = context
14+
map: t.Dict[str, t.List[str]] = defaultdict(list)
15+
for model in context.models.values():
16+
if model._path is not None:
17+
path = Path(model._path).resolve()
18+
map[f"file://{path.as_posix()}"].append(model.name)
19+
self.map = map

sqlmesh/lsp/main.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#!/usr/bin/env python
22
"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration, refactored without globals."""
33

4-
from collections import defaultdict
54
import logging
65
import typing as t
76
from pathlib import Path
@@ -12,21 +11,8 @@
1211
from sqlmesh._version import __version__
1312
from sqlmesh.core.context import Context
1413
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
15-
16-
17-
class LSPContext:
18-
"""
19-
A context that is used for linting. It contains the context and a reverse map of file uri to model names .
20-
"""
21-
22-
def __init__(self, context: Context) -> None:
23-
self.context = context
24-
map: t.Dict[str, t.List[str]] = defaultdict(list)
25-
for model in context.models.values():
26-
if model._path is not None:
27-
path = Path(model._path).resolve()
28-
map[f"file://{path.as_posix()}"].append(model.name)
29-
self.map = map
14+
from sqlmesh.lsp.context import LSPContext
15+
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
3016

3117

3218
class SQLMeshLanguageServer:
@@ -144,6 +130,43 @@ def formatting(
144130
ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error)
145131
return []
146132

133+
@self.server.feature(types.TEXT_DOCUMENT_DEFINITION)
134+
def goto_definition(
135+
ls: LanguageServer, params: types.DefinitionParams
136+
) -> t.List[types.LocationLink]:
137+
"""Jump to an object's definition."""
138+
try:
139+
self._ensure_context_for_document(params.text_document.uri)
140+
document = ls.workspace.get_document(params.text_document.uri)
141+
if self.lsp_context is None:
142+
raise RuntimeError(f"No context found for document: {document.path}")
143+
144+
references = get_model_definitions_for_a_path(
145+
self.lsp_context, params.text_document.uri
146+
)
147+
if not references:
148+
return []
149+
150+
return [
151+
types.LocationLink(
152+
target_uri=reference.uri,
153+
target_selection_range=types.Range(
154+
start=types.Position(line=0, character=0),
155+
end=types.Position(line=0, character=0),
156+
),
157+
target_range=types.Range(
158+
start=types.Position(line=0, character=0),
159+
end=types.Position(line=0, character=0),
160+
),
161+
origin_selection_range=reference.range,
162+
)
163+
for reference in references
164+
]
165+
166+
except Exception as e:
167+
ls.show_message(f"Error getting references: {e}", types.MessageType.Error)
168+
return []
169+
147170
def _context_get_or_load(self, document_uri: str) -> LSPContext:
148171
if self.lsp_context is None:
149172
self._ensure_context_for_document(document_uri)

sqlmesh/lsp/reference.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
from lsprotocol.types import Range, Position
2+
import typing as t
3+
4+
from sqlmesh.core.dialect import normalize_model_name
5+
from sqlmesh.core.model.definition import SqlModel
6+
from sqlmesh.lsp.context import LSPContext
7+
from sqlglot import exp
8+
9+
from sqlmesh.utils.pydantic import PydanticModel
10+
11+
12+
class Reference(PydanticModel):
13+
range: Range
14+
uri: str
15+
16+
17+
def get_model_definitions_for_a_path(
18+
lint_context: LSPContext, document_uri: str
19+
) -> t.List[Reference]:
20+
"""
21+
Get the model references for a given path.
22+
23+
Works for models and audits.
24+
Works for targeting sql and python models.
25+
26+
Steps:
27+
- Get the parsed query
28+
- Find all table objects using find_all exp.Table
29+
- Match the string against all model names
30+
- Need to normalize it before matching
31+
- Try get_model before normalization
32+
- Match to models that the model refers to
33+
"""
34+
# Ensure the path is a sql model
35+
if not document_uri.endswith(".sql"):
36+
return []
37+
38+
# Get the model
39+
models = lint_context.map[document_uri]
40+
if not models:
41+
return []
42+
model = lint_context.context.get_model(model_or_snapshot=models[0], raise_if_missing=False)
43+
if model is None or not isinstance(model, SqlModel):
44+
return []
45+
46+
# Find all possible references
47+
references = []
48+
tables = list(model.query.find_all(exp.Table))
49+
if len(tables) == 0:
50+
return []
51+
52+
read_file = open(model._path, "r").readlines()
53+
54+
for table in tables:
55+
depends_on = model.depends_on
56+
57+
# Normalize the table reference
58+
reference_name = table.sql(dialect=model.dialect)
59+
normalized_reference_name = normalize_model_name(
60+
reference_name,
61+
default_catalog=lint_context.context.default_catalog,
62+
dialect=model.dialect,
63+
)
64+
if normalized_reference_name not in depends_on:
65+
continue
66+
67+
# Get the referenced model uri
68+
referenced_model = lint_context.context.get_model(
69+
model_or_snapshot=normalized_reference_name, raise_if_missing=False
70+
)
71+
if referenced_model is None:
72+
continue
73+
referenced_model_path = referenced_model._path
74+
# Check whether the path exists
75+
if not referenced_model_path.is_file():
76+
continue
77+
referenced_model_uri = f"file://{referenced_model_path}"
78+
79+
# Extract metadata for positioning
80+
table_meta = TokenPositionDetails.from_meta(table.this.meta)
81+
table_range = _range_from_token_position_details(table_meta, read_file)
82+
start_pos = table_range.start
83+
end_pos = table_range.end
84+
85+
# If there's a catalog or database qualifier, adjust the start position
86+
catalog_or_db = table.args.get("catalog") or table.args.get("db")
87+
if catalog_or_db is not None:
88+
catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta)
89+
catalog_or_db_range = _range_from_token_position_details(catalog_or_db_meta, read_file)
90+
start_pos = catalog_or_db_range.start
91+
92+
references.append(
93+
Reference(uri=referenced_model_uri, range=Range(start=start_pos, end=end_pos))
94+
)
95+
96+
return references
97+
98+
99+
class TokenPositionDetails(PydanticModel):
100+
"""
101+
Details about a token's position in the source code.
102+
103+
Attributes:
104+
line (int): The line that the token ends on.
105+
col (int): The column that the token ends on.
106+
start (int): The start index of the token.
107+
end (int): The ending index of the token.
108+
"""
109+
110+
line: int
111+
col: int
112+
start: int
113+
end: int
114+
115+
@staticmethod
116+
def from_meta(meta: t.Dict[str, int]) -> "TokenPositionDetails":
117+
return TokenPositionDetails(
118+
line=meta["line"],
119+
col=meta["col"],
120+
start=meta["start"],
121+
end=meta["end"],
122+
)
123+
124+
125+
def _range_from_token_position_details(
126+
token_position_details: TokenPositionDetails, read_file: t.List[str]
127+
) -> Range:
128+
"""
129+
Convert a TokenPositionDetails object to a Range object.
130+
131+
:param token_position_details: Details about a token's position
132+
:param read_file: List of lines from the file
133+
:return: A Range object representing the token's position
134+
"""
135+
# Convert from 1-indexed to 0-indexed for line only
136+
end_line_0 = token_position_details.line - 1
137+
end_col_0 = token_position_details.col
138+
139+
# Find the start line and column by counting backwards from the end position
140+
start_pos = token_position_details.start
141+
end_pos = token_position_details.end
142+
143+
# Initialize with the end position
144+
start_line_0 = end_line_0
145+
start_col_0 = end_col_0 - (end_pos - start_pos + 1)
146+
147+
# If start_col_0 is negative, we need to go back to previous lines
148+
while start_col_0 < 0 and start_line_0 > 0:
149+
start_line_0 -= 1
150+
start_col_0 += len(read_file[start_line_0])
151+
# Account for newline character
152+
if start_col_0 >= 0:
153+
break
154+
start_col_0 += 1 # For the newline character
155+
156+
# Ensure we don't have negative values
157+
start_col_0 = max(0, start_col_0)
158+
return Range(
159+
start=Position(line=start_line_0, character=start_col_0),
160+
end=Position(line=end_line_0, character=end_col_0),
161+
)

tests/lsp/test_context.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pytest
2+
from sqlmesh.core.context import Context
3+
from sqlmesh.lsp.context import LSPContext
4+
5+
6+
@pytest.mark.fast
7+
def test_lsp_context():
8+
context = Context(paths=["examples/sushi"])
9+
lsp_context = LSPContext(context)
10+
11+
assert lsp_context is not None
12+
assert lsp_context.context is not None
13+
assert lsp_context.map is not None
14+
15+
# find one model in the map
16+
active_customers_key = next(
17+
key for key in lsp_context.map.keys() if key.endswith("models/active_customers.sql")
18+
)
19+
assert lsp_context.map[active_customers_key] == ["sushi.active_customers"]

tests/lsp/test_reference.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import pytest
2+
from sqlmesh.core.context import Context
3+
from sqlmesh.lsp.context import LSPContext
4+
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
5+
6+
7+
@pytest.mark.fast
8+
def test_reference() -> None:
9+
context = Context(paths=["examples/sushi"])
10+
lsp_context = LSPContext(context)
11+
12+
active_customers_uri = next(
13+
uri for uri, models in lsp_context.map.items() if "sushi.active_customers" in models
14+
)
15+
sushi_customers_uri = next(
16+
uri for uri, models in lsp_context.map.items() if "sushi.customers" in models
17+
)
18+
19+
references = get_model_definitions_for_a_path(lsp_context, active_customers_uri)
20+
21+
assert len(references) == 1
22+
assert references[0].uri == sushi_customers_uri
23+
24+
# Check that the reference in the correct range is sushi.customers
25+
path = active_customers_uri.removeprefix("file://")
26+
read_file = open(path, "r").readlines()
27+
# Get the string range in the read file
28+
reference_range = references[0].range
29+
start_line = reference_range.start.line
30+
end_line = reference_range.end.line
31+
start_character = reference_range.start.character
32+
end_character = reference_range.end.character
33+
# Get the string from the file
34+
35+
# If the reference spans multiple lines, handle it accordingly
36+
if start_line == end_line:
37+
# Reference is on a single line
38+
line_content = read_file[start_line]
39+
referenced_text = line_content[start_character:end_character]
40+
else:
41+
# Reference spans multiple lines
42+
referenced_text = read_file[start_line][
43+
start_character:
44+
] # First line from start_character to end
45+
for line_num in range(start_line + 1, end_line): # Middle lines (if any)
46+
referenced_text += read_file[line_num]
47+
referenced_text += read_file[end_line][:end_character] # Last line up to end_character
48+
assert referenced_text == "sushi.customers"

vscode/extension/src/lsp/lsp.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ export class LSPClient implements Disposable {
2727

2828
const sqlmesh = await sqlmesh_lsp_exec()
2929
if (isErr(sqlmesh)) {
30-
traceError(`Failed to get sqlmesh_lsp_exec, ${sqlmesh.error.type}`)
30+
traceError(
31+
`Failed to get sqlmesh_lsp_exec, ${JSON.stringify(sqlmesh.error)}`,
32+
)
3133
return sqlmesh
3234
}
3335
const workspaceFolders = getWorkspaceFolders()

0 commit comments

Comments
 (0)