@@ -60,11 +60,33 @@ class S3ToRedshiftOperator(BaseOperator):
6060 with. Only required if using a load_type of
6161 "upsert".
6262 :type incremental_key: string
63+ :param foreign_key: *(optional)* This specifies any foreign_keys
64+ in the table and which corresponding table
65+ and key they reference. This may be either
66+ a dictionary or list of dictionaries (for
67+ multiple foreign keys). The fields that are
68+ required in each dictionary are:
69+ - column_name
70+ - reftable
71+ - ref_column
72+ :type foreign_key: dictionary
73+ :param distkey: *(optional)* The distribution key for the
74+ table. Only one key may be specified.
75+ :type distkey: string
76+ :param sortkey: *(optional)* The sort keys for the table.
77+ If more than one key is specified, set this
78+ as a list.
79+ :type sortkey: string
80+ :param sort_type: *(optional)* The style of distribution
81+ to sort the table. Possible values include:
82+ - compound
83+ - interleaved
84+ Defaults to "compound".
85+ :type sort_type: string
6386 """
6487
6588 template_fields = ('s3_key' ,
66- 'origin_schema' ,
67- 'com' )
89+ 'origin_schema' )
6890
6991 @apply_defaults
7092 def __init__ (self ,
@@ -81,7 +103,10 @@ def __init__(self,
81103 load_type = 'append' ,
82104 primary_key = None ,
83105 incremental_key = None ,
84- timeformat = 'auto' ,
106+ foreign_key = {},
107+ distkey = None ,
108+ sortkey = '' ,
109+ sort_type = 'COMPOUND' ,
85110 * args ,
86111 ** kwargs ):
87112 super ().__init__ (* args , ** kwargs )
@@ -98,14 +123,29 @@ def __init__(self,
98123 self .load_type = load_type
99124 self .primary_key = primary_key
100125 self .incremental_key = incremental_key
101- self .timeformat = timeformat
126+ self .foreign_key = foreign_key
127+ self .distkey = distkey
128+ self .sortkey = sortkey
129+ self .sort_type = sort_type
102130
103131 if self .load_type .lower () not in ["append" , "rebuild" , "upsert" ]:
104132 raise Exception ('Please choose "append", "rebuild", or "upsert".' )
105133
106134 if self .schema_location .lower () not in ['s3' , 'local' ]:
107135 raise Exception ('Valid Schema Locations are "s3" or "local".' )
108136
137+ if not (isinstance (self .sortkey , str ) or isinstance (self .sortkey , list )):
138+ raise Exception ('Sort Keys must be specified as either a string or list.' )
139+
140+ if not (isinstance (self .foreign_key , dict ) or isinstance (self .foreign_key , list )):
141+ raise Exception ('Foreign Keys must be specified as either a dictionary or a list of dictionaries.' )
142+
143+ if ((',' in self .distkey ) or not isinstance (self .distkey , str )):
144+ raise Exception ('Only one distribution key may be specified.' )
145+
146+ if self .sort_type .lower () not in ('compound' , 'interleaved' ):
147+ raise Exception ('Please choose "compound" or "interleaved" for sort type.' )
148+
109149 def execute (self , context ):
110150 # Append a random string to the end of the staging table to ensure
111151 # no conflicts if multiple processes running concurrently.
@@ -337,6 +377,8 @@ def create_if_not_exists(self, schema, pg_hook, temp=False):
337377 for item in schema :
338378 k = "{quote}{key}{quote}" .format (quote = '"' , key = item ['name' ])
339379 field = ' ' .join ([k , item ['type' ]])
380+ if isinstance (self .sortkey , str ) and self .sortkey == item ['name' ]:
381+ field += ' sortkey'
340382 output += field
341383 output += ', '
342384 # Remove last comma and space after schema items loop ends
@@ -346,12 +388,50 @@ def create_if_not_exists(self, schema, pg_hook, temp=False):
346388 else :
347389 copy_table = self .table
348390 create_schema_query = \
349- '''CREATE SCHEMA IF NOT EXISTS "{0}";''' .format (
350- self .redshift_schema )
391+ '''
392+ CREATE SCHEMA IF NOT EXISTS "{0}";
393+ ''' .format (self .redshift_schema )
394+
395+ pk = ''
396+ fk = ''
397+ dk = ''
398+ sk = ''
399+
400+ if self .primary_key :
401+ pk = ', primary key("{0}")' .format (self .primary_key )
402+
403+ if self .foreign_key :
404+ if isinstance (self .foreign_key , list ):
405+ fk = ', '
406+ for i , e in enumerate (self .foreign_key ):
407+ fk += 'foreign key("{0}") references {1}("{2}")' .format (e ['column_name' ],
408+ e ['reftable' ],
409+ e ['ref_column' ])
410+ if i != (len (self .foreign_key ) - 1 ):
411+ fk += ', ' ""
412+ elif isinstance (self .foreign_key , dict ):
413+ fk += ', '
414+ fk += 'foreign key("{0}") references {1}("{2}")' .format (self .foreign_key ['column_name' ],
415+ self .foreign_key ['reftable' ],
416+ self .foreign_key ['ref_column' ])
417+ if self .distkey :
418+ dk = 'distkey({})' .format (self .distkey )
419+
420+ if self .sortkey :
421+ if isinstance (self .sortkey , list ):
422+ sk += '{0} sortkey({1})' .format (self .sort_type , ', ' .join (["{}" .format (e ) for e in self .sortkey ]))
423+
351424 create_table_query = \
352- '''CREATE TABLE IF NOT EXISTS "{0}"."{1}" ({2})''' .format (
353- self .redshift_schema ,
354- copy_table ,
355- output )
425+ '''
426+ CREATE TABLE IF NOT EXISTS "{schema}"."{table}"
427+ ({fields}{primary_key}{foreign_key}) {distkey} {sortkey}
428+ ''' .format (schema = self .redshift_schema ,
429+ table = copy_table ,
430+ fields = output ,
431+ primary_key = pk ,
432+ foreign_key = fk ,
433+ distkey = dk ,
434+ sortkey = sk )
435+
356436 pg_hook .run (create_schema_query )
357437 pg_hook .run (create_table_query )
0 commit comments