Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/subscription-controller/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add methods `startSubscriptionWithCrypto` and `getCryptoApproveTransactionParams` method ([#6456](https://github.com/MetaMask/core/pull/6456))
- Added `triggerAccessTokenRefresh` to trigger an access token refresh ([#6374](https://github.com/MetaMask/core/pull/6374))
- Add two new controller state metadata properties: `includeInStateLogs` and `usedInUi` ([#6504](https://github.com/MetaMask/core/pull/6504))
- Added `updatePaymentMethodCard` and `updatePaymentMethodCrypto` methods ([#6539](https://github.com/MetaMask/core/pull/6539))

[Unreleased]: https://github.com/MetaMask/core/
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import type {
PricingPaymentMethod,
StartCryptoSubscriptionRequest,
StartCryptoSubscriptionResponse,
UpdatePaymentMethodOpts,
} from './types';
import {
PaymentType,
Expand Down Expand Up @@ -86,6 +87,12 @@ const MOCK_PRICE_INFO_RESPONSE: PricingResponse = {
paymentMethods: [MOCK_PRICING_PAYMENT_METHOD],
};

const MOCK_GET_SUBSCRIPTIONS_RESPONSE = {
customerId: 'cus_1',
subscriptions: [MOCK_SUBSCRIPTION],
trialedProducts: [],
};

/**
* Creates a custom subscription messenger, in case tests need different permissions
*
Expand Down Expand Up @@ -158,13 +165,17 @@ function createMockSubscriptionService() {
const mockStartSubscriptionWithCard = jest.fn();
const mockGetPricing = jest.fn();
const mockStartSubscriptionWithCrypto = jest.fn();
const mockUpdatePaymentMethodCard = jest.fn();
const mockUpdatePaymentMethodCrypto = jest.fn();

const mockService = {
getSubscriptions: mockGetSubscriptions,
cancelSubscription: mockCancelSubscription,
startSubscriptionWithCard: mockStartSubscriptionWithCard,
getPricing: mockGetPricing,
startSubscriptionWithCrypto: mockStartSubscriptionWithCrypto,
updatePaymentMethodCard: mockUpdatePaymentMethodCard,
updatePaymentMethodCrypto: mockUpdatePaymentMethodCrypto,
};

return {
Expand All @@ -174,6 +185,8 @@ function createMockSubscriptionService() {
mockStartSubscriptionWithCard,
mockGetPricing,
mockStartSubscriptionWithCrypto,
mockUpdatePaymentMethodCard,
mockUpdatePaymentMethodCrypto,
};
}

Expand Down Expand Up @@ -271,17 +284,13 @@ describe('SubscriptionController', () => {
describe('getSubscription', () => {
it('should fetch and store subscription successfully', async () => {
await withController(async ({ controller, mockService }) => {
mockService.getSubscriptions.mockResolvedValue({
customerId: 'cus_1',
subscriptions: [MOCK_SUBSCRIPTION],
trialedProducts: [],
});
mockService.getSubscriptions.mockResolvedValue(
MOCK_GET_SUBSCRIPTIONS_RESPONSE,
);

const result = await controller.getSubscriptions();

expect(result).toStrictEqual([MOCK_SUBSCRIPTION]);
// For backward compatibility during refactor, keep single subscription mirror if present
// but assert new state field
expect(controller.state.subscriptions).toStrictEqual([
MOCK_SUBSCRIPTION,
]);
Expand Down Expand Up @@ -881,4 +890,76 @@ describe('SubscriptionController', () => {
});
});
});

describe('updatePaymentMethod', () => {
it('should update card payment method successfully', async () => {
await withController(async ({ controller, mockService }) => {
mockService.updatePaymentMethodCard.mockResolvedValue({});
mockService.getSubscriptions.mockResolvedValue(
MOCK_GET_SUBSCRIPTIONS_RESPONSE,
);

await controller.updatePaymentMethod({
subscriptionId: 'sub_123456789',
paymentType: PaymentType.byCard,
recurringInterval: RecurringInterval.month,
});

expect(mockService.updatePaymentMethodCard).toHaveBeenCalledWith({
subscriptionId: 'sub_123456789',
recurringInterval: RecurringInterval.month,
});

expect(controller.state.subscriptions).toStrictEqual([
MOCK_SUBSCRIPTION,
]);
});
});

it('should update crypto payment method successfully', async () => {
await withController(async ({ controller, mockService }) => {
mockService.updatePaymentMethodCrypto.mockResolvedValue({});
mockService.getSubscriptions.mockResolvedValue(
MOCK_GET_SUBSCRIPTIONS_RESPONSE,
);

const opts: UpdatePaymentMethodOpts = {
paymentType: PaymentType.byCrypto,
subscriptionId: 'sub_123456789',
chainId: '0x1',
payerAddress: '0x0000000000000000000000000000000000000001',
tokenSymbol: 'USDC',
rawTransaction: '0xdeadbeef',
recurringInterval: RecurringInterval.month,
billingCycles: 3,
};

await controller.updatePaymentMethod(opts);

const req = {
...opts,
paymentType: undefined,
};
expect(mockService.updatePaymentMethodCrypto).toHaveBeenCalledWith(req);

expect(controller.state.subscriptions).toStrictEqual([
MOCK_SUBSCRIPTION,
]);
});
});

it('throws when invalid payment type', async () => {
await withController(async ({ controller }) => {
const opts = {
subscriptionId: 'sub_123456789',
paymentType: 'invalid',
recurringInterval: RecurringInterval.month,
};
// @ts-expect-error Intentionally testing with invalid payment type.
await expect(controller.updatePaymentMethod(opts)).rejects.toThrow(
'Invalid payment type',
);
});
});
});
});
26 changes: 25 additions & 1 deletion packages/subscription-controller/src/SubscriptionController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import type {
ProductPrice,
StartCryptoSubscriptionRequest,
TokenPaymentInfo,
UpdatePaymentMethodOpts,
} from './types';
import {
PaymentType,
Expand Down Expand Up @@ -57,6 +58,10 @@ export type SubscriptionControllerStartSubscriptionWithCryptoAction = {
type: `${typeof controllerName}:startSubscriptionWithCrypto`;
handler: SubscriptionController['startSubscriptionWithCrypto'];
};
export type SubscriptionControllerUpdatePaymentMethodAction = {
type: `${typeof controllerName}:updatePaymentMethod`;
handler: SubscriptionController['updatePaymentMethod'];
};

export type SubscriptionControllerGetStateAction = ControllerGetStateAction<
typeof controllerName,
Expand All @@ -69,7 +74,8 @@ export type SubscriptionControllerActions =
| SubscriptionControllerGetPricingAction
| SubscriptionControllerGetStateAction
| SubscriptionControllerGetCryptoApproveTransactionParamsAction
| SubscriptionControllerStartSubscriptionWithCryptoAction;
| SubscriptionControllerStartSubscriptionWithCryptoAction
| SubscriptionControllerUpdatePaymentMethodAction;

export type AllowedActions =
| AuthenticationController.AuthenticationControllerGetBearerToken
Expand Down Expand Up @@ -208,6 +214,11 @@ export class SubscriptionController extends BaseController<
'SubscriptionController:startSubscriptionWithCrypto',
this.startSubscriptionWithCrypto.bind(this),
);

this.messagingSystem.registerActionHandler(
'SubscriptionController:updatePaymentMethod',
this.updatePaymentMethod.bind(this),
);
}

/**
Expand Down Expand Up @@ -322,6 +333,19 @@ export class SubscriptionController extends BaseController<
};
}

async updatePaymentMethod(opts: UpdatePaymentMethodOpts) {
if (opts.paymentType === PaymentType.byCard) {
const { paymentType, ...cardRequest } = opts;
await this.#subscriptionService.updatePaymentMethodCard(cardRequest);
} else if (opts.paymentType === PaymentType.byCrypto) {
const { paymentType, ...cryptoRequest } = opts;
await this.#subscriptionService.updatePaymentMethodCrypto(cryptoRequest);
} else {
throw new Error('Invalid payment type');
}
await this.getSubscriptions();
}

/**
* Calculate total subscription price amount from price info
* e.g: $8 per month * 12 months min billing cycles = $96
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ import {
SubscriptionControllerErrorMessage,
} from './constants';
import { SubscriptionServiceError } from './errors';
import { SubscriptionService } from './SubscriptionService';
import { SUBSCRIPTION_URL, SubscriptionService } from './SubscriptionService';
import type {
StartSubscriptionRequest,
StartCryptoSubscriptionRequest,
Subscription,
PricingResponse,
UpdatePaymentMethodCardRequest,
UpdatePaymentMethodCryptoRequest,
} from './types';
import {
PaymentType,
Expand Down Expand Up @@ -59,6 +61,11 @@ const MOCK_START_SUBSCRIPTION_RESPONSE = {
checkoutSessionUrl: 'https://checkout.example.com/session/123',
};

const MOCK_HEADERS = {
'Content-Type': 'application/json',
Authorization: `Bearer ${MOCK_ACCESS_TOKEN}`,
};

/**
* Creates a mock subscription service config for testing
*
Expand Down Expand Up @@ -291,4 +298,67 @@ describe('SubscriptionService', () => {
expect(result).toStrictEqual(mockPricingResponse);
});
});

describe('updatePaymentMethodCard', () => {
it('should update card payment method successfully', async () => {
await withMockSubscriptionService(async ({ service, config }) => {
const request: UpdatePaymentMethodCardRequest = {
subscriptionId: 'sub_123456789',
recurringInterval: RecurringInterval.month,
};

handleFetchMock.mockResolvedValue({});

await service.updatePaymentMethodCard(request);

expect(handleFetchMock).toHaveBeenCalledWith(
SUBSCRIPTION_URL(
config.env,
'subscriptions/sub_123456789/payment-method/card',
),
{
method: 'PATCH',
headers: MOCK_HEADERS,
body: JSON.stringify({
...request,
subscriptionId: undefined,
}),
},
);
});
});

it('should update crypto payment method successfully', async () => {
await withMockSubscriptionService(async ({ service, config }) => {
const request: UpdatePaymentMethodCryptoRequest = {
subscriptionId: 'sub_123456789',
chainId: '0x1',
payerAddress: '0x0000000000000000000000000000000000000001',
tokenSymbol: 'USDC',
rawTransaction: '0xdeadbeef',
recurringInterval: RecurringInterval.month,
billingCycles: 3,
};

handleFetchMock.mockResolvedValue({});

await service.updatePaymentMethodCrypto(request);

expect(handleFetchMock).toHaveBeenCalledWith(
SUBSCRIPTION_URL(
config.env,
'subscriptions/sub_123456789/payment-method/crypto',
),
{
method: 'PATCH',
headers: MOCK_HEADERS,
body: JSON.stringify({
...request,
subscriptionId: undefined,
}),
},
);
});
});
});
});
18 changes: 18 additions & 0 deletions packages/subscription-controller/src/SubscriptionService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import type {
StartCryptoSubscriptionResponse,
StartSubscriptionRequest,
StartSubscriptionResponse,
UpdatePaymentMethodCardRequest,
UpdatePaymentMethodCryptoRequest,
} from './types';

export type SubscriptionServiceConfig = {
Expand Down Expand Up @@ -65,6 +67,22 @@ export class SubscriptionService implements ISubscriptionService {
return await this.#makeRequest(path, 'POST', request);
}

async updatePaymentMethodCard(request: UpdatePaymentMethodCardRequest) {
const path = `subscriptions/${request.subscriptionId}/payment-method/card`;
await this.#makeRequest(path, 'PATCH', {
...request,
subscriptionId: undefined,
});
}

async updatePaymentMethodCrypto(request: UpdatePaymentMethodCryptoRequest) {
const path = `subscriptions/${request.subscriptionId}/payment-method/crypto`;
await this.#makeRequest(path, 'PATCH', {
...request,
subscriptionId: undefined,
});
}

async #makeRequest<Result>(
path: string,
method: 'GET' | 'POST' | 'DELETE' | 'PUT' | 'PATCH' = 'GET',
Expand Down
1 change: 1 addition & 0 deletions packages/subscription-controller/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ export type {
Currency,
PricingPaymentMethod,
PricingResponse,
UpdatePaymentMethodOpts,
} from './types';
export { SubscriptionServiceError } from './errors';
export { Env, SubscriptionControllerErrorMessage } from './constants';
Expand Down
36 changes: 36 additions & 0 deletions packages/subscription-controller/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -189,4 +189,40 @@ export type ISubscriptionService = {
startSubscriptionWithCrypto(
request: StartCryptoSubscriptionRequest,
): Promise<StartCryptoSubscriptionResponse>;
updatePaymentMethodCard(
request: UpdatePaymentMethodCardRequest,
): Promise<void>;
updatePaymentMethodCrypto(
request: UpdatePaymentMethodCryptoRequest,
): Promise<void>;
};

export type UpdatePaymentMethodOpts =
| ({
paymentType: PaymentType.byCard;
} & UpdatePaymentMethodCardRequest)
| ({
paymentType: PaymentType.byCrypto;
} & UpdatePaymentMethodCryptoRequest);

export type UpdatePaymentMethodCardRequest = {
/**
* Subscription ID
*/
subscriptionId: string;

/**
* Recurring interval
*/
recurringInterval: RecurringInterval;
};

export type UpdatePaymentMethodCryptoRequest = {
subscriptionId: string;
chainId: Hex;
payerAddress: Hex;
tokenSymbol: string;
rawTransaction: Hex;
recurringInterval: RecurringInterval;
billingCycles: number;
};
Loading