Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tcf_website/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""DRF Viewsets"""
import asyncio
from threading import Thread
from django.db import connection
from django.db import connections
from django.db.models import Avg, Sum
from django.http import JsonResponse
from rest_framework import viewsets
Expand Down Expand Up @@ -197,7 +197,7 @@ def _run_update():
except (asyncio.TimeoutError, requests.RequestException, ValueError) as exc:
print(f"Enrollment update failed for course {pk}: {exc}")
finally:
connection.close()
connections.close_all()

thread = Thread(target=_run_update, daemon=True)
thread.start()
Expand Down
47 changes: 30 additions & 17 deletions tcf_website/management/commands/fetch_enrollment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
docker exec -it tcf_django python manage.py fetch_enrollment
"""

import threading
import time
from concurrent.futures import ThreadPoolExecutor

import backoff
import requests
from django.core.management.base import BaseCommand
from django.db import connections
from requests.adapters import HTTPAdapter
from tqdm import tqdm
from urllib3.util.retry import Retry
Expand All @@ -22,28 +24,16 @@
# Maximum time to wait for a response from the server
TIMEOUT = 30
# Number of concurrent workers for fetching data
MAX_WORKERS = 20
MAX_WORKERS = 5
# Initial wait time for backoff (in seconds)
INITIAL_WAIT = 2
# Maximum number of retry attempts
MAX_TRIES = 8
# Kept larger than MAX_WORKERS to handle connection lifecycle issues
MAX_POOL_SIZE = 20 * 5

# Configure session with retry strategy and connection pooling
session = requests.Session()
retry_strategy = Retry(
total=3,
backoff_factor=0.1,
status_forcelist=[429, 500, 502, 503, 504],
)
adapter = HTTPAdapter(
pool_connections=MAX_POOL_SIZE,
pool_maxsize=MAX_POOL_SIZE,
max_retries=retry_strategy,
)
session.mount("http://", adapter)
session.mount("https://", adapter)
# Thread-local storage for requests sessions to avoid contention
_session_local = threading.local()


def should_retry_request(exception):
Expand All @@ -56,6 +46,25 @@ def should_retry_request(exception):
return False


def get_session():
    """Return the calling thread's requests session, building it on first use.

    The session is cached on the module-level ``_session_local`` object, so
    a thread that has already called this function gets the same instance
    back. A new session is mounted for both ``http://`` and ``https://``
    with one shared adapter that retries up to 3 times (backoff factor 0.1)
    on 429/500/502/503/504 responses and sizes its pools from
    ``MAX_POOL_SIZE``.
    """
    # EAFP lookup: a missing attribute means this thread has no session yet.
    try:
        return _session_local.session
    except AttributeError:
        pass

    shared_adapter = HTTPAdapter(
        pool_connections=MAX_POOL_SIZE,
        pool_maxsize=MAX_POOL_SIZE,
        max_retries=Retry(
            total=3,
            backoff_factor=0.1,
            status_forcelist=[429, 500, 502, 503, 504],
        ),
    )

    fresh_session = requests.Session()
    _session_local.session = fresh_session
    for url_prefix in ("http://", "https://"):
        fresh_session.mount(url_prefix, shared_adapter)
    return fresh_session


@backoff.on_exception(
backoff.expo,
requests.exceptions.RequestException,
Expand All @@ -81,6 +90,7 @@ def fetch_section_data(section):
)

try:
session = get_session()
# Fetch and validate response
response = session.get(url, timeout=TIMEOUT)
response.raise_for_status()
Expand Down Expand Up @@ -152,10 +162,12 @@ def handle(self, *args, **options):

print(f"Fetching enrollment data for semester {semester}...")

sections = Section.objects.filter(semester=semester)
total_sections = sections.count()
sections_queryset = Section.objects.filter(semester=semester)
total_sections = sections_queryset.count()
print(f"Found {total_sections} sections")

sections = list(sections_queryset)

# Process sections in parallel using ThreadPoolExecutor
success_count = 0
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
Expand All @@ -171,6 +183,7 @@ def handle(self, *args, **options):
except KeyboardInterrupt:
print("\nProcess interrupted by user. Shutting down...")
executor.shutdown(wait=False)
connections.close_all()
return

elapsed_time = time.time() - start_time
Expand Down
55 changes: 38 additions & 17 deletions tcf_website/tests/test_fetch_enrollment.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Tests for fetch_enrollment management command."""

import threading
from unittest.mock import patch

from django.test import TestCase

from tcf_website.management.commands.fetch_enrollment import fetch_section_data
from tcf_website.management.commands.fetch_enrollment import fetch_section_data, get_session
from tcf_website.models import SectionEnrollment

from .test_utils import setup
Expand All @@ -19,11 +20,13 @@ def setUp(self):
# pylint: disable=no-member
self.section = self.section_course

@patch("tcf_website.management.commands.fetch_enrollment.session.get")
def test_fetch_enrollment_success(self, mock_get):
@patch("tcf_website.management.commands.fetch_enrollment.get_session")
def test_fetch_enrollment_success(self, mock_get_session):
"""Test successful enrollment fetch."""
mock_get.return_value.status_code = 200
mock_get.return_value.json.return_value = {
mock_session = mock_get_session.return_value
mock_session.get.return_value.status_code = 200
mock_session.get.return_value.raise_for_status.return_value = None
mock_session.get.return_value.json.return_value = {
"classes": [
{
"enrollment_total": 15,
Expand All @@ -42,8 +45,8 @@ def test_fetch_enrollment_success(self, mock_get):
self.assertEqual(enrollment.waitlist_taken, 5)
self.assertEqual(enrollment.waitlist_limit, 10)

@patch("tcf_website.management.commands.fetch_enrollment.session.get")
def test_fetch_enrollment_update_existing(self, mock_get):
@patch("tcf_website.management.commands.fetch_enrollment.get_session")
def test_fetch_enrollment_update_existing(self, mock_get_session):
"""Test updating existing enrollment data."""
# Create initial enrollment
SectionEnrollment.objects.create(
Expand All @@ -54,8 +57,10 @@ def test_fetch_enrollment_update_existing(self, mock_get):
waitlist_limit=5,
)

mock_get.return_value.status_code = 200
mock_get.return_value.json.return_value = {
mock_session = mock_get_session.return_value
mock_session.get.return_value.status_code = 200
mock_session.get.return_value.raise_for_status.return_value = None
mock_session.get.return_value.json.return_value = {
"classes": [
{
"enrollment_total": 15,
Expand All @@ -74,24 +79,40 @@ def test_fetch_enrollment_update_existing(self, mock_get):
self.assertEqual(enrollment.waitlist_taken, 8)
self.assertEqual(enrollment.waitlist_limit, 12)

@patch("tcf_website.management.commands.fetch_enrollment.session.get")
def test_fetch_enrollment_empty_response(self, mock_get):
@patch("tcf_website.management.commands.fetch_enrollment.get_session")
def test_fetch_enrollment_empty_response(self, mock_get_session):
"""Test handling of empty API response."""
mock_get.return_value.status_code = 200
mock_get.return_value.json.return_value = {"classes": []}
mock_session = mock_get_session.return_value
mock_session.get.return_value.status_code = 200
mock_session.get.return_value.raise_for_status.return_value = None
mock_session.get.return_value.json.return_value = {"classes": []}

result = fetch_section_data(self.section)

self.assertFalse(result)
self.assertEqual(SectionEnrollment.objects.count(), 0)

@patch("tcf_website.management.commands.fetch_enrollment.session.get")
def test_fetch_enrollment_api_error(self, mock_get):
@patch("tcf_website.management.commands.fetch_enrollment.get_session")
def test_fetch_enrollment_api_error(self, mock_get_session):
"""Test handling of API error."""
mock_get.return_value.status_code = 500
mock_get.return_value.raise_for_status.side_effect = Exception("API Error")
mock_session = mock_get_session.return_value
mock_session.get.return_value.status_code = 500
mock_session.get.return_value.raise_for_status.side_effect = Exception("API Error")

result = fetch_section_data(self.section)

self.assertFalse(result)
self.assertEqual(SectionEnrollment.objects.count(), 0)

def test_thread_local_sessions(self):
    """Verify that get_session() returns a distinct session per thread."""
    collected = []

    def _record_session():
        # Runs inside the worker thread, so get_session() resolves that
        # thread's own session.
        collected.append(get_session())

    # Threads run one after another; each is joined before the next starts.
    for _ in range(5):
        worker = threading.Thread(target=_record_session)
        worker.start()
        worker.join()

    # Every session is still referenced by `collected`, so the id() values
    # compared here all belong to live objects.
    distinct_ids = {id(item) for item in collected}
    self.assertEqual(
        len(distinct_ids),
        5,
        "Each thread should have its own session"
    )