Skip to content

Commit 8efbe5f

Browse files
Feat(lsp): Add support for find and go to references for Model usages (#4680)
1 parent 1611644 commit 8efbe5f

File tree

8 files changed

+776
-43
lines changed

8 files changed

+776
-43
lines changed

sqlmesh/lsp/main.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,12 @@
4444
CustomMethod,
4545
)
4646
from sqlmesh.lsp.hints import get_hints
47-
from sqlmesh.lsp.reference import get_references, get_cte_references
47+
from sqlmesh.lsp.reference import (
48+
LSPCteReference,
49+
LSPModelReference,
50+
get_references,
51+
get_all_references,
52+
)
4853
from sqlmesh.lsp.uri import URI
4954
from web.server.api.endpoints.lineage import column_lineage, model_lineage
5055
from web.server.api.endpoints.models import get_models
@@ -368,7 +373,7 @@ def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hov
368373
if not references:
369374
return None
370375
reference = references[0]
371-
if not reference.markdown_description:
376+
if isinstance(reference, LSPCteReference) or not reference.markdown_description:
372377
return None
373378
return types.Hover(
374379
contents=types.MarkupContent(
@@ -418,8 +423,8 @@ def goto_definition(
418423
references = get_references(self.lsp_context, uri, params.position)
419424
location_links = []
420425
for reference in references:
421-
# Use target_range if available (for CTEs), otherwise default to start of file
422-
if reference.target_range:
426+
# Use target_range if available (CTEs, Macros), otherwise default to start of file
427+
if not isinstance(reference, LSPModelReference):
423428
target_range = reference.target_range
424429
target_selection_range = reference.target_range
425430
else:
@@ -449,18 +454,18 @@ def goto_definition(
449454
def find_references(
450455
ls: LanguageServer, params: types.ReferenceParams
451456
) -> t.Optional[t.List[types.Location]]:
452-
"""Find all references of a symbol (currently supporting CTEs)"""
457+
"""Find all references of a symbol (supporting CTEs, models for now)"""
453458
try:
454459
uri = URI(params.text_document.uri)
455460
self._ensure_context_for_document(uri)
456461
document = ls.workspace.get_text_document(params.text_document.uri)
457462
if self.lsp_context is None:
458463
raise RuntimeError(f"No context found for document: {document.path}")
459464

460-
cte_references = get_cte_references(self.lsp_context, uri, params.position)
465+
all_references = get_all_references(self.lsp_context, uri, params.position)
461466

462467
# Convert references to Location objects
463-
locations = [types.Location(uri=ref.uri, range=ref.range) for ref in cte_references]
468+
locations = [types.Location(uri=ref.uri, range=ref.range) for ref in all_references]
464469

465470
return locations if locations else None
466471
except Exception as e:

sqlmesh/lsp/reference.py

Lines changed: 162 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from lsprotocol.types import Range, Position
22
import typing as t
33
from pathlib import Path
4+
from pydantic import Field
45

56
from sqlmesh.core.audit import StandaloneAudit
67
from sqlmesh.core.dialect import normalize_model_name
@@ -23,21 +24,37 @@
2324
import inspect
2425

2526

26-
class Reference(PydanticModel):
27-
"""
28-
A reference to a model or CTE.
27+
class LSPModelReference(PydanticModel):
28+
"""A LSP reference to a model."""
29+
30+
type: t.Literal["model"] = "model"
31+
uri: str
32+
range: Range
33+
markdown_description: t.Optional[str] = None
2934

30-
Attributes:
31-
range: The range of the reference in the source file
32-
uri: The uri of the referenced model
33-
markdown_description: The markdown description of the referenced model
34-
target_range: The range of the definition for go-to-definition (optional, used for CTEs)
35-
"""
3635

36+
class LSPCteReference(PydanticModel):
37+
"""A LSP reference to a CTE."""
38+
39+
type: t.Literal["cte"] = "cte"
40+
uri: str
3741
range: Range
42+
target_range: Range
43+
44+
45+
class LSPMacroReference(PydanticModel):
46+
"""A LSP reference to a macro."""
47+
48+
type: t.Literal["macro"] = "macro"
3849
uri: str
50+
range: Range
51+
target_range: Range
3952
markdown_description: t.Optional[str] = None
40-
target_range: t.Optional[Range] = None
53+
54+
55+
Reference = t.Annotated[
56+
t.Union[LSPModelReference, LSPCteReference, LSPMacroReference], Field(discriminator="type")
57+
]
4158

4259

4360
def by_position(position: Position) -> t.Callable[[Reference], bool]:
@@ -136,7 +153,7 @@ def get_model_definitions_for_a_path(
136153
return []
137154

138155
# Find all possible references
139-
references = []
156+
references: t.List[Reference] = []
140157

141158
with open(file_path, "r", encoding="utf-8") as file:
142159
read_file = file.readlines()
@@ -173,7 +190,7 @@ def get_model_definitions_for_a_path(
173190
table_range = to_lsp_range(table_range_sqlmesh)
174191

175192
references.append(
176-
Reference(
193+
LSPCteReference(
177194
uri=document_uri.value, # Same file
178195
range=table_range,
179196
target_range=target_range,
@@ -227,7 +244,7 @@ def get_model_definitions_for_a_path(
227244
description = generate_markdown_description(referenced_model)
228245

229246
references.append(
230-
Reference(
247+
LSPModelReference(
231248
uri=referenced_model_uri.value,
232249
range=Range(
233250
start=to_lsp_position(start_pos_sqlmesh),
@@ -286,7 +303,7 @@ def get_macro_definitions_for_a_path(
286303
return []
287304

288305
references = []
289-
config_for_model, config_path = lsp_context.context.config_for_path(
306+
_, config_path = lsp_context.context.config_for_path(
290307
file_path,
291308
)
292309

@@ -372,7 +389,7 @@ def get_macro_reference(
372389
# Create a reference to the macro definition
373390
macro_uri = URI.from_path(path)
374391

375-
return Reference(
392+
return LSPMacroReference(
376393
uri=macro_uri.value,
377394
range=to_lsp_range(macro_range),
378395
target_range=Range(
@@ -405,7 +422,7 @@ def get_built_in_macro_reference(macro_name: str, macro_range: Range) -> t.Optio
405422
# Calculate the end line number by counting the number of source lines
406423
end_line_number = line_number + len(source_lines) - 1
407424

408-
return Reference(
425+
return LSPMacroReference(
409426
uri=URI.from_path(Path(filename)).value,
410427
range=macro_range,
411428
target_range=Range(
@@ -416,9 +433,99 @@ def get_built_in_macro_reference(macro_name: str, macro_range: Range) -> t.Optio
416433
)
417434

418435

436+
def get_model_find_all_references(
437+
lint_context: LSPContext, document_uri: URI, position: Position
438+
) -> t.List[LSPModelReference]:
439+
"""
440+
Get all references to a model across the entire project.
441+
442+
This function finds all usages of a model in other files by searching through
443+
all models in the project and checking their dependencies.
444+
445+
Args:
446+
lint_context: The LSP context
447+
document_uri: The URI of the document
448+
position: The position to check for model references
449+
450+
Returns:
451+
A list of references to the model across all files
452+
"""
453+
# First, get the references in the current file to determine what model we're looking for
454+
current_file_references = [
455+
ref
456+
for ref in get_model_definitions_for_a_path(lint_context, document_uri)
457+
if isinstance(ref, LSPModelReference)
458+
]
459+
460+
# Find the model reference at the cursor position
461+
target_model_uri: t.Optional[str] = None
462+
for ref in current_file_references:
463+
if _position_within_range(position, ref.range):
464+
# This is a model reference, get the target model URI
465+
target_model_uri = ref.uri
466+
break
467+
468+
if target_model_uri is None:
469+
return []
470+
471+
# Start with the model definition
472+
all_references: t.List[LSPModelReference] = [
473+
LSPModelReference(
474+
uri=ref.uri,
475+
range=Range(
476+
start=Position(line=0, character=0),
477+
end=Position(line=0, character=0),
478+
),
479+
markdown_description=ref.markdown_description,
480+
)
481+
]
482+
483+
# Then add the original reference
484+
for ref in current_file_references:
485+
if ref.uri == target_model_uri and isinstance(ref, LSPModelReference):
486+
all_references.append(
487+
LSPModelReference(
488+
uri=document_uri.value,
489+
range=ref.range,
490+
markdown_description=ref.markdown_description,
491+
)
492+
)
493+
494+
# Search through the models in the project
495+
for path, target in lint_context.map.items():
496+
if not isinstance(target, (ModelTarget, AuditTarget)):
497+
continue
498+
499+
file_uri = URI.from_path(path)
500+
501+
# Skip current file, already processed
502+
if file_uri.value == document_uri.value:
503+
continue
504+
505+
# Get model references for this file
506+
file_references = [
507+
ref
508+
for ref in get_model_definitions_for_a_path(lint_context, file_uri)
509+
if isinstance(ref, LSPModelReference)
510+
]
511+
512+
# Add references that point to the target model file
513+
for ref in file_references:
514+
if ref.uri == target_model_uri and isinstance(ref, LSPModelReference):
515+
all_references.append(
516+
LSPModelReference(
517+
uri=file_uri.value,
518+
range=ref.range,
519+
markdown_description=ref.markdown_description,
520+
)
521+
)
522+
523+
return all_references
524+
525+
419526
def get_cte_references(
420527
lint_context: LSPContext, document_uri: URI, position: Position
421-
) -> t.List[Reference]:
528+
) -> t.List[LSPCteReference]:
422529
"""
423530
Get all references to a CTE at a specific position in a document.
424531
@@ -432,12 +539,12 @@ def get_cte_references(
432539
Returns:
433540
A list of references to the CTE (including its definition and all usages)
434541
"""
435-
references = get_model_definitions_for_a_path(lint_context, document_uri)
436542

437-
# Filter for CTE references (those with target_range set and same URI)
438-
# TODO: Consider extending Reference class to explicitly indicate reference type instead
439-
cte_references = [
440-
ref for ref in references if ref.target_range is not None and ref.uri == document_uri.value
543+
# Filter to get the CTE references
544+
cte_references: t.List[LSPCteReference] = [
545+
ref
546+
for ref in get_model_definitions_for_a_path(lint_context, document_uri)
547+
if isinstance(ref, LSPCteReference)
441548
]
442549

443550
if not cte_references:
@@ -450,7 +557,7 @@ def get_cte_references(
450557
target_cte_definition_range = ref.target_range
451558
break
452559
# Check if cursor is on the CTE definition
453-
elif ref.target_range and _position_within_range(position, ref.target_range):
560+
elif _position_within_range(position, ref.target_range):
454561
target_cte_definition_range = ref.target_range
455562
break
456563

@@ -459,27 +566,55 @@ def get_cte_references(
459566

460567
# Add the CTE definition
461568
matching_references = [
462-
Reference(
569+
LSPCteReference(
463570
uri=document_uri.value,
464571
range=target_cte_definition_range,
465-
markdown_description="CTE definition",
572+
target_range=target_cte_definition_range,
466573
)
467574
]
468575

469576
# Add all usages
470577
for ref in cte_references:
471578
if ref.target_range == target_cte_definition_range:
472579
matching_references.append(
473-
Reference(
580+
LSPCteReference(
474581
uri=document_uri.value,
475582
range=ref.range,
476-
markdown_description="CTE usage",
583+
target_range=ref.target_range,
477584
)
478585
)
479586

480587
return matching_references
481588

482589

590+
def get_all_references(
591+
lint_context: LSPContext, document_uri: URI, position: Position
592+
) -> t.Sequence[Reference]:
593+
"""
594+
Get all references of a symbol at a specific position in a document.
595+
596+
This function determines the type of reference (CTE, model for now) at the cursor
597+
position and returns all references to that symbol across the project.
598+
599+
Args:
600+
lint_context: The LSP context
601+
document_uri: The URI of the document
602+
position: The position to check for references
603+
604+
Returns:
605+
A list of references to the symbol at the given position
606+
"""
607+
# First try CTE references (within same file)
608+
if cte_references := get_cte_references(lint_context, document_uri, position):
609+
return cte_references
610+
611+
# Then try model references (across files)
612+
if model_references := get_model_find_all_references(lint_context, document_uri, position):
613+
return model_references
614+
615+
return []
616+
617+
483618
def _position_within_range(position: Position, range: Range) -> bool:
484619
"""Check if a position is within a given range."""
485620
return (

tests/lsp/test_reference.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from lsprotocol.types import Position
22
from sqlmesh.core.context import Context
33
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
4-
from sqlmesh.lsp.reference import get_model_definitions_for_a_path, by_position
4+
from sqlmesh.lsp.reference import LSPModelReference, get_model_definitions_for_a_path, by_position
55
from sqlmesh.lsp.uri import URI
66

77

@@ -47,9 +47,13 @@ def test_reference_with_alias() -> None:
4747
if isinstance(info, ModelTarget) and "sushi.waiter_revenue_by_day" in info.names
4848
)
4949

50-
references = get_model_definitions_for_a_path(
51-
lsp_context, URI.from_path(waiter_revenue_by_day_path)
52-
)
50+
references = [
51+
ref
52+
for ref in get_model_definitions_for_a_path(
53+
lsp_context, URI.from_path(waiter_revenue_by_day_path)
54+
)
55+
if isinstance(ref, LSPModelReference)
56+
]
5357
assert len(references) == 3
5458

5559
with open(waiter_revenue_by_day_path, "r") as file:

0 commit comments

Comments
 (0)