Skip to content

Commit c09e49e

Browse files
authored
fix: wrap_exceptions is re-enabled (#475)
`wrap_exceptions` is now correctly passed into the exception handler context manager. Fixes #472
1 parent 8d4b59a commit c09e49e

File tree

2 files changed

+128
-48
lines changed

2 files changed

+128
-48
lines changed

advanced_alchemy/repository/_async.py

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,9 @@ async def add(
623623
error_messages=error_messages,
624624
default_messages=self.error_messages,
625625
)
626-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
626+
with wrap_sqlalchemy_exception(
627+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
628+
):
627629
instance = await self._attach_to_session(data)
628630
await self._flush_or_commit(auto_commit=auto_commit)
629631
await self._refresh(instance, auto_refresh=auto_refresh)
@@ -654,7 +656,9 @@ async def add_many(
654656
error_messages=error_messages,
655657
default_messages=self.error_messages,
656658
)
657-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
659+
with wrap_sqlalchemy_exception(
660+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
661+
):
658662
self.session.add_all(data)
659663
await self._flush_or_commit(auto_commit=auto_commit)
660664
for datum in data:
@@ -696,7 +700,9 @@ async def delete(
696700
error_messages=error_messages,
697701
default_messages=self.error_messages,
698702
)
699-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
703+
with wrap_sqlalchemy_exception(
704+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
705+
):
700706
execution_options = self._get_execution_options(execution_options)
701707
instance = await self.get(
702708
item_id,
@@ -747,7 +753,9 @@ async def delete_many(
747753
error_messages=error_messages,
748754
default_messages=self.error_messages,
749755
)
750-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
756+
with wrap_sqlalchemy_exception(
757+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
758+
):
751759
execution_options = self._get_execution_options(execution_options)
752760
loader_options, _loader_options_have_wildcard = self._get_loader_options(load)
753761
id_attribute = get_instrumented_attr(
@@ -846,7 +854,9 @@ async def delete_where(
846854
error_messages=error_messages,
847855
default_messages=self.error_messages,
848856
)
849-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
857+
with wrap_sqlalchemy_exception(
858+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
859+
):
850860
execution_options = self._get_execution_options(execution_options)
851861
loader_options, _loader_options_have_wildcard = self._get_loader_options(load)
852862
model_type = self.model_type
@@ -1002,7 +1012,9 @@ async def get(
10021012
error_messages=error_messages,
10031013
default_messages=self.error_messages,
10041014
)
1005-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
1015+
with wrap_sqlalchemy_exception(
1016+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
1017+
):
10061018
execution_options = self._get_execution_options(execution_options)
10071019
statement = self.statement if statement is None else statement
10081020
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
@@ -1051,7 +1063,9 @@ async def get_one(
10511063
error_messages=error_messages,
10521064
default_messages=self.error_messages,
10531065
)
1054-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
1066+
with wrap_sqlalchemy_exception(
1067+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
1068+
):
10551069
execution_options = self._get_execution_options(execution_options)
10561070
statement = self.statement if statement is None else statement
10571071
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
@@ -1099,7 +1113,9 @@ async def get_one_or_none(
10991113
error_messages=error_messages,
11001114
default_messages=self.error_messages,
11011115
)
1102-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
1116+
with wrap_sqlalchemy_exception(
1117+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
1118+
):
11031119
execution_options = self._get_execution_options(execution_options)
11041120
statement = self.statement if statement is None else statement
11051121
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
@@ -1167,7 +1183,9 @@ async def get_or_upsert(
11671183
error_messages=error_messages,
11681184
default_messages=self.error_messages,
11691185
)
1170-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
1186+
with wrap_sqlalchemy_exception(
1187+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
1188+
):
11711189
if match_fields := self._get_match_fields(match_fields=match_fields):
11721190
match_filter = {
11731191
field_name: kwargs.get(field_name)
@@ -1254,7 +1272,9 @@ async def get_and_update(
12541272
error_messages=error_messages,
12551273
default_messages=self.error_messages,
12561274
)
1257-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
1275+
with wrap_sqlalchemy_exception(
1276+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
1277+
):
12581278
if match_fields := self._get_match_fields(match_fields=match_fields):
12591279
match_filter = {
12601280
field_name: kwargs.get(field_name)
@@ -1311,7 +1331,9 @@ async def count(
13111331
error_messages=error_messages,
13121332
default_messages=self.error_messages,
13131333
)
1314-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
1334+
with wrap_sqlalchemy_exception(
1335+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
1336+
):
13151337
execution_options = self._get_execution_options(execution_options)
13161338
statement = self.statement if statement is None else statement
13171339
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
@@ -1374,7 +1396,9 @@ async def update(
13741396
error_messages=error_messages,
13751397
default_messages=self.error_messages,
13761398
)
1377-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
1399+
with wrap_sqlalchemy_exception(
1400+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
1401+
):
13781402
item_id = self.get_id_attribute_value(
13791403
data,
13801404
id_attribute=id_attribute,
@@ -1430,7 +1454,9 @@ async def update_many(
14301454
default_messages=self.error_messages,
14311455
)
14321456
data_to_update: list[dict[str, Any]] = [v.to_dict() if isinstance(v, self.model_type) else v for v in data] # type: ignore[misc]
1433-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
1457+
with wrap_sqlalchemy_exception(
1458+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
1459+
):
14341460
execution_options = self._get_execution_options(execution_options)
14351461
loader_options = self._get_loader_options(load)[0]
14361462
supports_returning = self._dialect.update_executemany_returning and self._dialect.name != "oracle"
@@ -1598,7 +1624,9 @@ async def _list_and_count_window(
15981624
error_messages=error_messages,
15991625
default_messages=self.error_messages,
16001626
)
1601-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
1627+
with wrap_sqlalchemy_exception(
1628+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
1629+
):
16021630
execution_options = self._get_execution_options(execution_options)
16031631
statement = self.statement if statement is None else statement
16041632
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
@@ -1655,7 +1683,9 @@ async def _list_and_count_basic(
16551683
error_messages=error_messages,
16561684
default_messages=self.error_messages,
16571685
)
1658-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
1686+
with wrap_sqlalchemy_exception(
1687+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
1688+
):
16591689
execution_options = self._get_execution_options(execution_options)
16601690
statement = self.statement if statement is None else statement
16611691
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
@@ -1763,7 +1793,9 @@ async def upsert(
17631793
auto_expunge=auto_expunge,
17641794
auto_refresh=auto_refresh,
17651795
)
1766-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
1796+
with wrap_sqlalchemy_exception(
1797+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
1798+
):
17671799
for field_name, new_field_value in data.to_dict(exclude={self.id_attribute}).items():
17681800
field = getattr(existing, field_name, MISSING)
17691801
if field is not MISSING and field != new_field_value:
@@ -1838,7 +1870,9 @@ async def upsert_many(
18381870
]
18391871
match_filter.append(any_(matched_values) == field if self._prefer_any else field.in_(matched_values)) # type: ignore[arg-type]
18401872

1841-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
1873+
with wrap_sqlalchemy_exception(
1874+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
1875+
):
18421876
existing_objs = await self.list(
18431877
*match_filter,
18441878
load=load,
@@ -1943,7 +1977,9 @@ async def list(
19431977
error_messages=error_messages,
19441978
default_messages=self.error_messages,
19451979
)
1946-
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
1980+
with wrap_sqlalchemy_exception(
1981+
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
1982+
):
19471983
execution_options = self._get_execution_options(execution_options)
19481984
statement = self.statement if statement is None else statement
19491985
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
@@ -2097,25 +2133,29 @@ class SQLAlchemyAsyncQueryRepository:
20972133
"""
20982134

20992135
error_messages: Optional[ErrorMessages] = None
2136+
wrap_exceptions: bool = True
21002137

21012138
def __init__(
21022139
self,
21032140
*,
21042141
session: Union[AsyncSession, async_scoped_session[AsyncSession]],
21052142
error_messages: Optional[ErrorMessages] = None,
2143+
wrap_exceptions: bool = True,
21062144
**kwargs: Any,
21072145
) -> None:
21082146
"""Repository pattern for SQLAlchemy models.
21092147
21102148
Args:
21112149
session: Session managing the unit-of-work for the operation.
21122150
error_messages: A set of error messages to use for operations.
2151+
wrap_exceptions: Whether to wrap exceptions in a SQLAlchemy exception.
21132152
**kwargs: Additional arguments.
21142153
21152154
"""
21162155
super().__init__(**kwargs)
21172156
self.session = session
21182157
self.error_messages = error_messages
2158+
self.wrap_exceptions = wrap_exceptions
21192159
self._dialect = self.session.bind.dialect if self.session.bind is not None else self.session.get_bind().dialect
21202160

21212161
async def get_one(
@@ -2132,7 +2172,7 @@ async def get_one(
21322172
Returns:
21332173
The retrieved instance.
21342174
"""
2135-
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
2175+
with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
21362176
statement = self._filter_statement_by_kwargs(statement, **kwargs)
21372177
instance = (await self.execute(statement)).scalar_one_or_none()
21382178
return self.check_not_found(instance)
@@ -2151,7 +2191,7 @@ async def get_one_or_none(
21512191
Returns:
21522192
The retrieved instance or None
21532193
"""
2154-
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
2194+
with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
21552195
statement = self._filter_statement_by_kwargs(statement, **kwargs)
21562196
instance = (await self.execute(statement)).scalar_one_or_none()
21572197
return instance or None
@@ -2166,7 +2206,7 @@ async def count(self, statement: Select[Any], **kwargs: Any) -> int:
21662206
Returns:
21672207
Count of records returned by query, ignoring pagination.
21682208
"""
2169-
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
2209+
with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
21702210
statement = statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True).order_by(
21712211
None,
21722212
)
@@ -2210,7 +2250,7 @@ async def _list_and_count_window(
22102250
Count of records returned by query using an analytical window function, ignoring pagination.
22112251
"""
22122252

2213-
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
2253+
with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
22142254
statement = statement.add_columns(over(sql_func.count(text("1"))))
22152255
statement = self._filter_statement_by_kwargs(statement, **kwargs)
22162256
result = await self.execute(statement)
@@ -2241,7 +2281,7 @@ async def _list_and_count_basic(
22412281
Count of records returned by query using 2 queries, ignoring pagination.
22422282
"""
22432283

2244-
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
2284+
with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
22452285
statement = self._filter_statement_by_kwargs(statement, **kwargs)
22462286
count_result = await self.session.execute(self._get_count_stmt(statement))
22472287
count = count_result.scalar_one()
@@ -2261,7 +2301,7 @@ async def list(self, statement: Select[Any], **kwargs: Any) -> list[Row[Any]]:
22612301
Returns:
22622302
The list of instances, after filtering applied.
22632303
"""
2264-
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
2304+
with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
22652305
statement = self._filter_statement_by_kwargs(statement, **kwargs)
22662306
result = await self.execute(statement)
22672307
return list(result.all())

0 commit comments

Comments
 (0)