Skip to content

Commit e48c500

Browse files
committed
feat(commands): reset_db for Postgres backend
- connects to POSTGRES_HOST as an elevated POSTGRES_USER - drops and recreates DJANGO_USER - drops and recreates DJANGO_DB, giving ownership to DJANGO_USER with some other minor configuration override database engine for tests where we continue to use SQLite
1 parent a9aaa52 commit e48c500

File tree

9 files changed

+196
-7
lines changed

9 files changed

+196
-7
lines changed

bin/reset_db.sh

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,7 @@ set -ex
55
DB_RESET="${DJANGO_DB_RESET:-true}"
66

77
if [[ $DB_RESET = true ]]; then
8-
# construct the path to the database file from environment or default
9-
DB_DIR="${DJANGO_STORAGE_DIR:-.}"
10-
DB_FILE="${DJANGO_DB_FILE:-django.db}"
11-
DB_PATH="${DB_DIR}/${DB_FILE}"
12-
13-
rm -f "${DB_PATH}"
8+
python manage.py reset_db
149

1510
# run database migrations and other initialization
1611
bin/init.sh

pems/core/management/__init__.py

Whitespace-only changes.

pems/core/management/commands/__init__.py

Whitespace-only changes.
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import os
2+
3+
import psycopg
4+
5+
from django.core.management.base import BaseCommand
6+
from django.db import connection
7+
8+
9+
class Command(BaseCommand):
10+
help = "Completely resets the database (DESTRUCTIVE)."
11+
12+
def admin_connection(self) -> psycopg.Connection:
13+
db_host = connection.settings_dict["HOST"]
14+
db_port = connection.settings_dict["PORT"]
15+
16+
postgres_db = os.environ.get("POSTGRES_DB", "postgres")
17+
admin_name = os.environ.get("POSTGRES_USER", "postgres")
18+
admin_password = os.environ.get("POSTGRES_PASSWORD")
19+
20+
if not admin_password:
21+
self.stderr.write(self.style.ERROR("POSTGRES_PASSWORD environment variable not set."))
22+
return
23+
24+
return psycopg.connect(
25+
host=db_host,
26+
port=db_port,
27+
user=admin_name,
28+
password=admin_password,
29+
dbname=postgres_db,
30+
# Execute SQL commands immediately
31+
autocommit=True,
32+
)
33+
34+
def handle(self, *args, **options):
35+
db_name = connection.settings_dict["NAME"]
36+
db_user = connection.settings_dict["USER"]
37+
db_password = connection.settings_dict["PASSWORD"]
38+
39+
if not db_password:
40+
self.stderr.write(self.style.ERROR("DJANGO_DB_PASSWORD environment variable not set."))
41+
return
42+
43+
try:
44+
with self.admin_connection() as conn:
45+
cursor = conn.cursor()
46+
self.stdout.write(self.style.WARNING("Attempting database reset..."))
47+
48+
# Revoke existing connections
49+
cursor.execute(
50+
"""
51+
SELECT pg_terminate_backend(pg_stat_activity.pid)
52+
FROM pg_stat_activity
53+
WHERE datname = %s AND pg_stat_activity.pid <> pg_backend_pid()
54+
""",
55+
[db_name],
56+
)
57+
self.stdout.write(self.style.SUCCESS(f"Terminated existing connections to '{db_name}'."))
58+
59+
# Drop database
60+
cursor.execute(f"DROP DATABASE IF EXISTS {db_name}")
61+
self.stdout.write(self.style.SUCCESS(f"Database '{db_name}' dropped."))
62+
63+
# Create Django user
64+
cursor.execute(f"DROP USER IF EXISTS {db_user}")
65+
cursor.execute(f"CREATE USER {db_user} WITH PASSWORD '{db_password}'")
66+
self.stdout.write(self.style.SUCCESS(f"Django user '{db_user}' created."))
67+
68+
# Create Django database with Django user as owner
69+
cursor.execute(f"CREATE DATABASE {db_name} WITH OWNER {db_user} ENCODING 'UTF-8'")
70+
self.stdout.write(self.style.SUCCESS(f"Database '{db_name}' created and owned by '{db_user}'."))
71+
72+
self.stdout.write(self.style.SUCCESS("Database reset and user setup complete."))
73+
74+
except Exception as e:
75+
self.stderr.write(self.style.ERROR(f"Error during database reset: {e}"))

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ preserve_blank_lines = true
4848
use_gitignore = true
4949

5050
[tool.pytest.ini_options]
51-
DJANGO_SETTINGS_MODULE = "pems.settings"
51+
DJANGO_SETTINGS_MODULE = "tests.pytest.settings"
5252

5353
[tool.setuptools.packages.find]
5454
include = ["pems*"]

tests/pytest/pems/core/management/__init__.py

Whitespace-only changes.

tests/pytest/pems/core/management/commands/__init__.py

Whitespace-only changes.
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import pytest
2+
3+
from pems.core.management.commands import reset_db
4+
5+
6+
@pytest.fixture
7+
def command(mocker):
8+
cmd = reset_db.Command()
9+
# mocking command I/O for tests
10+
cmd.stderr = mocker.Mock()
11+
cmd.stdout = mocker.Mock()
12+
cmd.style.ERROR = str
13+
cmd.style.SUCCESS = str
14+
cmd.style.WARNING = str
15+
return cmd
16+
17+
18+
@pytest.fixture(autouse=True)
19+
def db_settings(settings):
20+
settings.DATABASES["default"]["NAME"] = "test_db"
21+
settings.DATABASES["default"]["USER"] = "test_user"
22+
settings.DATABASES["default"]["PASSWORD"] = "test_password"
23+
settings.DATABASES["default"]["HOST"] = "test_host"
24+
settings.DATABASES["default"]["PORT"] = "1234"
25+
return settings
26+
27+
28+
@pytest.fixture
29+
def mock_connect(mocker):
30+
mock = mocker.patch.object(reset_db.psycopg, "connect", mocker.Mock())
31+
return mock
32+
33+
34+
@pytest.fixture
35+
def mock_cursor(mocker):
36+
return mocker.Mock()
37+
38+
39+
@pytest.fixture
40+
def mock_conn(mocker, mock_cursor):
41+
mock = mocker.MagicMock()
42+
# fake context manager support for `with manager() as ctx:`
43+
mock.__enter__.return_value = mock
44+
mock.cursor.return_value = mock_cursor
45+
return mock
46+
47+
48+
@pytest.fixture
49+
def mock_admin_connection(mocker, mock_conn):
50+
return mocker.patch.object(reset_db.Command, "admin_connection", return_value=mock_conn)
51+
52+
53+
@pytest.mark.django_db
54+
class TestCommand:
55+
@pytest.mark.parametrize(
56+
"env_vars,should_connect",
57+
[
58+
({"POSTGRES_DB": "test_db", "POSTGRES_USER": "test_user", "POSTGRES_PASSWORD": "test_pass"}, True),
59+
({"POSTGRES_DB": "test_db", "POSTGRES_USER": "test_user", "POSTGRES_PASSWORD": ""}, False),
60+
],
61+
)
62+
def test_admin_connection(self, command: reset_db.Command, monkeypatch, mock_connect, env_vars, should_connect):
63+
for key, value in env_vars.items():
64+
monkeypatch.setenv(key, value)
65+
66+
result = command.admin_connection()
67+
68+
if should_connect:
69+
mock_connect.assert_called_once_with(
70+
host="test_host", port="1234", user="test_user", password="test_pass", dbname="test_db", autocommit=True
71+
)
72+
assert result is not None
73+
else:
74+
assert result is None
75+
76+
def test_handle_success(self, command: reset_db.Command, mocker, mock_admin_connection, mock_cursor):
77+
mocker.patch("os.environ.get", return_value="test_password")
78+
79+
command.handle()
80+
81+
mock_admin_connection.assert_called_once()
82+
assert mock_cursor.execute.call_count == 5
83+
84+
# Verify success messages were logged
85+
expected_messages = [
86+
"Attempting database reset...",
87+
"Terminated existing connections to 'test_db'.",
88+
"Database 'test_db' dropped.",
89+
"Django user 'test_user' created.",
90+
"Database 'test_db' created and owned by 'test_user'.",
91+
"Database reset and user setup complete.",
92+
]
93+
94+
for message in expected_messages:
95+
command.stdout.write.assert_any_call(message)
96+
97+
def test_handle_no_db_password(self, command: reset_db.Command, db_settings, mock_admin_connection):
98+
db_settings.DATABASES["default"]["PASSWORD"] = None
99+
100+
command.handle()
101+
102+
mock_admin_connection.assert_not_called()
103+
command.stderr.write.assert_any_call("DJANGO_DB_PASSWORD environment variable not set.")
104+
105+
def test_handle_admin_connection_error(self, command: reset_db.Command, mock_admin_connection):
106+
mock_admin_connection.side_effect = Exception("Admin connection failed.")
107+
108+
command.handle()
109+
110+
mock_admin_connection.assert_called_once()
111+
command.stderr.write.assert_any_call("Error during database reset: Admin connection failed.")

tests/pytest/settings.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from pems.settings import * # noqa: F401, F403
2+
3+
DATABASES = {
4+
"default": {
5+
"ENGINE": "django.db.backends.sqlite3",
6+
"NAME": "test",
7+
}
8+
}

0 commit comments

Comments
 (0)