Skip to content

Commit d776c51

Browse files
change tmp name and abstract copy params
1 parent b2c3302 commit d776c51

File tree

5 files changed

+48
-11
lines changed

5 files changed

+48
-11
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# S3 To Redshift Operator
2+
3+
4+
# License
5+
Apache 2.0

__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from airflow.plugins_manager import AirflowPlugin
2-
from redshift_plugin.operators.s3_to_redshift import S3ToRedshiftOperator
2+
from s3_to_redshift_operator.operators.s3_to_redshift import S3ToRedshiftOperator
3+
from s3_to_redshift_operator.macros.redshift_auth import redshift_auth
34

45

56
class S3ToRedshiftPlugin(AirflowPlugin):
@@ -8,7 +9,7 @@ class S3ToRedshiftPlugin(AirflowPlugin):
89
# Leave in for explicitness
910
hooks = []
1011
executors = []
11-
macros = []
12+
macros = [redshift_auth]
1213
admin_views = []
1314
flask_blueprints = []
1415
menu_links = []

macros/__init__.py

Whitespace-only changes.

macros/redshift_auth.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from airflow.utils.db import provide_session
2+
from airflow.models import Connection
3+
4+
5+
@provide_session
def get_conn(conn_id, session=None):
    """Fetch the Airflow ``Connection`` row whose ``conn_id`` matches.

    :param conn_id: The connection id to look up.
    :type conn_id: string
    :param session: The database session used for the query; expected to
                    be supplied by the ``provide_session`` decorator when
                    the caller does not pass one.
    :return: The first matching ``Connection``, or whatever ``.first()``
             yields when nothing matches (``None`` for a SQLAlchemy query).
    """
    matching = session.query(Connection).filter(Connection.conn_id == conn_id)
    return matching.first()
12+
13+
14+
def redshift_auth(s3_conn_id):
    """Build a Redshift credentials string from an Airflow connection.

    Reads ``aws_access_key_id`` and ``aws_secret_access_key`` from the
    connection's ``extra_dejson`` and formats them as
    ``aws_access_key_id=<key>;aws_secret_access_key=<secret>``.

    :param s3_conn_id: The conn_id of the connection holding the AWS keys
                       in its extras.
    :type s3_conn_id: string
    :return: The formatted credentials string.
    :rtype: string
    :raises ValueError: If no connection with that conn_id is found, or if
                        either key is missing from the connection extras.
    """
    s3_conn = get_conn(s3_conn_id)
    if s3_conn is None:
        # The lookup yields None when no row matches; fail with a clear
        # message instead of an AttributeError on None.
        raise ValueError(
            "No Airflow connection found with conn_id '{0}'"
            .format(s3_conn_id))

    extras = s3_conn.extra_dejson
    aws_key = extras.get('aws_access_key_id')
    aws_secret = extras.get('aws_secret_access_key')
    if not aws_key or not aws_secret:
        # Without this guard the literal text "None" would be embedded in
        # the credentials string. The secret itself is never echoed.
        raise ValueError(
            "Connection '{0}' must define 'aws_access_key_id' and "
            "'aws_secret_access_key' in its extras".format(s3_conn_id))

    return ("aws_access_key_id={0};aws_secret_access_key={1}"
            .format(aws_key, aws_secret))

operators/s3_to_redshift.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ class S3ToRedshiftOperator(BaseOperator):
2525
:type s3_bucket: string
2626
:param s3_key: The source s3 key.
2727
:type s3_key: string
28+
:param copy_params: The parameters to be included when issuing
29+
the copy statement in Redshift.
30+
:type copy_params: list of strings
2831
:param origin_schema: The s3 key for the incoming data schema.
2932
Expects a JSON file with a single dict
3033
specifying column and datatype as a
@@ -59,7 +62,9 @@ class S3ToRedshiftOperator(BaseOperator):
5962
:type incremental_key: string
6063
"""
6164

62-
template_fields = ['s3_key', 'origin_schema']
65+
# Names of operator attributes that Airflow renders as Jinja templates
# before execute() runs; every entry must match an attribute set on the
# operator instance.
template_fields = ('s3_key',
                   'origin_schema',
                   'copy_params')
6368

6469
@apply_defaults
6570
def __init__(self,
@@ -69,6 +74,7 @@ def __init__(self,
6974
redshift_conn_id,
7075
redshift_schema,
7176
table,
77+
copy_params=[],
7278
origin_schema=None,
7379
schema_location='s3',
7480
origin_datatype=None,
@@ -85,6 +91,7 @@ def __init__(self,
8591
self.redshift_conn_id = redshift_conn_id
8692
self.redshift_schema = redshift_schema
8793
self.table = table
94+
self.copy_params = copy_params
8895
self.origin_schema = origin_schema
8996
self.schema_location = schema_location
9097
self.origin_datatype = origin_datatype
@@ -104,7 +111,7 @@ def execute(self, context):
104111
# no conflicts if multiple processes running concurrently.
105112
letters = string.ascii_lowercase
106113
random_string = ''.join(random.choice(letters) for _ in range(7))
107-
self.temp_suffix = '_astro_temp_{0}'.format(random_string)
114+
self.temp_suffix = '_tmp_{0}'.format(random_string)
108115
if self.origin_schema:
109116
schema = self.read_and_format()
110117
pg_hook = PostgresHook(self.redshift_conn_id)
@@ -284,20 +291,25 @@ def getS3Conn():
284291
TRUNCATE TABLE "{0}"."{1}"
285292
'''.format(self.redshift_schema, self.table)
286293

294+
params = '\n'.join(self.copy_params)
295+
296+
# Example params for loading json from US-East-1 S3 region
297+
# params = ["COMPUPDATE OFF",
298+
# "STATUPDATE OFF",
299+
# "JSON 'auto'",
300+
# "TIMEFORMAT 'auto'",
301+
# "TRUNCATECOLUMNS",
302+
# "region as 'us-east-1'"]
303+
287304
base_sql = \
288305
"""
289306
FROM 's3://{0}/{1}'
290307
CREDENTIALS '{2}'
291-
COMPUPDATE OFF
292-
STATUPDATE OFF
293-
JSON 'auto'
294-
TIMEFORMAT '{3}'
295-
TRUNCATECOLUMNS
296-
region as 'us-east-1';
308+
{3};
297309
""".format(self.s3_bucket,
298310
self.s3_key,
299311
getS3Conn(),
300-
self.timeformat)
312+
params)
301313

302314
load_sql = '''COPY "{0}"."{1}" {2}'''.format(self.redshift_schema,
303315
self.table,

0 commit comments

Comments
 (0)