Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/shield-controller/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Add two new controller state metadata properties: `includeInStateLogs` and `usedInUi` ([#6504](https://github.com/MetaMask/core/pull/6504))
- Add signature coverage checking ([#6501](https://github.com/MetaMask/core/pull/6501))

### Changed

Expand Down
2 changes: 2 additions & 0 deletions packages/shield-controller/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"@lavamoat/allow-scripts": "^3.0.4",
"@lavamoat/preinstall-always-fail": "^2.1.0",
"@metamask/auto-changelog": "^3.4.4",
"@metamask/signature-controller": "^33.0.0",
"@metamask/transaction-controller": "^60.3.0",
"@ts-bridge/cli": "^0.6.1",
"@types/jest": "^27.4.1",
Expand All @@ -67,6 +68,7 @@
"uuid": "^8.3.2"
},
"peerDependencies": {
"@metamask/signature-controller": "^33.0.0",
"@metamask/transaction-controller": "^60.0.0"
},
"engines": {
Expand Down
76 changes: 75 additions & 1 deletion packages/shield-controller/src/ShieldController.test.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import { deriveStateFromMetadata } from '@metamask/base-controller';
import type { SignatureControllerState } from '@metamask/signature-controller';
import type { TransactionControllerState } from '@metamask/transaction-controller';

import { ShieldController } from './ShieldController';
import { createMockBackend } from '../tests/mocks/backend';
import { createMockMessenger } from '../tests/mocks/messenger';
import { generateMockTxMeta } from '../tests/utils';
import {
generateMockSignatureRequest,
generateMockTxMeta,
} from '../tests/utils';

/**
* Sets up a ShieldController for testing.
Expand Down Expand Up @@ -144,6 +148,76 @@ describe('ShieldController', () => {
});
});

describe('checkSignatureCoverage', () => {
it('should check signature coverage', async () => {
const { baseMessenger, backend } = setup();
const signatureRequest = generateMockSignatureRequest();
const coverageResultReceived = new Promise<void>((resolve) => {
baseMessenger.subscribe(
'ShieldController:coverageResultReceived',
(_coverageResult) => resolve(),
);
});
baseMessenger.publish(
'SignatureController:stateChange',
{
signatureRequests: { [signatureRequest.id]: signatureRequest },
} as SignatureControllerState,
undefined as never,
);
expect(await coverageResultReceived).toBeUndefined();
expect(backend.checkSignatureCoverage).toHaveBeenCalledWith(
signatureRequest,
);
});
});

it('should check coverage for multiple signature request', async () => {
const { baseMessenger, backend } = setup();
const signatureRequest1 = generateMockSignatureRequest();
const coverageResultReceived1 = new Promise<void>((resolve) => {
baseMessenger.subscribe(
'ShieldController:coverageResultReceived',
(_coverageResult) => resolve(),
);
});
baseMessenger.publish(
'SignatureController:stateChange',
{
signatureRequests: {
[signatureRequest1.id]: signatureRequest1,
},
} as SignatureControllerState,
undefined as never,
);
expect(await coverageResultReceived1).toBeUndefined();
expect(backend.checkSignatureCoverage).toHaveBeenCalledWith(
signatureRequest1,
);

const signatureRequest2 = generateMockSignatureRequest();
const coverageResultReceived2 = new Promise<void>((resolve) => {
baseMessenger.subscribe(
'ShieldController:coverageResultReceived',
(_coverageResult) => resolve(),
);
});
baseMessenger.publish(
'SignatureController:stateChange',
{
signatureRequests: {
[signatureRequest2.id]: signatureRequest2,
},
} as SignatureControllerState,
undefined as never,
);

expect(await coverageResultReceived2).toBeUndefined();
expect(backend.checkSignatureCoverage).toHaveBeenCalledWith(
signatureRequest2,
);
});

describe('metadata', () => {
it('includes expected state in debug snapshots', () => {
const { controller } = setup();
Expand Down
84 changes: 80 additions & 4 deletions packages/shield-controller/src/ShieldController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ import type {
ControllerStateChangeEvent,
RestrictedMessenger,
} from '@metamask/base-controller';
import {
SignatureRequestType,
type SignatureRequest,
type SignatureStateChange,
} from '@metamask/signature-controller';
import type {
TransactionControllerStateChangeEvent,
TransactionMeta,
Expand Down Expand Up @@ -82,7 +87,9 @@ type AllowedActions = never;
/**
* The external events available to the ShieldController.
*/
type AllowedEvents = TransactionControllerStateChangeEvent;
type AllowedEvents =
| SignatureStateChange
| TransactionControllerStateChangeEvent;

/**
* The messenger of the {@link ShieldController}.
Expand Down Expand Up @@ -138,6 +145,11 @@ export class ShieldController extends BaseController<
previousTransactions: TransactionMeta[] | undefined,
) => void;

readonly #signatureControllerStateChangeHandler: (
signatureRequests: Record<string, SignatureRequest>,
previousSignatureRequests: Record<string, SignatureRequest> | undefined,
) => void;

constructor(options: ShieldControllerOptions) {
const {
messenger,
Expand All @@ -161,6 +173,8 @@ export class ShieldController extends BaseController<
this.#transactionHistoryLimit = transactionHistoryLimit;
this.#transactionControllerStateChangeHandler =
this.#handleTransactionControllerStateChange.bind(this);
this.#signatureControllerStateChangeHandler =
this.#handleSignatureControllerStateChange.bind(this);
}

start() {
Expand All @@ -169,13 +183,54 @@ export class ShieldController extends BaseController<
this.#transactionControllerStateChangeHandler,
(state) => state.transactions,
);

this.messagingSystem.subscribe(
'SignatureController:stateChange',
this.#signatureControllerStateChangeHandler,
(state) => state.signatureRequests,
);
}

stop() {
this.messagingSystem.unsubscribe(
'TransactionController:stateChange',
this.#transactionControllerStateChangeHandler,
);

this.messagingSystem.unsubscribe(
'SignatureController:stateChange',
this.#signatureControllerStateChangeHandler,
);
}

#handleSignatureControllerStateChange(
signatureRequests: Record<string, SignatureRequest>,
previousSignatureRequests: Record<string, SignatureRequest> | undefined,
) {
const signatureRequestsArray = Object.values(signatureRequests);
const previousSignatureRequestsArray = Object.values(
previousSignatureRequests ?? {},
);
const previousSignatureRequestsById = new Map<string, SignatureRequest>(
previousSignatureRequestsArray.map((request) => [request.id, request]),
);
for (const signatureRequest of signatureRequestsArray) {
const previousSignatureRequest = previousSignatureRequestsById.get(
signatureRequest.id,
);

// Check coverage if the signature request is new and has type
// `personal_sign`.
if (
!previousSignatureRequest &&
signatureRequest.type === SignatureRequestType.PersonalSign
) {
this.checkSignatureCoverage(signatureRequest).catch(
// istanbul ignore next
(error) => log('Error checking coverage:', error),
);
}
}
}

#handleTransactionControllerStateChange(
Expand Down Expand Up @@ -212,7 +267,7 @@ export class ShieldController extends BaseController<
*/
async checkCoverage(txMeta: TransactionMeta): Promise<CoverageResult> {
// Check coverage
const coverageResult = await this.#fetchCoverageResult(txMeta);
const coverageResult = await this.#backend.checkCoverage(txMeta);

// Publish coverage result
this.messagingSystem.publish(
Expand All @@ -226,8 +281,29 @@ export class ShieldController extends BaseController<
return coverageResult;
}

async #fetchCoverageResult(txMeta: TransactionMeta): Promise<CoverageResult> {
return this.#backend.checkCoverage(txMeta);
/**
* Checks the coverage of a signature request.
*
* @param signatureRequest - The signature request to check coverage for.
* @returns The coverage result.
*/
async checkSignatureCoverage(
signatureRequest: SignatureRequest,
): Promise<CoverageResult> {
// Check coverage
const coverageResult =
await this.#backend.checkSignatureCoverage(signatureRequest);

// Publish coverage result
this.messagingSystem.publish(
`${controllerName}:coverageResultReceived`,
coverageResult,
);

// Update state
this.#addCoverageResult(signatureRequest.id, coverageResult);

return coverageResult;
}

#addCoverageResult(txId: string, coverageResult: CoverageResult) {
Expand Down
42 changes: 41 additions & 1 deletion packages/shield-controller/src/backend.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import { ShieldRemoteBackend } from './backend';
import { generateMockTxMeta, getRandomCoverageStatus } from '../tests/utils';
import {
generateMockSignatureRequest,
generateMockTxMeta,
getRandomCoverageStatus,
} from '../tests/utils';

/**
* Setup the test environment.
Expand Down Expand Up @@ -141,4 +145,40 @@ describe('ShieldRemoteBackend', () => {
// that the polling loop is exited as expected.
await new Promise((resolve) => setTimeout(resolve, 10));
});

describe('checkSignatureCoverage', () => {
it('should check signature coverage', async () => {
const { backend, fetchMock, getAccessToken } = setup();

// Mock init coverage check.
fetchMock.mockResolvedValueOnce({
status: 200,
json: jest.fn().mockResolvedValue({ coverageId: 'coverageId' }),
} as unknown as Response);

// Mock get coverage result.
const status = getRandomCoverageStatus();
fetchMock.mockResolvedValueOnce({
status: 200,
json: jest.fn().mockResolvedValue({ status }),
} as unknown as Response);

const signatureRequest = generateMockSignatureRequest();
const coverageResult =
await backend.checkSignatureCoverage(signatureRequest);
expect(coverageResult).toStrictEqual({ status });
expect(fetchMock).toHaveBeenCalledTimes(2);
expect(getAccessToken).toHaveBeenCalledTimes(2);
});

it('throws with invalid data', async () => {
const { backend } = setup();

const signatureRequest = generateMockSignatureRequest();
signatureRequest.messageParams.data = [];
await expect(
backend.checkSignatureCoverage(signatureRequest),
).rejects.toThrow('Signature data must be a string');
});
});
});
59 changes: 45 additions & 14 deletions packages/shield-controller/src/backend.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type { SignatureRequest } from '@metamask/signature-controller';
import type { TransactionMeta } from '@metamask/transaction-controller';

import type { CoverageResult, CoverageStatus, ShieldBackend } from './types';
Expand All @@ -16,6 +17,14 @@ export type InitCoverageCheckRequest = {
origin?: string;
};

export type InitSignatureCoverageCheckRequest = {
chainId: string;
data: string;
from: string;
method: string;
origin?: string;
};

export type InitCoverageCheckResponse = {
coverageId: string;
};
Expand Down Expand Up @@ -59,9 +68,7 @@ export class ShieldRemoteBackend implements ShieldBackend {
this.#fetch = fetchFn;
}

checkCoverage: (txMeta: TransactionMeta) => Promise<CoverageResult> = async (
txMeta,
) => {
async checkCoverage(txMeta: TransactionMeta): Promise<CoverageResult> {
const reqBody: InitCoverageCheckRequest = {
txParams: [
{
Expand All @@ -76,22 +83,46 @@ export class ShieldRemoteBackend implements ShieldBackend {
origin: txMeta.origin,
};

const { coverageId } = await this.#initCoverageCheck(reqBody);
const { coverageId } = await this.#initCoverageCheck(
'v1/transaction/coverage/init',
reqBody,
);

return this.#getCoverageResult(coverageId);
};
}

async checkSignatureCoverage(
signatureRequest: SignatureRequest,
): Promise<CoverageResult> {
if (typeof signatureRequest.messageParams.data !== 'string') {
throw new Error('Signature data must be a string');
}

const reqBody: InitSignatureCoverageCheckRequest = {
chainId: signatureRequest.chainId,
data: signatureRequest.messageParams.data,
from: signatureRequest.messageParams.from,
method: signatureRequest.type,
origin: signatureRequest.messageParams.origin,
};

const { coverageId } = await this.#initCoverageCheck(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please note that we can't always except coverageId to be in the response, with the 200 status.

Please check the swagger, as it has updated recently. (please select schema tab under the 200 response, not the example value)

Screenshot 2025-09-11 at 6 56 50 PM

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sample cases of 200 responses without coverageId ~

  • where origin is localhost or empty string (requests directly from metamask wallet)
  • when providing invalid chain id values

In such cases, we don't proceed to process the rulesets, we straight away return the Unknown result from the server.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case also applies to transaction-coverage

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we throw an error in that case?
So basically now the init endpoint can also return a result? Kind of confusing. This creates unnecessary edge cases I believe.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we please change the API backend. I do see several issues with this change. The expectation so far is that if we have a coverage result, either for signatures, or for transactions, it always has the same format. If we don't have that, a lot of complexity is added to the whole system. (Currently we can just treat the results within the same data structure. Then we might need separate data structure. Also, missing the coverageId might be problematic because it means we can't look up the corresponding backend calls.)

'v1/signature/coverage/init',
reqBody,
);

return this.#getCoverageResult(coverageId);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Coverage ID Missing in API Response

The code assumes coverageId is always present in the initCoverageCheck response, but the API can return a 200 status without it (e.g., for localhost or invalid chain IDs). This results in coverageId being undefined, which causes #getCoverageResult to fail for both transaction and signature coverage.

Additional Locations (1)

Fix in Cursor Fix in Web

}

async #initCoverageCheck(
reqBody: InitCoverageCheckRequest,
path: string,
reqBody: unknown,
): Promise<InitCoverageCheckResponse> {
const res = await this.#fetch(
`${this.#baseUrl}/v1/transaction/coverage/init`,
{
method: 'POST',
headers: await this.#createHeaders(),
body: JSON.stringify(reqBody),
},
);
const res = await this.#fetch(`${this.#baseUrl}/${path}`, {
method: 'POST',
headers: await this.#createHeaders(),
body: JSON.stringify(reqBody),
});
if (res.status !== 200) {
throw new Error(`Failed to init coverage check: ${res.status}`);
}
Expand Down
Loading
Loading