|
6 | 6 | from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
|
7 | 7 | from sqlglot import exp
|
8 | 8 | from sqlmesh.lsp.description import generate_markdown_description
|
| 9 | +from sqlglot.optimizer.scope import build_scope |
9 | 10 | from sqlmesh.lsp.uri import URI
|
10 | 11 | from sqlmesh.utils.pydantic import PydanticModel
|
| 12 | +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers |
11 | 13 |
|
12 | 14 |
|
class Reference(PydanticModel):
    """
    A reference to a model or CTE found in a SQL document.

    Attributes:
        range: The range of the reference in the source file.
        uri: The uri of the file that holds the definition. For a model
            reference this is the referenced model's file; for a CTE
            reference it is the same document the reference appears in.
        markdown_description: The markdown description of the referenced
            model. Left as None for CTE references.
        target_range: The range of the definition to jump to for
            go-to-definition. Set for CTE references (the CTE alias
            identifier); left as None for model references.
    """

    range: Range
    uri: str
    markdown_description: t.Optional[str] = None
    target_range: t.Optional[Range] = None
26 | 30 |
|
27 | 31 |
|
28 | 32 | def by_position(position: Position) -> t.Callable[[Reference], bool]:
|
@@ -88,6 +92,7 @@ def get_model_definitions_for_a_path(
|
88 | 92 | - Need to normalize it before matching
|
89 | 93 | - Try get_model before normalization
|
90 | 94 | - Match to models that the model refers to
|
| 95 | + - Also find CTE references within the query |
91 | 96 | """
|
92 | 97 | path = document_uri.to_path()
|
93 | 98 | if path.suffix != ".sql":
|
@@ -126,66 +131,95 @@ def get_model_definitions_for_a_path(
|
126 | 131 | # Find all possible references
|
127 | 132 | references = []
|
128 | 133 |
|
129 |
| - # Get SQL query and find all table references |
130 |
| - tables = list(query.find_all(exp.Table)) |
131 |
| - if len(tables) == 0: |
132 |
| - return [] |
133 |
| - |
134 | 134 | with open(file_path, "r", encoding="utf-8") as file:
|
135 | 135 | read_file = file.readlines()
|
136 | 136 |
|
137 |
| - for table in tables: |
138 |
| - # Normalize the table reference |
139 |
| - unaliased = table.copy() |
140 |
| - if unaliased.args.get("alias") is not None: |
141 |
| - unaliased.set("alias", None) |
142 |
| - reference_name = unaliased.sql(dialect=dialect) |
143 |
| - try: |
144 |
| - normalized_reference_name = normalize_model_name( |
145 |
| - reference_name, |
146 |
| - default_catalog=lint_context.context.default_catalog, |
147 |
| - dialect=dialect, |
148 |
| - ) |
149 |
| - if normalized_reference_name not in depends_on: |
150 |
| - continue |
151 |
| - except Exception: |
152 |
| - # Skip references that cannot be normalized |
153 |
| - continue |
154 |
| - |
155 |
| - # Get the referenced model uri |
156 |
| - referenced_model = lint_context.context.get_model( |
157 |
| - model_or_snapshot=normalized_reference_name, raise_if_missing=False |
158 |
| - ) |
159 |
| - if referenced_model is None: |
160 |
| - continue |
161 |
| - referenced_model_path = referenced_model._path |
162 |
| - # Check whether the path exists |
163 |
| - if not referenced_model_path.is_file(): |
164 |
| - continue |
165 |
| - referenced_model_uri = URI.from_path(referenced_model_path) |
166 |
| - |
167 |
| - # Extract metadata for positioning |
168 |
| - table_meta = TokenPositionDetails.from_meta(table.this.meta) |
169 |
| - table_range = _range_from_token_position_details(table_meta, read_file) |
170 |
| - start_pos = table_range.start |
171 |
| - end_pos = table_range.end |
172 |
| - |
173 |
| - # If there's a catalog or database qualifier, adjust the start position |
174 |
| - catalog_or_db = table.args.get("catalog") or table.args.get("db") |
175 |
| - if catalog_or_db is not None: |
176 |
| - catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta) |
177 |
| - catalog_or_db_range = _range_from_token_position_details(catalog_or_db_meta, read_file) |
178 |
| - start_pos = catalog_or_db_range.start |
179 |
| - |
180 |
| - description = generate_markdown_description(referenced_model) |
181 |
| - |
182 |
| - references.append( |
183 |
| - Reference( |
184 |
| - uri=referenced_model_uri.value, |
185 |
| - range=Range(start=start_pos, end=end_pos), |
186 |
| - markdown_description=description, |
187 |
| - ) |
188 |
| - ) |
| 137 | + # Build scope tree to properly handle nested CTEs |
| 138 | + query = normalize_identifiers(query.copy(), dialect=dialect) |
| 139 | + root_scope = build_scope(query) |
| 140 | + |
| 141 | + if root_scope: |
| 142 | + # Traverse all scopes to find CTE definitions and table references |
| 143 | + for scope in root_scope.traverse(): |
| 144 | + for table in scope.tables: |
| 145 | + table_name = table.name |
| 146 | + |
| 147 | + # Check if this table reference is a CTE in the current scope |
| 148 | + if cte_scope := scope.cte_sources.get(table_name): |
| 149 | + cte = cte_scope.expression.parent |
| 150 | + alias = cte.args["alias"] |
| 151 | + if isinstance(alias, exp.TableAlias): |
| 152 | + identifier = alias.this |
| 153 | + if isinstance(identifier, exp.Identifier): |
| 154 | + target_range = _range_from_token_position_details( |
| 155 | + TokenPositionDetails.from_meta(identifier.meta), read_file |
| 156 | + ) |
| 157 | + table_range = _range_from_token_position_details( |
| 158 | + TokenPositionDetails.from_meta(table.this.meta), read_file |
| 159 | + ) |
| 160 | + references.append( |
| 161 | + Reference( |
| 162 | + uri=document_uri.value, # Same file |
| 163 | + range=table_range, |
| 164 | + target_range=target_range, |
| 165 | + ) |
| 166 | + ) |
| 167 | + continue |
| 168 | + |
| 169 | + # For non-CTE tables, process as before (external model references) |
| 170 | + # Normalize the table reference |
| 171 | + unaliased = table.copy() |
| 172 | + if unaliased.args.get("alias") is not None: |
| 173 | + unaliased.set("alias", None) |
| 174 | + reference_name = unaliased.sql(dialect=dialect) |
| 175 | + try: |
| 176 | + normalized_reference_name = normalize_model_name( |
| 177 | + reference_name, |
| 178 | + default_catalog=lint_context.context.default_catalog, |
| 179 | + dialect=dialect, |
| 180 | + ) |
| 181 | + if normalized_reference_name not in depends_on: |
| 182 | + continue |
| 183 | + except Exception: |
| 184 | + # Skip references that cannot be normalized |
| 185 | + continue |
| 186 | + |
| 187 | + # Get the referenced model uri |
| 188 | + referenced_model = lint_context.context.get_model( |
| 189 | + model_or_snapshot=normalized_reference_name, raise_if_missing=False |
| 190 | + ) |
| 191 | + if referenced_model is None: |
| 192 | + continue |
| 193 | + referenced_model_path = referenced_model._path |
| 194 | + # Check whether the path exists |
| 195 | + if not referenced_model_path.is_file(): |
| 196 | + continue |
| 197 | + referenced_model_uri = URI.from_path(referenced_model_path) |
| 198 | + |
| 199 | + # Extract metadata for positioning |
| 200 | + table_meta = TokenPositionDetails.from_meta(table.this.meta) |
| 201 | + table_range = _range_from_token_position_details(table_meta, read_file) |
| 202 | + start_pos = table_range.start |
| 203 | + end_pos = table_range.end |
| 204 | + |
| 205 | + # If there's a catalog or database qualifier, adjust the start position |
| 206 | + catalog_or_db = table.args.get("catalog") or table.args.get("db") |
| 207 | + if catalog_or_db is not None: |
| 208 | + catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta) |
| 209 | + catalog_or_db_range = _range_from_token_position_details( |
| 210 | + catalog_or_db_meta, read_file |
| 211 | + ) |
| 212 | + start_pos = catalog_or_db_range.start |
| 213 | + |
| 214 | + description = generate_markdown_description(referenced_model) |
| 215 | + |
| 216 | + references.append( |
| 217 | + Reference( |
| 218 | + uri=referenced_model_uri.value, |
| 219 | + range=Range(start=start_pos, end=end_pos), |
| 220 | + markdown_description=description, |
| 221 | + ) |
| 222 | + ) |
189 | 223 |
|
190 | 224 | return references
|
191 | 225 |
|
|
0 commit comments