Skip to content

Commit 3752504

Browse files
committed
Use mock auth service in tests
1 parent 8112741 commit 3752504

File tree

4 files changed

+100
-4
lines changed

4 files changed

+100
-4
lines changed

soauth/service/mock.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""
2+
The mock Auth Provider, used for testing.
3+
"""
4+
5+
from sqlalchemy.ext.asyncio import AsyncSession
6+
from structlog.typing import FilteringBoundLogger
7+
8+
from soauth.config.settings import Settings
9+
from soauth.database.login import LoginRequest
10+
from soauth.database.user import User
11+
from soauth.service import user as user_service
12+
from soauth.service.provider import AuthProvider, BaseLoginError
13+
14+
15+
class MockLoginError(BaseLoginError):
16+
pass
17+
18+
19+
class MockProvider(AuthProvider):
20+
name = "mock"
21+
22+
user_name: str
23+
full_name: str
24+
email: str
25+
grants: str
26+
27+
def __init__(self, user_name: str, full_name: str, email: str, grants: str):
28+
self.user_name = user_name
29+
self.full_name = full_name
30+
self.email = email
31+
self.grants = grants
32+
33+
async def redirect(self, login_request: LoginRequest, settings: Settings) -> str:
34+
# Does not make sense for the mock provider
35+
raise NotImplementedError
36+
37+
async def login(
38+
self,
39+
code: str,
40+
settings: Settings,
41+
conn: AsyncSession,
42+
log: FilteringBoundLogger,
43+
) -> User:
44+
try:
45+
user = await user_service.read_by_name(user_name=self.user_name, conn=conn)
46+
except user_service.UserNotFound:
47+
user = await user_service.create_user(
48+
user_name=self.user_name,
49+
email=self.email,
50+
full_name=self.email,
51+
grants=self.grants,
52+
conn=conn,
53+
log=log,
54+
)
55+
56+
conn.add(user)
57+
58+
return user
59+
60+
async def refresh(
61+
self,
62+
user: User,
63+
settings: Settings,
64+
conn: AsyncSession,
65+
log: FilteringBoundLogger,
66+
) -> User:
67+
user.full_name = self.full_name
68+
user.email = self.email
69+
user.grants = self.grants
70+
71+
conn.add(user)
72+
73+
return user

tests/test_service/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from soauth.config.settings import Settings
99
from soauth.service import app as app_service
1010
from soauth.service import user as user_service
11+
from soauth.service.mock import MockProvider
1112

1213

1314
@pytest_asyncio.fixture(scope="session")
@@ -70,3 +71,14 @@ async def app(session_manager, logger, user, server_settings):
7071
conn=conn,
7172
log=logger,
7273
)
74+
75+
76+
@pytest_asyncio.fixture(scope="session")
77+
async def provider():
78+
# Same as `user`
79+
yield MockProvider(
80+
user_name="admin",
81+
email="admin@simonsobservatory.org",
82+
full_name="Admin User",
83+
grants="admin",
84+
)

tests/test_service/test_flow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
@pytest.mark.asyncio(loop_scope="session")
1414
async def test_primary_then_secondary(
15-
user, app, logger, server_settings, session_manager
15+
user, app, logger, server_settings, session_manager, provider
1616
):
1717
async with session_manager.session() as conn:
1818
async with conn.begin():
@@ -50,6 +50,7 @@ async def test_primary_then_secondary(
5050
settings=server_settings,
5151
conn=conn,
5252
log=logger,
53+
provider=provider,
5354
)
5455

5556
auth = key_content.access_token

tests/test_service/test_refresh.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111

1212

1313
@pytest.mark.asyncio(loop_scope="session")
14-
async def test_create_refresh_key(user, app, session_manager, logger, server_settings):
14+
async def test_create_refresh_key(
15+
user, app, session_manager, logger, server_settings, provider
16+
):
1517
async with session_manager.session() as conn:
1618
async with conn.begin():
1719
encoded, refresh_key = await refresh_service.create_refresh_key(
@@ -60,7 +62,11 @@ async def test_create_refresh_key(user, app, session_manager, logger, server_set
6062
encoded_payload=encoded, conn=conn
6163
)
6264
await refresh_service.refresh_refresh_key(
63-
payload=decoded, settings=server_settings, conn=conn, log=logger
65+
payload=decoded,
66+
settings=server_settings,
67+
conn=conn,
68+
log=logger,
69+
provider=provider,
6470
)
6571

6672
# Now let's refresh our new refresh key.
@@ -73,7 +79,11 @@ async def test_create_refresh_key(user, app, session_manager, logger, server_set
7379
refreshed_encoded,
7480
refreshed_refresh_key,
7581
) = await refresh_service.refresh_refresh_key(
76-
payload=decoded, settings=server_settings, conn=conn, log=logger
82+
payload=decoded,
83+
settings=server_settings,
84+
conn=conn,
85+
log=logger,
86+
provider=provider,
7787
)
7888
REFRESHED_KEY_ID = refreshed_refresh_key.refresh_key_id
7989
assert refreshed_refresh_key.previous == NEW_REFRESH_KEY_ID

0 commit comments

Comments
 (0)