@@ -623,7 +623,9 @@ async def add(
623
623
error_messages = error_messages ,
624
624
default_messages = self .error_messages ,
625
625
)
626
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
626
+ with wrap_sqlalchemy_exception (
627
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
628
+ ):
627
629
instance = await self ._attach_to_session (data )
628
630
await self ._flush_or_commit (auto_commit = auto_commit )
629
631
await self ._refresh (instance , auto_refresh = auto_refresh )
@@ -654,7 +656,9 @@ async def add_many(
654
656
error_messages = error_messages ,
655
657
default_messages = self .error_messages ,
656
658
)
657
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
659
+ with wrap_sqlalchemy_exception (
660
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
661
+ ):
658
662
self .session .add_all (data )
659
663
await self ._flush_or_commit (auto_commit = auto_commit )
660
664
for datum in data :
@@ -696,7 +700,9 @@ async def delete(
696
700
error_messages = error_messages ,
697
701
default_messages = self .error_messages ,
698
702
)
699
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
703
+ with wrap_sqlalchemy_exception (
704
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
705
+ ):
700
706
execution_options = self ._get_execution_options (execution_options )
701
707
instance = await self .get (
702
708
item_id ,
@@ -747,7 +753,9 @@ async def delete_many(
747
753
error_messages = error_messages ,
748
754
default_messages = self .error_messages ,
749
755
)
750
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
756
+ with wrap_sqlalchemy_exception (
757
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
758
+ ):
751
759
execution_options = self ._get_execution_options (execution_options )
752
760
loader_options , _loader_options_have_wildcard = self ._get_loader_options (load )
753
761
id_attribute = get_instrumented_attr (
@@ -846,7 +854,9 @@ async def delete_where(
846
854
error_messages = error_messages ,
847
855
default_messages = self .error_messages ,
848
856
)
849
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
857
+ with wrap_sqlalchemy_exception (
858
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
859
+ ):
850
860
execution_options = self ._get_execution_options (execution_options )
851
861
loader_options , _loader_options_have_wildcard = self ._get_loader_options (load )
852
862
model_type = self .model_type
@@ -1002,7 +1012,9 @@ async def get(
1002
1012
error_messages = error_messages ,
1003
1013
default_messages = self .error_messages ,
1004
1014
)
1005
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
1015
+ with wrap_sqlalchemy_exception (
1016
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
1017
+ ):
1006
1018
execution_options = self ._get_execution_options (execution_options )
1007
1019
statement = self .statement if statement is None else statement
1008
1020
loader_options , loader_options_have_wildcard = self ._get_loader_options (load )
@@ -1051,7 +1063,9 @@ async def get_one(
1051
1063
error_messages = error_messages ,
1052
1064
default_messages = self .error_messages ,
1053
1065
)
1054
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
1066
+ with wrap_sqlalchemy_exception (
1067
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
1068
+ ):
1055
1069
execution_options = self ._get_execution_options (execution_options )
1056
1070
statement = self .statement if statement is None else statement
1057
1071
loader_options , loader_options_have_wildcard = self ._get_loader_options (load )
@@ -1099,7 +1113,9 @@ async def get_one_or_none(
1099
1113
error_messages = error_messages ,
1100
1114
default_messages = self .error_messages ,
1101
1115
)
1102
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
1116
+ with wrap_sqlalchemy_exception (
1117
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
1118
+ ):
1103
1119
execution_options = self ._get_execution_options (execution_options )
1104
1120
statement = self .statement if statement is None else statement
1105
1121
loader_options , loader_options_have_wildcard = self ._get_loader_options (load )
@@ -1167,7 +1183,9 @@ async def get_or_upsert(
1167
1183
error_messages = error_messages ,
1168
1184
default_messages = self .error_messages ,
1169
1185
)
1170
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
1186
+ with wrap_sqlalchemy_exception (
1187
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
1188
+ ):
1171
1189
if match_fields := self ._get_match_fields (match_fields = match_fields ):
1172
1190
match_filter = {
1173
1191
field_name : kwargs .get (field_name )
@@ -1254,7 +1272,9 @@ async def get_and_update(
1254
1272
error_messages = error_messages ,
1255
1273
default_messages = self .error_messages ,
1256
1274
)
1257
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
1275
+ with wrap_sqlalchemy_exception (
1276
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
1277
+ ):
1258
1278
if match_fields := self ._get_match_fields (match_fields = match_fields ):
1259
1279
match_filter = {
1260
1280
field_name : kwargs .get (field_name )
@@ -1311,7 +1331,9 @@ async def count(
1311
1331
error_messages = error_messages ,
1312
1332
default_messages = self .error_messages ,
1313
1333
)
1314
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
1334
+ with wrap_sqlalchemy_exception (
1335
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
1336
+ ):
1315
1337
execution_options = self ._get_execution_options (execution_options )
1316
1338
statement = self .statement if statement is None else statement
1317
1339
loader_options , loader_options_have_wildcard = self ._get_loader_options (load )
@@ -1374,7 +1396,9 @@ async def update(
1374
1396
error_messages = error_messages ,
1375
1397
default_messages = self .error_messages ,
1376
1398
)
1377
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
1399
+ with wrap_sqlalchemy_exception (
1400
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
1401
+ ):
1378
1402
item_id = self .get_id_attribute_value (
1379
1403
data ,
1380
1404
id_attribute = id_attribute ,
@@ -1430,7 +1454,9 @@ async def update_many(
1430
1454
default_messages = self .error_messages ,
1431
1455
)
1432
1456
data_to_update : list [dict [str , Any ]] = [v .to_dict () if isinstance (v , self .model_type ) else v for v in data ] # type: ignore[misc]
1433
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
1457
+ with wrap_sqlalchemy_exception (
1458
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
1459
+ ):
1434
1460
execution_options = self ._get_execution_options (execution_options )
1435
1461
loader_options = self ._get_loader_options (load )[0 ]
1436
1462
supports_returning = self ._dialect .update_executemany_returning and self ._dialect .name != "oracle"
@@ -1598,7 +1624,9 @@ async def _list_and_count_window(
1598
1624
error_messages = error_messages ,
1599
1625
default_messages = self .error_messages ,
1600
1626
)
1601
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
1627
+ with wrap_sqlalchemy_exception (
1628
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
1629
+ ):
1602
1630
execution_options = self ._get_execution_options (execution_options )
1603
1631
statement = self .statement if statement is None else statement
1604
1632
loader_options , loader_options_have_wildcard = self ._get_loader_options (load )
@@ -1655,7 +1683,9 @@ async def _list_and_count_basic(
1655
1683
error_messages = error_messages ,
1656
1684
default_messages = self .error_messages ,
1657
1685
)
1658
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
1686
+ with wrap_sqlalchemy_exception (
1687
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
1688
+ ):
1659
1689
execution_options = self ._get_execution_options (execution_options )
1660
1690
statement = self .statement if statement is None else statement
1661
1691
loader_options , loader_options_have_wildcard = self ._get_loader_options (load )
@@ -1763,7 +1793,9 @@ async def upsert(
1763
1793
auto_expunge = auto_expunge ,
1764
1794
auto_refresh = auto_refresh ,
1765
1795
)
1766
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
1796
+ with wrap_sqlalchemy_exception (
1797
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
1798
+ ):
1767
1799
for field_name , new_field_value in data .to_dict (exclude = {self .id_attribute }).items ():
1768
1800
field = getattr (existing , field_name , MISSING )
1769
1801
if field is not MISSING and field != new_field_value :
@@ -1838,7 +1870,9 @@ async def upsert_many(
1838
1870
]
1839
1871
match_filter .append (any_ (matched_values ) == field if self ._prefer_any else field .in_ (matched_values )) # type: ignore[arg-type]
1840
1872
1841
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
1873
+ with wrap_sqlalchemy_exception (
1874
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
1875
+ ):
1842
1876
existing_objs = await self .list (
1843
1877
* match_filter ,
1844
1878
load = load ,
@@ -1943,7 +1977,9 @@ async def list(
1943
1977
error_messages = error_messages ,
1944
1978
default_messages = self .error_messages ,
1945
1979
)
1946
- with wrap_sqlalchemy_exception (error_messages = error_messages , dialect_name = self ._dialect .name ):
1980
+ with wrap_sqlalchemy_exception (
1981
+ error_messages = error_messages , dialect_name = self ._dialect .name , wrap_exceptions = self .wrap_exceptions
1982
+ ):
1947
1983
execution_options = self ._get_execution_options (execution_options )
1948
1984
statement = self .statement if statement is None else statement
1949
1985
loader_options , loader_options_have_wildcard = self ._get_loader_options (load )
@@ -2097,25 +2133,29 @@ class SQLAlchemyAsyncQueryRepository:
2097
2133
"""
2098
2134
2099
2135
error_messages : Optional [ErrorMessages ] = None
2136
+ wrap_exceptions : bool = True
2100
2137
2101
2138
def __init__ (
2102
2139
self ,
2103
2140
* ,
2104
2141
session : Union [AsyncSession , async_scoped_session [AsyncSession ]],
2105
2142
error_messages : Optional [ErrorMessages ] = None ,
2143
+ wrap_exceptions : bool = True ,
2106
2144
** kwargs : Any ,
2107
2145
) -> None :
2108
2146
"""Repository pattern for SQLAlchemy models.
2109
2147
2110
2148
Args:
2111
2149
session: Session managing the unit-of-work for the operation.
2112
2150
error_messages: A set of error messages to use for operations.
2151
+ wrap_exceptions: Whether to wrap exceptions in a SQLAlchemy exception.
2113
2152
**kwargs: Additional arguments.
2114
2153
2115
2154
"""
2116
2155
super ().__init__ (** kwargs )
2117
2156
self .session = session
2118
2157
self .error_messages = error_messages
2158
+ self .wrap_exceptions = wrap_exceptions
2119
2159
self ._dialect = self .session .bind .dialect if self .session .bind is not None else self .session .get_bind ().dialect
2120
2160
2121
2161
async def get_one (
@@ -2132,7 +2172,7 @@ async def get_one(
2132
2172
Returns:
2133
2173
The retrieved instance.
2134
2174
"""
2135
- with wrap_sqlalchemy_exception (error_messages = self .error_messages ):
2175
+ with wrap_sqlalchemy_exception (error_messages = self .error_messages , wrap_exceptions = self . wrap_exceptions ):
2136
2176
statement = self ._filter_statement_by_kwargs (statement , ** kwargs )
2137
2177
instance = (await self .execute (statement )).scalar_one_or_none ()
2138
2178
return self .check_not_found (instance )
@@ -2151,7 +2191,7 @@ async def get_one_or_none(
2151
2191
Returns:
2152
2192
The retrieved instance or None
2153
2193
"""
2154
- with wrap_sqlalchemy_exception (error_messages = self .error_messages ):
2194
+ with wrap_sqlalchemy_exception (error_messages = self .error_messages , wrap_exceptions = self . wrap_exceptions ):
2155
2195
statement = self ._filter_statement_by_kwargs (statement , ** kwargs )
2156
2196
instance = (await self .execute (statement )).scalar_one_or_none ()
2157
2197
return instance or None
@@ -2166,7 +2206,7 @@ async def count(self, statement: Select[Any], **kwargs: Any) -> int:
2166
2206
Returns:
2167
2207
Count of records returned by query, ignoring pagination.
2168
2208
"""
2169
- with wrap_sqlalchemy_exception (error_messages = self .error_messages ):
2209
+ with wrap_sqlalchemy_exception (error_messages = self .error_messages , wrap_exceptions = self . wrap_exceptions ):
2170
2210
statement = statement .with_only_columns (sql_func .count (text ("1" )), maintain_column_froms = True ).order_by (
2171
2211
None ,
2172
2212
)
@@ -2210,7 +2250,7 @@ async def _list_and_count_window(
2210
2250
Count of records returned by query using an analytical window function, ignoring pagination.
2211
2251
"""
2212
2252
2213
- with wrap_sqlalchemy_exception (error_messages = self .error_messages ):
2253
+ with wrap_sqlalchemy_exception (error_messages = self .error_messages , wrap_exceptions = self . wrap_exceptions ):
2214
2254
statement = statement .add_columns (over (sql_func .count (text ("1" ))))
2215
2255
statement = self ._filter_statement_by_kwargs (statement , ** kwargs )
2216
2256
result = await self .execute (statement )
@@ -2241,7 +2281,7 @@ async def _list_and_count_basic(
2241
2281
Count of records returned by query using 2 queries, ignoring pagination.
2242
2282
"""
2243
2283
2244
- with wrap_sqlalchemy_exception (error_messages = self .error_messages ):
2284
+ with wrap_sqlalchemy_exception (error_messages = self .error_messages , wrap_exceptions = self . wrap_exceptions ):
2245
2285
statement = self ._filter_statement_by_kwargs (statement , ** kwargs )
2246
2286
count_result = await self .session .execute (self ._get_count_stmt (statement ))
2247
2287
count = count_result .scalar_one ()
@@ -2261,7 +2301,7 @@ async def list(self, statement: Select[Any], **kwargs: Any) -> list[Row[Any]]:
2261
2301
Returns:
2262
2302
The list of instances, after filtering applied.
2263
2303
"""
2264
- with wrap_sqlalchemy_exception (error_messages = self .error_messages ):
2304
+ with wrap_sqlalchemy_exception (error_messages = self .error_messages , wrap_exceptions = self . wrap_exceptions ):
2265
2305
statement = self ._filter_statement_by_kwargs (statement , ** kwargs )
2266
2306
result = await self .execute (statement )
2267
2307
return list (result .all ())
0 commit comments