diff --git a/enterprise_access/apps/customer_billing/api.py b/enterprise_access/apps/customer_billing/api.py index db8dd3a2..38061c4e 100644 --- a/enterprise_access/apps/customer_billing/api.py +++ b/enterprise_access/apps/customer_billing/api.py @@ -423,7 +423,10 @@ def create_free_trial_checkout_session( checkout_intent=intent, ) - intent.update_stripe_session_id(checkout_session['id']) + intent.update_stripe_identifiers( + session_id=checkout_session['id'], + customer_id=checkout_session.get('customer'), + ) logger.info(f'Updated checkout intent {intent.id} with Stripe session {checkout_session["id"]}') return checkout_session diff --git a/enterprise_access/apps/customer_billing/models.py b/enterprise_access/apps/customer_billing/models.py index 5d031d8b..d95208cb 100644 --- a/enterprise_access/apps/customer_billing/models.py +++ b/enterprise_access/apps/customer_billing/models.py @@ -188,7 +188,7 @@ def is_valid_state_transition( allowed_transitions = ALLOWED_CHECKOUT_INTENT_STATE_TRANSITIONS.get(current_state, []) return new_state in allowed_transitions - def mark_as_paid(self, stripe_session_id=None): + def mark_as_paid(self, stripe_session_id=None, stripe_customer_id=None, **kwargs): """Mark the intent as paid after successful Stripe checkout.""" if not self.is_valid_state_transition(CheckoutIntentState(self.state), CheckoutIntentState.PAID): raise ValueError(f"Cannot transition from {self.state} to {CheckoutIntentState.PAID}.") @@ -197,10 +197,17 @@ def mark_as_paid(self, stripe_session_id=None): if self.state == CheckoutIntentState.PAID and stripe_session_id != self.stripe_checkout_session_id: raise ValueError("Cannot transition from PAID to PAID with a different stripe_checkout_session_id") + if stripe_customer_id: + if self.state == CheckoutIntentState.PAID and stripe_customer_id != self.stripe_customer_id: + raise ValueError("Cannot transition from PAID to PAID with a different stripe_customer_id") + self.state = CheckoutIntentState.PAID if stripe_session_id: self.stripe_checkout_session_id = stripe_session_id - self.save(update_fields=['state', 'stripe_checkout_session_id', 'modified']) + if stripe_customer_id: + self.stripe_customer_id = stripe_customer_id + + self.save(update_fields=['state', 'stripe_checkout_session_id', 'stripe_customer_id', 'modified']) logger.info(f'CheckoutIntent {self} marked as {CheckoutIntentState.PAID}.') return self @@ -467,6 +474,19 @@ def for_user(cls, user): return cls.objects.filter(user=user).first() def update_stripe_session_id(self, session_id): - """Update the associated Stripe checkout session ID.""" + """ + Deprecated in favor of update_stripe_identifiers below. + Update the associated Stripe checkout session ID. + """ self.stripe_checkout_session_id = session_id self.save(update_fields=['stripe_checkout_session_id', 'modified']) + + def update_stripe_identifiers(self, session_id=None, customer_id=None): + """ + Updates stripe identifiers related to this checkout intent record. + """ + if session_id: + self.stripe_checkout_session_id = session_id + if customer_id: + self.stripe_customer_id = customer_id + self.save(update_fields=['stripe_checkout_session_id', 'stripe_customer_id', 'modified']) diff --git a/enterprise_access/apps/customer_billing/stripe_event_handlers.py b/enterprise_access/apps/customer_billing/stripe_event_handlers.py index e8c68d84..9dfcaa36 100644 --- a/enterprise_access/apps/customer_billing/stripe_event_handlers.py +++ b/enterprise_access/apps/customer_billing/stripe_event_handlers.py @@ -61,8 +61,10 @@ def invoice_paid(event: stripe.Event) -> None: Handle invoice.paid events. """ # Extract relevant metadata for logging - invoice_id = event.data.object.id - subscription_details = event.data.object.parent.subscription_details + invoice = event.data.object + invoice_id = invoice.id + stripe_customer_id = invoice['customer'] + subscription_details = invoice.parent.subscription_details subscription_id = subscription_details['subscription'] # Extract the checkout_intent ID from the related subscription. @@ -78,11 +80,12 @@ def invoice_paid(event: stripe.Event) -> None: logger.info( 'Found existing CheckoutIntent record with ' f'id={checkout_intent_id}, ' + f'stripe_customer_id={stripe_customer_id}, ' f'stripe_checkout_session_id={checkout_intent.stripe_checkout_session_id}, ' f'state={checkout_intent.state}. ' 'Marking intent as paid...' ) - checkout_intent.mark_as_paid() + checkout_intent.mark_as_paid(stripe_customer_id=stripe_customer_id) @on_stripe_event('customer.subscription.trial_will_end') @staticmethod diff --git a/enterprise_access/apps/customer_billing/tests/test_api.py b/enterprise_access/apps/customer_billing/tests/test_api.py index 16b40549..bfe69ff2 100644 --- a/enterprise_access/apps/customer_billing/tests/test_api.py +++ b/enterprise_access/apps/customer_billing/tests/test_api.py @@ -96,7 +96,10 @@ def test_create_free_trial_checkout_session_success( mock_lms_client = mock_lms_client_class.return_value mock_lms_client.get_lms_user_account.return_value = [{'id': self.user.lms_user_id}] mock_lms_client.get_enterprise_customer_data.side_effect = raise_404_error - mock_stripe.checkout.Session.create.return_value = {'id': 'test-stripe-checkout-session'} + mock_stripe.checkout.Session.create.return_value = { + 'id': 'test-stripe-checkout-session', + 'customer': 'cust-123', + } mock_stripe.Customer.search.return_value.data = [] # Actually call the API under test. @@ -112,7 +115,7 @@ def test_create_free_trial_checkout_session_success( # Assert API response. self.assertEqual( result, - {'id': 'test-stripe-checkout-session'}, + {'id': 'test-stripe-checkout-session', 'customer': 'cust-123'}, ) # Assert that a CheckoutIntent was created @@ -121,6 +124,7 @@ def test_create_free_trial_checkout_session_success( self.assertEqual(intent.enterprise_slug, 'my-sluggy') self.assertEqual(intent.enterprise_name, 'My Cool Company') self.assertEqual(intent.stripe_checkout_session_id, 'test-stripe-checkout-session') + self.assertEqual(intent.stripe_customer_id, 'cust-123') self.assertFalse(intent.is_expired()) # Assert library methods were called correctly. diff --git a/enterprise_access/apps/customer_billing/tests/test_models.py b/enterprise_access/apps/customer_billing/tests/test_models.py index 3ff07fec..755fc3d1 100644 --- a/enterprise_access/apps/customer_billing/tests/test_models.py +++ b/enterprise_access/apps/customer_billing/tests/test_models.py @@ -539,3 +539,89 @@ def test_create_intent_with_terms_metadata(self): terms_metadata=complex_metadata ) self.assertEqual(intent3.terms_metadata, complex_metadata) + + def test_mark_as_paid_with_stripe_customer_id(self): + """Test mark_as_paid with stripe_customer_id parameter.""" + intent = CheckoutIntent.create_intent( + user=cast(AbstractUser, self.user1), + slug=self.basic_data['enterprise_slug'], + name=self.basic_data['enterprise_name'], + quantity=self.basic_data['quantity'] + ) + + # Test marking as paid with both session_id and customer_id + intent.mark_as_paid('cs_test_123', 'cus_test_456') + self.assertEqual(intent.state, CheckoutIntentState.PAID) + self.assertEqual(intent.stripe_checkout_session_id, 'cs_test_123') + self.assertEqual(intent.stripe_customer_id, 'cus_test_456') + + def test_mark_as_paid_with_only_stripe_customer_id(self): + """Test mark_as_paid with only stripe_customer_id parameter.""" + intent = CheckoutIntent.create_intent( + user=cast(AbstractUser, self.user2), + slug='another-slug', + name='Another Enterprise', + quantity=7 + ) + + # Test marking as paid with only customer_id + intent.mark_as_paid(stripe_customer_id='cus_test_789') + self.assertEqual(intent.state, CheckoutIntentState.PAID) + self.assertIsNone(intent.stripe_checkout_session_id) + self.assertEqual(intent.stripe_customer_id, 'cus_test_789') + + def test_mark_as_paid_idempotent_with_stripe_customer_id(self): + """Test that mark_as_paid is idempotent when called with same stripe_customer_id.""" + intent = CheckoutIntent.create_intent( + user=cast(AbstractUser, self.user1), + slug=self.basic_data['enterprise_slug'], + name=self.basic_data['enterprise_name'], + quantity=self.basic_data['quantity'] + ) + + # Mark as paid first time + intent.mark_as_paid('cs_test_123', 'cus_test_456') + first_modified = intent.modified + + # Mark as paid again with same values - should be idempotent + intent.mark_as_paid('cs_test_123', 'cus_test_456') + self.assertEqual(intent.state, CheckoutIntentState.PAID) + self.assertEqual(intent.stripe_checkout_session_id, 'cs_test_123') + self.assertEqual(intent.stripe_customer_id, 'cus_test_456') + # Modified time should have changed since we called save() + self.assertGreater(intent.modified, first_modified) + + def test_mark_as_paid_different_stripe_customer_id_raises_error(self): + """Test that mark_as_paid raises error when called with different stripe_customer_id.""" + intent = CheckoutIntent.create_intent( + user=cast(AbstractUser, self.user1), + slug=self.basic_data['enterprise_slug'], + name=self.basic_data['enterprise_name'], + quantity=self.basic_data['quantity'] + ) + + # Mark as paid first time + intent.mark_as_paid(stripe_customer_id='cus_test_456') + + # Try to mark as paid with different customer_id - should raise ValueError + with self.assertRaises(ValueError) as context: + intent.mark_as_paid(stripe_customer_id='cus_test_different') + + self.assertIn('Cannot transition from PAID to PAID with a different stripe_customer_id', str(context.exception)) + + def test_mark_as_paid_update_fields_includes_stripe_customer_id(self): + """Test that save() includes stripe_customer_id in update_fields.""" + intent = CheckoutIntent.create_intent( + user=cast(AbstractUser, self.user1), + slug=self.basic_data['enterprise_slug'], + name=self.basic_data['enterprise_name'], + quantity=self.basic_data['quantity'] + ) + + with mock.patch.object(intent, 'save') as mock_save: + intent.mark_as_paid(stripe_customer_id='cus_test_456') + + # Verify save was called with correct update_fields + mock_save.assert_called_once_with( + update_fields=['state', 'stripe_checkout_session_id', 'stripe_customer_id', 'modified'] + ) diff --git a/enterprise_access/apps/customer_billing/tests/test_stripe_event_handlers.py b/enterprise_access/apps/customer_billing/tests/test_stripe_event_handlers.py index b7418a40..95b832c0 100644 --- a/enterprise_access/apps/customer_billing/tests/test_stripe_event_handlers.py +++ b/enterprise_access/apps/customer_billing/tests/test_stripe_event_handlers.py @@ -122,18 +122,21 @@ def test_invoice_paid_handler( ): """Test various scenarios for the invoice.paid event handler.""" if checkout_intent_state == CheckoutIntentState.PAID: - self.checkout_intent.update_stripe_session_id(self.stripe_checkout_session_id) - self.checkout_intent.mark_as_paid() + self.checkout_intent.mark_as_paid( + stripe_session_id=self.stripe_checkout_session_id, + stripe_customer_id='cus_test_789', + ) subscription_id = 'sub_test_123456' mock_subscription = self._create_mock_stripe_subscription(intent_id_override or self.checkout_intent.id) invoice_data = { 'id': 'in_test_123456', + 'customer': 'cus_test_789', 'parent': { 'subscription_details': { 'metadata': mock_subscription, 'subscription': subscription_id - } + }, }, } @@ -159,6 +162,7 @@ def test_invoice_paid_handler( mock_logger.info.assert_any_call( 'Found existing CheckoutIntent record with ' f'id={self.checkout_intent.id}, ' + f'stripe_customer_id=cus_test_789, ' f'stripe_checkout_session_id={self.checkout_intent.stripe_checkout_session_id}, ' f'state={checkout_intent_state}. ' 'Marking intent as paid...' @@ -166,3 +170,77 @@ def test_invoice_paid_handler( mock_logger.info.assert_any_call( f'[StripeEventHandler] handler for complete.' ) + + @mock.patch('enterprise_access.apps.customer_billing.stripe_event_handlers.logger') + def test_invoice_paid_handler_sets_stripe_customer_id(self, mock_logger): + """Test that invoice.paid handler correctly sets stripe_customer_id on CheckoutIntent.""" + subscription_id = 'sub_test_customer_id_123' + stripe_customer_id = 'cus_test_customer_456' + mock_subscription = self._create_mock_stripe_subscription(self.checkout_intent.id) + + invoice_data = { + 'id': 'in_test_customer_123', + 'customer': stripe_customer_id, + 'parent': { + 'subscription_details': { + 'metadata': mock_subscription, + 'subscription': subscription_id + } + }, + } + + mock_event = self._create_mock_stripe_event('invoice.paid', invoice_data) + + # Verify initial state + self.assertEqual(self.checkout_intent.state, CheckoutIntentState.CREATED) + self.assertIsNone(self.checkout_intent.stripe_customer_id) + + # Handle the event + StripeEventHandler.dispatch(mock_event) + + # Verify the checkout intent was updated correctly + self.checkout_intent.refresh_from_db() + self.assertEqual(self.checkout_intent.state, CheckoutIntentState.PAID) + self.assertEqual(self.checkout_intent.stripe_customer_id, stripe_customer_id) + + # Verify logging includes the customer_id + mock_logger.info.assert_any_call( + 'Found existing CheckoutIntent record with ' + f'id={self.checkout_intent.id}, ' + f'stripe_customer_id={stripe_customer_id}, ' + f'stripe_checkout_session_id={self.checkout_intent.stripe_checkout_session_id}, ' + f'state={CheckoutIntentState.CREATED}. ' + 'Marking intent as paid...' + ) + + def test_invoice_paid_handler_idempotent_with_same_customer_id(self): + """Test that invoice.paid handler is idempotent when called with same stripe_customer_id.""" + subscription_id = 'sub_test_idempotent_123' + stripe_customer_id = 'cus_test_idempotent_456' + mock_subscription = self._create_mock_stripe_subscription(self.checkout_intent.id) + + # First mark the intent as paid with the customer_id + self.checkout_intent.mark_as_paid(stripe_customer_id=stripe_customer_id) + self.assertEqual(self.checkout_intent.state, CheckoutIntentState.PAID) + self.assertEqual(self.checkout_intent.stripe_customer_id, stripe_customer_id) + + invoice_data = { + 'id': 'in_test_idempotent_123', + 'customer': stripe_customer_id, + 'parent': { + 'subscription_details': { + 'metadata': mock_subscription, + 'subscription': subscription_id + } + }, + } + + mock_event = self._create_mock_stripe_event('invoice.paid', invoice_data) + + # Handle the event - should be idempotent + StripeEventHandler.dispatch(mock_event) + + # Verify the checkout intent state remains unchanged + self.checkout_intent.refresh_from_db() + self.assertEqual(self.checkout_intent.state, CheckoutIntentState.PAID) + self.assertEqual(self.checkout_intent.stripe_customer_id, stripe_customer_id)