Skip to content

Commit f0f84e1

Browse files
committed
fixed typing for python >=3.8, optimized with inmutables, added unit tests for index configs
1 parent d354c11 commit f0f84e1

File tree

3 files changed

+199
-15
lines changed

3 files changed

+199
-15
lines changed

dbt/adapters/sqlserver/relation_configs/index.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass, field
22
from datetime import datetime, timezone
3-
from typing import Optional
3+
from typing import FrozenSet, Optional, Set, Tuple
44

55
import agate
66
from dbt.adapters.exceptions import IndexConfigError, IndexConfigNotDictError
@@ -16,6 +16,11 @@
1616
from dbt_common.utils import encoding as dbt_encoding
1717

1818

19+
# Handle datetime now for testing.
20+
def datetime_now(tz: Optional[timezone] = timezone.utc) -> datetime:
21+
return datetime.now(tz)
22+
23+
1924
# ALTERED FROM:
2025
# github.com/dbt-labs/dbt-postgres/blob/main/dbt/adapters/postgres/relation_configs/index.py
2126
class SQLServerIndexType(StrEnum):
@@ -32,7 +37,7 @@ def default(cls) -> "SQLServerIndexType":
3237

3338
@classmethod
3439
def valid_types(cls):
35-
return list(cls)
40+
return tuple(cls)
3641

3742

3843
@dataclass(frozen=True, eq=True, unsafe_hash=True)
@@ -52,17 +57,19 @@ class SQLServerIndexConfig(RelationConfigBase, RelationConfigValidationMixin, db
5257
"""
5358

5459
name: str = field(default="", hash=False, compare=False)
55-
columns: list[str] = field(default_factory=list, hash=True) # Keeping order is important
60+
columns: Tuple[str, ...] = field(
61+
default_factory=tuple, hash=True
62+
) # Keeping order is important
5663
unique: bool = field(
5764
default=False, hash=True
5865
) # Uniqueness can be a property of both clustered and nonclustered indexes.
5966
type: SQLServerIndexType = field(default=SQLServerIndexType.default(), hash=True)
60-
included_columns: frozenset[str] = field(
67+
included_columns: FrozenSet[str] = field(
6168
default_factory=frozenset, hash=True
6269
) # Keeping order is not important
6370

6471
@property
65-
def validation_rules(self) -> set[RelationConfigValidationRule]:
72+
def validation_rules(self) -> Set[RelationConfigValidationRule]:
6673
return {
6774
RelationConfigValidationRule(
6875
validation_check=True if self.columns else False,
@@ -102,7 +109,7 @@ def validation_rules(self) -> set[RelationConfigValidationRule]:
102109
def from_dict(cls, config_dict) -> "SQLServerIndexConfig":
103110
kwargs_dict = {
104111
"name": config_dict.get("name"),
105-
"columns": list(column for column in config_dict.get("columns", list())),
112+
"columns": tuple(column for column in config_dict.get("columns", tuple())),
106113
"unique": config_dict.get("unique"),
107114
"type": config_dict.get("type"),
108115
"included_columns": frozenset(
@@ -115,7 +122,7 @@ def from_dict(cls, config_dict) -> "SQLServerIndexConfig":
115122
@classmethod
116123
def parse_model_node(cls, model_node_entry: dict) -> dict:
117124
config_dict = {
118-
"columns": list(model_node_entry.get("columns", list())),
125+
"columns": tuple(model_node_entry.get("columns", tuple())),
119126
"unique": model_node_entry.get("unique"),
120127
"type": model_node_entry.get("type"),
121128
"included_columns": frozenset(model_node_entry.get("included_columns", set())),
@@ -126,7 +133,7 @@ def parse_model_node(cls, model_node_entry: dict) -> dict:
126133
def parse_relation_results(cls, relation_results_entry: agate.Row) -> dict:
127134
config_dict = {
128135
"name": relation_results_entry.get("name"),
129-
"columns": list(relation_results_entry.get("columns", "").split(",")),
136+
"columns": tuple(relation_results_entry.get("columns", "").split(",")),
130137
"unique": relation_results_entry.get("unique"),
131138
"type": relation_results_entry.get("type"),
132139
"included_columns": set(relation_results_entry.get("included_columns", "").split(",")),
@@ -139,10 +146,10 @@ def as_node_config(self) -> dict:
139146
Returns: a dictionary that can be passed into `get_create_index_sql()`
140147
"""
141148
node_config = {
142-
"columns": list(self.columns),
149+
"columns": tuple(self.columns),
143150
"unique": self.unique,
144151
"type": self.type.value,
145-
"included_columns": list(self.included_columns),
152+
"included_columns": frozenset(self.included_columns),
146153
}
147154
return node_config
148155

@@ -152,16 +159,19 @@ def render(self, relation):
152159
# https://github.com/dbt-labs/dbt-core/issues/1945#issuecomment-576714925
153160
# for an explanation.
154161

155-
now = datetime.now(timezone.utc).isoformat()
156-
inputs = self.columns + [relation.render(), str(self.unique), str(self.type), now]
162+
now = datetime_now(tz=timezone.utc).isoformat()
163+
inputs = self.columns + tuple((relation.render(), str(self.unique), str(self.type), now))
157164
string = "_".join(inputs)
165+
print(f"Actual string before MD5: {string}")
158166
return dbt_encoding.md5(string)
159167

160168
@classmethod
161169
def parse(cls, raw_index) -> Optional["SQLServerIndexConfig"]:
162170
if raw_index is None:
163171
return None
164172
try:
173+
if not isinstance(raw_index, dict):
174+
raise IndexConfigNotDictError(raw_index)
165175
cls.validate(raw_index)
166176
return cls.from_dict(raw_index)
167177
except ValidationError as exc:
@@ -202,7 +212,7 @@ def requires_full_refresh(self) -> bool:
202212
return False
203213

204214
@property
205-
def validation_rules(self) -> set[RelationConfigValidationRule]:
215+
def validation_rules(self) -> Set[RelationConfigValidationRule]:
206216
return {
207217
RelationConfigValidationRule(
208218
validation_check=self.action
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import Optional
2+
from typing import Optional, Tuple
33

44
from dbt.adapters.fabric import FabricConfigs
55

@@ -8,4 +8,4 @@
88

99
@dataclass
1010
class SQLServerConfigs(FabricConfigs):
11-
indexes: Optional[list[SQLServerIndexConfig]] = None
11+
indexes: Optional[Tuple[SQLServerIndexConfig]] = None
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
from datetime import datetime, timezone
2+
from unittest.mock import MagicMock, patch
3+
4+
import pytest
5+
from dbt.adapters.exceptions import IndexConfigError, IndexConfigNotDictError
6+
from dbt.exceptions import DbtRuntimeError
7+
from dbt_common.utils import encoding as dbt_encoding
8+
9+
from dbt.adapters.sqlserver.relation_configs.index import SQLServerIndexConfig, SQLServerIndexType
10+
11+
12+
def test_sqlserver_index_type_default():
13+
assert SQLServerIndexType.default() == SQLServerIndexType.nonclustered
14+
15+
16+
def test_sqlserver_index_type_valid_types():
17+
valid_types = SQLServerIndexType.valid_types()
18+
assert isinstance(valid_types, tuple)
19+
assert len(valid_types) > 0
20+
21+
22+
def test_sqlserver_index_config_creation():
23+
config = SQLServerIndexConfig(
24+
columns=("col1", "col2"),
25+
unique=True,
26+
type=SQLServerIndexType.nonclustered,
27+
included_columns=frozenset(["col3", "col4"]),
28+
)
29+
assert config.columns == ("col1", "col2")
30+
assert config.unique is True
31+
assert config.type == SQLServerIndexType.nonclustered
32+
assert config.included_columns == frozenset(["col3", "col4"])
33+
34+
35+
def test_sqlserver_index_config_from_dict():
36+
config_dict = {
37+
"columns": ["col1", "col2"],
38+
"unique": True,
39+
"type": "nonclustered",
40+
"included_columns": ["col3", "col4"],
41+
}
42+
config = SQLServerIndexConfig.from_dict(config_dict)
43+
assert config.columns == ("col1", "col2")
44+
assert config.unique is True
45+
assert config.type == SQLServerIndexType.nonclustered
46+
assert config.included_columns == frozenset(["col3", "col4"])
47+
48+
49+
def test_sqlserver_index_config_validation_rules():
50+
# Test valid configuration
51+
valid_config = SQLServerIndexConfig(
52+
columns=("col1", "col2"),
53+
unique=True,
54+
type=SQLServerIndexType.nonclustered,
55+
included_columns=frozenset(["col3", "col4"]),
56+
)
57+
assert len(valid_config.validation_rules) == 4
58+
for rule in valid_config.validation_rules:
59+
assert rule.validation_check is True
60+
61+
# Test invalid configurations
62+
with pytest.raises(DbtRuntimeError, match="'columns' is a required property"):
63+
SQLServerIndexConfig(columns=())
64+
65+
with pytest.raises(
66+
DbtRuntimeError,
67+
match="Non-clustered indexes are the only index types that can include extra columns",
68+
):
69+
SQLServerIndexConfig(
70+
columns=("col1",),
71+
type=SQLServerIndexType.clustered,
72+
included_columns=frozenset(["col2"]),
73+
)
74+
75+
with pytest.raises(
76+
DbtRuntimeError,
77+
match="Clustered and nonclustered indexes are the only types that can be unique",
78+
):
79+
SQLServerIndexConfig(columns=("col1",), unique=True, type=SQLServerIndexType.columnstore)
80+
81+
82+
def test_sqlserver_index_config_parse_model_node():
83+
model_node_entry = {
84+
"columns": ["col1", "col2"],
85+
"unique": True,
86+
"type": "nonclustered",
87+
"included_columns": ["col3", "col4"],
88+
}
89+
parsed_dict = SQLServerIndexConfig.parse_model_node(model_node_entry)
90+
assert parsed_dict == {
91+
"columns": ("col1", "col2"),
92+
"unique": True,
93+
"type": "nonclustered",
94+
"included_columns": frozenset(["col3", "col4"]),
95+
}
96+
97+
98+
def test_sqlserver_index_config_parse_relation_results():
99+
relation_results_entry = {
100+
"name": "index_name",
101+
"columns": "col1,col2",
102+
"unique": True,
103+
"type": "nonclustered",
104+
"included_columns": "col3,col4",
105+
}
106+
parsed_dict = SQLServerIndexConfig.parse_relation_results(relation_results_entry)
107+
assert parsed_dict == {
108+
"name": "index_name",
109+
"columns": ("col1", "col2"),
110+
"unique": True,
111+
"type": "nonclustered",
112+
"included_columns": {"col3", "col4"},
113+
}
114+
115+
116+
def test_sqlserver_index_config_as_node_config():
117+
config = SQLServerIndexConfig(
118+
columns=("col1", "col2"),
119+
unique=True,
120+
type=SQLServerIndexType.nonclustered,
121+
included_columns=frozenset(["col3", "col4"]),
122+
)
123+
node_config = config.as_node_config
124+
assert node_config == {
125+
"columns": ("col1", "col2"),
126+
"unique": True,
127+
"type": "nonclustered",
128+
"included_columns": frozenset(["col3", "col4"]),
129+
}
130+
131+
132+
FAKE_NOW = datetime(2023, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
133+
134+
135+
@pytest.fixture(autouse=True)
136+
def patch_datetime_now():
137+
with patch("dbt.adapters.sqlserver.relation_configs.index.datetime_now") as mocked_datetime:
138+
mocked_datetime.return_value = FAKE_NOW
139+
yield mocked_datetime
140+
141+
142+
def test_sqlserver_index_config_render():
143+
config = SQLServerIndexConfig(
144+
columns=("col1", "col2"), unique=True, type=SQLServerIndexType.nonclustered
145+
)
146+
relation = MagicMock()
147+
relation.render.return_value = "test_relation"
148+
149+
result = config.render(relation)
150+
151+
expected_string = "col1_col2_test_relation_True_nonclustered_2023-01-01T00:00:00+00:00"
152+
153+
print(f"Expected string: {expected_string}")
154+
print(f"Actual result (MD5): {result}")
155+
print(f"Expected result (MD5): {dbt_encoding.md5(expected_string)}")
156+
157+
assert result == dbt_encoding.md5(expected_string)
158+
159+
160+
def test_sqlserver_index_config_parse():
161+
valid_raw_index = {"columns": ["col1", "col2"], "unique": True, "type": "nonclustered"}
162+
result = SQLServerIndexConfig.parse(valid_raw_index)
163+
assert isinstance(result, SQLServerIndexConfig)
164+
assert result.columns == ("col1", "col2")
165+
assert result.unique is True
166+
assert result.type == SQLServerIndexType.nonclustered
167+
168+
assert SQLServerIndexConfig.parse(None) is None
169+
170+
with pytest.raises(IndexConfigError):
171+
SQLServerIndexConfig.parse({"invalid": "config"})
172+
173+
with pytest.raises(IndexConfigNotDictError):
174+
SQLServerIndexConfig.parse("not a dict")

0 commit comments

Comments
 (0)