Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion enterprise_access/apps/customer_billing/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,10 @@ def create_free_trial_checkout_session(
checkout_intent=intent,
)

intent.update_stripe_session_id(checkout_session['id'])
intent.update_stripe_identifiers(
session_id=checkout_session['id'],
customer_id=checkout_session.get('customer'),
)
logger.info(f'Updated checkout intent {intent.id} with Stripe session {checkout_session["id"]}')

return checkout_session
26 changes: 23 additions & 3 deletions enterprise_access/apps/customer_billing/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def is_valid_state_transition(
allowed_transitions = ALLOWED_CHECKOUT_INTENT_STATE_TRANSITIONS.get(current_state, [])
return new_state in allowed_transitions

def mark_as_paid(self, stripe_session_id=None):
def mark_as_paid(self, stripe_session_id=None, stripe_customer_id=None, **kwargs):
"""Mark the intent as paid after successful Stripe checkout."""
if not self.is_valid_state_transition(CheckoutIntentState(self.state), CheckoutIntentState.PAID):
raise ValueError(f"Cannot transition from {self.state} to {CheckoutIntentState.PAID}.")
Expand All @@ -197,10 +197,17 @@ def mark_as_paid(self, stripe_session_id=None):
if self.state == CheckoutIntentState.PAID and stripe_session_id != self.stripe_checkout_session_id:
raise ValueError("Cannot transition from PAID to PAID with a different stripe_checkout_session_id")

if stripe_customer_id:
if self.state == CheckoutIntentState.PAID and stripe_customer_id != self.stripe_customer_id:
raise ValueError("Cannot transition from PAID to PAID with a different stripe_customer_id")

self.state = CheckoutIntentState.PAID
if stripe_session_id:
self.stripe_checkout_session_id = stripe_session_id
self.save(update_fields=['state', 'stripe_checkout_session_id', 'modified'])
if stripe_customer_id:
self.stripe_customer_id = stripe_customer_id

self.save(update_fields=['state', 'stripe_checkout_session_id', 'stripe_customer_id', 'modified'])
logger.info(f'CheckoutIntent {self} marked as {CheckoutIntentState.PAID}.')
return self

Expand Down Expand Up @@ -467,6 +474,19 @@ def for_user(cls, user):
return cls.objects.filter(user=user).first()

def update_stripe_session_id(self, session_id):
"""Update the associated Stripe checkout session ID."""
"""
Deprecated in favor of update_stripe_identifiers below.
Update the associated Stripe checkout session ID.
"""
self.stripe_checkout_session_id = session_id
self.save(update_fields=['stripe_checkout_session_id', 'modified'])

def update_stripe_identifiers(self, session_id=None, customer_id=None):
"""
Updates stripe identifiers related to this checkout intent record.
"""
if session_id:
self.stripe_checkout_session_id = session_id
if customer_id:
self.stripe_customer_id = customer_id
self.save(update_fields=['stripe_checkout_session_id', 'stripe_customer_id', 'modified'])
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ def invoice_paid(event: stripe.Event) -> None:
Handle invoice.paid events.
"""
# Extract relevant metadata for logging
invoice_id = event.data.object.id
subscription_details = event.data.object.parent.subscription_details
invoice = event.data.object
invoice_id = invoice.id
stripe_customer_id = invoice['customer']
subscription_details = invoice.parent.subscription_details
subscription_id = subscription_details['subscription']

# Extract the checkout_intent ID from the related subscription.
Expand All @@ -78,11 +80,12 @@ def invoice_paid(event: stripe.Event) -> None:
logger.info(
'Found existing CheckoutIntent record with '
f'id={checkout_intent_id}, '
f'stripe_customer_id={stripe_customer_id}, '
f'stripe_checkout_session_id={checkout_intent.stripe_checkout_session_id}, '
f'state={checkout_intent.state}. '
'Marking intent as paid...'
)
checkout_intent.mark_as_paid()
checkout_intent.mark_as_paid(stripe_customer_id=stripe_customer_id)

@on_stripe_event('customer.subscription.trial_will_end')
@staticmethod
Expand Down
8 changes: 6 additions & 2 deletions enterprise_access/apps/customer_billing/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ def test_create_free_trial_checkout_session_success(
mock_lms_client = mock_lms_client_class.return_value
mock_lms_client.get_lms_user_account.return_value = [{'id': self.user.lms_user_id}]
mock_lms_client.get_enterprise_customer_data.side_effect = raise_404_error
mock_stripe.checkout.Session.create.return_value = {'id': 'test-stripe-checkout-session'}
mock_stripe.checkout.Session.create.return_value = {
'id': 'test-stripe-checkout-session',
'customer': 'cust-123',
}
mock_stripe.Customer.search.return_value.data = []

# Actually call the API under test.
Expand All @@ -112,7 +115,7 @@ def test_create_free_trial_checkout_session_success(
# Assert API response.
self.assertEqual(
result,
{'id': 'test-stripe-checkout-session'},
{'id': 'test-stripe-checkout-session', 'customer': 'cust-123'},
)

# Assert that a CheckoutIntent was created
Expand All @@ -121,6 +124,7 @@ def test_create_free_trial_checkout_session_success(
self.assertEqual(intent.enterprise_slug, 'my-sluggy')
self.assertEqual(intent.enterprise_name, 'My Cool Company')
self.assertEqual(intent.stripe_checkout_session_id, 'test-stripe-checkout-session')
self.assertEqual(intent.stripe_customer_id, 'cust-123')
self.assertFalse(intent.is_expired())

# Assert library methods were called correctly.
Expand Down
86 changes: 86 additions & 0 deletions enterprise_access/apps/customer_billing/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,3 +539,89 @@ def test_create_intent_with_terms_metadata(self):
terms_metadata=complex_metadata
)
self.assertEqual(intent3.terms_metadata, complex_metadata)

def test_mark_as_paid_with_stripe_customer_id(self):
"""Test mark_as_paid with stripe_customer_id parameter."""
intent = CheckoutIntent.create_intent(
user=cast(AbstractUser, self.user1),
slug=self.basic_data['enterprise_slug'],
name=self.basic_data['enterprise_name'],
quantity=self.basic_data['quantity']
)

# Test marking as paid with both session_id and customer_id
intent.mark_as_paid('cs_test_123', 'cus_test_456')
self.assertEqual(intent.state, CheckoutIntentState.PAID)
self.assertEqual(intent.stripe_checkout_session_id, 'cs_test_123')
self.assertEqual(intent.stripe_customer_id, 'cus_test_456')

def test_mark_as_paid_with_only_stripe_customer_id(self):
"""Test mark_as_paid with only stripe_customer_id parameter."""
intent = CheckoutIntent.create_intent(
user=cast(AbstractUser, self.user2),
slug='another-slug',
name='Another Enterprise',
quantity=7
)

# Test marking as paid with only customer_id
intent.mark_as_paid(stripe_customer_id='cus_test_789')
self.assertEqual(intent.state, CheckoutIntentState.PAID)
self.assertIsNone(intent.stripe_checkout_session_id)
self.assertEqual(intent.stripe_customer_id, 'cus_test_789')

def test_mark_as_paid_idempotent_with_stripe_customer_id(self):
"""Test that mark_as_paid is idempotent when called with same stripe_customer_id."""
intent = CheckoutIntent.create_intent(
user=cast(AbstractUser, self.user1),
slug=self.basic_data['enterprise_slug'],
name=self.basic_data['enterprise_name'],
quantity=self.basic_data['quantity']
)

# Mark as paid first time
intent.mark_as_paid('cs_test_123', 'cus_test_456')
first_modified = intent.modified

# Mark as paid again with same values - should be idempotent
intent.mark_as_paid('cs_test_123', 'cus_test_456')
self.assertEqual(intent.state, CheckoutIntentState.PAID)
self.assertEqual(intent.stripe_checkout_session_id, 'cs_test_123')
self.assertEqual(intent.stripe_customer_id, 'cus_test_456')
# Modified time should have changed since we called save()
self.assertGreater(intent.modified, first_modified)

def test_mark_as_paid_different_stripe_customer_id_raises_error(self):
"""Test that mark_as_paid raises error when called with different stripe_customer_id."""
intent = CheckoutIntent.create_intent(
user=cast(AbstractUser, self.user1),
slug=self.basic_data['enterprise_slug'],
name=self.basic_data['enterprise_name'],
quantity=self.basic_data['quantity']
)

# Mark as paid first time
intent.mark_as_paid(stripe_customer_id='cus_test_456')

# Try to mark as paid with different customer_id - should raise ValueError
with self.assertRaises(ValueError) as context:
intent.mark_as_paid(stripe_customer_id='cus_test_different')

self.assertIn('Cannot transition from PAID to PAID with a different stripe_customer_id', str(context.exception))

def test_mark_as_paid_update_fields_includes_stripe_customer_id(self):
"""Test that save() includes stripe_customer_id in update_fields."""
intent = CheckoutIntent.create_intent(
user=cast(AbstractUser, self.user1),
slug=self.basic_data['enterprise_slug'],
name=self.basic_data['enterprise_name'],
quantity=self.basic_data['quantity']
)

with mock.patch.object(intent, 'save') as mock_save:
intent.mark_as_paid(stripe_customer_id='cus_test_456')

# Verify save was called with correct update_fields
mock_save.assert_called_once_with(
update_fields=['state', 'stripe_checkout_session_id', 'stripe_customer_id', 'modified']
)
Original file line number Diff line number Diff line change
Expand Up @@ -122,18 +122,21 @@ def test_invoice_paid_handler(
):
"""Test various scenarios for the invoice.paid event handler."""
if checkout_intent_state == CheckoutIntentState.PAID:
self.checkout_intent.update_stripe_session_id(self.stripe_checkout_session_id)
self.checkout_intent.mark_as_paid()
self.checkout_intent.mark_as_paid(
stripe_session_id=self.stripe_checkout_session_id,
stripe_customer_id='cus_test_789',
)

subscription_id = 'sub_test_123456'
mock_subscription = self._create_mock_stripe_subscription(intent_id_override or self.checkout_intent.id)
invoice_data = {
'id': 'in_test_123456',
'customer': 'cus_test_789',
'parent': {
'subscription_details': {
'metadata': mock_subscription,
'subscription': subscription_id
}
},
},
}

Expand All @@ -159,10 +162,85 @@ def test_invoice_paid_handler(
mock_logger.info.assert_any_call(
'Found existing CheckoutIntent record with '
f'id={self.checkout_intent.id}, '
f'stripe_customer_id=cus_test_789, '
f'stripe_checkout_session_id={self.checkout_intent.stripe_checkout_session_id}, '
f'state={checkout_intent_state}. '
'Marking intent as paid...'
)
mock_logger.info.assert_any_call(
f'[StripeEventHandler] handler for <stripe.Event id={mock_event.id} type=invoice.paid> complete.'
)

@mock.patch('enterprise_access.apps.customer_billing.stripe_event_handlers.logger')
def test_invoice_paid_handler_sets_stripe_customer_id(self, mock_logger):
"""Test that invoice.paid handler correctly sets stripe_customer_id on CheckoutIntent."""
subscription_id = 'sub_test_customer_id_123'
stripe_customer_id = 'cus_test_customer_456'
mock_subscription = self._create_mock_stripe_subscription(self.checkout_intent.id)

invoice_data = {
'id': 'in_test_customer_123',
'customer': stripe_customer_id,
'parent': {
'subscription_details': {
'metadata': mock_subscription,
'subscription': subscription_id
}
},
}

mock_event = self._create_mock_stripe_event('invoice.paid', invoice_data)

# Verify initial state
self.assertEqual(self.checkout_intent.state, CheckoutIntentState.CREATED)
self.assertIsNone(self.checkout_intent.stripe_customer_id)

# Handle the event
StripeEventHandler.dispatch(mock_event)

# Verify the checkout intent was updated correctly
self.checkout_intent.refresh_from_db()
self.assertEqual(self.checkout_intent.state, CheckoutIntentState.PAID)
self.assertEqual(self.checkout_intent.stripe_customer_id, stripe_customer_id)

# Verify logging includes the customer_id
mock_logger.info.assert_any_call(
'Found existing CheckoutIntent record with '
f'id={self.checkout_intent.id}, '
f'stripe_customer_id={stripe_customer_id}, '
f'stripe_checkout_session_id={self.checkout_intent.stripe_checkout_session_id}, '
f'state={CheckoutIntentState.CREATED}. '
'Marking intent as paid...'
)

def test_invoice_paid_handler_idempotent_with_same_customer_id(self):
"""Test that invoice.paid handler is idempotent when called with same stripe_customer_id."""
subscription_id = 'sub_test_idempotent_123'
stripe_customer_id = 'cus_test_idempotent_456'
mock_subscription = self._create_mock_stripe_subscription(self.checkout_intent.id)

# First mark the intent as paid with the customer_id
self.checkout_intent.mark_as_paid(stripe_customer_id=stripe_customer_id)
self.assertEqual(self.checkout_intent.state, CheckoutIntentState.PAID)
self.assertEqual(self.checkout_intent.stripe_customer_id, stripe_customer_id)

invoice_data = {
'id': 'in_test_idempotent_123',
'customer': stripe_customer_id,
'parent': {
'subscription_details': {
'metadata': mock_subscription,
'subscription': subscription_id
}
},
}

mock_event = self._create_mock_stripe_event('invoice.paid', invoice_data)

# Handle the event - should be idempotent
StripeEventHandler.dispatch(mock_event)

# Verify the checkout intent state remains unchanged
self.checkout_intent.refresh_from_db()
self.assertEqual(self.checkout_intent.state, CheckoutIntentState.PAID)
self.assertEqual(self.checkout_intent.stripe_customer_id, stripe_customer_id)