Skip to content

Commit c8efc05

Browse files
committed
- add decorator to compile sql for pool or return scalar from statement
1 parent e003454 commit c8efc05

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

app/api/stuff.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ async def create_stuff(
4545
async def find_stuff(name: str, db_session: AsyncSession = Depends(get_db)):
4646
result = await Stuff.find(db_session, name)
4747
if not result:
48-
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Stuff with name {name} not found.")
48+
raise HTTPException(
49+
status_code=status.HTTP_404_NOT_FOUND,
50+
detail=f"Stuff with name {name} not found.",
51+
)
4952
return result
5053

5154

@@ -91,7 +94,6 @@ async def find_stuff_pool(
9194
return result
9295

9396

94-
9597
@router.delete("/{name}")
9698
async def delete_stuff(name: str, db_session: AsyncSession = Depends(get_db)):
9799
stuff = await Stuff.find(db_session, name)

app/models/stuff.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,20 @@
99
from app.models.base import Base
1010
from app.models.nonsense import Nonsense
1111

12+
from functools import wraps
13+
14+
15+
def compile_sql_or_scalar(func):
16+
@wraps(func)
17+
async def wrapper(cls, db_session, name, compile_sql=False, *args, **kwargs):
18+
stmt = await func(cls, db_session, name, *args, **kwargs)
19+
if compile_sql:
20+
return stmt.compile(compile_kwargs={"literal_binds": True})
21+
result = await db_session.execute(stmt)
22+
return result.scalars().first()
23+
24+
return wrapper
25+
1226

1327
class Stuff(Base):
1428
__tablename__ = "stuff"
@@ -24,12 +38,10 @@ class Stuff(Base):
2438
)
2539

2640
@classmethod
27-
async def find(cls, db_session: AsyncSession, name: str, compile_sql: bool = False):
41+
@compile_sql_or_scalar
42+
async def find(cls, db_session: AsyncSession, name: str, compile_sql=False):
2843
stmt = select(cls).options(joinedload(cls.nonsense)).where(cls.name == name)
29-
if compile_sql:
30-
return stmt.compile(compile_kwargs={"literal_binds": True})
31-
result = await db_session.execute(stmt)
32-
return result.scalars().first()
44+
return stmt
3345

3446

3547
class StuffFullOfNonsense(Base):

0 commit comments

Comments
 (0)