Skip to content

Commit 836b0c9

Browse files
Christophe Di Primacpcloud
Christophe Di Prima
authored andcommitted
feat(risingwave): properly support structs
Add a specific visit method for visit_StructField. "f{idx}" is not supported like in Postgres. Use dot and parenthesis annotation as dots only is not working on the first level. `dataframe.metadata["user.profile.link"].value.url`` Will be translated to: `(((("t1"."metadata")."user.profile.link").value).url)`` Resolves #11182
1 parent 431488a commit 836b0c9

File tree

10 files changed

+125
-81
lines changed

10 files changed

+125
-81
lines changed

ci/schema/risingwave.sql

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,17 @@ DROP TABLE IF EXISTS "topk";
188188

189189
CREATE TABLE "topk" ("x" BIGINT);
190190
INSERT INTO "topk" VALUES (1), (1), (NULL);
191+
192+
DROP TABLE IF EXISTS "struct";
193+
CREATE TABLE "struct" (
194+
"abc" STRUCT<"a" DOUBLE, "b" STRING, "c" BIGINT>
195+
);
196+
197+
INSERT INTO "struct" VALUES
198+
(ROW(1.0, 'banana', 2)),
199+
(ROW(2.0, 'apple', 3)),
200+
(ROW(3.0, 'orange', 4)),
201+
(ROW(NULL, 'banana', 2)),
202+
(ROW(2.0, NULL, 3)),
203+
(NULL),
204+
(ROW(3.0, 'orange', NULL));

ibis/backends/risingwave/__init__.py

Lines changed: 78 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import contextlib
6+
import itertools
67
from operator import itemgetter
78
from typing import TYPE_CHECKING, Any
89
from urllib.parse import unquote_plus
@@ -16,39 +17,66 @@
1617
import ibis.backends.sql.compilers as sc
1718
import ibis.common.exceptions as com
1819
import ibis.common.exceptions as exc
20+
import ibis.expr.datatypes as dt
1921
import ibis.expr.operations as ops
2022
import ibis.expr.schema as sch
2123
import ibis.expr.types as ir
2224
from ibis import util
2325
from ibis.backends import CanCreateDatabase, CanListCatalog, NoExampleLoader
2426
from ibis.backends.sql import SQLBackend
25-
from ibis.backends.sql.compilers.base import TRUE, C, ColGen
27+
from ibis.backends.sql.compilers.base import TRUE, C
2628
from ibis.util import experimental
2729

2830
if TYPE_CHECKING:
31+
from collections.abc import Iterable, Mapping
2932
from urllib.parse import ParseResult
3033

3134
import pandas as pd
3235
import polars as pl
3336
import pyarrow as pa
3437

3538

39+
def dict_to_struct(struct):
40+
return dt.Struct(
41+
{
42+
field: dtype if isinstance(dtype, dt.DataType) else dict_to_struct(dtype)
43+
for field, dtype in struct.items()
44+
}
45+
)
46+
47+
48+
def string_to_struct(
49+
dtypes: Iterable[tuple[str, dt.DataType, int]], top: str
50+
) -> Mapping[str, dt.DataType]:
51+
result = {}
52+
for field, dtype, _ in dtypes:
53+
field_top, *components, bottom = field.split(".")
54+
assert top == field_top, f"{top} != {field_top}"
55+
child = result.setdefault(top, {})
56+
for component in components:
57+
child = child.setdefault(component, {})
58+
child[bottom] = dtype
59+
return result[top]
60+
61+
3662
def data_and_encode_format(data_format, encode_format, encode_properties):
37-
res = ""
63+
res = []
3864
if data_format is not None:
39-
res = res + " FORMAT " + data_format.upper()
65+
res.append("FORMAT")
66+
res.append(data_format.upper())
4067
if encode_format is not None:
41-
res = res + " ENCODE " + encode_format.upper()
68+
res.append("ENCODE")
69+
res.append(encode_format.upper())
4270
if encode_properties is not None:
43-
res = res + " " + format_properties(encode_properties)
44-
return res
71+
res.append(format_properties(encode_properties))
72+
return " ".join(res)
4573

4674

4775
def format_properties(props):
4876
tokens = []
4977
for k, v in props.items():
5078
tokens.append(f"{k}='{v}'")
51-
return "( {} ) ".format(", ".join(tokens))
79+
return "({}) ".format(", ".join(tokens))
5280

5381

5482
class Backend(SQLBackend, CanListCatalog, CanCreateDatabase, NoExampleLoader):
@@ -264,12 +292,6 @@ def get_schema(
264292
catalog: str | None = None,
265293
database: str | None = None,
266294
):
267-
a = ColGen(table="a")
268-
c = ColGen(table="c")
269-
n = ColGen(table="n")
270-
271-
format_type = self.compiler.f["pg_catalog.format_type"]
272-
273295
# If no database is specified, assume the current database
274296
db = database or self.current_database
275297

@@ -280,46 +302,44 @@ def get_schema(
280302
if database is None and (temp_table_db := self._session_temp_db) is not None:
281303
dbs.append(sge.convert(temp_table_db))
282304

283-
type_info = (
284-
sg.select(
285-
a.attname.as_("column_name"),
286-
format_type(a.atttypid, a.atttypmod).as_("data_type"),
287-
sg.not_(a.attnotnull).as_("nullable"),
288-
)
289-
.from_(sg.table("pg_attribute", db="pg_catalog").as_("a"))
290-
.join(
291-
sg.table("pg_class", db="pg_catalog").as_("c"),
292-
on=c.oid.eq(a.attrelid),
293-
join_type="INNER",
294-
)
295-
.join(
296-
sg.table("pg_namespace", db="pg_catalog").as_("n"),
297-
on=n.oid.eq(c.relnamespace),
298-
join_type="INNER",
299-
)
300-
.where(
301-
a.attnum > 0,
302-
sg.not_(a.attisdropped),
303-
n.nspname.isin(*dbs),
304-
c.relname.eq(sge.convert(name)),
305-
)
306-
.order_by(a.attnum)
307-
)
305+
ident = sg.table(name, catalog=catalog, db=database, quoted=True)
306+
try:
307+
with self._safe_raw_sql(sge.Describe(this=ident)) as cur:
308+
raw_rows = cur.fetchall()
309+
except psycopg2.InternalError as exc:
310+
raise com.TableNotFound(name) from exc
308311

309312
type_mapper = self.compiler.type_mapper
310313

311-
with self._safe_raw_sql(type_info) as cur:
312-
rows = cur.fetchall()
314+
rows = []
315+
field_number = 0
316+
for raw_name, dtype, hidden, *_ in raw_rows:
317+
if hidden == "false":
318+
field_number += (not dtype) or "." not in raw_name
319+
rows.append(
320+
(
321+
raw_name,
322+
None
323+
if not dtype
324+
else type_mapper.from_string(dtype, nullable=True),
325+
field_number,
326+
)
327+
)
313328

314-
if not rows:
315-
raise com.TableNotFound(name)
329+
schema = {}
316330

317-
return sch.Schema(
318-
{
319-
col: type_mapper.from_string(typestr, nullable=nullable)
320-
for col, typestr, nullable in rows
321-
}
322-
)
331+
for _, values in itertools.groupby(rows, key=lambda x: x[-1]):
332+
vals = list(values)
333+
assert vals, "vals is empty"
334+
335+
if len(vals) == 1:
336+
name, dtype, _ = vals.pop()
337+
schema[name] = dtype
338+
else:
339+
name, _, _ = vals[0]
340+
schema[name] = dict_to_struct(string_to_struct(vals[1:], name))
341+
342+
return sch.Schema(schema)
323343

324344
def _get_schema_using_query(self, query: str) -> sch.Schema:
325345
name = util.gen_name(f"{self.name}_metadata")
@@ -583,7 +603,9 @@ def create_table(
583603
create_stmt = sge.Create(
584604
kind="TABLE",
585605
this=target,
586-
properties=sge.Properties.from_dict(connector_properties),
606+
properties=sge.Properties(
607+
expressions=sge.Properties.from_dict(connector_properties)
608+
),
587609
)
588610
create_stmt = create_stmt.sql(self.dialect) + data_and_encode_format(
589611
data_format, encode_format, encode_properties
@@ -742,7 +764,6 @@ def create_source(
742764
data_format: str,
743765
encode_format: str,
744766
encode_properties: dict | None = None,
745-
includes: dict[str, str] | None = None,
746767
) -> ir.Table:
747768
"""Creating a source.
748769
@@ -763,32 +784,23 @@ def create_source(
763784
The encode format for the new source, e.g., "JSON". data_format and encode_format must be specified at the same time.
764785
encode_properties
765786
The properties of encode format, providing information like schema registry url. Refer https://docs.risingwave.com/docs/current/sql-create-source/ for more details.
766-
includes
767-
A dict of `INCLUDE` clauses of the form `{field: alias, ...}`.
768-
Set value(s) to `None` if no alias is needed. Refer to https://docs.risingwave.com/docs/current/sql-create-source/ for more details.
769787
770788
Returns
771789
-------
772790
Table
773791
Table expression
774792
"""
775-
quoted = self.compiler.quoted
776-
table = sg.table(name, db=database, quoted=quoted)
793+
table = sg.table(name, db=database, quoted=self.compiler.quoted)
777794
target = sge.Schema(this=table, expressions=schema.to_sqlglot(self.dialect))
778795

779-
properties = sge.Properties.from_dict(connector_properties)
780-
properties.expressions.extend(
781-
sge.IncludeProperty(
782-
this=sg.to_identifier(include_type),
783-
alias=sg.to_identifier(column_name, quoted=quoted)
784-
if column_name
785-
else None,
786-
)
787-
for include_type, column_name in (includes or {}).items()
796+
create_stmt = sge.Create(
797+
kind="SOURCE",
798+
this=target,
799+
properties=sge.Properties(
800+
expressions=sge.Properties.from_dict(connector_properties)
801+
),
788802
)
789803

790-
create_stmt = sge.Create(kind="SOURCE", this=target, properties=properties)
791-
792804
create_stmt = create_stmt.sql(self.dialect) + data_and_encode_format(
793805
data_format, encode_format, encode_properties
794806
)

ibis/backends/risingwave/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class TestConf(ServiceBackendTest):
3030
# for numeric and decimal
3131

3232
returned_timestamp_unit = "s"
33-
supports_structs = False
33+
supports_structs = True
3434
rounding_method = "half_to_even"
3535
service_name = "risingwave"
3636
deps = ("psycopg2",)

ibis/backends/risingwave/tests/test_client.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import ibis
1111
import ibis.expr.datatypes as dt
1212
import ibis.expr.types as ir
13+
from ibis.backends.risingwave import string_to_struct
1314
from ibis.util import gen_name
1415

1516
pytest.importorskip("psycopg2")
@@ -132,3 +133,17 @@ def test_insert_with_cte(con):
132133
assert Y.execute().empty
133134
con.drop_table("Y")
134135
con.drop_table("X")
136+
137+
138+
@pytest.mark.parametrize(
139+
("expected", "shredded", "top"),
140+
[
141+
(
142+
{"b": {"c": "int", "d": "string"}},
143+
[("a.b.c", "int", 1), ("a.b.d", "string", 1)],
144+
"a",
145+
),
146+
],
147+
)
148+
def test_string_to_struct(expected, shredded, top):
149+
assert string_to_struct(shredded, top) == expected

ibis/backends/sql/compilers/risingwave.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,13 @@ def to_sqlglot(
7676
for col, typ in schema.items()
7777
if typ.is_map()
7878
)
79+
# similar to maps are structs (spelled ROW in RisingWave), which come
80+
# back as stringified tuples without first converting
81+
conversions.update(
82+
(col, table_expr[col].cast(dt.JSON(binary=True)))
83+
for col, typ in schema.items()
84+
if typ.is_struct()
85+
)
7986

8087
if conversions:
8188
table_expr = table_expr.mutate(**conversions)
@@ -84,6 +91,9 @@ def to_sqlglot(
8491
def visit_DateNow(self, op):
8592
return self.cast(sge.CurrentTimestamp(), dt.date)
8693

94+
def visit_Array(self, op, *, exprs):
95+
return self.cast(self.f.array(*exprs), op.dtype)
96+
8797
def visit_Cast(self, op, *, arg, to):
8898
if to.is_json():
8999
return self.f.to_jsonb(arg)
@@ -165,6 +175,8 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
165175
return self.f.map_from_key_values(
166176
self.f.array(*value.keys()), self.f.array(*value.values())
167177
)
178+
elif dtype.is_struct():
179+
return self.cast(self.f.row(*value.values()), dtype)
168180
return None
169181

170182
def visit_MapGet(self, op, *, arg, key, default):
@@ -193,5 +205,8 @@ def visit_MapContains(self, op, *, arg, key):
193205
self.cast(arg, op.arg.dtype), self.cast(key, op.key.dtype)
194206
)
195207

208+
def visit_StructField(self, op, *, arg, field):
209+
return sge.Dot(this=sge.Paren(this=arg), expression=sge.to_identifier(field))
210+
196211

197212
compiler = RisingWaveCompiler()

ibis/backends/tests/test_array.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -283,11 +283,6 @@ def test_array_index(con, idx):
283283
reason="backend does not support nullable nested types",
284284
raises=AssertionError,
285285
)
286-
@pytest.mark.notimpl(
287-
["risingwave"],
288-
raises=AssertionError,
289-
reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype",
290-
)
291286
@pytest.mark.never(
292287
["bigquery"], reason="doesn't support arrays of arrays", raises=AssertionError
293288
)

ibis/backends/tests/test_generic.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1741,7 +1741,6 @@ def hash_256(col):
17411741
pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError),
17421742
pytest.mark.notimpl(["oracle"], raises=OracleDatabaseError),
17431743
pytest.mark.notimpl(["postgres"], raises=PsycoPgSyntaxError),
1744-
pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError),
17451744
pytest.mark.notimpl(["snowflake"], raises=AssertionError),
17461745
pytest.mark.never(
17471746
["datafusion", "exasol", "impala", "mssql", "mysql", "sqlite"],

ibis/backends/tests/test_map.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,6 @@ def test_literal_map_getitem_broadcast(backend, alltypes, df):
427427
marks=[
428428
pytest.mark.notyet("clickhouse", reason="nested types can't be null"),
429429
mark_notyet_postgres,
430-
pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError),
431430
],
432431
id="struct",
433432
),

ibis/backends/tests/test_param.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def test_scalar_param_array(con):
7171
assert result == len(value)
7272

7373

74-
@pytest.mark.notimpl(["impala", "postgres", "risingwave", "druid", "oracle", "exasol"])
74+
@pytest.mark.notimpl(["impala", "postgres", "druid", "oracle", "exasol"])
7575
@pytest.mark.never(
7676
["mysql", "sqlite", "mssql"],
7777
reason="mysql and sqlite will never implement struct types",

ibis/backends/tests/test_struct.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_all_fields(struct, struct_df):
7777
_NULL_STRUCT_LITERAL = ibis.null().cast("struct<a: int64, b: string, c: float64>")
7878

7979

80-
@pytest.mark.notimpl(["postgres", "risingwave"])
80+
@pytest.mark.notimpl(["postgres"])
8181
def test_literal(backend, con):
8282
dtype = _STRUCT_LITERAL.type().to_pandas()
8383
result = pd.Series([con.execute(_STRUCT_LITERAL)], dtype=dtype)
@@ -137,11 +137,6 @@ def test_collect_into_struct(alltypes):
137137
@pytest.mark.notimpl(
138138
["postgres"], reason="struct literals not implemented", raises=PsycoPgSyntaxError
139139
)
140-
@pytest.mark.notimpl(
141-
["risingwave"],
142-
reason="struct literals not implemented",
143-
raises=PsycoPg2InternalError,
144-
)
145140
@pytest.mark.notyet(["datafusion"], raises=Exception, reason="unsupported syntax")
146141
@pytest.mark.notimpl(["flink"], raises=Py4JJavaError, reason="not implemented in ibis")
147142
def test_field_access_after_case(con):

0 commit comments

Comments
 (0)