Skip to content

Commit cf7d75e

Browse files
committed
tests again
1 parent 29da04d commit cf7d75e

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

tests/integration/test_fastapi_reconnection_isolation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def _get_cassandra_control(self, container=None):
2222

2323
@pytest.mark.integration
2424
@pytest.mark.asyncio
25-
async def test_session_health_check_pattern(self):
25+
async def test_session_health_check_pattern(self, cassandra_container):
2626
"""Test the FastAPI health check pattern that might prevent reconnection."""
2727
print("\n=== Testing FastAPI Health Check Pattern ===")
2828

@@ -71,7 +71,7 @@ async def health_check():
7171

7272
# Disable Cassandra
7373
print("\nDisabling Cassandra...")
74-
control = self._get_cassandra_control()
74+
control = self._get_cassandra_control(cassandra_container)
7575

7676
if os.environ.get("CI") == "true":
7777
# Still test that health check works with available service
@@ -132,7 +132,7 @@ async def health_check():
132132

133133
@pytest.mark.integration
134134
@pytest.mark.asyncio
135-
async def test_global_session_reconnection(self):
135+
async def test_global_session_reconnection(self, cassandra_container):
136136
"""Test reconnection with global session variable like FastAPI."""
137137
print("\n=== Testing Global Session Reconnection ===")
138138

@@ -170,7 +170,7 @@ async def test_global_session_reconnection(self):
170170
print("✓ Initial query works")
171171

172172
# Get control interface
173-
control = self._get_cassandra_control()
173+
control = self._get_cassandra_control(cassandra_container)
174174

175175
if os.environ.get("CI") == "true":
176176
print("\nSkipping outage simulation in CI")

tests/unit/test_sql_injection_protection.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,12 @@ async def test_no_string_interpolation_in_queries(self):
104104
await mock_session.prepare("SELECT * FROM users LIMIT ?")
105105
await mock_session.execute(mock_stmt, [int(limit.split(";")[0])]) # Parse safely
106106

107-
# Verify no direct string interpolation
108-
assert all("DROP TABLE" not in str(call) for call in mock_session.execute.call_args_list)
107+
# Verify prepared statements were used (not string interpolation)
108+
# The execute calls should have the mock statement and parameters, not raw SQL
109+
for exec_call in mock_session.execute.call_args_list:
110+
# Each call should be execute(mock_stmt, [params])
111+
assert exec_call[0][0] == mock_stmt # First arg is the prepared statement
112+
assert isinstance(exec_call[0][1], list) # Second arg is parameters list
109113

110114
@pytest.mark.asyncio
111115
async def test_hardcoded_keyspace_names(self):

0 commit comments

Comments
 (0)