diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index bf83b0409..0f7739d1d 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -460,6 +460,7 @@ async def get_session( storage_events = ( session_factory.query(StorageEvent) .filter(StorageEvent.session_id == storage_session.id) + .filter(StorageEvent.user_id == user_id) .filter(timestamp_filter) .order_by(StorageEvent.timestamp.desc()) .limit( diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index d8344194f..995d739f5 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -126,6 +126,7 @@ async def test_session_state(service_type): app_name = 'my_app' user_id_1 = 'user1' user_id_2 = 'user2' + user_id_malicious = 'malicious' session_id_11 = 'session11' session_id_12 = 'session12' session_id_2 = 'session2' @@ -148,6 +149,10 @@ async def test_session_state(service_type): app_name=app_name, user_id=user_id_2, session_id=session_id_2 ) + await session_service.create_session( + app_name=app_name, user_id=user_id_malicious, session_id=session_id_11 + ) + assert session_11.state.get('key11') == 'value11' event = Event( @@ -196,6 +201,13 @@ async def test_session_state(service_type): assert session_11.state.get('user:key1') == 'value1' assert not session_11.state.get('temp:key') + # Make sure a malicious user can obtain a session and events not belonging to them + session_mismatch = await session_service.get_session( + app_name=app_name, user_id=user_id_malicious, session_id=session_id_11 + ) + + assert len(session_mismatch.events) == 0 + @pytest.mark.asyncio @pytest.mark.parametrize(