diff --git a/.gitignore b/.gitignore
index d11d5ddc..7dd03f5f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -84,6 +84,9 @@ target/
 *.swp
 *.swo
 
+# vscode
+.vscode/
+
 # Mypy cache
 .mypy_cache/
 
diff --git a/dbt/adapters/sqlserver/relation_configs/__init__.py b/dbt/adapters/sqlserver/relation_configs/__init__.py
index b93c52a0..37e7986e 100644
--- a/dbt/adapters/sqlserver/relation_configs/__init__.py
+++ b/dbt/adapters/sqlserver/relation_configs/__init__.py
@@ -1,3 +1,8 @@
+from dbt.adapters.sqlserver.relation_configs.index import (
+    SQLServerIndexConfig,
+    SQLServerIndexConfigChange,
+    SQLServerIndexType,
+)
 from dbt.adapters.sqlserver.relation_configs.policies import (
     MAX_CHARACTERS_IN_IDENTIFIER,
     SQLServerIncludePolicy,
@@ -10,4 +15,7 @@
     "SQLServerIncludePolicy",
     "SQLServerQuotePolicy",
     "SQLServerRelationType",
+    "SQLServerIndexType",
+    "SQLServerIndexConfig",
+    "SQLServerIndexConfigChange",
 ]
diff --git a/dbt/adapters/sqlserver/relation_configs/index.py b/dbt/adapters/sqlserver/relation_configs/index.py
new file mode 100644
index 00000000..029448c0
--- /dev/null
+++ b/dbt/adapters/sqlserver/relation_configs/index.py
@@ -0,0 +1,241 @@
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import FrozenSet, Optional, Set, Tuple
+
+import agate
+from dbt.adapters.exceptions import IndexConfigError, IndexConfigNotDictError
+from dbt.adapters.relation_configs import (
+    RelationConfigBase,
+    RelationConfigChange,
+    RelationConfigChangeAction,
+    RelationConfigValidationMixin,
+    RelationConfigValidationRule,
+)
+from dbt_common.dataclass_schema import StrEnum, ValidationError, dbtClassMixin
+from dbt_common.exceptions import DbtRuntimeError
+from dbt_common.utils import encoding as dbt_encoding
+
+
+# Wrapper around datetime.now so it can be patched in tests.
+def datetime_now(tz: Optional[timezone] = timezone.utc) -> datetime:
+    return datetime.now(tz)
+
+
+# ALTERED FROM:
+# github.com/dbt-labs/dbt-postgres/blob/main/dbt/adapters/postgres/relation_configs/index.py
+class SQLServerIndexType(StrEnum):
+    # btree = "btree"  # All common SQL Server indexes are B-tree indexes
+    # hash = "hash"  # A hash index can exist only on a memory-optimized table.
+    # TODO: Implement memory-optimized table materialization.
+    clustered = "clustered"  # Can't have included columns
+    nonclustered = "nonclustered"
+    columnstore = "columnstore"  # Can't have included columns or a unique config
+
+    @classmethod
+    def default(cls) -> "SQLServerIndexType":
+        return cls("nonclustered")
+
+    @classmethod
+    def valid_types(cls):
+        return tuple(cls)
+
+
+@dataclass(frozen=True, eq=True, unsafe_hash=True)
+class SQLServerIndexConfig(RelationConfigBase, RelationConfigValidationMixin, dbtClassMixin):
+    """
+    This config follows the specs found here:
+
+    https://learn.microsoft.com/en-us/sql/t-sql/statements/create-index-transact-sql
+
+    The following parameters are configurable by dbt:
+    - name: the name of the index in the database; not predictable, since we append a timestamp
+    - unique: checks for duplicate values when the index is created and on data updates
+    - type: the index type to be used
+    - columns: the column names in the index
+    - included_columns: the names of the extra columns included in the index
+
+    """
+
+    name: str = field(default="", hash=False, compare=False)
+    columns: Tuple[str, ...] = field(
+        default_factory=tuple, hash=True
+    )  # Keeping order is important
+    unique: bool = field(
+        default=False, hash=True
+    )  # Uniqueness can be a property of both clustered and nonclustered indexes.
+    type: SQLServerIndexType = field(default=SQLServerIndexType.default(), hash=True)
+    included_columns: FrozenSet[str] = field(
+        default_factory=frozenset, hash=True
+    )  # Keeping order is not important
+
+    @property
+    def validation_rules(self) -> Set[RelationConfigValidationRule]:
+        return {
+            RelationConfigValidationRule(
+                validation_check=bool(self.columns),
+                validation_error=DbtRuntimeError("'columns' is a required property"),
+            ),
+            RelationConfigValidationRule(
+                validation_check=(
+                    True
+                    if not self.included_columns
+                    else self.type == SQLServerIndexType.nonclustered
+                ),
+                validation_error=DbtRuntimeError(
+                    "Non-clustered indexes are the only index types that can include extra columns"
+                ),
+            ),
+            RelationConfigValidationRule(
+                validation_check=(
+                    True
+                    if not self.unique
+                    else self.type
+                    in (SQLServerIndexType.clustered, SQLServerIndexType.nonclustered)
+                ),
+                validation_error=DbtRuntimeError(
+                    "Clustered and nonclustered indexes are the only types that can be unique"
+                ),
+            ),
+            RelationConfigValidationRule(
+                validation_check=self.type in SQLServerIndexType.valid_types(),
+                validation_error=DbtRuntimeError(
+                    f"Invalid index type: {self.type}, valid types: "
+                    + f"{SQLServerIndexType.valid_types()}"
+                ),
+            ),
+        }
+
+    @classmethod
+    def from_dict(cls, config_dict) -> "SQLServerIndexConfig":
+        kwargs_dict = {
+            "name": config_dict.get("name"),
+            "columns": tuple(column for column in config_dict.get("columns", tuple())),
+            "unique": config_dict.get("unique"),
+            "type": config_dict.get("type"),
+            "included_columns": frozenset(
+                column for column in config_dict.get("included_columns", set())
+            ),
+        }
+        index: "SQLServerIndexConfig" = super().from_dict(kwargs_dict)  # type: ignore
+        return index
+
+    @classmethod
+    def parse_model_node(cls, model_node_entry: dict) -> dict:
+        config_dict = {
+            "columns": tuple(model_node_entry.get("columns", tuple())),
+            "unique": model_node_entry.get("unique"),
+            "type": model_node_entry.get("type"),
+            "included_columns": frozenset(model_node_entry.get("included_columns", set())),
+        }
+        return config_dict
+
+    @classmethod
+    def parse_relation_results(cls, relation_results_entry: agate.Row) -> dict:
+        config_dict = {
+            "name": relation_results_entry.get("name"),
+            "columns": tuple(relation_results_entry.get("columns", "").split(",")),
+            "unique": relation_results_entry.get("unique"),
+            "type": relation_results_entry.get("type"),
+            "included_columns": set(relation_results_entry.get("included_columns", "").split(",")),
+        }
+        return config_dict
+
+    @property
+    def as_node_config(self) -> dict:
+        """
+        Returns: a dictionary that can be passed into `get_create_index_sql()`
+        """
+        node_config = {
+            "columns": tuple(self.columns),
+            "unique": self.unique,
+            "type": self.type.value,
+            "included_columns": frozenset(self.included_columns),
+        }
+        return node_config
+
+    def render(self, relation):
+        # We append the current timestamp to the index name because otherwise
+        # the index will only be created on every other run. See
+        # https://github.com/dbt-labs/dbt-core/issues/1945#issuecomment-576714925
+        # for an explanation.
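+        # For illustration (hypothetical values): columns=("col1", "col2") on a
+        # relation rendered as "test_relation", with unique=True and the default
+        # nonclustered type, hashes the string
+        # "col1_col2_test_relation_True_nonclustered_<ISO-8601 timestamp>".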
+
+        now = datetime_now(tz=timezone.utc).isoformat()
+        inputs = self.columns + tuple((relation.render(), str(self.unique), str(self.type), now))
+        string = "_".join(inputs)
+        return dbt_encoding.md5(string)
+
+    @classmethod
+    def parse(cls, raw_index) -> Optional["SQLServerIndexConfig"]:
+        if raw_index is None:
+            return None
+        try:
+            if not isinstance(raw_index, dict):
+                raise IndexConfigNotDictError(raw_index)
+            cls.validate(raw_index)
+            return cls.from_dict(raw_index)
+        except ValidationError as exc:
+            raise IndexConfigError(exc)
+        except TypeError:
+            raise IndexConfigNotDictError(raw_index)
+
+
+@dataclass(frozen=True, eq=True, unsafe_hash=True)
+class SQLServerIndexConfigChange(RelationConfigChange, RelationConfigValidationMixin):
+    """
+    Example of an index change:
+    {
+        "action": "create",
+        "context": {
+            "name": "",  # we don't know the name since it gets created as a hash at runtime
+            "columns": ["column_1", "column_3"],
+            "type": "clustered",
+            "unique": True
+        }
+    },
+    {
+        "action": "drop",
+        "context": {
+            "name": "index_abc",  # we only need this to drop, but we need the rest to compare
+            "columns": ["column_1"],
+            "type": "nonclustered",
+            "unique": True
+        }
+    }
+    """
+
+    # TODO: Implement the change actions on the adapter
+    context: SQLServerIndexConfig
+
+    @property
+    def requires_full_refresh(self) -> bool:
+        return False
+
+    @property
+    def validation_rules(self) -> Set[RelationConfigValidationRule]:
+        return {
+            RelationConfigValidationRule(
+                validation_check=self.action
+                in {RelationConfigChangeAction.create, RelationConfigChangeAction.drop},
+                validation_error=DbtRuntimeError(
+                    "Invalid operation, only `drop` and `create` are supported for indexes."
+                ),
+            ),
+            RelationConfigValidationRule(
+                validation_check=not (
+                    self.action == RelationConfigChangeAction.drop and self.context.name is None
+                ),
+                validation_error=DbtRuntimeError(
+                    "Invalid operation, attempting to drop an index with no name."
+                ),
+            ),
+            RelationConfigValidationRule(
+                validation_check=not (
+                    self.action == RelationConfigChangeAction.create
+                    and not self.context.columns
+                ),
+                validation_error=DbtRuntimeError(
+                    "Invalid operation, attempting to create an index with no columns."
+ ), + ), + } diff --git a/dbt/adapters/sqlserver/sqlserver_adapter.py b/dbt/adapters/sqlserver/sqlserver_adapter.py index 6f05c501..9e24e74a 100644 --- a/dbt/adapters/sqlserver/sqlserver_adapter.py +++ b/dbt/adapters/sqlserver/sqlserver_adapter.py @@ -1,11 +1,13 @@ -from typing import Optional +from typing import Any, Optional import dbt.exceptions -from dbt.adapters.base.impl import ConstraintSupport +from dbt.adapters.base import ConstraintSupport, available from dbt.adapters.fabric import FabricAdapter from dbt.contracts.graph.nodes import ConstraintType +from dbt.adapters.sqlserver.relation_configs import SQLServerIndexConfig from dbt.adapters.sqlserver.sqlserver_column import SQLServerColumn +from dbt.adapters.sqlserver.sqlserver_configs import SQLServerConfigs from dbt.adapters.sqlserver.sqlserver_connections import SQLServerConnectionManager from dbt.adapters.sqlserver.sqlserver_relation import SQLServerRelation @@ -18,6 +20,7 @@ class SQLServerAdapter(FabricAdapter): ConnectionManager = SQLServerConnectionManager Column = SQLServerColumn Relation = SQLServerRelation + AdapterSpecificConfigs = SQLServerConfigs CONSTRAINT_SUPPORT = { ConstraintType.check: ConstraintSupport.ENFORCED, @@ -60,6 +63,10 @@ def render_model_constraint(cls, constraint) -> Optional[str]: def date_function(cls): return "getdate()" + @available + def parse_index(self, raw_index: Any) -> Optional[SQLServerIndexConfig]: + return SQLServerIndexConfig.parse(raw_index) + def valid_incremental_strategies(self): """The set of standard builtin strategies which this adapter supports out-of-the-box. Not used to validate custom strategies defined by end users. diff --git a/dbt/adapters/sqlserver/sqlserver_configs.py b/dbt/adapters/sqlserver/sqlserver_configs.py index 35ce4262..1cdbc8cd 100644 --- a/dbt/adapters/sqlserver/sqlserver_configs.py +++ b/dbt/adapters/sqlserver/sqlserver_configs.py @@ -1,8 +1,11 @@ from dataclasses import dataclass +from typing import Optional, Tuple from dbt.adapters.fabric import FabricConfigs +from dbt.adapters.sqlserver.relation_configs import SQLServerIndexConfig + @dataclass class SQLServerConfigs(FabricConfigs): - pass + indexes: Optional[Tuple[SQLServerIndexConfig]] = None diff --git a/dbt/include/sqlserver/macros/adapter/indexes.sql b/dbt/include/sqlserver/macros/adapter/indexes.sql index 33fa6cfe..97e79414 100644 --- a/dbt/include/sqlserver/macros/adapter/indexes.sql +++ b/dbt/include/sqlserver/macros/adapter/indexes.sql @@ -168,3 +168,21 @@ {% endif %} end {% endmacro %} + + +{% macro sqlserver__get_create_index_sql(relation, index_dict) -%} + {%- set index_config = adapter.parse_index(index_dict) -%} + {%- set comma_separated_columns = ", ".join(index_config.columns) -%} + {%- set index_name = index_config.render(relation) -%} + + {# Validations are made on the adapter class SQLServerIndexConfig to control resulting sql #} + create + {% if index_config.unique -%} unique {% endif %}{{ index_config.type }} + index "{{ index_name }}" + on {{ relation }} + ({{ comma_separated_columns }}) + {% if index_config.included_columns -%} + include ({{ ", ".join(index_config.included_columns) }}) + {% endif %} + +{%- endmacro %} diff --git a/tests/functional/adapter/mssql/test_index_config.py b/tests/functional/adapter/mssql/test_index_config.py new file mode 100644 index 00000000..2b62507d --- /dev/null +++ b/tests/functional/adapter/mssql/test_index_config.py @@ -0,0 +1,497 @@ +import re + +import pytest +from dbt.tests.util import run_dbt, run_dbt_and_capture + +base_validation = """ +with 
base_query AS ( +select i.[name] as index_name, + substring(column_names, 1, len(column_names)-1) as [columns], + substring(included_column_names, 1, len(included_column_names)-1) as included_columns, + case when i.[type] = 1 then 'clustered' + when i.[type] = 2 then 'nonclustered' + when i.[type] = 3 then 'xml' + when i.[type] = 4 then 'spatial' + when i.[type] = 5 then 'clustered columnstore' + when i.[type] = 6 then 'nonclustered columnstore' + when i.[type] = 7 then 'nonclustered hash' + end as index_type, + case when i.is_unique = 1 then 'Unique' + else 'Not unique' end as [unique], + schema_name(t.schema_id) + '.' + t.[name] as table_view, + case when t.[type] = 'U' then 'Table' + when t.[type] = 'V' then 'View' + end as [object_type], + s.name as schema_name +from sys.objects t + inner join sys.schemas s + on + t.schema_id = s.schema_id + inner join sys.indexes i + on t.object_id = i.object_id + cross apply (select col.[name] + ', ' + from sys.index_columns ic + inner join sys.columns col + on ic.object_id = col.object_id + and ic.column_id = col.column_id + where ic.object_id = t.object_id + and ic.index_id = i.index_id + and ic.is_included_column = 0 + order by key_ordinal + for xml path ('') ) D (column_names) + cross apply (select col.[name] + ', ' + from sys.index_columns ic + inner join sys.columns col + on ic.object_id = col.object_id + and ic.column_id = col.column_id + where ic.object_id = t.object_id + and ic.index_id = i.index_id + and ic.is_included_column = 1 + order by key_ordinal + for xml path ('') ) E (included_column_names) +where t.is_ms_shipped <> 1 +and index_id > 0 +) +""" + +index_count = ( + base_validation + + """ +select + index_type + case when [unique] = 'Unique' then ' unique' else '' end as index_type, + count(*) index_count +from + base_query +WHERE + schema_name='{schema_name}' +group by index_type + case when [unique] = 'Unique' then ' unique' else '' end +""" +) + +indexes_def = ( + base_validation + + """ +SELECT + index_name, + [columns], + [included_columns], + index_type, + [unique], + table_view, + [object_type], + schema_name +FROM + base_query +WHERE + schema_name='{schema_name}' + AND + table_view='{schema_name}.{table_name}' + +""" +) + +# Altered from: https://github.com/dbt-labs/dbt-postgres + +models__incremental_sql = """ +{{ + config( + materialized = "incremental", + as_columnstore = False, + indexes=[ + {'columns': ['column_a'], 'type': 'nonclustered'}, + {'columns': ['column_a', 'column_b'], 'unique': True}, + ] + ) +}} + +select * +from ( + select 1 as column_a, 2 as column_b +) t + +{% if is_incremental() %} + where column_a > (select max(column_a) from {{this}}) +{% endif %} + +""" + +models__columnstore_sql = """ +{{ + config( + materialized = "incremental", + as_columnstore = False, + indexes=[ + {'columns': ['column_a'], 'type': 'columnstore'}, + ] + ) +}} + +select * +from ( + select 1 as column_a, 2 as column_b +) t + +{% if is_incremental() %} + where column_a > (select max(column_a) from {{this}}) +{% endif %} + +""" + + +models__table_sql = """ +{{ + config( + materialized = "table", + as_columnstore = False, + indexes=[ + {'columns': ['column_a']}, + {'columns': ['column_b']}, + {'columns': ['column_a', 'column_b']}, + {'columns': ['column_b', 'column_a'], 'type': 'clustered', 'unique': True}, + {'columns': ['column_a','column_c'], + 'type': 'nonclustered', + 'included_columns': ['column_b']}, + ] + ) +}} + +select 1 as column_a, 2 as column_b, 3 as column_c + +""" + + +models__table_included_sql = """ +{{ + config( + 
materialized = "table", + as_columnstore = False, + indexes=[ + {'columns': ['column_a'], 'included_columns': ['column_b']}, + {'columns': ['column_b'], 'type': 'clustered'} + ] + ) +}} + +select 1 as column_a, 2 as column_b + +""" + +models_invalid__invalid_columns_type_sql = """ +{{ + config( + materialized = "table", + indexes=[ + {'columns': 'column_a, column_b'}, + ] + ) +}} + +select 1 as column_a, 2 as column_b + +""" + +models_invalid__invalid_type_sql = """ +{{ + config( + materialized = "table", + indexes=[ + {'columns': ['column_a'], 'type': 'non_existent_type'}, + ] + ) +}} + +select 1 as column_a, 2 as column_b + +""" + +models_invalid__invalid_unique_config_sql = """ +{{ + config( + materialized = "table", + indexes=[ + {'columns': ['column_a'], 'unique': 'yes'}, + ] + ) +}} + +select 1 as column_a, 2 as column_b + +""" + +models_invalid__missing_columns_sql = """ +{{ + config( + materialized = "table", + indexes=[ + {'unique': True}, + ] + ) +}} + +select 1 as column_a, 2 as column_b + +""" + +snapshots__colors_sql = """ +{% snapshot colors %} + + {{ + config( + target_database=database, + target_schema=schema, + as_columnstore=False, + unique_key='id', + strategy='check', + check_cols=['color'], + indexes=[ + {'columns': ['id'], 'type': 'nonclustered'}, + {'columns': ['id', 'color'], 'unique': True}, + ] + ) + }} + + {% if var('version') == 1 %} + + select 1 as id, 'red' as color union all + select 2 as id, 'green' as color + + {% else %} + + select 1 as id, 'blue' as color union all + select 2 as id, 'green' as color + + {% endif %} + +{% endsnapshot %} + +""" + +seeds__seed_csv = """country_code,country_name +US,United States +CA,Canada +GB,United Kingdom +""" + + +class TestSQLServerIndex: + @pytest.fixture(scope="class") + def models(self): + return { + "table.sql": models__table_sql, + "incremental.sql": models__incremental_sql, + "columnstore.sql": models__columnstore_sql, + "table_included.sql": models__table_included_sql, + } + + @pytest.fixture(scope="class") + def seeds(self): + return {"seed.csv": seeds__seed_csv} + + @pytest.fixture(scope="class") + def snapshots(self): + return {"colors.sql": snapshots__colors_sql} + + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "config-version": 2, + "seeds": { + "quote_columns": False, + "indexes": [ + {"columns": ["country_code"], "unique": False}, + { + "columns": ["country_code", "country_name"], + "unique": True, + "type": "clustered", + }, + ], + }, + "vars": { + "version": 1, + }, + } + + def test_table(self, project, unique_schema): + results = run_dbt(["run", "--models", "table"]) + assert len(results) == 1 + + indexes = self.get_indexes("table", project, unique_schema) + indexes = self.sort_indexes(indexes) + expected = [ + { + "columns": "column_a", + "unique": False, + "type": "nonclustered", + "included_columns": None, + }, + { + "columns": "column_a, column_b", + "unique": False, + "type": "nonclustered", + "included_columns": None, + }, + { + "columns": "column_a, column_c", + "unique": False, + "type": "nonclustered", + "included_columns": "column_b", + }, + { + "columns": "column_b", + "unique": False, + "type": "nonclustered", + "included_columns": None, + }, + { + "columns": "column_b, column_a", + "unique": True, + "type": "clustered", + "included_columns": None, + }, + ] + assert indexes == expected + + def test_table_included(self, project, unique_schema): + results = run_dbt(["run", "--models", "table_included"]) + assert len(results) == 1 + + indexes = 
self.get_indexes("table_included", project, unique_schema) + indexes = self.sort_indexes(indexes) + expected = [ + { + "columns": "column_a", + "unique": False, + "type": "nonclustered", + "included_columns": "column_b", + }, + { + "columns": "column_b", + "unique": False, + "type": "clustered", + "included_columns": None, + }, + ] + assert indexes == expected + + def test_incremental(self, project, unique_schema): + for additional_argument in [[], [], ["--full-refresh"]]: + results = run_dbt(["run", "--models", "incremental"] + additional_argument) + assert len(results) == 1 + + indexes = self.get_indexes("incremental", project, unique_schema) + indexes = self.sort_indexes(indexes) + expected = [ + { + "columns": "column_a", + "unique": False, + "type": "nonclustered", + "included_columns": None, + }, + { + "columns": "column_a, column_b", + "unique": True, + "type": "nonclustered", + "included_columns": None, + }, + ] + assert indexes == expected + + def test_columnstore(self, project, unique_schema): + for additional_argument in [[], [], ["--full-refresh"]]: + results = run_dbt(["run", "--models", "columnstore"] + additional_argument) + assert len(results) == 1 + + indexes = self.get_indexes("columnstore", project, unique_schema) + expected = [ + { + "columns": "column_a", + "unique": False, + "type": "columnstore", + "included_columns": None, + }, + ] + assert len(indexes) == len( + expected + ) # Nonclustered columnstore indexes meta is different + + def test_seed(self, project, unique_schema): + for additional_argument in [[], [], ["--full-refresh"]]: + results = run_dbt(["seed"] + additional_argument) + assert len(results) == 1 + + indexes = self.get_indexes("seed", project, unique_schema) + indexes = self.sort_indexes(indexes) + expected = [ + { + "columns": "country_code", + "unique": False, + "type": "nonclustered", + "included_columns": None, + }, + { + "columns": "country_code, country_name", + "unique": True, + "type": "clustered", + "included_columns": None, + }, + ] + assert indexes == expected + + def test_snapshot(self, project, unique_schema): + for version in [1, 2]: + results = run_dbt(["snapshot", "--vars", f"version: {version}"]) + assert len(results) == 1 + + indexes = self.get_indexes("colors", project, unique_schema) + indexes = self.sort_indexes(indexes) + expected = [ + { + "columns": "id", + "unique": False, + "type": "nonclustered", + "included_columns": None, + }, + { + "columns": "id, color", + "unique": True, + "type": "nonclustered", + "included_columns": None, + }, + ] + assert indexes == expected + + def get_indexes(self, table_name, project, unique_schema): + sql = indexes_def.format(schema_name=unique_schema, table_name=table_name) + results = project.run_sql(sql, fetch="all") + return [self.index_definition_dict(row) for row in results] + + def index_definition_dict(self, index_definition): + is_unique = index_definition[4] == "Unique" + return { + "columns": index_definition[1], + "included_columns": index_definition[2], + "unique": is_unique, + "type": index_definition[3], + } + + def sort_indexes(self, indexes): + return sorted(indexes, key=lambda x: (x["columns"], x["type"])) + + +class TestSQLServerInvalidIndex: + @pytest.fixture(scope="class") + def models(self): + return { + "invalid_unique_config.sql": models_invalid__invalid_unique_config_sql, + "invalid_type.sql": models_invalid__invalid_type_sql, + "invalid_columns_type.sql": models_invalid__invalid_columns_type_sql, + "missing_columns.sql": models_invalid__missing_columns_sql, + } + + def 
test_invalid_index_configs(self, project): + results, output = run_dbt_and_capture(expect_pass=False) + assert len(results) == 4 + assert re.search(r"columns.*is not of type 'array'", output) + assert re.search(r"unique.*is not of type 'boolean'", output) + assert re.search(r"'columns' is a required property", output) + assert re.search(r"'non_existent_type'.*is not one of", output) diff --git a/tests/functional/adapter/mssql/test_index.py b/tests/functional/adapter/mssql/test_index_macros.py similarity index 70% rename from tests/functional/adapter/mssql/test_index.py rename to tests/functional/adapter/mssql/test_index_macros.py index 9b9588a7..855c32fe 100644 --- a/tests/functional/adapter/mssql/test_index.py +++ b/tests/functional/adapter/mssql/test_index_macros.py @@ -1,6 +1,8 @@ import pytest from dbt.tests.util import get_connection, run_dbt +from tests.functional.adapter.mssql.test_index_config import index_count, indexes_def + # flake8: noqa: E501 index_seed_csv = """id_col,data,secondary_data,tertiary_data @@ -66,74 +68,6 @@ select * from {{ ref('raw_data') }} """ -base_validation = """ -with base_query AS ( -select i.[name] as index_name, - substring(column_names, 1, len(column_names)-1) as [columns], - case when i.[type] = 1 then 'Clustered index' - when i.[type] = 2 then 'Nonclustered unique index' - when i.[type] = 3 then 'XML index' - when i.[type] = 4 then 'Spatial index' - when i.[type] = 5 then 'Clustered columnstore index' - when i.[type] = 6 then 'Nonclustered columnstore index' - when i.[type] = 7 then 'Nonclustered hash index' - end as index_type, - case when i.is_unique = 1 then 'Unique' - else 'Not unique' end as [unique], - schema_name(t.schema_id) + '.' + t.[name] as table_view, - case when t.[type] = 'U' then 'Table' - when t.[type] = 'V' then 'View' - end as [object_type], - s.name as schema_name -from sys.objects t - inner join sys.schemas s - on - t.schema_id = s.schema_id - inner join sys.indexes i - on t.object_id = i.object_id - cross apply (select col.[name] + ', ' - from sys.index_columns ic - inner join sys.columns col - on ic.object_id = col.object_id - and ic.column_id = col.column_id - where ic.object_id = t.object_id - and ic.index_id = i.index_id - order by key_ordinal - for xml path ('') ) D (column_names) -where t.is_ms_shipped <> 1 -and index_id > 0 -) -""" - -index_count = ( - base_validation - + """ -select - index_type, - count(*) index_count -from - base_query -WHERE - schema_name='{schema_name}' -group by index_type -""" -) - -other_index_count = ( - base_validation - + """ -SELECT - * -FROM - base_query -WHERE - schema_name='{schema_name}' - AND - table_view='{schema_name}.{table_name}' - -""" -) - class TestIndex: @pytest.fixture(scope="class") @@ -155,6 +89,11 @@ def models(self): "schema.yml": model_yml, } + def drop_artifacts(self, project): + with get_connection(project.adapter): + project.adapter.execute("DROP TABLE IF EXISTS index_model", fetch=True) + project.adapter.execute("DROP TABLE IF EXISTS index_ccs_model") + def test_create_index(self, project): run_dbt(["seed"]) run_dbt(["run"]) @@ -165,10 +104,11 @@ def test_create_index(self, project): ) schema_dict = {_[0]: _[1] for _ in table.rows} expected = { - "Clustered columnstore index": 1, - "Clustered index": 1, - "Nonclustered unique index": 4, + "clustered columnstore": 1, + "clustered unique": 1, + "nonclustered": 4, } + self.drop_artifacts(project) assert schema_dict == expected @@ -230,7 +170,7 @@ def drop_schema_artifacts(self, project): def validate_other_schema(self, 
project): with get_connection(project.adapter): result, table = project.adapter.execute( - other_index_count.format( + indexes_def.format( schema_name=project.test_schema + "other", table_name="index_model" ), fetch=True, diff --git a/tests/unit/adapters/mssql/test_index_configs.py b/tests/unit/adapters/mssql/test_index_configs.py new file mode 100644 index 00000000..278ba8e7 --- /dev/null +++ b/tests/unit/adapters/mssql/test_index_configs.py @@ -0,0 +1,174 @@ +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest +from dbt.adapters.exceptions import IndexConfigError, IndexConfigNotDictError +from dbt.exceptions import DbtRuntimeError +from dbt_common.utils import encoding as dbt_encoding + +from dbt.adapters.sqlserver.relation_configs.index import SQLServerIndexConfig, SQLServerIndexType + + +def test_sqlserver_index_type_default(): + assert SQLServerIndexType.default() == SQLServerIndexType.nonclustered + + +def test_sqlserver_index_type_valid_types(): + valid_types = SQLServerIndexType.valid_types() + assert isinstance(valid_types, tuple) + assert len(valid_types) > 0 + + +def test_sqlserver_index_config_creation(): + config = SQLServerIndexConfig( + columns=("col1", "col2"), + unique=True, + type=SQLServerIndexType.nonclustered, + included_columns=frozenset(["col3", "col4"]), + ) + assert config.columns == ("col1", "col2") + assert config.unique is True + assert config.type == SQLServerIndexType.nonclustered + assert config.included_columns == frozenset(["col3", "col4"]) + + +def test_sqlserver_index_config_from_dict(): + config_dict = { + "columns": ["col1", "col2"], + "unique": True, + "type": "nonclustered", + "included_columns": ["col3", "col4"], + } + config = SQLServerIndexConfig.from_dict(config_dict) + assert config.columns == ("col1", "col2") + assert config.unique is True + assert config.type == SQLServerIndexType.nonclustered + assert config.included_columns == frozenset(["col3", "col4"]) + + +def test_sqlserver_index_config_validation_rules(): + # Test valid configuration + valid_config = SQLServerIndexConfig( + columns=("col1", "col2"), + unique=True, + type=SQLServerIndexType.nonclustered, + included_columns=frozenset(["col3", "col4"]), + ) + assert len(valid_config.validation_rules) == 4 + for rule in valid_config.validation_rules: + assert rule.validation_check is True + + # Test invalid configurations + with pytest.raises(DbtRuntimeError, match="'columns' is a required property"): + SQLServerIndexConfig(columns=()) + + with pytest.raises( + DbtRuntimeError, + match="Non-clustered indexes are the only index types that can include extra columns", + ): + SQLServerIndexConfig( + columns=("col1",), + type=SQLServerIndexType.clustered, + included_columns=frozenset(["col2"]), + ) + + with pytest.raises( + DbtRuntimeError, + match="Clustered and nonclustered indexes are the only types that can be unique", + ): + SQLServerIndexConfig(columns=("col1",), unique=True, type=SQLServerIndexType.columnstore) + + +def test_sqlserver_index_config_parse_model_node(): + model_node_entry = { + "columns": ["col1", "col2"], + "unique": True, + "type": "nonclustered", + "included_columns": ["col3", "col4"], + } + parsed_dict = SQLServerIndexConfig.parse_model_node(model_node_entry) + assert parsed_dict == { + "columns": ("col1", "col2"), + "unique": True, + "type": "nonclustered", + "included_columns": frozenset(["col3", "col4"]), + } + + +def test_sqlserver_index_config_parse_relation_results(): + relation_results_entry = { + "name": 
"index_name", + "columns": "col1,col2", + "unique": True, + "type": "nonclustered", + "included_columns": "col3,col4", + } + parsed_dict = SQLServerIndexConfig.parse_relation_results(relation_results_entry) + assert parsed_dict == { + "name": "index_name", + "columns": ("col1", "col2"), + "unique": True, + "type": "nonclustered", + "included_columns": {"col3", "col4"}, + } + + +def test_sqlserver_index_config_as_node_config(): + config = SQLServerIndexConfig( + columns=("col1", "col2"), + unique=True, + type=SQLServerIndexType.nonclustered, + included_columns=frozenset(["col3", "col4"]), + ) + node_config = config.as_node_config + assert node_config == { + "columns": ("col1", "col2"), + "unique": True, + "type": "nonclustered", + "included_columns": frozenset(["col3", "col4"]), + } + + +FAKE_NOW = datetime(2023, 1, 1, 0, 0, 0, tzinfo=timezone.utc) + + +@pytest.fixture(autouse=True) +def patch_datetime_now(): + with patch("dbt.adapters.sqlserver.relation_configs.index.datetime_now") as mocked_datetime: + mocked_datetime.return_value = FAKE_NOW + yield mocked_datetime + + +def test_sqlserver_index_config_render(): + config = SQLServerIndexConfig( + columns=("col1", "col2"), unique=True, type=SQLServerIndexType.nonclustered + ) + relation = MagicMock() + relation.render.return_value = "test_relation" + + result = config.render(relation) + + expected_string = "col1_col2_test_relation_True_nonclustered_2023-01-01T00:00:00+00:00" + + print(f"Expected string: {expected_string}") + print(f"Actual result (MD5): {result}") + print(f"Expected result (MD5): {dbt_encoding.md5(expected_string)}") + + assert result == dbt_encoding.md5(expected_string) + + +def test_sqlserver_index_config_parse(): + valid_raw_index = {"columns": ["col1", "col2"], "unique": True, "type": "nonclustered"} + result = SQLServerIndexConfig.parse(valid_raw_index) + assert isinstance(result, SQLServerIndexConfig) + assert result.columns == ("col1", "col2") + assert result.unique is True + assert result.type == SQLServerIndexType.nonclustered + + assert SQLServerIndexConfig.parse(None) is None + + with pytest.raises(IndexConfigError): + SQLServerIndexConfig.parse({"invalid": "config"}) + + with pytest.raises(IndexConfigNotDictError): + SQLServerIndexConfig.parse("not a dict")