3
3
from __future__ import annotations
4
4
5
5
import contextlib
6
+ import itertools
6
7
from operator import itemgetter
7
8
from typing import TYPE_CHECKING , Any
8
9
from urllib .parse import unquote_plus
16
17
import ibis .backends .sql .compilers as sc
17
18
import ibis .common .exceptions as com
18
19
import ibis .common .exceptions as exc
20
+ import ibis .expr .datatypes as dt
19
21
import ibis .expr .operations as ops
20
22
import ibis .expr .schema as sch
21
23
import ibis .expr .types as ir
22
24
from ibis import util
23
25
from ibis .backends import CanCreateDatabase , CanListCatalog , NoExampleLoader
24
26
from ibis .backends .sql import SQLBackend
25
- from ibis .backends .sql .compilers .base import TRUE , C , ColGen
27
+ from ibis .backends .sql .compilers .base import TRUE , C
26
28
from ibis .util import experimental
27
29
28
30
if TYPE_CHECKING :
31
+ from collections .abc import Iterable , Mapping
29
32
from urllib .parse import ParseResult
30
33
31
34
import pandas as pd
32
35
import polars as pl
33
36
import pyarrow as pa
34
37
35
38
39
def dict_to_struct(struct):
    """Build an ibis ``Struct`` datatype from a possibly nested mapping.

    Values that are already ``dt.DataType`` instances are used unchanged;
    any other value is treated as a nested mapping and converted recursively.
    """
    fields = {}
    for field, dtype in struct.items():
        if isinstance(dtype, dt.DataType):
            fields[field] = dtype
        else:
            # Not a datatype yet: recurse into the nested mapping.
            fields[field] = dict_to_struct(dtype)
    return dt.Struct(fields)
46
+
47
+
48
def string_to_struct(
    dtypes: Iterable[tuple[str, dt.DataType, int]], top: str
) -> Mapping[str, dt.DataType]:
    """Nest dotted field paths into a mapping rooted below `top`.

    Parameters
    ----------
    dtypes
        Triples of ``(dotted_path, dtype, _)``. Every path must start with
        `top` and contain at least one further component, e.g. ``"top.a.b"``.
        The third element of each triple is ignored.
    top
        The leading path component shared by every entry.

    Returns
    -------
    Mapping
        Nested dictionaries keyed by path component, with the dtype stored at
        the final component. Empty when `dtypes` is empty.

    Raises
    ------
    ValueError
        If a path does not start with `top`, or names nothing below it.
    """
    root = {}
    for field, dtype, _ in dtypes:
        field_top, *components = field.split(".")
        # Explicit raise rather than `assert`: asserts are stripped under -O,
        # which would silently file a foreign field under `top`.
        if field_top != top:
            raise ValueError(f"{top} != {field_top}")
        if not components:
            raise ValueError(
                f"expected a dotted path below {top!r}, got {field!r}"
            )
        *parents, leaf = components
        node = root
        for component in parents:
            node = node.setdefault(component, {})
        node[leaf] = dtype
    return root
60
+
61
+
36
62
def data_and_encode_format(data_format, encode_format, encode_properties):
    """Render the trailing ``FORMAT ... ENCODE ... (...)`` clause.

    Parameters
    ----------
    data_format
        Data format name, upper-cased in the output; skipped when ``None``.
    encode_format
        Encode format name, upper-cased in the output; skipped when ``None``.
    encode_properties
        Mapping rendered with ``format_properties`` after the ``ENCODE``
        clause; skipped when ``None`` or when `encode_format` is ``None``.

    Returns
    -------
    str
        The clause with a single leading space before each part, or an empty
        string when nothing is specified. The leading space matters because
        callers append the result directly to the compiled ``CREATE``
        statement.
    """
    clauses = []
    if data_format is not None:
        clauses.append(f"FORMAT {data_format.upper()}")
    if encode_format is not None:
        clauses.append(f"ENCODE {encode_format.upper()}")
        if encode_properties is not None:
            clauses.append(format_properties(encode_properties))
    # Prefix (rather than join) with spaces so the clause never fuses with
    # the statement text it is concatenated onto.
    return "".join(f" {clause}" for clause in clauses)
45
73
46
74
47
75
def format_properties(props):
    """Render `props` as a parenthesised ``key='value'`` list.

    Parameters
    ----------
    props
        Mapping of property names to values. Values are converted with
        ``str``.

    Returns
    -------
    str
        ``(k1='v1', k2='v2')``; ``()`` for an empty mapping.

    Notes
    -----
    Single quotes inside a value are doubled so that the value stays inside
    its SQL string literal instead of terminating it early.
    """
    tokens = []
    for k, v in props.items():
        escaped = str(v).replace("'", "''")
        tokens.append(f"{k}='{escaped}'")
    return "({})".format(", ".join(tokens))
52
80
53
81
54
82
class Backend (SQLBackend , CanListCatalog , CanCreateDatabase , NoExampleLoader ):
@@ -264,12 +292,6 @@ def get_schema(
264
292
catalog : str | None = None ,
265
293
database : str | None = None ,
266
294
):
267
- a = ColGen (table = "a" )
268
- c = ColGen (table = "c" )
269
- n = ColGen (table = "n" )
270
-
271
- format_type = self .compiler .f ["pg_catalog.format_type" ]
272
-
273
295
# If no database is specified, assume the current database
274
296
db = database or self .current_database
275
297
@@ -280,46 +302,44 @@ def get_schema(
280
302
if database is None and (temp_table_db := self ._session_temp_db ) is not None :
281
303
dbs .append (sge .convert (temp_table_db ))
282
304
283
- type_info = (
284
- sg .select (
285
- a .attname .as_ ("column_name" ),
286
- format_type (a .atttypid , a .atttypmod ).as_ ("data_type" ),
287
- sg .not_ (a .attnotnull ).as_ ("nullable" ),
288
- )
289
- .from_ (sg .table ("pg_attribute" , db = "pg_catalog" ).as_ ("a" ))
290
- .join (
291
- sg .table ("pg_class" , db = "pg_catalog" ).as_ ("c" ),
292
- on = c .oid .eq (a .attrelid ),
293
- join_type = "INNER" ,
294
- )
295
- .join (
296
- sg .table ("pg_namespace" , db = "pg_catalog" ).as_ ("n" ),
297
- on = n .oid .eq (c .relnamespace ),
298
- join_type = "INNER" ,
299
- )
300
- .where (
301
- a .attnum > 0 ,
302
- sg .not_ (a .attisdropped ),
303
- n .nspname .isin (* dbs ),
304
- c .relname .eq (sge .convert (name )),
305
- )
306
- .order_by (a .attnum )
307
- )
305
+ ident = sg .table (name , catalog = catalog , db = database , quoted = True )
306
+ try :
307
+ with self ._safe_raw_sql (sge .Describe (this = ident )) as cur :
308
+ raw_rows = cur .fetchall ()
309
+ except psycopg2 .InternalError as exc :
310
+ raise com .TableNotFound (name ) from exc
308
311
309
312
type_mapper = self .compiler .type_mapper
310
313
311
- with self ._safe_raw_sql (type_info ) as cur :
312
- rows = cur .fetchall ()
314
+ rows = []
315
+ field_number = 0
316
+ for raw_name , dtype , hidden , * _ in raw_rows :
317
+ if hidden == "false" :
318
+ field_number += (not dtype ) or "." not in raw_name
319
+ rows .append (
320
+ (
321
+ raw_name ,
322
+ None
323
+ if not dtype
324
+ else type_mapper .from_string (dtype , nullable = True ),
325
+ field_number ,
326
+ )
327
+ )
313
328
314
- if not rows :
315
- raise com .TableNotFound (name )
329
+ schema = {}
316
330
317
- return sch .Schema (
318
- {
319
- col : type_mapper .from_string (typestr , nullable = nullable )
320
- for col , typestr , nullable in rows
321
- }
322
- )
331
+ for _ , values in itertools .groupby (rows , key = lambda x : x [- 1 ]):
332
+ vals = list (values )
333
+ assert vals , "vals is empty"
334
+
335
+ if len (vals ) == 1 :
336
+ name , dtype , _ = vals .pop ()
337
+ schema [name ] = dtype
338
+ else :
339
+ name , _ , _ = vals [0 ]
340
+ schema [name ] = dict_to_struct (string_to_struct (vals [1 :], name ))
341
+
342
+ return sch .Schema (schema )
323
343
324
344
def _get_schema_using_query (self , query : str ) -> sch .Schema :
325
345
name = util .gen_name (f"{ self .name } _metadata" )
@@ -583,7 +603,9 @@ def create_table(
583
603
create_stmt = sge .Create (
584
604
kind = "TABLE" ,
585
605
this = target ,
586
- properties = sge .Properties .from_dict (connector_properties ),
606
+ properties = sge .Properties (
607
+ expressions = sge .Properties .from_dict (connector_properties )
608
+ ),
587
609
)
588
610
create_stmt = create_stmt .sql (self .dialect ) + data_and_encode_format (
589
611
data_format , encode_format , encode_properties
@@ -742,7 +764,6 @@ def create_source(
742
764
data_format : str ,
743
765
encode_format : str ,
744
766
encode_properties : dict | None = None ,
745
- includes : dict [str , str ] | None = None ,
746
767
) -> ir .Table :
747
768
"""Creating a source.
748
769
@@ -763,32 +784,23 @@ def create_source(
763
784
The encode format for the new source, e.g., "JSON". data_format and encode_format must be specified at the same time.
764
785
encode_properties
765
786
The properties of encode format, providing information like schema registry url. Refer https://docs.risingwave.com/docs/current/sql-create-source/ for more details.
766
- includes
767
- A dict of `INCLUDE` clauses of the form `{field: alias, ...}`.
768
- Set value(s) to `None` if no alias is needed. Refer to https://docs.risingwave.com/docs/current/sql-create-source/ for more details.
769
787
770
788
Returns
771
789
-------
772
790
Table
773
791
Table expression
774
792
"""
775
- quoted = self .compiler .quoted
776
- table = sg .table (name , db = database , quoted = quoted )
793
+ table = sg .table (name , db = database , quoted = self .compiler .quoted )
777
794
target = sge .Schema (this = table , expressions = schema .to_sqlglot (self .dialect ))
778
795
779
- properties = sge .Properties .from_dict (connector_properties )
780
- properties .expressions .extend (
781
- sge .IncludeProperty (
782
- this = sg .to_identifier (include_type ),
783
- alias = sg .to_identifier (column_name , quoted = quoted )
784
- if column_name
785
- else None ,
786
- )
787
- for include_type , column_name in (includes or {}).items ()
796
+ create_stmt = sge .Create (
797
+ kind = "SOURCE" ,
798
+ this = target ,
799
+ properties = sge .Properties (
800
+ expressions = sge .Properties .from_dict (connector_properties )
801
+ ),
788
802
)
789
803
790
- create_stmt = sge .Create (kind = "SOURCE" , this = target , properties = properties )
791
-
792
804
create_stmt = create_stmt .sql (self .dialect ) + data_and_encode_format (
793
805
data_format , encode_format , encode_properties
794
806
)
0 commit comments