From 5ad86367ddf7e62f5770cc89040a52183350308e Mon Sep 17 00:00:00 2001 From: Brandon Istfan Date: Sun, 2 Nov 2025 18:11:14 -0500 Subject: [PATCH 1/2] fix: enrollment thread and temp storage issue --- .../management/commands/fetch_enrollment.py | 47 ++++++++----- tcf_website/tests/test_fetch_enrollment.py | 69 ++++++++++++++----- 2 files changed, 82 insertions(+), 34 deletions(-) diff --git a/tcf_website/management/commands/fetch_enrollment.py b/tcf_website/management/commands/fetch_enrollment.py index 2545c456..d00802a0 100644 --- a/tcf_website/management/commands/fetch_enrollment.py +++ b/tcf_website/management/commands/fetch_enrollment.py @@ -7,12 +7,14 @@ docker exec -it tcf_django python manage.py fetch_enrollment """ +import threading import time from concurrent.futures import ThreadPoolExecutor import backoff import requests from django.core.management.base import BaseCommand +from django.db import connections from requests.adapters import HTTPAdapter from tqdm import tqdm from urllib3.util.retry import Retry @@ -22,7 +24,7 @@ # Maximum time to wait for a response from the server TIMEOUT = 30 # Number of concurrent workers for fetching data -MAX_WORKERS = 20 +MAX_WORKERS = 5 # Initial wait time for backoff (in seconds) INITIAL_WAIT = 2 # Maximum number of retry attempts @@ -30,20 +32,8 @@ # Kept larger than MAX_WORKERS to handle connection lifecycle issues MAX_POOL_SIZE = 20 * 5 -# Configure session with retry strategy and connection pooling -session = requests.Session() -retry_strategy = Retry( - total=3, - backoff_factor=0.1, - status_forcelist=[429, 500, 502, 503, 504], -) -adapter = HTTPAdapter( - pool_connections=MAX_POOL_SIZE, - pool_maxsize=MAX_POOL_SIZE, - max_retries=retry_strategy, -) -session.mount("http://", adapter) -session.mount("https://", adapter) +# Thread-local storage for requests sessions to avoid contention +_session_local = threading.local() def should_retry_request(exception): @@ -56,6 +46,25 @@ def should_retry_request(exception): return False +def get_session(): + """Get or create a thread-local requests session.""" + if not hasattr(_session_local, "session"): + retry_strategy = Retry( + total=3, + backoff_factor=0.1, + status_forcelist=[429, 500, 502, 503, 504], + ) + adapter = HTTPAdapter( + pool_connections=MAX_POOL_SIZE, + pool_maxsize=MAX_POOL_SIZE, + max_retries=retry_strategy, + ) + _session_local.session = requests.Session() + _session_local.session.mount("http://", adapter) + _session_local.session.mount("https://", adapter) + return _session_local.session + + @backoff.on_exception( backoff.expo, requests.exceptions.RequestException, @@ -81,6 +90,7 @@ def fetch_section_data(section): ) try: + session = get_session() # Fetch and validate response response = session.get(url, timeout=TIMEOUT) response.raise_for_status() @@ -152,9 +162,11 @@ def handle(self, *args, **options): print(f"Fetching enrollment data for semester {semester}...") - sections = Section.objects.filter(semester=semester) - total_sections = sections.count() + sections_queryset = Section.objects.filter(semester=semester) + total_sections = sections_queryset.count() print(f"Found {total_sections} sections") + + sections = list(sections_queryset) # Process sections in parallel using ThreadPoolExecutor success_count = 0 @@ -171,6 +183,7 @@ def handle(self, *args, **options): except KeyboardInterrupt: print("\nProcess interrupted by user. Shutting down...") executor.shutdown(wait=False) + connections.close_all() return elapsed_time = time.time() - start_time diff --git a/tcf_website/tests/test_fetch_enrollment.py b/tcf_website/tests/test_fetch_enrollment.py index 5c4a15e5..41b66fde 100644 --- a/tcf_website/tests/test_fetch_enrollment.py +++ b/tcf_website/tests/test_fetch_enrollment.py @@ -1,10 +1,13 @@ """Tests for fetch_enrollment management command.""" +import gc +import threading from unittest.mock import patch +from django.db import connections from django.test import TestCase -from tcf_website.management.commands.fetch_enrollment import fetch_section_data +from tcf_website.management.commands.fetch_enrollment import fetch_section_data, get_session from tcf_website.models import SectionEnrollment from .test_utils import setup @@ -19,11 +22,13 @@ def setUp(self): # pylint: disable=no-member self.section = self.section_course - @patch("tcf_website.management.commands.fetch_enrollment.session.get") - def test_fetch_enrollment_success(self, mock_get): + @patch("tcf_website.management.commands.fetch_enrollment.get_session") + def test_fetch_enrollment_success(self, mock_get_session): """Test successful enrollment fetch.""" - mock_get.return_value.status_code = 200 - mock_get.return_value.json.return_value = { + mock_session = mock_get_session.return_value + mock_session.get.return_value.status_code = 200 + mock_session.get.return_value.raise_for_status.return_value = None + mock_session.get.return_value.json.return_value = { "classes": [ { "enrollment_total": 15, @@ -42,8 +47,8 @@ def test_fetch_enrollment_success(self, mock_get): self.assertEqual(enrollment.waitlist_taken, 5) self.assertEqual(enrollment.waitlist_limit, 10) - @patch("tcf_website.management.commands.fetch_enrollment.session.get") - def test_fetch_enrollment_update_existing(self, mock_get): + @patch("tcf_website.management.commands.fetch_enrollment.get_session") + def test_fetch_enrollment_update_existing(self, mock_get_session): """Test updating existing enrollment data.""" # Create initial enrollment SectionEnrollment.objects.create( @@ -54,8 +59,10 @@ def test_fetch_enrollment_update_existing(self, mock_get): waitlist_limit=5, ) - mock_get.return_value.status_code = 200 - mock_get.return_value.json.return_value = { + mock_session = mock_get_session.return_value + mock_session.get.return_value.status_code = 200 + mock_session.get.return_value.raise_for_status.return_value = None + mock_session.get.return_value.json.return_value = { "classes": [ { "enrollment_total": 15, @@ -74,24 +81,52 @@ def test_fetch_enrollment_update_existing(self, mock_get): self.assertEqual(enrollment.waitlist_taken, 8) self.assertEqual(enrollment.waitlist_limit, 12) - @patch("tcf_website.management.commands.fetch_enrollment.session.get") - def test_fetch_enrollment_empty_response(self, mock_get): + @patch("tcf_website.management.commands.fetch_enrollment.get_session") + def test_fetch_enrollment_empty_response(self, mock_get_session): """Test handling of empty API response.""" - mock_get.return_value.status_code = 200 - mock_get.return_value.json.return_value = {"classes": []} + mock_session = mock_get_session.return_value + mock_session.get.return_value.status_code = 200 + mock_session.get.return_value.raise_for_status.return_value = None + mock_session.get.return_value.json.return_value = {"classes": []} result = fetch_section_data(self.section) self.assertFalse(result) self.assertEqual(SectionEnrollment.objects.count(), 0) - @patch("tcf_website.management.commands.fetch_enrollment.session.get") - def test_fetch_enrollment_api_error(self, mock_get): + @patch("tcf_website.management.commands.fetch_enrollment.get_session") + def test_fetch_enrollment_api_error(self, mock_get_session): """Test handling of API error.""" - mock_get.return_value.status_code = 500 - mock_get.return_value.raise_for_status.side_effect = Exception("API Error") + mock_session = mock_get_session.return_value + mock_session.get.return_value.status_code = 500 + mock_session.get.return_value.raise_for_status.side_effect = Exception("API Error") result = fetch_section_data(self.section) self.assertFalse(result) self.assertEqual(SectionEnrollment.objects.count(), 0) + + @patch("tcf_website.management.commands.fetch_enrollment.get_session") + def test_no_resource_leaks(self, mock_get_session): + """Test that threads and database connections are properly cleaned up.""" + mock_session = mock_get_session.return_value + mock_session.get.return_value.status_code = 200 + mock_session.get.return_value.raise_for_status.return_value = None + mock_session.get.return_value.json.return_value = { + "classes": [{"enrollment_total": 15, "class_capacity": 20, "wait_tot": 0, "wait_cap": 0}] + } + + # Check thread isolation: each thread gets its own session + sessions = [] + for _ in range(5): + thread = threading.Thread(target=lambda: sessions.append(get_session())) + thread.start() + thread.join() + self.assertEqual(len(set(id(s) for s in sessions)), 5, "Each thread should have its own session") + + # Check connection cleanup: no connection leak after fetch + initial_connections = len([c for c in connections.all() if c.is_usable()]) + fetch_section_data(self.section) + gc.collect() + final_connections = len([c for c in connections.all() if c.is_usable()]) + self.assertLessEqual(final_connections, initial_connections + 1, "Connections should be closed") From 50d6f716e08db56eb2a3270317e95a3b039af333 Mon Sep 17 00:00:00 2001 From: Brandon Istfan Date: Sun, 2 Nov 2025 18:23:15 -0500 Subject: [PATCH 2/2] lint fixes --- tcf_website/api/views.py | 4 +-- .../management/commands/fetch_enrollment.py | 2 +- tcf_website/tests/test_fetch_enrollment.py | 28 +++++-------------- 3 files changed, 10 insertions(+), 24 deletions(-) diff --git a/tcf_website/api/views.py b/tcf_website/api/views.py index 50bff55a..7f180ab6 100644 --- a/tcf_website/api/views.py +++ b/tcf_website/api/views.py @@ -2,7 +2,7 @@ """DRF Viewsets""" import asyncio from threading import Thread -from django.db import connection +from django.db import connections from django.db.models import Avg, Sum from django.http import JsonResponse from rest_framework import viewsets @@ -197,7 +197,7 @@ def _run_update(): except (asyncio.TimeoutError, requests.RequestException, ValueError) as exc: print(f"Enrollment update failed for course {pk}: {exc}") finally: - connection.close() + connections.close_all() thread = Thread(target=_run_update, daemon=True) thread.start() diff --git a/tcf_website/management/commands/fetch_enrollment.py b/tcf_website/management/commands/fetch_enrollment.py index d00802a0..6cd15ab2 100644 --- a/tcf_website/management/commands/fetch_enrollment.py +++ b/tcf_website/management/commands/fetch_enrollment.py @@ -165,7 +165,7 @@ def handle(self, *args, **options): sections_queryset = Section.objects.filter(semester=semester) total_sections = sections_queryset.count() print(f"Found {total_sections} sections") - + sections = list(sections_queryset) # Process sections in parallel using ThreadPoolExecutor diff --git a/tcf_website/tests/test_fetch_enrollment.py b/tcf_website/tests/test_fetch_enrollment.py index 41b66fde..bcc1d95d 100644 --- a/tcf_website/tests/test_fetch_enrollment.py +++ b/tcf_website/tests/test_fetch_enrollment.py @@ -1,10 +1,8 @@ """Tests for fetch_enrollment management command.""" -import gc import threading from unittest.mock import patch -from django.db import connections from django.test import TestCase from tcf_website.management.commands.fetch_enrollment import fetch_section_data, get_session @@ -106,27 +104,15 @@ def test_fetch_enrollment_api_error(self, mock_get_session): self.assertFalse(result) self.assertEqual(SectionEnrollment.objects.count(), 0) - @patch("tcf_website.management.commands.fetch_enrollment.get_session") - def test_no_resource_leaks(self, mock_get_session): - """Test that threads and database connections are properly cleaned up.""" - mock_session = mock_get_session.return_value - mock_session.get.return_value.status_code = 200 - mock_session.get.return_value.raise_for_status.return_value = None - mock_session.get.return_value.json.return_value = { - "classes": [{"enrollment_total": 15, "class_capacity": 20, "wait_tot": 0, "wait_cap": 0}] - } - - # Check thread isolation: each thread gets its own session + def test_thread_local_sessions(self): + """Test that each thread gets its own session instance.""" sessions = [] for _ in range(5): thread = threading.Thread(target=lambda: sessions.append(get_session())) thread.start() thread.join() - self.assertEqual(len(set(id(s) for s in sessions)), 5, "Each thread should have its own session") - - # Check connection cleanup: no connection leak after fetch - initial_connections = len([c for c in connections.all() if c.is_usable()]) - fetch_section_data(self.section) - gc.collect() - final_connections = len([c for c in connections.all() if c.is_usable()]) - self.assertLessEqual(final_connections, initial_connections + 1, "Connections should be closed") + self.assertEqual( + len(set(id(s) for s in sessions)), + 5, + "Each thread should have its own session" + )