Skip to content

Commit 4ee6cb3

Browse files
authored
Fix: Allow python models to call resolve_table() on themselves within a unit test (#4967)
1 parent 5dfdeca commit 4ee6cb3

File tree

3 files changed

+284
-3
lines changed

3 files changed

+284
-3
lines changed

sqlmesh/core/context.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,15 @@ def resolve_table(self, model_name: str) -> str:
214214
"""
215215
model_name = normalize_model_name(model_name, self.default_catalog, self.default_dialect)
216216

217+
if model_name not in self._model_tables:
218+
model_name_list = "\n".join(list(self._model_tables))
219+
logger.debug(
220+
f"'{model_name}' not found in model to table mapping. Available model names: \n{model_name_list}"
221+
)
222+
raise SQLMeshError(
223+
f"Unable to find a table mapping for model '{model_name}'. Has it been spelled correctly?"
224+
)
225+
217226
# We generate SQL for the default dialect because the table name may be used in a
218227
# fetchdf call and so the quotes need to be correct (eg. backticks for bigquery)
219228
return parse_one(self._model_tables[model_name]).sql(

sqlmesh/core/test/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def _model_tables(self) -> t.Dict[str, str]:
4343
# Include upstream dependencies to ensure they can be resolved during test execution
4444
return {
4545
name: self._test._test_fixture_table(name).sql()
46-
for model in self._models.values()
47-
for name in [model.name, *model.depends_on]
46+
for normalized_model_name, model in self._models.items()
47+
for name in [normalized_model_name, *model.depends_on]
4848
}
4949

5050
def with_variables(

tests/core/test_test.py

Lines changed: 273 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from sqlmesh.core.model import Model, SqlModel, load_sql_based_model, model
3232
from sqlmesh.core.test.definition import ModelTest, PythonModelTest, SqlModelTest
3333
from sqlmesh.core.test.result import ModelTextTestResult
34+
from sqlmesh.core.test.context import TestExecutionContext
3435
from sqlmesh.utils import Verbosity
3536
from sqlmesh.utils.errors import ConfigError, SQLMeshError, TestError
3637
from sqlmesh.utils.yaml import dump as dump_yaml
@@ -69,11 +70,14 @@ def _create_model(
6970
meta: str = SUSHI_FOO_META,
7071
dialect: t.Optional[str] = None,
7172
default_catalog: t.Optional[str] = None,
73+
**kwargs: t.Any,
7274
) -> SqlModel:
7375
parsed_definition = parse(f"{meta};{query}", default_dialect=dialect)
7476
return t.cast(
7577
SqlModel,
76-
load_sql_based_model(parsed_definition, dialect=dialect, default_catalog=default_catalog),
78+
load_sql_based_model(
79+
parsed_definition, dialect=dialect, default_catalog=default_catalog, **kwargs
80+
),
7781
)
7882

7983

@@ -2814,3 +2818,271 @@ def test_test_generation_with_timestamp_nat(tmp_path: Path) -> None:
28142818
assert query_output[0]["ts_col"] == datetime.datetime(2024, 9, 20, 11, 30, 0, 123456)
28152819
assert query_output[1]["ts_col"] is None
28162820
assert query_output[2]["ts_col"] == datetime.datetime(2024, 9, 21, 15, 45, 0, 987654)
2821+
2822+
2823+
def test_parameterized_name_sql_model() -> None:
2824+
variables = {"table_catalog": "gold"}
2825+
model = _create_model(
2826+
"select 1 as id, 'foo' as name",
2827+
meta="""
2828+
MODEL (
2829+
name @{table_catalog}.sushi.foo,
2830+
kind FULL
2831+
)
2832+
""",
2833+
dialect="snowflake",
2834+
variables=variables,
2835+
)
2836+
assert model.fqn == '"GOLD"."SUSHI"."FOO"'
2837+
2838+
test = _create_test(
2839+
body=load_yaml(
2840+
"""
2841+
test_foo:
2842+
model: {{ var('table_catalog' ) }}.sushi.foo
2843+
outputs:
2844+
query:
2845+
- id: 1
2846+
name: foo
2847+
""",
2848+
variables=variables,
2849+
),
2850+
test_name="test_foo",
2851+
model=model,
2852+
context=Context(
2853+
config=Config(
2854+
model_defaults=ModelDefaultsConfig(dialect="snowflake"), variables=variables
2855+
)
2856+
),
2857+
)
2858+
2859+
assert test.body["model"] == '"GOLD"."SUSHI"."FOO"'
2860+
2861+
_check_successful_or_raise(test.run())
2862+
2863+
2864+
def test_parameterized_name_python_model() -> None:
2865+
variables = {"table_catalog": "gold"}
2866+
2867+
@model(
2868+
name="@{table_catalog}.sushi.foo",
2869+
columns={
2870+
"id": "int",
2871+
"name": "varchar",
2872+
},
2873+
dialect="snowflake",
2874+
)
2875+
def execute(
2876+
context: ExecutionContext,
2877+
**kwargs: t.Any,
2878+
) -> pd.DataFrame:
2879+
return pd.DataFrame([{"ID": 1, "NAME": "foo"}])
2880+
2881+
python_model = model.get_registry()["@{table_catalog}.sushi.foo"].model(
2882+
module_path=Path("."), path=Path("."), variables=variables
2883+
)
2884+
2885+
assert python_model.fqn == '"GOLD"."SUSHI"."FOO"'
2886+
2887+
test = _create_test(
2888+
body=load_yaml(
2889+
"""
2890+
test_foo:
2891+
model: {{ var('table_catalog' ) }}.sushi.foo
2892+
outputs:
2893+
query:
2894+
- id: 1
2895+
name: foo
2896+
""",
2897+
variables=variables,
2898+
),
2899+
test_name="test_foo",
2900+
model=python_model,
2901+
context=Context(
2902+
config=Config(
2903+
model_defaults=ModelDefaultsConfig(dialect="snowflake"), variables=variables
2904+
)
2905+
),
2906+
)
2907+
2908+
assert test.body["model"] == '"GOLD"."SUSHI"."FOO"'
2909+
2910+
_check_successful_or_raise(test.run())
2911+
2912+
2913+
def test_parameterized_name_self_referential_model():
2914+
variables = {"table_catalog": "gold"}
2915+
model = _create_model(
2916+
"""
2917+
with last_value as (
2918+
select coalesce(max(v), 0) as v from @{table_catalog}.sushi.foo
2919+
)
2920+
select v + 1 as v from last_value
2921+
""",
2922+
meta="""
2923+
MODEL (
2924+
name @{table_catalog}.sushi.foo,
2925+
kind FULL
2926+
)
2927+
""",
2928+
dialect="snowflake",
2929+
variables=variables,
2930+
)
2931+
assert model.fqn == '"GOLD"."SUSHI"."FOO"'
2932+
2933+
test1 = _create_test(
2934+
body=load_yaml(
2935+
"""
2936+
test_foo_intial_state:
2937+
model: {{ var('table_catalog' ) }}.sushi.foo
2938+
inputs:
2939+
{{ var('table_catalog' ) }}.sushi.foo:
2940+
rows: []
2941+
columns:
2942+
v: int
2943+
outputs:
2944+
query:
2945+
- v: 1
2946+
""",
2947+
variables=variables,
2948+
),
2949+
test_name="test_foo_intial_state",
2950+
model=model,
2951+
context=Context(
2952+
config=Config(
2953+
model_defaults=ModelDefaultsConfig(dialect="snowflake"), variables=variables
2954+
)
2955+
),
2956+
)
2957+
assert isinstance(test1, SqlModelTest)
2958+
assert test1.body["model"] == '"GOLD"."SUSHI"."FOO"'
2959+
test1_model_query = test1._render_model_query().sql(dialect="snowflake")
2960+
assert '"GOLD"."SUSHI"."FOO"' not in test1_model_query
2961+
assert (
2962+
test1._test_fixture_table('"GOLD"."SUSHI"."FOO"').sql(dialect="snowflake", identify=True)
2963+
in test1_model_query
2964+
)
2965+
2966+
test2 = _create_test(
2967+
body=load_yaml(
2968+
"""
2969+
test_foo_cumulative:
2970+
model: {{ var('table_catalog' ) }}.sushi.foo
2971+
inputs:
2972+
{{ var('table_catalog' ) }}.sushi.foo:
2973+
rows:
2974+
- v: 5
2975+
outputs:
2976+
query:
2977+
- v: 6
2978+
""",
2979+
variables=variables,
2980+
),
2981+
test_name="test_foo_cumulative",
2982+
model=model,
2983+
context=Context(
2984+
config=Config(
2985+
model_defaults=ModelDefaultsConfig(dialect="snowflake"), variables=variables
2986+
)
2987+
),
2988+
)
2989+
assert isinstance(test2, SqlModelTest)
2990+
assert test2.body["model"] == '"GOLD"."SUSHI"."FOO"'
2991+
test2_model_query = test2._render_model_query().sql(dialect="snowflake")
2992+
assert '"GOLD"."SUSHI"."FOO"' not in test2_model_query
2993+
assert (
2994+
test2._test_fixture_table('"GOLD"."SUSHI"."FOO"').sql(dialect="snowflake", identify=True)
2995+
in test2_model_query
2996+
)
2997+
2998+
_check_successful_or_raise(test1.run())
2999+
_check_successful_or_raise(test2.run())
3000+
3001+
3002+
def test_parameterized_name_self_referential_python_model():
3003+
variables = {"table_catalog": "gold"}
3004+
3005+
@model(
3006+
name="@{table_catalog}.sushi.foo",
3007+
columns={
3008+
"id": "int",
3009+
},
3010+
depends_on=["@{table_catalog}.sushi.bar"],
3011+
dialect="snowflake",
3012+
)
3013+
def execute(
3014+
context: ExecutionContext,
3015+
**kwargs: t.Any,
3016+
) -> pd.DataFrame:
3017+
current_table = context.resolve_table(f"{context.var('table_catalog')}.sushi.foo")
3018+
current_df = context.fetchdf(f"select id from {current_table}")
3019+
upstream_table = context.resolve_table(f"{context.var('table_catalog')}.sushi.bar")
3020+
upstream_df = context.fetchdf(f"select id from {upstream_table}")
3021+
3022+
return pd.DataFrame([{"ID": upstream_df["ID"].sum() + current_df["ID"].sum()}])
3023+
3024+
@model(
3025+
name="@{table_catalog}.sushi.bar",
3026+
columns={
3027+
"id": "int",
3028+
},
3029+
dialect="snowflake",
3030+
)
3031+
def execute(
3032+
context: ExecutionContext,
3033+
**kwargs: t.Any,
3034+
) -> pd.DataFrame:
3035+
return pd.DataFrame([{"ID": 1}])
3036+
3037+
model_foo = model.get_registry()["@{table_catalog}.sushi.foo"].model(
3038+
module_path=Path("."), path=Path("."), variables=variables
3039+
)
3040+
model_bar = model.get_registry()["@{table_catalog}.sushi.bar"].model(
3041+
module_path=Path("."), path=Path("."), variables=variables
3042+
)
3043+
3044+
assert model_foo.fqn == '"GOLD"."SUSHI"."FOO"'
3045+
assert model_bar.fqn == '"GOLD"."SUSHI"."BAR"'
3046+
3047+
ctx = Context(
3048+
config=Config(model_defaults=ModelDefaultsConfig(dialect="snowflake"), variables=variables)
3049+
)
3050+
ctx.upsert_model(model_foo)
3051+
ctx.upsert_model(model_bar)
3052+
3053+
test = _create_test(
3054+
body=load_yaml(
3055+
"""
3056+
test_foo:
3057+
model: {{ var('table_catalog') }}.sushi.foo
3058+
inputs:
3059+
{{ var('table_catalog') }}.sushi.foo:
3060+
rows:
3061+
- id: 3
3062+
{{ var('table_catalog') }}.sushi.bar:
3063+
rows:
3064+
- id: 5
3065+
outputs:
3066+
query:
3067+
- id: 8
3068+
""",
3069+
variables=variables,
3070+
),
3071+
test_name="test_foo",
3072+
model=model_foo,
3073+
context=ctx,
3074+
)
3075+
3076+
assert isinstance(test, PythonModelTest)
3077+
3078+
assert test.body["model"] == '"GOLD"."SUSHI"."FOO"'
3079+
assert '"GOLD"."SUSHI"."BAR"' in test.body["inputs"]
3080+
3081+
assert isinstance(test.context, TestExecutionContext)
3082+
assert '"GOLD"."SUSHI"."FOO"' in test.context._model_tables
3083+
assert '"GOLD"."SUSHI"."BAR"' in test.context._model_tables
3084+
3085+
with pytest.raises(SQLMeshError, match=r"Unable to find a table mapping"):
3086+
test.context.resolve_table("silver.sushi.bar")
3087+
3088+
_check_successful_or_raise(test.run())

0 commit comments

Comments
 (0)