Skip to content

Commit f7a3e25

Browse files
Merge pull request #125 from discord-modmail/fix-aiohttp-session
tests: mock aiohttp responses and fix windows
2 parents 6ab7918 + 25192c5 commit f7a3e25

File tree

6 files changed

+129
-8
lines changed

6 files changed

+129
-8
lines changed

modmail/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
1+
import asyncio
12
import logging
23
import logging.handlers
4+
import os
35
from pathlib import Path
46

57
import coloredlogs
68

79
from modmail.log import ModmailLogger
810

911

12+
# On windows aiodns's asyncio support relies on APIs like add_reader (which aiodns uses)
13+
# are not guaranteed to be available, and in particular are not available when using the
14+
# ProactorEventLoop on Windows, this method is only supported with Windows SelectorEventLoop
15+
if os.name == "nt":
16+
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
17+
1018
logging.TRACE = 5
1119
logging.NOTICE = 25
1220
logging.addLevelName(logging.TRACE, "TRACE")

modmail/bot.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import asyncio
22
import logging
33
import signal
4+
import socket
45
import typing as t
5-
from typing import Any
66

7+
import aiohttp
78
import arrow
89
import discord
9-
from aiohttp import ClientSession
1010
from discord import Activity, AllowedMentions, Intents
1111
from discord.client import _cleanup_loop
1212
from discord.ext import commands
@@ -41,9 +41,12 @@ class ModmailBot(commands.Bot):
4141
def __init__(self, **kwargs):
4242
self.config = CONFIG
4343
self.start_time: t.Optional[arrow.Arrow] = None # arrow.utcnow()
44-
self.http_session: t.Optional[ClientSession] = None
44+
self.http_session: t.Optional[aiohttp.ClientSession] = None
4545
self.dispatcher = Dispatcher()
4646

47+
self._connector = None
48+
self._resolver = None
49+
4750
status = discord.Status.online
4851
activity = Activity(type=discord.ActivityType.listening, name="users dming me!")
4952
# listen to messages mentioning the bot or matching the prefix
@@ -65,6 +68,24 @@ def __init__(self, **kwargs):
6568
**kwargs,
6669
)
6770

71+
async def create_connectors(self, *args, **kwargs) -> None:
72+
"""Re-create the connector and set up sessions before logging into Discord."""
73+
# Use asyncio for DNS resolution instead of threads so threads aren't spammed.
74+
self._resolver = aiohttp.AsyncResolver()
75+
76+
# Use AF_INET as its socket family to prevent HTTPS related problems both locally
77+
# and in production.
78+
self._connector = aiohttp.TCPConnector(
79+
resolver=self._resolver,
80+
family=socket.AF_INET,
81+
)
82+
83+
# Client.login() will call HTTPClient.static_login() which will create a session using
84+
# this connector attribute.
85+
self.http.connector = self._connector
86+
87+
self.http_session = aiohttp.ClientSession(connector=self._connector)
88+
6889
async def start(self, token: str, reconnect: bool = True) -> None:
6990
"""
7091
Start the bot.
@@ -74,8 +95,8 @@ async def start(self, token: str, reconnect: bool = True) -> None:
7495
"""
7596
try:
7697
# create the aiohttp session
77-
self.http_session = ClientSession(loop=self.loop)
78-
self.logger.trace("Created ClientSession.")
98+
await self.create_connectors()
99+
self.logger.trace("Created aiohttp.ClientSession.")
79100
# set start time to when we started the bot.
80101
# This is now, since we're about to connect to the gateway.
81102
# This should also be before we load any extensions, since if they have a load time, it should
@@ -122,7 +143,7 @@ def run(self, *args, **kwargs) -> None:
122143
except NotImplementedError:
123144
pass
124145

125-
def stop_loop_on_completion(f: Any) -> None:
146+
def stop_loop_on_completion(f: t.Any) -> None:
126147
loop.stop()
127148

128149
future = asyncio.ensure_future(self.start(*args, **kwargs), loop=loop)
@@ -164,10 +185,16 @@ async def close(self) -> None:
164185
except Exception:
165186
self.logger.error(f"Exception occured while removing cog {cog.name}", exc_info=True)
166187

188+
await super().close()
189+
167190
if self.http_session:
168191
await self.http_session.close()
169192

170-
await super().close()
193+
if self._connector:
194+
await self._connector.close()
195+
196+
if self._resolver:
197+
await self._resolver.close()
171198

172199
def load_extensions(self) -> None:
173200
"""Load all enabled extensions."""

poetry.lock

Lines changed: 16 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ flake8-todo = "~=0.7"
4444
isort = "^5.9.2"
4545
pep8-naming = "~=0.11"
4646
# testing
47+
aioresponses = "^0.7.2"
4748
coverage = { extras = ["toml"], version = "^6.0.2" }
4849
coveralls = "^3.3.1"
4950
pytest = "^6.2.4"

tests/conftest.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,29 @@
1+
import aiohttp
2+
import aioresponses
13
import pytest
4+
5+
6+
@pytest.fixture
7+
def aioresponse():
8+
"""Fixture to mock aiohttp responses."""
9+
with aioresponses.aioresponses() as aioresponse:
10+
yield aioresponse
11+
12+
13+
@pytest.fixture
14+
@pytest.mark.asyncio
15+
async def http_session(aioresponse) -> aiohttp.ClientSession:
16+
"""
17+
Fixture function for a aiohttp.ClientSession.
18+
19+
Requests fixture aioresponse to ensure that all client sessions do not make actual requests.
20+
"""
21+
resolver = aiohttp.AsyncResolver()
22+
connector = aiohttp.TCPConnector(resolver=resolver)
23+
client_session = aiohttp.ClientSession(connector=connector)
24+
25+
yield client_session
26+
27+
await client_session.close()
28+
await connector.close()
29+
await resolver.close()

tests/test_fixtures.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import aiohttp
6+
import pytest
7+
8+
9+
if TYPE_CHECKING:
10+
import aioresponses
11+
12+
13+
class TestSessionFixture:
14+
"""Grouping for aiohttp.ClientSession fixture tests."""
15+
16+
@pytest.mark.asyncio
17+
async def test_session_fixture_no_requests(self, http_session: aiohttp.ClientSession):
18+
"""
19+
Test all requests fail.
20+
21+
This means that aioresponses is being requested by the http_session fixture.
22+
"""
23+
url = "https://github.com/"
24+
25+
with pytest.raises(aiohttp.ClientConnectionError):
26+
await http_session.get(url)
27+
28+
@pytest.mark.asyncio
29+
async def test_session_fixture_mock_requests(
30+
self, aioresponse: aioresponses.aioresponses, http_session: aiohttp.ClientSession
31+
):
32+
"""
33+
Test all requests fail.
34+
35+
This means that aioresponses is being requested by the http_session fixture.
36+
"""
37+
url = "https://github.com/"
38+
status = 200
39+
aioresponse.get(url, status=status)
40+
41+
async with http_session.get(url) as resp:
42+
assert status == resp.status

0 commit comments

Comments
 (0)