Skip to content

Commit 4e7d6c7

Browse files
author
Fredhøi
committed
feat: Add support for Microsoft Fabric Warehouse
1 parent d4cc990 commit 4e7d6c7

File tree

3 files changed

+259
-0
lines changed

3 files changed

+259
-0
lines changed

sqlmesh/core/config/connection.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1587,6 +1587,28 @@ def _extra_engine_config(self) -> t.Dict[str, t.Any]:
15871587
return {"catalog_support": CatalogSupport.SINGLE_CATALOG_ONLY}
15881588

15891589

1590+
class FabricWarehouseConnectionConfig(MSSQLConnectionConfig):
    """Connection configuration for a Microsoft Fabric Warehouse.

    Inherits every connection setting from ``MSSQLConnectionConfig``; only the
    ``type`` discriminator, the engine adapter class, and the extra engine
    configuration differ.
    """

    # Discriminator used by the config loader to select this connection type.
    type_: t.Literal["fabric_warehouse"] = Field(alias="type", default="fabric_warehouse")  # type: ignore
    # Default to autocommit (the matching adapter disables transaction support).
    autocommit: t.Optional[bool] = True

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        # Imported lazily so the adapter module is only loaded when this
        # connection type is actually used.
        from sqlmesh.core.engine_adapter.fabric_warehouse import FabricWarehouseAdapter

        return FabricWarehouseAdapter

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        # The adapter needs the target database to issue `USE [...]` and to
        # fully qualify object names.
        config: t.Dict[str, t.Any] = {"database": self.database}
        config["catalog_support"] = CatalogSupport.REQUIRES_SET_CATALOG
        return config
1610+
1611+
15901612
class SparkConnectionConfig(ConnectionConfig):
15911613
"""
15921614
Vanilla Spark Connection Configuration. Use `DatabricksConnectionConfig` for Databricks.

sqlmesh/core/engine_adapter/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sqlmesh.core.engine_adapter.trino import TrinoEngineAdapter
2020
from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter
2121
from sqlmesh.core.engine_adapter.risingwave import RisingwaveEngineAdapter
22+
from sqlmesh.core.engine_adapter.fabric_warehouse import FabricWarehouseAdapter
2223

2324
DIALECT_TO_ENGINE_ADAPTER = {
2425
"hive": SparkEngineAdapter,
@@ -35,6 +36,7 @@
3536
"trino": TrinoEngineAdapter,
3637
"athena": AthenaEngineAdapter,
3738
"risingwave": RisingwaveEngineAdapter,
39+
"fabric_warehouse": FabricWarehouseAdapter,
3840
}
3941

4042
DIALECT_ALIASES = {
@@ -45,9 +47,11 @@
4547
def create_engine_adapter(
4648
connection_factory: t.Callable[[], t.Any], dialect: str, **kwargs: t.Any
4749
) -> EngineAdapter:
50+
print(kwargs)
4851
dialect = dialect.lower()
4952
dialect = DIALECT_ALIASES.get(dialect, dialect)
5053
engine_adapter = DIALECT_TO_ENGINE_ADAPTER.get(dialect)
54+
print(engine_adapter)
5155
if engine_adapter is None:
5256
return EngineAdapter(connection_factory, dialect, **kwargs)
5357
if engine_adapter is EngineAdapterWithIndexSupport:
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
from __future__ import annotations
2+
3+
import typing as t
4+
from sqlglot import exp
5+
from sqlmesh.core.engine_adapter.mssql import MSSQLEngineAdapter
6+
from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, SourceQuery
7+
8+
if t.TYPE_CHECKING:
9+
from sqlmesh.core._typing import SchemaName, TableName
10+
from sqlmesh.core.engine_adapter._typing import QueryOrDF
11+
12+
13+
class FabricWarehouseAdapter(MSSQLEngineAdapter):
    """Engine adapter for Microsoft Fabric Warehouses.

    Extends the MSSQL adapter with Fabric-specific behavior:

    - object names are fully qualified with the configured database,
    - ``INFORMATION_SCHEMA`` views are queried with uppercase names (Fabric
      environments can be case sensitive),
    - insert-overwrite is forced to a DELETE/INSERT strategy because Fabric's
      ``MERGE`` statement has limitations.
    """

    DIALECT = "tsql"
    SUPPORTS_INDEXES = False
    SUPPORTS_TRANSACTIONS = False

    INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.DELETE_INSERT

    def __init__(self, *args: t.Any, **kwargs: t.Any):
        # The target database is mandatory: it is used to qualify every object
        # name and to set the session's database context below.
        self.database = kwargs.get("database")

        super().__init__(*args, **kwargs)

        if not self.database:
            raise ValueError(
                "The 'database' parameter is required in the connection config for the FabricWarehouseAdapter."
            )
        try:
            self.execute(f"USE [{self.database}]")
        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise RuntimeError(
                f"Failed to set database context to '{self.database}'. Reason: {e}"
            ) from e

    @staticmethod
    def _escape_literal(value: t.Any) -> str:
        """Escape a value for interpolation into a T-SQL string literal."""
        return str(value).replace("'", "''")

    def _get_schema_name(self, name: t.Union[TableName, SchemaName]) -> str:
        """Extract the schema name from a sqlglot object or string.

        Falls back to the object's own name when *name* carries no schema part
        (i.e. a bare schema name was passed).
        """
        table = exp.to_table(name)
        schema_part = table.db

        # NOTE(review): sqlglot's `Table.db` is normally a plain string; the
        # Identifier branch is kept defensively for raw nodes.
        if isinstance(schema_part, exp.Identifier):
            return schema_part.name
        # Require a non-empty string: `Table.db` is `""` (not None) when there
        # is no schema part, and the original check returned that empty string
        # instead of falling through to the bare-name case below.
        if isinstance(schema_part, str) and schema_part:
            return schema_part

        if not schema_part and isinstance(table.this, exp.Identifier):
            return table.this.name

        raise ValueError(f"Could not determine schema name from '{name}'")

    def create_schema(
        self,
        schema: SchemaName,
        ignore_if_exists: bool = True,
        warn_on_error: bool = False,
    ) -> None:
        """
        Creates a schema in a Microsoft Fabric Warehouse.

        Overridden to handle Fabric's specific T-SQL requirements.
        T-SQL's `CREATE SCHEMA` command does not support `IF NOT EXISTS`, so this
        implementation first checks for the schema's existence in the
        `INFORMATION_SCHEMA.SCHEMATA` view.

        `ignore_if_exists` / `warn_on_error` mirror the base adapter's
        signature so callers that pass them keep working.
        """
        sql = (
            exp.select("1")
            .from_(f"{self.database}.INFORMATION_SCHEMA.SCHEMATA")
            # Escape quotes: the schema name ends up inside a string literal.
            .where(f"SCHEMA_NAME = '{self._escape_literal(schema)}'")
        )
        try:
            if self.fetchone(sql):
                if ignore_if_exists:
                    return
                # Fall through so the engine raises its native
                # "schema already exists" error.
            self.execute(f"USE [{self.database}]")
            self.execute(f"CREATE SCHEMA [{schema}]")
        except Exception:
            # Best-effort mode requested by the caller; otherwise propagate.
            if not warn_on_error:
                raise

    def _create_table_from_columns(
        self,
        table_name: TableName,
        columns_to_types: t.Dict[str, exp.DataType],
        primary_key: t.Optional[t.Tuple[str, ...]] = None,
        exists: bool = True,
        table_description: t.Optional[str] = None,
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        **kwargs: t.Any,
    ) -> None:
        """
        Creates a table, ensuring the schema exists first and that all
        object names are fully qualified with the database.

        `primary_key` is accepted for interface compatibility but not applied
        here (indexes are unsupported on Fabric warehouses).
        """
        table_exp = exp.to_table(table_name)
        schema_name = self._get_schema_name(table_name)

        self.create_schema(schema_name)

        fully_qualified_table_name = f"[{self.database}].[{schema_name}].[{table_exp.name}]"

        column_defs = ", ".join(
            f"[{col}] {kind.sql(dialect=self.dialect)}" for col, kind in columns_to_types.items()
        )

        create_table_sql = f"CREATE TABLE {fully_qualified_table_name} ({column_defs})"

        if not exists:
            self.execute(create_table_sql)
            return

        if not self.table_exists(table_name):
            self.execute(create_table_sql)

        # Qualify once, up front: the original bound this name only inside the
        # table_description branch, raising NameError when only column
        # descriptions were provided.
        if self.comments_enabled and (table_description or column_descriptions):
            qualified_table_for_comment = self._fully_qualify(table_name)
            if table_description:
                self._create_table_comment(qualified_table_for_comment, table_description)
            if column_descriptions:
                self._create_column_comments(qualified_table_for_comment, column_descriptions)

    def table_exists(self, table_name: TableName) -> bool:
        """
        Checks if a table exists.

        Overridden to query the uppercase `INFORMATION_SCHEMA` required
        by case-sensitive Fabric environments.
        """
        table = exp.to_table(table_name)
        schema = self._get_schema_name(table_name)

        sql = (
            exp.select("1")
            .from_("INFORMATION_SCHEMA.TABLES")
            # Use `table.name` (consistent with `columns`) and escape quotes
            # since the value lands inside a string literal.
            .where(f"TABLE_NAME = '{self._escape_literal(table.name)}'")
            .where(f"TABLE_SCHEMA = '{self._escape_literal(schema)}'")
        )

        result = self.fetchone(sql, quote_identifiers=True)

        return result[0] == 1 if result else False

    def _fully_qualify(self, name: t.Union[TableName, SchemaName]) -> exp.Table:
        """Ensures an object name is prefixed with the configured database."""
        table = exp.to_table(name)
        return exp.Table(
            this=table.this,
            # Wrap the schema part in an identifier node; `Table.db` is a raw
            # string and sqlglot expects expression nodes as args.
            db=exp.to_identifier(table.db) if table.db else None,
            catalog=exp.to_identifier(self.database),
        )

    def create_view(
        self,
        view_name: TableName,
        query_or_df: QueryOrDF,
        columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        replace: bool = True,
        materialized: bool = False,
        materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
        table_description: t.Optional[str] = None,
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
        **create_kwargs: t.Any,
    ) -> None:
        """
        Creates a view from a query or DataFrame.

        Overridden to ensure that the view name and all tables referenced
        in the source query are fully qualified with the database name,
        as required by Fabric.
        """
        view_schema = self._get_schema_name(view_name)
        self.create_schema(view_schema)

        qualified_view_name = self._fully_qualify(view_name)

        if isinstance(query_or_df, exp.Expression):
            # Qualify only tables that do not already carry a catalog, so
            # cross-database references are left untouched.
            for table in query_or_df.find_all(exp.Table):
                if not table.catalog:
                    table.replace(self._fully_qualify(table))

        return super().create_view(
            qualified_view_name,
            query_or_df,
            columns_to_types,
            replace,
            materialized,
            table_description=table_description,
            column_descriptions=column_descriptions,
            view_properties=view_properties,
            **create_kwargs,
        )

    def columns(
        self, table_name: TableName, include_pseudo_columns: bool = False
    ) -> t.Dict[str, exp.DataType]:
        """
        Fetches column names and types for the target table.

        Overridden to query the uppercase `INFORMATION_SCHEMA.COLUMNS` view
        required by case-sensitive Fabric environments.
        """
        table = exp.to_table(table_name)
        schema = self._get_schema_name(table_name)
        sql = (
            exp.select("COLUMN_NAME", "DATA_TYPE")
            .from_(f"{self.database}.INFORMATION_SCHEMA.COLUMNS")
            .where(f"TABLE_NAME = '{self._escape_literal(table.name)}'")
            .where(f"TABLE_SCHEMA = '{self._escape_literal(schema)}'")
            .order_by("ORDINAL_POSITION")
        )
        df = self.fetchdf(sql)
        return {
            str(row.COLUMN_NAME): exp.DataType.build(str(row.DATA_TYPE), dialect=self.dialect)
            for row in df.itertuples()
        }

    def _insert_overwrite_by_condition(
        self,
        table_name: TableName,
        source_queries: t.List[SourceQuery],
        columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        where: t.Optional[exp.Condition] = None,
        insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None,
        **kwargs: t.Any,
    ) -> None:
        """
        Implements the insert overwrite strategy for Fabric.

        Overridden to enforce a `DELETE`/`INSERT` strategy, as Fabric's
        `MERGE` statement has limitations. `insert_overwrite_strategy_override`
        is deliberately ignored for that reason.
        """
        columns_to_types = columns_to_types or self.columns(table_name)

        # No `where` means a full overwrite: delete everything first.
        self.delete_from(table_name, where=where or exp.true())

        for source_query in source_queries:
            with source_query as query:
                query = self._order_projections_and_filter(query, columns_to_types)
                self._insert_append_query(
                    table_name,
                    query,
                    columns_to_types=columns_to_types,
                    order_projections=False,
                )

0 commit comments

Comments
 (0)