diff --git a/bonobo_sqlalchemy/writers.py b/bonobo_sqlalchemy/writers.py index 0734d41..537d53d 100644 --- a/bonobo_sqlalchemy/writers.py +++ b/bonobo_sqlalchemy/writers.py @@ -95,10 +95,7 @@ def commit(self, table, connection, buffer, force=False): if force or (buffer.qsize() >= self.buffer_size): with connection.begin(): while buffer.qsize() > 0: - try: - yield self.insert_or_update(table, connection, buffer.get()) - except Exception as exc: - yield exc + yield self.insert_or_update(table, connection, buffer.get()) def insert_or_update(self, table, connection, row): """ Actual database load transformation logic, without the buffering / transaction logic. @@ -110,11 +107,12 @@ def insert_or_update(self, table, connection, row): # TODO XXX use actual database function instead of this stupid thing now = datetime.datetime.now() + target_row = row._asdict() column_names = table.columns.keys() # UpdatedAt field configured ? Let's set the value in source hash if self.updated_at_field in column_names: - row[self.updated_at_field] = now # XXX not pure ... + target_row[self.updated_at_field] = now # XXX not pure ... # Update logic if dbrow: @@ -122,9 +120,9 @@ def insert_or_update(self, table, connection, row): raise ProhibitedOperationError('UPDATE operations are not allowed by this transformation.') query = table.update().values( - **{col: row.get(col) - for col in self.get_columns_for(column_names, row, dbrow)} - ).where(and_(*(getattr(table.c, col) == row.get(col) for col in self.discriminant))) + **{col: target_row.get(col) + for col in self.get_columns_for(column_names, target_row, dbrow)} + ).where(and_(*(getattr(table.c, col) == target_row.get(col) for col in self.discriminant))) # INSERT else: @@ -132,12 +130,12 @@ def insert_or_update(self, table, connection, row): raise ProhibitedOperationError('INSERT operations are not allowed by this transformation.') if self.created_at_field in column_names: - row[self.created_at_field] = now # XXX UNPURE + target_row[self.created_at_field] = now # XXX UNPURE else: - if self.created_at_field in row: - del row[self.created_at_field] # UNPURE + if self.created_at_field in target_row: + del target_row[self.created_at_field] # UNPURE - query = table.insert().values(**{col: row.get(col) for col in self.get_columns_for(column_names, row)}) + query = table.insert().values(**{col: target_row.get(col) for col in self.get_columns_for(column_names, target_row)}) # Execute try: @@ -175,7 +173,7 @@ def find(self, connection, table, row): return dict(row) if row else None - def get_columns_for(self, column_names, row, dbrow=None): + def get_columns_for(self, column_names, target_row, dbrow=None): """Retrieve list of table column names for which we have a value in given hash. """ @@ -184,12 +182,7 @@ def get_columns_for(self, column_names, row, dbrow=None): else: candidates = column_names - try: - fields = row._fields - except AttributeError as exc: - fields = list(row.keys()) - - return set(candidates).intersection(fields) + return set(candidates).intersection(target_row.keys()) def add_fetch_columns(self, *columns, **aliased_columns): self.fetch_columns = {