diff --git a/tcf_website/api/views.py b/tcf_website/api/views.py index 50bff55a..7f180ab6 100644 --- a/tcf_website/api/views.py +++ b/tcf_website/api/views.py @@ -2,7 +2,7 @@ """DRF Viewsets""" import asyncio from threading import Thread -from django.db import connection +from django.db import connections from django.db.models import Avg, Sum from django.http import JsonResponse from rest_framework import viewsets @@ -197,7 +197,7 @@ def _run_update(): except (asyncio.TimeoutError, requests.RequestException, ValueError) as exc: print(f"Enrollment update failed for course {pk}: {exc}") finally: - connection.close() + connections.close_all() thread = Thread(target=_run_update, daemon=True) thread.start() diff --git a/tcf_website/management/commands/fetch_enrollment.py b/tcf_website/management/commands/fetch_enrollment.py index 2545c456..6cd15ab2 100644 --- a/tcf_website/management/commands/fetch_enrollment.py +++ b/tcf_website/management/commands/fetch_enrollment.py @@ -7,12 +7,14 @@ docker exec -it tcf_django python manage.py fetch_enrollment """ +import threading import time from concurrent.futures import ThreadPoolExecutor import backoff import requests from django.core.management.base import BaseCommand +from django.db import connections from requests.adapters import HTTPAdapter from tqdm import tqdm from urllib3.util.retry import Retry @@ -22,7 +24,7 @@ # Maximum time to wait for a response from the server TIMEOUT = 30 # Number of concurrent workers for fetching data -MAX_WORKERS = 20 +MAX_WORKERS = 5 # Initial wait time for backoff (in seconds) INITIAL_WAIT = 2 # Maximum number of retry attempts @@ -30,20 +32,8 @@ # Kept larger than MAX_WORKERS to handle connection lifecycle issues MAX_POOL_SIZE = 20 * 5 -# Configure session with retry strategy and connection pooling -session = requests.Session() -retry_strategy = Retry( - total=3, - backoff_factor=0.1, - status_forcelist=[429, 500, 502, 503, 504], -) -adapter = HTTPAdapter( - pool_connections=MAX_POOL_SIZE, - pool_maxsize=MAX_POOL_SIZE, - max_retries=retry_strategy, -) -session.mount("http://", adapter) -session.mount("https://", adapter) +# Thread-local storage for requests sessions to avoid contention +_session_local = threading.local() def should_retry_request(exception): @@ -56,6 +46,25 @@ def should_retry_request(exception): return False +def get_session(): + """Get or create a thread-local requests session.""" + if not hasattr(_session_local, "session"): + retry_strategy = Retry( + total=3, + backoff_factor=0.1, + status_forcelist=[429, 500, 502, 503, 504], + ) + adapter = HTTPAdapter( + pool_connections=MAX_POOL_SIZE, + pool_maxsize=MAX_POOL_SIZE, + max_retries=retry_strategy, + ) + _session_local.session = requests.Session() + _session_local.session.mount("http://", adapter) + _session_local.session.mount("https://", adapter) + return _session_local.session + + @backoff.on_exception( backoff.expo, requests.exceptions.RequestException, @@ -81,6 +90,7 @@ def fetch_section_data(section): ) try: + session = get_session() # Fetch and validate response response = session.get(url, timeout=TIMEOUT) response.raise_for_status() @@ -152,10 +162,12 @@ def handle(self, *args, **options): print(f"Fetching enrollment data for semester {semester}...") - sections = Section.objects.filter(semester=semester) - total_sections = sections.count() + sections_queryset = Section.objects.filter(semester=semester) + total_sections = sections_queryset.count() print(f"Found {total_sections} sections") + sections = list(sections_queryset) + # Process sections in parallel using ThreadPoolExecutor success_count = 0 with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: @@ -171,6 +183,7 @@ def handle(self, *args, **options): except KeyboardInterrupt: print("\nProcess interrupted by user. Shutting down...") executor.shutdown(wait=False) + connections.close_all() return elapsed_time = time.time() - start_time diff --git a/tcf_website/tests/test_fetch_enrollment.py b/tcf_website/tests/test_fetch_enrollment.py index 5c4a15e5..bcc1d95d 100644 --- a/tcf_website/tests/test_fetch_enrollment.py +++ b/tcf_website/tests/test_fetch_enrollment.py @@ -1,10 +1,11 @@ """Tests for fetch_enrollment management command.""" +import threading from unittest.mock import patch from django.test import TestCase -from tcf_website.management.commands.fetch_enrollment import fetch_section_data +from tcf_website.management.commands.fetch_enrollment import fetch_section_data, get_session from tcf_website.models import SectionEnrollment from .test_utils import setup @@ -19,11 +20,13 @@ def setUp(self): # pylint: disable=no-member self.section = self.section_course - @patch("tcf_website.management.commands.fetch_enrollment.session.get") - def test_fetch_enrollment_success(self, mock_get): + @patch("tcf_website.management.commands.fetch_enrollment.get_session") + def test_fetch_enrollment_success(self, mock_get_session): """Test successful enrollment fetch.""" - mock_get.return_value.status_code = 200 - mock_get.return_value.json.return_value = { + mock_session = mock_get_session.return_value + mock_session.get.return_value.status_code = 200 + mock_session.get.return_value.raise_for_status.return_value = None + mock_session.get.return_value.json.return_value = { "classes": [ { "enrollment_total": 15, @@ -42,8 +45,8 @@ def test_fetch_enrollment_success(self, mock_get): self.assertEqual(enrollment.waitlist_taken, 5) self.assertEqual(enrollment.waitlist_limit, 10) - @patch("tcf_website.management.commands.fetch_enrollment.session.get") - def test_fetch_enrollment_update_existing(self, mock_get): + @patch("tcf_website.management.commands.fetch_enrollment.get_session") + def test_fetch_enrollment_update_existing(self, mock_get_session): """Test updating existing enrollment data.""" # Create initial enrollment SectionEnrollment.objects.create( @@ -54,8 +57,10 @@ def test_fetch_enrollment_update_existing(self, mock_get): waitlist_limit=5, ) - mock_get.return_value.status_code = 200 - mock_get.return_value.json.return_value = { + mock_session = mock_get_session.return_value + mock_session.get.return_value.status_code = 200 + mock_session.get.return_value.raise_for_status.return_value = None + mock_session.get.return_value.json.return_value = { "classes": [ { "enrollment_total": 15, @@ -74,24 +79,40 @@ def test_fetch_enrollment_update_existing(self, mock_get): self.assertEqual(enrollment.waitlist_taken, 8) self.assertEqual(enrollment.waitlist_limit, 12) - @patch("tcf_website.management.commands.fetch_enrollment.session.get") - def test_fetch_enrollment_empty_response(self, mock_get): + @patch("tcf_website.management.commands.fetch_enrollment.get_session") + def test_fetch_enrollment_empty_response(self, mock_get_session): """Test handling of empty API response.""" - mock_get.return_value.status_code = 200 - mock_get.return_value.json.return_value = {"classes": []} + mock_session = mock_get_session.return_value + mock_session.get.return_value.status_code = 200 + mock_session.get.return_value.raise_for_status.return_value = None + mock_session.get.return_value.json.return_value = {"classes": []} result = fetch_section_data(self.section) self.assertFalse(result) self.assertEqual(SectionEnrollment.objects.count(), 0) - @patch("tcf_website.management.commands.fetch_enrollment.session.get") - def test_fetch_enrollment_api_error(self, mock_get): + @patch("tcf_website.management.commands.fetch_enrollment.get_session") + def test_fetch_enrollment_api_error(self, mock_get_session): """Test handling of API error.""" - mock_get.return_value.status_code = 500 - mock_get.return_value.raise_for_status.side_effect = Exception("API Error") + mock_session = mock_get_session.return_value + mock_session.get.return_value.status_code = 500 + mock_session.get.return_value.raise_for_status.side_effect = Exception("API Error") result = fetch_section_data(self.section) self.assertFalse(result) self.assertEqual(SectionEnrollment.objects.count(), 0) + + def test_thread_local_sessions(self): + """Test that each thread gets its own session instance.""" + sessions = [] + for _ in range(5): + thread = threading.Thread(target=lambda: sessions.append(get_session())) + thread.start() + thread.join() + self.assertEqual( + len(set(id(s) for s in sessions)), + 5, + "Each thread should have its own session" + )