Skip to content

Commit b111ac3

Browse files
fix(databricks)!: transpile TRY_DIVIDE to/from other dialects [CLAUDE] (#7489)
Databricks' TRY_DIVIDE(a, b) returns NULL on division by zero. It is now mapped to exp.SafeDivide, which already has the correct semantics and generates appropriate SQL for all target dialects. Closes #7312
1 parent 4d5a9e0 commit b111ac3

7 files changed

Lines changed: 49 additions & 5 deletions

File tree

sqlglot-integration-tests

sqlglot/generators/duckdb.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3414,6 +3414,21 @@ def right_sql(self, expression: exp.Right) -> str:
34143414
def rtrimmedlength_sql(self, expression: exp.RtrimmedLength) -> str:
    """Render RTRIMMED_LENGTH as LENGTH(TRIM(TRAILING FROM x)) for DuckDB."""
    trimmed = exp.Trim(this=expression.this, position="TRAILING")
    return self.func("LENGTH", trimmed)

def stuff_sql(self, expression: exp.Stuff) -> str:
    """Emulate T-SQL STUFF(base, start, length, insertion) in DuckDB.

    STUFF deletes `length` characters of `base` starting at 1-based
    position `start` and inserts `insertion` there. DuckDB has no STUFF,
    so it is rendered as:  prefix || insertion || suffix  where
    prefix = SUBSTRING(base, 1, start - 1) and
    suffix = SUBSTRING(base, start + length).
    """
    target = expression.this
    position = expression.args["start"]
    count = expression.args["length"]
    replacement = expression.expression

    one = exp.Literal.number(1)
    # Characters strictly before the insertion point (1-based, hence start - 1).
    prefix = exp.Substring(
        this=target.copy(),
        start=one,
        length=position.copy() - exp.Literal.number(1),
    )
    # Everything after the removed span.
    suffix = exp.Substring(this=target.copy(), start=position.copy() + count.copy())

    concatenated = exp.DPipe(
        this=exp.DPipe(this=prefix, expression=replacement),
        expression=suffix,
    )
    return self.sql(concatenated)
34173432
def rand_sql(self, expression: exp.Rand) -> str:
34183433
seed = expression.this
34193434
if seed is not None:

sqlglot/generators/spark.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ class SparkGenerator(Spark2Generator):
119119
f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}"
120120
),
121121
exp.SafeAdd: rename_func("TRY_ADD"),
122+
exp.SafeDivide: rename_func("TRY_DIVIDE"),
122123
exp.SafeMultiply: rename_func("TRY_MULTIPLY"),
123124
exp.SafeSubtract: rename_func("TRY_SUBTRACT"),
124125
exp.StartsWith: rename_func("STARTSWITH"),

sqlglot/parsers/spark.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class SparkParser(Spark2Parser):
8787
"TIMESTAMPADD": _build_dateadd,
8888
"TIMESTAMPDIFF": build_date_delta(exp.TimestampDiff),
8989
"TRY_ADD": exp.SafeAdd.from_arg_list,
90+
"TRY_DIVIDE": exp.SafeDivide.from_arg_list,
9091
"TRY_MULTIPLY": exp.SafeMultiply.from_arg_list,
9192
"TRY_SUBTRACT": exp.SafeSubtract.from_arg_list,
9293
"DATEDIFF": _build_datediff,

tests/dialects/test_bigquery.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1768,8 +1768,8 @@ def test_bigquery(self):
17681768
"trino": "IF(y <> 0, CAST(x AS DOUBLE) / y, NULL)",
17691769
"hive": "IF(y <> 0, x / y, NULL)",
17701770
"spark2": "IF(y <> 0, x / y, NULL)",
1771-
"spark": "IF(y <> 0, x / y, NULL)",
1772-
"databricks": "IF(y <> 0, x / y, NULL)",
1771+
"spark": "TRY_DIVIDE(x, y)",
1772+
"databricks": "TRY_DIVIDE(x, y)",
17731773
"snowflake": "IFF(y <> 0, x / y, NULL)",
17741774
"postgres": "CASE WHEN y <> 0 THEN CAST(x AS DOUBLE PRECISION) / y ELSE NULL END",
17751775
},
@@ -1783,8 +1783,8 @@ def test_bigquery(self):
17831783
"trino": "IF((2 * y) <> 0, CAST((x + 1) AS DOUBLE) / (2 * y), NULL)",
17841784
"hive": "IF((2 * y) <> 0, (x + 1) / (2 * y), NULL)",
17851785
"spark2": "IF((2 * y) <> 0, (x + 1) / (2 * y), NULL)",
1786-
"spark": "IF((2 * y) <> 0, (x + 1) / (2 * y), NULL)",
1787-
"databricks": "IF((2 * y) <> 0, (x + 1) / (2 * y), NULL)",
1786+
"spark": "TRY_DIVIDE(x + 1, 2 * y)",
1787+
"databricks": "TRY_DIVIDE(x + 1, 2 * y)",
17881788
"snowflake": "IFF((2 * y) <> 0, (x + 1) / (2 * y), NULL)",
17891789
"postgres": "CASE WHEN (2 * y) <> 0 THEN CAST((x + 1) AS DOUBLE PRECISION) / (2 * y) ELSE NULL END",
17901790
},

tests/dialects/test_databricks.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,18 @@ def test_iff(self):
510510
},
511511
)
512512

def test_try_divide(self):
    """TRY_DIVIDE round-trips through Databricks/Spark and transpiles to
    NULL-safe division in dialects without a native equivalent."""
    query = "SELECT TRY_DIVIDE(a, b)"
    self.validate_all(
        query,
        read={"databricks": query},
        write={
            "databricks": query,
            "spark": query,
            "snowflake": "SELECT IFF(b <> 0, a / b, NULL)",
            "duckdb": "SELECT CASE WHEN b <> 0 THEN a / b ELSE NULL END",
        },
    )
513525
def test_declare(self):
514526
self.validate_identity("DECLARE VAR x INT", "DECLARE x INT")
515527
self.validate_identity("DECLARE x INT")

tests/dialects/test_spark.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,3 +1382,18 @@ def test_set_variable(self):
13821382
"databricks": "SET VARIABLE (v1, v2) = (SELECT 1, 2)",
13831383
},
13841384
)
1385+
def test_try_divide(self):
    """TRY_DIVIDE parses from Spark and Databricks, and generates
    NULL-safe division for dialects lacking the function."""
    sql = "SELECT TRY_DIVIDE(a, b)"
    native = {"spark": sql, "databricks": sql}
    self.validate_all(
        sql,
        read=native,
        write={
            **native,
            "snowflake": "SELECT IFF(b <> 0, a / b, NULL)",
            "duckdb": "SELECT CASE WHEN b <> 0 THEN a / b ELSE NULL END",
        },
    )

0 commit comments

Comments (0)