Commit b2c3302

ic
0 parents  commit b2c3302

File tree

4 files changed: +360 −0 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
.DS_Store

__init__.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
from airflow.plugins_manager import AirflowPlugin
from redshift_plugin.operators.s3_to_redshift import S3ToRedshiftOperator


class S3ToRedshiftPlugin(AirflowPlugin):
    name = "S3ToRedshiftPlugin"
    operators = [S3ToRedshiftOperator]
    # Leave in for explicitness
    hooks = []
    executors = []
    macros = []
    admin_views = []
    flask_blueprints = []
    menu_links = []
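
Once the plugin is on Airflow's plugins path, the operator can be used from a DAG. A minimal sketch, assuming hypothetical connection IDs, bucket, and table names (the exact import path for plugin operators varies by Airflow version):

from datetime import datetime

from airflow import DAG
from airflow.operators import S3ToRedshiftOperator  # import path varies by Airflow version

dag = DAG('s3_to_redshift_example',
          start_date=datetime(2017, 1, 1),
          schedule_interval='@daily')

load_task = S3ToRedshiftOperator(
    task_id='load_events',
    s3_conn_id='my_s3_conn',                  # hypothetical connection id
    s3_bucket='my-bucket',                    # hypothetical bucket
    s3_key='data/events.json',
    redshift_conn_id='my_redshift_conn',      # hypothetical connection id
    redshift_schema='analytics',
    table='events',
    origin_schema='data/events_schema.json',  # S3 key of the schema file
    origin_datatype='mysql',
    load_type='upsert',
    primary_key='id',
    incremental_key='updated_at',
    dag=dag)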

operators/__init__.py

Whitespace-only changes.

operators/s3_to_redshift.py

Lines changed: 345 additions & 0 deletions
@@ -0,0 +1,345 @@
import json
import random
import string
import logging
from airflow.utils.decorators import apply_defaults
from airflow.models import BaseOperator
from airflow.hooks.S3_hook import S3Hook
from airflow.hooks.postgres_hook import PostgresHook
from airflow.utils.db import provide_session
from airflow.models import Connection


class S3ToRedshiftOperator(BaseOperator):
    """
    S3 To Redshift Operator

    :param redshift_conn_id: The destination redshift connection id.
    :type redshift_conn_id: string
    :param redshift_schema: The destination redshift schema.
    :type redshift_schema: string
    :param table: The destination redshift table.
    :type table: string
    :param s3_conn_id: The source s3 connection id.
    :type s3_conn_id: string
    :param s3_bucket: The source s3 bucket.
    :type s3_bucket: string
    :param s3_key: The source s3 key.
    :type s3_key: string
    :param origin_schema: The s3 key for the incoming data schema.
                          Expects a JSON file containing a list of
                          dicts, each specifying a column name and
                          datatype.
                          (e.g. {"name": "column1", "type": "int(11)"})
    :type origin_schema: string
    :param schema_location: The location of the origin schema. This
                            can be set to 's3' or 'local'.
                            If 's3', it will expect a valid S3 key. If
                            'local', it will expect a schema that is
                            passed to the operator directly. By
                            default the location is set to 's3'.
    :type schema_location: string
    :param origin_datatype: The incoming database type from which to
                            convert the origin schema. Required when
                            specifying the origin_schema. Current
                            possible values include "mysql" and "avro".
    :type origin_datatype: string
    :param load_type: The method of loading into Redshift that
                      should occur. Options are "append",
                      "rebuild", and "upsert". Defaults to
                      "append".
    :type load_type: string
    :param primary_key: *(optional)* The primary key for the
                        destination table. Not enforced by redshift
                        and only required if using a load_type of
                        "upsert".
    :type primary_key: string
    :param incremental_key: *(optional)* The incremental key to compare
                            new data against the destination table
                            with. Only required if using a load_type of
                            "upsert".
    :type incremental_key: string
    :param timeformat: *(optional)* The TIMEFORMAT applied to incoming
                       timestamps by the Redshift COPY command.
                       Defaults to 'auto'.
    :type timeformat: string
    """

    template_fields = ['s3_key', 'origin_schema']

    @apply_defaults
    def __init__(self,
                 s3_conn_id,
                 s3_bucket,
                 s3_key,
                 redshift_conn_id,
                 redshift_schema,
                 table,
                 origin_schema=None,
                 schema_location='s3',
                 origin_datatype=None,
                 load_type='append',
                 primary_key=None,
                 incremental_key=None,
                 timeformat='auto',
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.s3_conn_id = s3_conn_id
        self.s3_bucket = s3_bucket
        self.s3_key = s3_key
        self.redshift_conn_id = redshift_conn_id
        self.redshift_schema = redshift_schema
        self.table = table
        self.origin_schema = origin_schema
        self.schema_location = schema_location
        self.origin_datatype = origin_datatype
        self.load_type = load_type
        self.primary_key = primary_key
        self.incremental_key = incremental_key
        self.timeformat = timeformat

        if self.load_type.lower() not in ["append", "rebuild", "upsert"]:
            raise Exception('Please choose "append", "rebuild", or "upsert".')

        if self.schema_location.lower() not in ['s3', 'local']:
            raise Exception('Valid Schema Locations are "s3" or "local".')

    def execute(self, context):
        # Append a random string to the end of the staging table to ensure
        # no conflicts if multiple processes are running concurrently.
        letters = string.ascii_lowercase
        random_string = ''.join(random.choice(letters) for _ in range(7))
        self.temp_suffix = '_astro_temp_{0}'.format(random_string)
        pg_hook = PostgresHook(self.redshift_conn_id)
        if self.origin_schema:
            schema = self.read_and_format()
            self.create_if_not_exists(schema, pg_hook)
            self.reconcile_schemas(schema, pg_hook)
        else:
            schema = None
        self.copy_data(pg_hook, schema)

    def read_and_format(self):
        if self.schema_location.lower() == 's3':
            hook = S3Hook(self.s3_conn_id)
            schema = (hook.get_key(self.origin_schema,
                                   bucket_name=self.s3_bucket)
                      .get_contents_as_string(encoding='utf-8'))
            schema = json.loads(schema.replace("'", '"'))
        else:
            schema = self.origin_schema

        # Mapping of MySQL datatypes to their Redshift equivalents.
        schema_map = {
            "tinyint(1)": "bool",
            "float": "float4",
            "double": "float8",
            "int(11)": "int4",
            "longtext": "text",
            "bigint(21)": "int8"
        }

        # Mapping of Avro (and, where available, MySQL) datatypes to their
        # Redshift equivalents.
        schema_mapper = [{"avro": "string",
                          "mysql": "varchar(256)",
                          "redshift": "text"},
                         {"avro": "int",
                          "mysql": "int(11)",
                          "redshift": "int4"},
                         {"avro": "long",
                          "mysql": "bigint(21)",
                          "redshift": "int8"},
                         {"avro": "long-timestamp-millis",
                          "redshift": "timestamp"},
                         {"avro": "boolean",
                          "mysql": "tinyint(1)",
                          "redshift": "boolean"},
                         {"avro": "date",
                          "mysql": "date",
                          "redshift": "date"},
                         {"avro": "long-timestamp-millis",
                          "mysql": "timestamp(3)",
                          "redshift": "timestamp"},
                         {"mysql": "float",
                          "redshift": "float4"},
                         {"mysql": "double",
                          "redshift": "float8"},
                         {"mysql": "longtext",
                          "redshift": "text"}]

        if self.origin_datatype:
            if self.origin_datatype.lower() == 'mysql':
                for i in schema:
                    if i['type'] in schema_map:
                        i['type'] = schema_map[i['type']]
            elif self.origin_datatype.lower() == 'avro':
                for i in schema:
                    # Fold the logicalType into the type name (e.g.
                    # "long-timestamp-millis") so it matches the mapper.
                    if 'logicalType' in list(i.keys()):
                        i['type'] = '-'.join([i['type'], i['logicalType']])
                        del i['logicalType']
                    for e in schema_mapper:
                        if 'avro' in list(e.keys()):
                            if i['type'] == e['avro']:
                                i['type'] = e['redshift']

        logging.info(schema)
        return schema

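    # Example: with origin_datatype='mysql', a schema such as
    #     [{"name": "id", "type": "int(11)"}, {"name": "price", "type": "double"}]
    # is rewritten in place (via schema_map above) to
    #     [{"name": "id", "type": "int4"}, {"name": "price", "type": "float8"}]
    # (hypothetical column names, shown for illustration only).
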
    def reconcile_schemas(self, schema, pg_hook):
        pg_query = \
            """
            SELECT column_name, udt_name
            FROM information_schema.columns
            WHERE table_schema = '{0}' AND table_name = '{1}';
            """.format(self.redshift_schema, self.table)

        pg_schema = dict(pg_hook.get_records(pg_query))
        incoming_keys = [column['name'] for column in schema]
        diff = list(set(incoming_keys) - set(pg_schema.keys()))
        logging.info(diff)
        # Check length of column differential to see if any new columns exist
        if len(diff):
            for i in diff:
                for e in schema:
                    if i == e['name']:
                        alter_query = \
                            """
                            ALTER TABLE "{0}"."{1}"
                            ADD COLUMN {2} {3}
                            """.format(self.redshift_schema,
                                       self.table,
                                       e['name'],
                                       e['type'])
                        pg_hook.run(alter_query)
            logging.info('The new columns were: ' + str(diff))
        else:
            logging.info('There were no new columns.')

    def copy_data(self, pg_hook, schema=None):
        @provide_session
        def get_conn(conn_id, session=None):
            conn = (
                session.query(Connection)
                       .filter(Connection.conn_id == conn_id)
                       .first())
            return conn

        def get_s3_credentials():
            # Pull the AWS credentials out of the S3 connection's extras.
            s3_conn = get_conn(self.s3_conn_id)
            aws_key = s3_conn.extra_dejson.get('aws_access_key_id')
            aws_secret = s3_conn.extra_dejson.get('aws_secret_access_key')
            return ("aws_access_key_id={0};aws_secret_access_key={1}"
                    .format(aws_key, aws_secret))

        # Delete records from the destination table where the staging table's
        # incremental_key is greater than or equal to the destination's
        # and the primary key is the same.
        # (e.g. Source: {"id": 1, "updated_at": "2017-01-02 00:00:00"};
        # Destination: {"id": 1, "updated_at": "2017-01-01 00:00:00"})

        delete_sql = \
            '''
            DELETE FROM "{rs_schema}"."{rs_table}"
            USING "{rs_schema}"."{rs_table}{rs_suffix}"
            WHERE "{rs_schema}"."{rs_table}"."{rs_pk}" =
                  "{rs_schema}"."{rs_table}{rs_suffix}"."{rs_pk}"
            AND "{rs_schema}"."{rs_table}{rs_suffix}"."{rs_ik}" >=
                "{rs_schema}"."{rs_table}"."{rs_ik}"
            '''.format(rs_schema=self.redshift_schema,
                       rs_table=self.table,
                       rs_pk=self.primary_key,
                       rs_suffix=self.temp_suffix,
                       rs_ik=self.incremental_key)

        # Delete records from the staging table where the destination table's
        # incremental_key is greater than or equal to the staging table's
        # and the primary key is the same. This handles the edge case
        # where data is pulled BEFORE it is altered in the source table but
        # AFTER a workflow containing an updated version of the record runs.
        # In this case, not running this would cause the older record to be
        # added as a duplicate of the newer record.
        # (e.g. Source: {"id": 1, "updated_at": "2017-01-01 00:00:00"};
        # Destination: {"id": 1, "updated_at": "2017-01-02 00:00:00"})

        delete_confirm_sql = \
            '''
            DELETE FROM "{rs_schema}"."{rs_table}{rs_suffix}"
            USING "{rs_schema}"."{rs_table}"
            WHERE "{rs_schema}"."{rs_table}{rs_suffix}"."{rs_pk}" =
                  "{rs_schema}"."{rs_table}"."{rs_pk}"
            AND "{rs_schema}"."{rs_table}"."{rs_ik}" >=
                "{rs_schema}"."{rs_table}{rs_suffix}"."{rs_ik}"
            '''.format(rs_schema=self.redshift_schema,
                       rs_table=self.table,
                       rs_pk=self.primary_key,
                       rs_suffix=self.temp_suffix,
                       rs_ik=self.incremental_key)

        append_sql = \
            '''
            ALTER TABLE "{0}"."{1}"
            APPEND FROM "{0}"."{1}{2}"
            FILLTARGET
            '''.format(self.redshift_schema, self.table, self.temp_suffix)

        drop_temp_sql = \
            '''
            DROP TABLE IF EXISTS "{0}"."{1}{2}"
            '''.format(self.redshift_schema, self.table, self.temp_suffix)

        truncate_sql = \
            '''
            TRUNCATE TABLE "{0}"."{1}"
            '''.format(self.redshift_schema, self.table)

        base_sql = \
            """
            FROM 's3://{0}/{1}'
            CREDENTIALS '{2}'
            COMPUPDATE OFF
            STATUPDATE OFF
            JSON 'auto'
            TIMEFORMAT '{3}'
            TRUNCATECOLUMNS
            region as 'us-east-1';
            """.format(self.s3_bucket,
                       self.s3_key,
                       get_s3_credentials(),
                       self.timeformat)

        load_sql = '''COPY "{0}"."{1}" {2}'''.format(self.redshift_schema,
                                                     self.table,
                                                     base_sql)
        if self.load_type.lower() == 'append':
            pg_hook.run(load_sql)
        elif self.load_type.lower() == 'upsert':
            self.create_if_not_exists(schema, pg_hook, temp=True)
            load_temp_sql = \
                '''COPY "{0}"."{1}{2}" {3}'''.format(self.redshift_schema,
                                                     self.table,
                                                     self.temp_suffix,
                                                     base_sql)
            pg_hook.run(load_temp_sql)
            pg_hook.run(delete_sql)
            pg_hook.run(delete_confirm_sql)
            pg_hook.run(append_sql, autocommit=True)
            pg_hook.run(drop_temp_sql)
        elif self.load_type.lower() == 'rebuild':
            pg_hook.run(truncate_sql)
            pg_hook.run(load_sql)

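    # For illustration, with hypothetical schema/table/bucket names the
    # COPY statement built above is roughly (credentials elided):
    #     COPY "analytics"."events_astro_temp_abcdefg"
    #     FROM 's3://my-bucket/data/events.json'
    #     CREDENTIALS 'aws_access_key_id=...;aws_secret_access_key=...'
    #     COMPUPDATE OFF STATUPDATE OFF JSON 'auto' TIMEFORMAT 'auto'
    #     TRUNCATECOLUMNS region as 'us-east-1';
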
    def create_if_not_exists(self, schema, pg_hook, temp=False):
        output = ''
        for item in schema:
            k = "{quote}{key}{quote}".format(quote='"', key=item['name'])
            field = ' '.join([k, item['type']])
            output += field
            output += ', '
        # Remove the trailing comma and space after the loop ends
        output = output[:-2]
        if temp:
            copy_table = '{0}{1}'.format(self.table, self.temp_suffix)
        else:
            copy_table = self.table
        create_schema_query = \
            '''CREATE SCHEMA IF NOT EXISTS "{0}";'''.format(
                self.redshift_schema)
        create_table_query = \
            '''CREATE TABLE IF NOT EXISTS "{0}"."{1}" ({2})'''.format(
                self.redshift_schema,
                copy_table,
                output)
        pg_hook.run(create_schema_query)
        pg_hook.run(create_table_query)
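
For reference, create_if_not_exists() builds its DDL directly from the schema list. With a hypothetical two-column schema [{"name": "id", "type": "int4"}, {"name": "price", "type": "float8"}] and schema/table names from the earlier sketch, the queries it runs would look like:

CREATE SCHEMA IF NOT EXISTS "analytics";
CREATE TABLE IF NOT EXISTS "analytics"."events" ("id" int4, "price" float8)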
