Skip to content

Fix InsertOrUpdate for Bonobo 0.6 #44

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 12 additions & 19 deletions bonobo_sqlalchemy/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,7 @@ def commit(self, table, connection, buffer, force=False):
if force or (buffer.qsize() >= self.buffer_size):
with connection.begin():
while buffer.qsize() > 0:
try:
yield self.insert_or_update(table, connection, buffer.get())
except Exception as exc:
yield exc
yield self.insert_or_update(table, connection, buffer.get())

def insert_or_update(self, table, connection, row):
""" Actual database load transformation logic, without the buffering / transaction logic.
Expand All @@ -110,34 +107,35 @@ def insert_or_update(self, table, connection, row):
# TODO XXX use actual database function instead of this stupid thing
now = datetime.datetime.now()

target_row = row._asdict()
column_names = table.columns.keys()

# UpdatedAt field configured ? Let's set the value in source hash
if self.updated_at_field in column_names:
row[self.updated_at_field] = now # XXX not pure ...
target_row[self.updated_at_field] = now # XXX not pure ...

# Update logic
if dbrow:
if not UPDATE in self.allowed_operations:
raise ProhibitedOperationError('UPDATE operations are not allowed by this transformation.')

query = table.update().values(
**{col: row.get(col)
for col in self.get_columns_for(column_names, row, dbrow)}
).where(and_(*(getattr(table.c, col) == row.get(col) for col in self.discriminant)))
**{col: target_row.get(col)
for col in self.get_columns_for(column_names, target_row, dbrow)}
).where(and_(*(getattr(table.c, col) == target_row.get(col) for col in self.discriminant)))

# INSERT
else:
if not INSERT in self.allowed_operations:
raise ProhibitedOperationError('INSERT operations are not allowed by this transformation.')

if self.created_at_field in column_names:
row[self.created_at_field] = now # XXX UNPURE
target_row[self.created_at_field] = now # XXX UNPURE
else:
if self.created_at_field in row:
del row[self.created_at_field] # UNPURE
if self.created_at_field in target_row:
del target_row[self.created_at_field] # UNPURE

query = table.insert().values(**{col: row.get(col) for col in self.get_columns_for(column_names, row)})
query = table.insert().values(**{col: target_row.get(col) for col in self.get_columns_for(column_names, target_row)})

# Execute
try:
Expand Down Expand Up @@ -175,7 +173,7 @@ def find(self, connection, table, row):

return dict(row) if row else None

def get_columns_for(self, column_names, row, dbrow=None):
def get_columns_for(self, column_names, target_row, dbrow=None):
"""Retrieve list of table column names for which we have a value in given hash.

"""
Expand All @@ -184,12 +182,7 @@ def get_columns_for(self, column_names, row, dbrow=None):
else:
candidates = column_names

try:
fields = row._fields
except AttributeError as exc:
fields = list(row.keys())

return set(candidates).intersection(fields)
return set(candidates).intersection(target_row.keys())

def add_fetch_columns(self, *columns, **aliased_columns):
self.fetch_columns = {
Expand Down