 import random
 import string
 import logging
+
+from airflow.utils.db import provide_session
+from airflow.models import Connection
 from airflow.utils.decorators import apply_defaults
+
 from airflow.models import BaseOperator
 from airflow.hooks.S3_hook import S3Hook
 from airflow.hooks.postgres_hook import PostgresHook
-from airflow.utils.db import provide_session
-from airflow.models import Connection
 
 
 class S3ToRedshiftOperator(BaseOperator):
@@ -46,9 +48,12 @@ class S3ToRedshiftOperator(BaseOperator):
                            possible values include "mysql".
     :type origin_datatype: string
     :param load_type: The method of loading into Redshift that
-                      should occur. Options are "append",
-                      "rebuild", and "upsert". Defaults to
-                      "append."
+                      should occur. Options:
+                      - "append"
+                      - "rebuild"
+                      - "truncate"
+                      - "upsert"
+                      Defaults to "append."
     :type load_type: string
     :param primary_key: *(optional)* The primary key for the
                         destination table. Not enforced by redshift
@@ -128,10 +133,10 @@ def __init__(self,
         self.sortkey = sortkey
         self.sort_type = sort_type
 
-        if self.load_type.lower() not in ["append", "rebuild", "upsert"]:
-            raise Exception('Please choose "append", "rebuild", or "upsert".')
+        if self.load_type.lower() not in ("append", "rebuild", "truncate", "upsert"):
+            raise Exception('Please choose "append", "rebuild", "truncate", or "upsert".')
 
-        if self.schema_location.lower() not in ['s3', 'local']:
+        if self.schema_location.lower() not in ('s3', 'local'):
             raise Exception('Valid Schema Locations are "s3" or "local".')
 
         if not (isinstance(self.sortkey, str) or isinstance(self.sortkey, list)):
@@ -152,9 +157,12 @@ def execute(self, context):
         letters = string.ascii_lowercase
         random_string = ''.join(random.choice(letters) for _ in range(7))
         self.temp_suffix = '_tmp_{0}'.format(random_string)
+
         if self.origin_schema:
             schema = self.read_and_format()
+
         pg_hook = PostgresHook(self.redshift_conn_id)
+
         self.create_if_not_exists(schema, pg_hook)
         self.reconcile_schemas(schema, pg_hook)
         self.copy_data(pg_hook, schema)
@@ -221,7 +229,6 @@ def read_and_format(self):
                 if i['type'] == e['avro']:
                     i['type'] = e['redshift']
 
-        print(schema)
         return schema
 
     def reconcile_schemas(self, schema, pg_hook):
@@ -277,7 +284,7 @@ def getS3Conn():
         elif aws_role_arn:
             creds = ("aws_iam_role={0}"
                      .format(aws_role_arn))
-
+
         return creds
 
     # Delete records from the destination table where the incremental_key
@@ -331,6 +338,11 @@ def getS3Conn():
             FILLTARGET
             '''.format(self.redshift_schema, self.table, self.temp_suffix)
 
+        drop_sql = \
+            '''
+            DROP TABLE IF EXISTS "{0}"."{1}"
+            '''.format(self.redshift_schema, self.table)
+
         drop_temp_sql = \
             '''
             DROP TABLE IF EXISTS "{0}"."{1}{2}"
@@ -366,6 +378,13 @@ def getS3Conn():
                          base_sql)
         if self.load_type == 'append':
             pg_hook.run(load_sql)
+        elif self.load_type == 'rebuild':
+            pg_hook.run(drop_sql)
+            self.create_if_not_exists(schema, pg_hook)
+            pg_hook.run(load_sql)
+        elif self.load_type == 'truncate':
+            pg_hook.run(truncate_sql)
+            pg_hook.run(load_sql)
         elif self.load_type == 'upsert':
             self.create_if_not_exists(schema, pg_hook, temp=True)
             load_temp_sql = \
@@ -378,9 +397,6 @@ def getS3Conn():
             pg_hook.run(delete_confirm_sql)
             pg_hook.run(append_sql, autocommit=True)
             pg_hook.run(drop_temp_sql)
-        elif self.load_type == 'rebuild':
-            pg_hook.run(truncate_sql)
-            pg_hook.run(load_sql)
 
     def create_if_not_exists(self, schema, pg_hook, temp=False):
         output = ''
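
With this commit, "rebuild" now drops and recreates the destination table, while the previous TRUNCATE-then-COPY behavior moves to the new "truncate" option. Below is a minimal usage sketch of the operator with the new option; the module path and the s3_bucket/s3_key argument names are assumptions not visible in this diff, while redshift_conn_id, redshift_schema, table, load_type, and schema_location all appear above.

# Hedged usage sketch, not part of the diff. The import path and the
# s3_bucket/s3_key argument names are guesses about the plugin's full
# signature; the remaining parameters appear in the code above.
from datetime import datetime

from airflow import DAG
from s3_to_redshift_operator import S3ToRedshiftOperator  # assumed module path

dag = DAG(
    dag_id='s3_to_redshift_example',
    start_date=datetime(2019, 1, 1),
    schedule_interval='@daily',
)

load_events = S3ToRedshiftOperator(
    task_id='load_events',
    redshift_conn_id='redshift_default',
    redshift_schema='analytics',
    table='events',
    load_type='truncate',          # new in this commit: TRUNCATE target, then COPY
    schema_location='s3',
    s3_bucket='my-data-bucket',    # assumed argument name
    s3_key='events/{{ ds }}.avro', # assumed argument name
    dag=dag,
)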