Skip to content

Commit bdc6576

Browse files
committed
check signature coverage
1 parent a0737cb commit bdc6576

File tree

12 files changed

+282
-21
lines changed

12 files changed

+282
-21
lines changed

packages/shield-controller/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
"uuid": "^8.3.2"
6868
},
6969
"peerDependencies": {
70+
"@metamask/signature-controller": "^33.0.0",
7071
"@metamask/transaction-controller": "^60.0.0"
7172
},
7273
"engines": {

packages/shield-controller/src/ShieldController.test.ts

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import { deriveStateFromMetadata } from '@metamask/base-controller';
2+
import type { SignatureControllerState } from '@metamask/signature-controller';
23
import type { TransactionControllerState } from '@metamask/transaction-controller';
34

45
import { ShieldController } from './ShieldController';
56
import { createMockBackend } from '../tests/mocks/backend';
67
import { createMockMessenger } from '../tests/mocks/messenger';
7-
import { generateMockTxMeta } from '../tests/utils';
8+
import {
9+
generateMockSignatureRequest,
10+
generateMockTxMeta,
11+
} from '../tests/utils';
812

913
/**
1014
* Sets up a ShieldController for testing.
@@ -144,6 +148,76 @@ describe('ShieldController', () => {
144148
});
145149
});
146150

151+
describe('checkSignatureCoverage', () => {
152+
it('should check signature coverage', async () => {
153+
const { baseMessenger, backend } = setup();
154+
const signatureRequest = generateMockSignatureRequest();
155+
const coverageResultReceived = new Promise<void>((resolve) => {
156+
baseMessenger.subscribe(
157+
'ShieldController:coverageResultReceived',
158+
(_coverageResult) => resolve(),
159+
);
160+
});
161+
baseMessenger.publish(
162+
'SignatureController:stateChange',
163+
{
164+
signatureRequests: { [signatureRequest.id]: signatureRequest },
165+
} as SignatureControllerState,
166+
undefined as never,
167+
);
168+
expect(await coverageResultReceived).toBeUndefined();
169+
expect(backend.checkSignatureCoverage).toHaveBeenCalledWith(
170+
signatureRequest,
171+
);
172+
});
173+
});
174+
175+
it('should check coverage for multiple signature request', async () => {
176+
const { baseMessenger, backend } = setup();
177+
const signatureRequest1 = generateMockSignatureRequest();
178+
const coverageResultReceived1 = new Promise<void>((resolve) => {
179+
baseMessenger.subscribe(
180+
'ShieldController:coverageResultReceived',
181+
(_coverageResult) => resolve(),
182+
);
183+
});
184+
baseMessenger.publish(
185+
'SignatureController:stateChange',
186+
{
187+
signatureRequests: {
188+
[signatureRequest1.id]: signatureRequest1,
189+
},
190+
} as SignatureControllerState,
191+
undefined as never,
192+
);
193+
expect(await coverageResultReceived1).toBeUndefined();
194+
expect(backend.checkSignatureCoverage).toHaveBeenCalledWith(
195+
signatureRequest1,
196+
);
197+
198+
const signatureRequest2 = generateMockSignatureRequest();
199+
const coverageResultReceived2 = new Promise<void>((resolve) => {
200+
baseMessenger.subscribe(
201+
'ShieldController:coverageResultReceived',
202+
(_coverageResult) => resolve(),
203+
);
204+
});
205+
baseMessenger.publish(
206+
'SignatureController:stateChange',
207+
{
208+
signatureRequests: {
209+
[signatureRequest2.id]: signatureRequest2,
210+
},
211+
} as SignatureControllerState,
212+
undefined as never,
213+
);
214+
215+
expect(await coverageResultReceived2).toBeUndefined();
216+
expect(backend.checkSignatureCoverage).toHaveBeenCalledWith(
217+
signatureRequest2,
218+
);
219+
});
220+
147221
describe('metadata', () => {
148222
it('includes expected state in debug snapshots', () => {
149223
const { controller } = setup();

packages/shield-controller/src/ShieldController.ts

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@ import type {
33
ControllerStateChangeEvent,
44
RestrictedMessenger,
55
} from '@metamask/base-controller';
6+
import {
7+
SignatureRequestType,
8+
type SignatureRequest,
9+
type SignatureStateChange,
10+
} from '@metamask/signature-controller';
611
import type {
712
TransactionControllerStateChangeEvent,
813
TransactionMeta,
@@ -82,7 +87,9 @@ type AllowedActions = never;
8287
/**
8388
* The external events available to the ShieldController.
8489
*/
85-
type AllowedEvents = TransactionControllerStateChangeEvent;
90+
type AllowedEvents =
91+
| SignatureStateChange
92+
| TransactionControllerStateChangeEvent;
8693

8794
/**
8895
* The messenger of the {@link ShieldController}.
@@ -138,6 +145,11 @@ export class ShieldController extends BaseController<
138145
previousTransactions: TransactionMeta[] | undefined,
139146
) => void;
140147

148+
readonly #signatureControllerStateChangeHandler: (
149+
signatureRequests: Record<string, SignatureRequest>,
150+
previousSignatureRequests: Record<string, SignatureRequest> | undefined,
151+
) => void;
152+
141153
constructor(options: ShieldControllerOptions) {
142154
const {
143155
messenger,
@@ -161,6 +173,8 @@ export class ShieldController extends BaseController<
161173
this.#transactionHistoryLimit = transactionHistoryLimit;
162174
this.#transactionControllerStateChangeHandler =
163175
this.#handleTransactionControllerStateChange.bind(this);
176+
this.#signatureControllerStateChangeHandler =
177+
this.#handleSignatureControllerStateChange.bind(this);
164178
}
165179

166180
start() {
@@ -169,13 +183,54 @@ export class ShieldController extends BaseController<
169183
this.#transactionControllerStateChangeHandler,
170184
(state) => state.transactions,
171185
);
186+
187+
this.messagingSystem.subscribe(
188+
'SignatureController:stateChange',
189+
this.#signatureControllerStateChangeHandler,
190+
(state) => state.signatureRequests,
191+
);
172192
}
173193

174194
stop() {
175195
this.messagingSystem.unsubscribe(
176196
'TransactionController:stateChange',
177197
this.#transactionControllerStateChangeHandler,
178198
);
199+
200+
this.messagingSystem.unsubscribe(
201+
'SignatureController:stateChange',
202+
this.#signatureControllerStateChangeHandler,
203+
);
204+
}
205+
206+
#handleSignatureControllerStateChange(
207+
signatureRequests: Record<string, SignatureRequest>,
208+
previousSignatureRequests: Record<string, SignatureRequest> | undefined,
209+
) {
210+
const signatureRequestsArray = Object.values(signatureRequests);
211+
const previousSignatureRequestsArray = Object.values(
212+
previousSignatureRequests ?? {},
213+
);
214+
const previousSignatureRequestsById = new Map<string, SignatureRequest>(
215+
previousSignatureRequestsArray.map((request) => [request.id, request]),
216+
);
217+
for (const signatureRequest of signatureRequestsArray) {
218+
const previousSignatureRequest = previousSignatureRequestsById.get(
219+
signatureRequest.id,
220+
);
221+
222+
// Check coverage if the signature request is new and has type
223+
// `personal_sign`.
224+
if (
225+
!previousSignatureRequest &&
226+
signatureRequest.type === SignatureRequestType.PersonalSign
227+
) {
228+
this.checkSignatureCoverage(signatureRequest).catch(
229+
// istanbul ignore next
230+
(error) => log('Error checking coverage:', error),
231+
);
232+
}
233+
}
179234
}
180235

181236
#handleTransactionControllerStateChange(
@@ -212,7 +267,7 @@ export class ShieldController extends BaseController<
212267
*/
213268
async checkCoverage(txMeta: TransactionMeta): Promise<CoverageResult> {
214269
// Check coverage
215-
const coverageResult = await this.#fetchCoverageResult(txMeta);
270+
const coverageResult = await this.#backend.checkCoverage(txMeta);
216271

217272
// Publish coverage result
218273
this.messagingSystem.publish(
@@ -226,8 +281,29 @@ export class ShieldController extends BaseController<
226281
return coverageResult;
227282
}
228283

229-
async #fetchCoverageResult(txMeta: TransactionMeta): Promise<CoverageResult> {
230-
return this.#backend.checkCoverage(txMeta);
284+
/**
285+
* Checks the coverage of a signature request.
286+
*
287+
* @param signatureRequest - The signature request to check coverage for.
288+
* @returns The coverage result.
289+
*/
290+
async checkSignatureCoverage(
291+
signatureRequest: SignatureRequest,
292+
): Promise<CoverageResult> {
293+
// Check coverage
294+
const coverageResult =
295+
await this.#backend.checkSignatureCoverage(signatureRequest);
296+
297+
// Publish coverage result
298+
this.messagingSystem.publish(
299+
`${controllerName}:coverageResultReceived`,
300+
coverageResult,
301+
);
302+
303+
// Update state
304+
this.#addCoverageResult(signatureRequest.id, coverageResult);
305+
306+
return coverageResult;
231307
}
232308

233309
#addCoverageResult(txId: string, coverageResult: CoverageResult) {

packages/shield-controller/src/backend.test.ts

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import { ShieldRemoteBackend } from './backend';
2-
import { generateMockTxMeta, getRandomCoverageStatus } from '../tests/utils';
2+
import {
3+
generateMockSignatureRequest,
4+
generateMockTxMeta,
5+
getRandomCoverageStatus,
6+
} from '../tests/utils';
37

48
/**
59
* Setup the test environment.
@@ -141,4 +145,40 @@ describe('ShieldRemoteBackend', () => {
141145
// that the polling loop is exited as expected.
142146
await new Promise((resolve) => setTimeout(resolve, 10));
143147
});
148+
149+
describe('checkSignatureCoverage', () => {
150+
it('should check signature coverage', async () => {
151+
const { backend, fetchMock, getAccessToken } = setup();
152+
153+
// Mock init coverage check.
154+
fetchMock.mockResolvedValueOnce({
155+
status: 200,
156+
json: jest.fn().mockResolvedValue({ coverageId: 'coverageId' }),
157+
} as unknown as Response);
158+
159+
// Mock get coverage result.
160+
const status = getRandomCoverageStatus();
161+
fetchMock.mockResolvedValueOnce({
162+
status: 200,
163+
json: jest.fn().mockResolvedValue({ status }),
164+
} as unknown as Response);
165+
166+
const signatureRequest = generateMockSignatureRequest();
167+
const coverageResult =
168+
await backend.checkSignatureCoverage(signatureRequest);
169+
expect(coverageResult).toStrictEqual({ status });
170+
expect(fetchMock).toHaveBeenCalledTimes(2);
171+
expect(getAccessToken).toHaveBeenCalledTimes(2);
172+
});
173+
174+
it('throws with invalid data', async () => {
175+
const { backend } = setup();
176+
177+
const signatureRequest = generateMockSignatureRequest();
178+
signatureRequest.messageParams.data = [];
179+
await expect(
180+
backend.checkSignatureCoverage(signatureRequest),
181+
).rejects.toThrow('Signature data must be a string');
182+
});
183+
});
144184
});

packages/shield-controller/src/backend.ts

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import type { SignatureRequest } from '@metamask/signature-controller';
12
import type { TransactionMeta } from '@metamask/transaction-controller';
23

34
import type { CoverageResult, CoverageStatus, ShieldBackend } from './types';
@@ -16,6 +17,14 @@ export type InitCoverageCheckRequest = {
1617
origin?: string;
1718
};
1819

20+
export type InitSignatureCoverageCheckRequest = {
21+
chainId: string;
22+
data: string;
23+
from: string;
24+
method: string;
25+
origin?: string;
26+
};
27+
1928
export type InitCoverageCheckResponse = {
2029
coverageId: string;
2130
};
@@ -59,9 +68,7 @@ export class ShieldRemoteBackend implements ShieldBackend {
5968
this.#fetch = fetchFn;
6069
}
6170

62-
checkCoverage: (txMeta: TransactionMeta) => Promise<CoverageResult> = async (
63-
txMeta,
64-
) => {
71+
async checkCoverage(txMeta: TransactionMeta): Promise<CoverageResult> {
6572
const reqBody: InitCoverageCheckRequest = {
6673
txParams: [
6774
{
@@ -76,22 +83,46 @@ export class ShieldRemoteBackend implements ShieldBackend {
7683
origin: txMeta.origin,
7784
};
7885

79-
const { coverageId } = await this.#initCoverageCheck(reqBody);
86+
const { coverageId } = await this.#initCoverageCheck(
87+
'v1/transaction/coverage/init',
88+
reqBody,
89+
);
8090

8191
return this.#getCoverageResult(coverageId);
82-
};
92+
}
93+
94+
async checkSignatureCoverage(
95+
signatureRequest: SignatureRequest,
96+
): Promise<CoverageResult> {
97+
if (typeof signatureRequest.messageParams.data !== 'string') {
98+
throw new Error('Signature data must be a string');
99+
}
100+
101+
const reqBody: InitSignatureCoverageCheckRequest = {
102+
chainId: signatureRequest.chainId,
103+
data: signatureRequest.messageParams.data,
104+
from: signatureRequest.messageParams.from,
105+
method: signatureRequest.type,
106+
origin: signatureRequest.messageParams.origin,
107+
};
108+
109+
const { coverageId } = await this.#initCoverageCheck(
110+
'v1/signature/coverage/init',
111+
reqBody,
112+
);
113+
114+
return this.#getCoverageResult(coverageId);
115+
}
83116

84117
async #initCoverageCheck(
85-
reqBody: InitCoverageCheckRequest,
118+
path: string,
119+
reqBody: unknown,
86120
): Promise<InitCoverageCheckResponse> {
87-
const res = await this.#fetch(
88-
`${this.#baseUrl}/v1/transaction/coverage/init`,
89-
{
90-
method: 'POST',
91-
headers: await this.#createHeaders(),
92-
body: JSON.stringify(reqBody),
93-
},
94-
);
121+
const res = await this.#fetch(`${this.#baseUrl}/${path}`, {
122+
method: 'POST',
123+
headers: await this.#createHeaders(),
124+
body: JSON.stringify(reqBody),
125+
});
95126
if (res.status !== 200) {
96127
throw new Error(`Failed to init coverage check: ${res.status}`);
97128
}

packages/shield-controller/src/types.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import type { SignatureRequest } from '@metamask/signature-controller';
12
import type { TransactionMeta } from '@metamask/transaction-controller';
23

34
export type CoverageResult = {
@@ -9,4 +10,7 @@ export type CoverageStatus = (typeof coverageStatuses)[number];
910

1011
export type ShieldBackend = {
1112
checkCoverage: (txMeta: TransactionMeta) => Promise<CoverageResult>;
13+
checkSignatureCoverage: (
14+
signatureRequest: SignatureRequest,
15+
) => Promise<CoverageResult>;
1216
};

0 commit comments

Comments
 (0)