Skip to content

Commit 8ab05b1

Browse files
authored
Fix pool on testing (#2)
During testing the pool is injected to be able to rollback the transaction on every test. If the asyncpg.Pool is injected on construction (testing) it should not be closed by the on_disconnect event from the asgi
1 parent 7af40cb commit 8ab05b1

File tree

4 files changed

+11
-4
lines changed

4 files changed

+11
-4
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1-
## 1.0.0
1+
- 1.0.1
2+
Fix testing pool incorrectly disconnected
23

3-
- Initial release
4+
- 1.0.0
5+
Initial release

fastapi_asyncpg/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(
4444
async def on_connect(self):
4545
"""handler called during initialitzation of asgi app, that connects to
4646
the db"""
47+
# if the pool is comming from outside (tests), don't connect it
4748
if self._pool:
4849
self.app.state.pool = self._pool
4950
return
@@ -53,6 +54,10 @@ async def on_connect(self):
5354
self.app.state.pool = pool
5455

5556
async def on_disconnect(self):
57+
# if the pool is comming from outside, don't desconnect it
58+
# someone else will do (usualy a pytest fixture)
59+
if self._pool:
60+
return
5661
await self.app.state.pool.close()
5762

5863
def on_init(self, func):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
setup(
1010
name="fastapi_asyncpg",
11-
version="1.0.0",
11+
version="1.0.1",
1212
url="https://github.com/jordic/fastapi_asyncpg",
1313
license="MIT",
1414
author="Jordi collell",

tests/test_db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ async def test_pool_releases_connections(asgiapp):
116116
res = await client.post("/", json={"key": "test", "value": "val1"})
117117
assert res.status_code == 200
118118
tasks = []
119-
for i in range(5):
119+
for i in range(20):
120120
tasks.append(client.get("/test"))
121121

122122
await asyncio.gather(*tasks)

0 commit comments

Comments
 (0)