From d1b2e0672fb009511ca1b7a3d24a00d13d0c87de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sindri=20Sn=C3=A6r=20Gunnarsson?= Date: Wed, 25 Jun 2025 11:49:20 +0000 Subject: [PATCH] fix: scenario where a user can access another users events given the same session id test: refine test_session_state in test_session_state to catch event leakage fix: revert app name to my_app for test_session_state test style: fix pyink style warnings --- src/google/adk/sessions/database_session_service.py | 1 + tests/unittests/sessions/test_session_service.py | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index bf83b0409..0f7739d1d 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -460,6 +460,7 @@ async def get_session( storage_events = ( session_factory.query(StorageEvent) .filter(StorageEvent.session_id == storage_session.id) + .filter(StorageEvent.user_id == user_id) .filter(timestamp_filter) .order_by(StorageEvent.timestamp.desc()) .limit( diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index d8344194f..995d739f5 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -126,6 +126,7 @@ async def test_session_state(service_type): app_name = 'my_app' user_id_1 = 'user1' user_id_2 = 'user2' + user_id_malicious = 'malicious' session_id_11 = 'session11' session_id_12 = 'session12' session_id_2 = 'session2' @@ -148,6 +149,10 @@ async def test_session_state(service_type): app_name=app_name, user_id=user_id_2, session_id=session_id_2 ) + await session_service.create_session( + app_name=app_name, user_id=user_id_malicious, session_id=session_id_11 + ) + assert session_11.state.get('key11') == 'value11' event = Event( @@ -196,6 +201,13 @@ async def test_session_state(service_type): assert session_11.state.get('user:key1') == 'value1' assert not session_11.state.get('temp:key') + # Make sure a malicious user can obtain a session and events not belonging to them + session_mismatch = await session_service.get_session( + app_name=app_name, user_id=user_id_malicious, session_id=session_id_11 + ) + + assert len(session_mismatch.events) == 0 + @pytest.mark.asyncio @pytest.mark.parametrize(