import json
import random
import string
import logging
from airflow.utils.decorators import apply_defaults
from airflow.models import BaseOperator
from airflow.hooks.S3_hook import S3Hook
from airflow.hooks.postgres_hook import PostgresHook
from airflow.utils.db import provide_session
from airflow.models import Connection


class S3ToRedshiftOperator(BaseOperator):
    """
    S3 To Redshift Operator

    :param redshift_conn_id: The destination redshift connection id.
    :type redshift_conn_id: string
    :param redshift_schema: The destination redshift schema.
    :type redshift_schema: string
    :param table: The destination redshift table.
    :type table: string
    :param s3_conn_id: The source s3 connection id.
    :type s3_conn_id: string
    :param s3_bucket: The source s3 bucket.
    :type s3_bucket: string
    :param s3_key: The source s3 key.
    :type s3_key: string
    :param origin_schema: The s3 key for the incoming data schema.
                          For "mysql", expects a JSON file with a single
                          dict mapping column to datatype as a key-value
                          pair (e.g. "column1":"int(11)"). For "avro",
                          expects a list of field dicts with "name",
                          "type", and optional "logicalType" keys.
    :type origin_schema: string
    :param schema_location: The location of the origin schema. This
                            can be set to 'S3' or 'Local'.
                            If 'S3', it will expect a valid S3 key. If
                            'Local', it will expect a dictionary that
                            is defined in the operator itself. By
                            default the location is set to 's3'.
    :type schema_location: string
    :param origin_datatype: The incoming database type from which to
                            convert the origin schema. Required when
                            specifying the origin_schema. Current
                            possible values are "mysql" and "avro".
    :type origin_datatype: string
    :param load_type: The method of loading into Redshift that
                      should occur. Options are "append",
                      "rebuild", and "upsert". Defaults to
                      "append".
    :type load_type: string
    :param primary_key: *(optional)* The primary key for the
                        destination table. Not enforced by Redshift
                        and only required if using a load_type of
                        "upsert".
    :type primary_key: string
    :param incremental_key: *(optional)* The incremental key to compare
                            new data against the destination table
                            with. Only required if using a load_type of
                            "upsert".
    :type incremental_key: string
    :param timeformat: *(optional)* The TIMEFORMAT passed to the
                       Redshift COPY command. Defaults to 'auto'.
    :type timeformat: string
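
    A minimal usage sketch (the DAG object, connection ids, bucket, and
    schema key below are placeholder values, not part of this module)::

        copy_task = S3ToRedshiftOperator(
            task_id='s3_to_redshift',
            s3_conn_id='my_s3_conn',
            s3_bucket='my-bucket',
            s3_key='exports/users.json',
            origin_schema='exports/users_schema.json',
            origin_datatype='mysql',
            redshift_conn_id='my_redshift_conn',
            redshift_schema='public',
            table='users',
            load_type='upsert',
            primary_key='id',
            incremental_key='updated_at',
            dag=dag)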
    """

    template_fields = ['s3_key', 'origin_schema']

    @apply_defaults
    def __init__(self,
                 s3_conn_id,
                 s3_bucket,
                 s3_key,
                 redshift_conn_id,
                 redshift_schema,
                 table,
                 origin_schema=None,
                 schema_location='s3',
                 origin_datatype=None,
                 load_type='append',
                 primary_key=None,
                 incremental_key=None,
                 timeformat='auto',
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.s3_conn_id = s3_conn_id
        self.s3_bucket = s3_bucket
        self.s3_key = s3_key
        self.redshift_conn_id = redshift_conn_id
        self.redshift_schema = redshift_schema
        self.table = table
        self.origin_schema = origin_schema
        self.schema_location = schema_location
        self.origin_datatype = origin_datatype
        self.load_type = load_type
        self.primary_key = primary_key
        self.incremental_key = incremental_key
        self.timeformat = timeformat

        if self.load_type.lower() not in ["append", "rebuild", "upsert"]:
            raise Exception('Please choose "append", "rebuild", or "upsert".')

        if self.schema_location.lower() not in ['s3', 'local']:
            raise Exception('Valid Schema Locations are "s3" or "local".')

    def execute(self, context):
        # Append a random string to the end of the staging table name so
        # that concurrent runs do not conflict with each other.
        letters = string.ascii_lowercase
        random_string = ''.join(random.choice(letters) for _ in range(7))
        self.temp_suffix = '_astro_temp_{0}'.format(random_string)
        schema = None
        if self.origin_schema:
            schema = self.read_and_format()
        pg_hook = PostgresHook(self.redshift_conn_id)
        if schema:
            self.create_if_not_exists(schema, pg_hook)
            self.reconcile_schemas(schema, pg_hook)
        self.copy_data(pg_hook, schema)

    def read_and_format(self):
        if self.schema_location.lower() == 's3':
            hook = S3Hook(self.s3_conn_id)
            schema = (hook.get_key(self.origin_schema,
                                   bucket_name=self.s3_bucket)
                      .get_contents_as_string(encoding='utf-8'))
            schema = json.loads(schema.replace("'", '"'))
        else:
            schema = self.origin_schema

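        # Mapping of common MySQL column types to their closest Redshift
        # equivalents; types not listed here are passed through unchanged.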
        schema_map = {
            "tinyint(1)": "bool",
            "float": "float4",
            "double": "float8",
            "int(11)": "int4",
            "longtext": "text",
            "bigint(21)": "int8"
        }

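        # Cross-database type equivalences used when converting an Avro
        # schema (optionally carrying a logicalType) to Redshift types.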
        schemaMapper = [{"avro": "string",
                         "mysql": "varchar(256)",
                         "redshift": "text"},
                        {"avro": "int",
                         "mysql": "int(11)",
                         "redshift": "int4"},
                        {"avro": "long",
                         "mysql": "bigint(21)",
                         "redshift": "int8"},
                        {"avro": "long-timestamp-millis",
                         "redshift": "timestamp"},
                        {"avro": "boolean",
                         "mysql": "tinyint(1)",
                         "redshift": "boolean"},
                        {"avro": "date",
                         "mysql": "date",
                         "redshift": "date"},
                        {"avro": "long-timestamp-millis",
                         "mysql": "timestamp(3)",
                         "redshift": "timestamp"},
                        {"mysql": "float",
                         "redshift": "float4"},
                        {"mysql": "double",
                         "redshift": "float8"},
                        {"mysql": "longtext",
                         "redshift": "text"}]

        if self.origin_datatype:
            if self.origin_datatype.lower() == 'mysql':
                # Normalize the {"column": "mysql_type"} mapping into the
                # [{"name": ..., "type": ...}] list expected downstream,
                # translating MySQL types to Redshift types along the way.
                schema = [{'name': column,
                           'type': schema_map.get(datatype, datatype)}
                          for column, datatype in schema.items()]
            elif self.origin_datatype.lower() == 'avro':
                for i in schema:
                    if 'logicalType' in i:
                        i['type'] = '-'.join([i['type'], i['logicalType']])
                        del i['logicalType']
                    for e in schemaMapper:
                        if 'avro' in e and i['type'] == e['avro']:
                            i['type'] = e['redshift']

        logging.info(schema)
        return schema

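    # For illustration (hypothetical values): a MySQL origin_schema of
    # {"id": "int(11)", "name": "varchar(256)"} is normalized by
    # read_and_format() into
    # [{"name": "id", "type": "int4"}, {"name": "name", "type": "varchar(256)"}],
    # which is the list-of-dicts shape reconcile_schemas() and
    # create_if_not_exists() expect.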
    def reconcile_schemas(self, schema, pg_hook):
        pg_query = \
            """
            SELECT column_name, udt_name
            FROM information_schema.columns
            WHERE table_schema = '{0}' AND table_name = '{1}';
            """.format(self.redshift_schema, self.table)

        pg_schema = dict(pg_hook.get_records(pg_query))
        incoming_keys = [column['name'] for column in schema]
        diff = list(set(incoming_keys) - set(pg_schema.keys()))
        logging.info('Column differential: ' + str(diff))
        # Check the column differential to see if any new columns exist
        if diff:
            for i in diff:
                for e in schema:
                    if i == e['name']:
                        alter_query = \
                            """
                            ALTER TABLE "{0}"."{1}"
                            ADD COLUMN {2} {3}
                            """.format(self.redshift_schema,
                                       self.table,
                                       e['name'],
                                       e['type'])
                        pg_hook.run(alter_query)
            logging.info('The new columns were: ' + str(diff))
        else:
            logging.info('There were no new columns.')

    def copy_data(self, pg_hook, schema=None):
        @provide_session
        def get_conn(conn_id, session=None):
            conn = (
                session.query(Connection)
                .filter(Connection.conn_id == conn_id)
                .first())
            return conn

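        # Build the CREDENTIALS string for the Redshift COPY command from
        # the AWS keys stored in the S3 connection's extra field.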
        def getS3Conn():
            s3_conn = get_conn(self.s3_conn_id)
            aws_key = s3_conn.extra_dejson.get('aws_access_key_id')
            aws_secret = s3_conn.extra_dejson.get('aws_secret_access_key')
            return ("aws_access_key_id={0};aws_secret_access_key={1}"
                    .format(aws_key, aws_secret))

        # Delete records from the destination table where the staging
        # (source) record's incremental_key is greater than or equal to
        # the destination's and the primary key is the same.
        # (e.g. Source: {"id": 1, "updated_at": "2017-01-02 00:00:00"};
        # Destination: {"id": 1, "updated_at": "2017-01-01 00:00:00"})

        delete_sql = \
            '''
            DELETE FROM "{rs_schema}"."{rs_table}"
            USING "{rs_schema}"."{rs_table}{rs_suffix}"
            WHERE "{rs_schema}"."{rs_table}"."{rs_pk}" =
                "{rs_schema}"."{rs_table}{rs_suffix}"."{rs_pk}"
            AND "{rs_schema}"."{rs_table}{rs_suffix}"."{rs_ik}" >=
                "{rs_schema}"."{rs_table}"."{rs_ik}"
            '''.format(rs_schema=self.redshift_schema,
                       rs_table=self.table,
                       rs_pk=self.primary_key,
                       rs_suffix=self.temp_suffix,
                       rs_ik=self.incremental_key)

        # Delete records from the staging table where the destination
        # table's incremental_key is greater than or equal to the staging
        # record's and the primary key is the same. This covers the edge
        # case where data is pulled BEFORE it is altered in the source
        # system but lands AFTER a workflow containing an updated version
        # of the record has already run. Without this step, the older
        # record would be appended as a duplicate of the newer record.
        # (e.g. Source: {"id": 1, "updated_at": "2017-01-01 00:00:00"};
        # Destination: {"id": 1, "updated_at": "2017-01-02 00:00:00"})

        delete_confirm_sql = \
            '''
            DELETE FROM "{rs_schema}"."{rs_table}{rs_suffix}"
            USING "{rs_schema}"."{rs_table}"
            WHERE "{rs_schema}"."{rs_table}{rs_suffix}"."{rs_pk}" =
                "{rs_schema}"."{rs_table}"."{rs_pk}"
            AND "{rs_schema}"."{rs_table}"."{rs_ik}" >=
                "{rs_schema}"."{rs_table}{rs_suffix}"."{rs_ik}"
            '''.format(rs_schema=self.redshift_schema,
                       rs_table=self.table,
                       rs_pk=self.primary_key,
                       rs_suffix=self.temp_suffix,
                       rs_ik=self.incremental_key)

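        # ALTER TABLE APPEND moves rows from the staging table into the
        # target table; FILLTARGET fills target columns that are missing
        # from the staging table with their default value or NULL.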
        append_sql = \
            '''
            ALTER TABLE "{0}"."{1}"
            APPEND FROM "{0}"."{1}{2}"
            FILLTARGET
            '''.format(self.redshift_schema, self.table, self.temp_suffix)

        drop_temp_sql = \
            '''
            DROP TABLE IF EXISTS "{0}"."{1}{2}"
            '''.format(self.redshift_schema, self.table, self.temp_suffix)

        truncate_sql = \
            '''
            TRUNCATE TABLE "{0}"."{1}"
            '''.format(self.redshift_schema, self.table)

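        # COPY options: COMPUPDATE OFF and STATUPDATE OFF skip automatic
        # compression analysis and statistics updates, JSON 'auto' maps
        # source JSON keys to columns with matching names, and
        # TRUNCATECOLUMNS truncates over-length VARCHAR values rather than
        # failing the load. The region is hard-coded to us-east-1.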
        base_sql = \
            """
            FROM 's3://{0}/{1}'
            CREDENTIALS '{2}'
            COMPUPDATE OFF
            STATUPDATE OFF
            JSON 'auto'
            TIMEFORMAT '{3}'
            TRUNCATECOLUMNS
            region as 'us-east-1';
            """.format(self.s3_bucket,
                       self.s3_key,
                       getS3Conn(),
                       self.timeformat)

        load_sql = '''COPY "{0}"."{1}" {2}'''.format(self.redshift_schema,
                                                     self.table,
                                                     base_sql)
        if self.load_type.lower() == 'append':
            pg_hook.run(load_sql)
        elif self.load_type.lower() == 'upsert':
            self.create_if_not_exists(schema, pg_hook, temp=True)
            load_temp_sql = \
                '''COPY "{0}"."{1}{2}" {3}'''.format(self.redshift_schema,
                                                     self.table,
                                                     self.temp_suffix,
                                                     base_sql)
            pg_hook.run(load_temp_sql)
            pg_hook.run(delete_sql)
            pg_hook.run(delete_confirm_sql)
            pg_hook.run(append_sql, autocommit=True)
            pg_hook.run(drop_temp_sql)
        elif self.load_type.lower() == 'rebuild':
            pg_hook.run(truncate_sql)
            pg_hook.run(load_sql)

    def create_if_not_exists(self, schema, pg_hook, temp=False):
        # Build the comma-separated column definition list from the schema.
        output = ', '.join('"{0}" {1}'.format(item['name'], item['type'])
                           for item in schema)
        if temp:
            copy_table = '{0}{1}'.format(self.table, self.temp_suffix)
        else:
            copy_table = self.table
        create_schema_query = \
            '''CREATE SCHEMA IF NOT EXISTS "{0}";'''.format(
                self.redshift_schema)
        create_table_query = \
            '''CREATE TABLE IF NOT EXISTS "{0}"."{1}" ({2})'''.format(
                self.redshift_schema,
                copy_table,
                output)
        pg_hook.run(create_schema_query)
        pg_hook.run(create_table_query)