Skip to content

Commit 07ab2a5

Browse files
committed
check signature coverage
1 parent 8029bcd commit 07ab2a5

File tree

12 files changed

+278
-21
lines changed

12 files changed

+278
-21
lines changed

packages/shield-controller/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
"uuid": "^8.3.2"
6868
},
6969
"peerDependencies": {
70+
"@metamask/signature-controller": "^33.0.0",
7071
"@metamask/transaction-controller": "^60.0.0"
7172
},
7273
"engines": {

packages/shield-controller/src/ShieldController.test.ts

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
import type { SignatureControllerState } from '@metamask/signature-controller';
12
import type { TransactionControllerState } from '@metamask/transaction-controller';
23

34
import { ShieldController } from './ShieldController';
45
import { createMockBackend } from '../tests/mocks/backend';
56
import { createMockMessenger } from '../tests/mocks/messenger';
6-
import { generateMockTxMeta } from '../tests/utils';
7+
import {
8+
generateMockSignatureRequest,
9+
generateMockTxMeta,
10+
} from '../tests/utils';
711

812
/**
913
* Sets up a ShieldController for testing.
@@ -142,4 +146,74 @@ describe('ShieldController', () => {
142146
expect(backend.checkCoverage).toHaveBeenCalledWith(txMeta);
143147
});
144148
});
149+
150+
describe('checkSignatureCoverage', () => {
151+
it('should check signature coverage', async () => {
152+
const { baseMessenger, backend } = setup();
153+
const signatureRequest = generateMockSignatureRequest();
154+
const coverageResultReceived = new Promise<void>((resolve) => {
155+
baseMessenger.subscribe(
156+
'ShieldController:coverageResultReceived',
157+
(_coverageResult) => resolve(),
158+
);
159+
});
160+
baseMessenger.publish(
161+
'SignatureController:stateChange',
162+
{
163+
signatureRequests: { [signatureRequest.id]: signatureRequest },
164+
} as SignatureControllerState,
165+
undefined as never,
166+
);
167+
expect(await coverageResultReceived).toBeUndefined();
168+
expect(backend.checkSignatureCoverage).toHaveBeenCalledWith(
169+
signatureRequest,
170+
);
171+
});
172+
});
173+
174+
it('should check coverage for multiple signature request', async () => {
175+
const { baseMessenger, backend } = setup();
176+
const signatureRequest1 = generateMockSignatureRequest();
177+
const coverageResultReceived1 = new Promise<void>((resolve) => {
178+
baseMessenger.subscribe(
179+
'ShieldController:coverageResultReceived',
180+
(_coverageResult) => resolve(),
181+
);
182+
});
183+
baseMessenger.publish(
184+
'SignatureController:stateChange',
185+
{
186+
signatureRequests: {
187+
[signatureRequest1.id]: signatureRequest1,
188+
},
189+
} as SignatureControllerState,
190+
undefined as never,
191+
);
192+
expect(await coverageResultReceived1).toBeUndefined();
193+
expect(backend.checkSignatureCoverage).toHaveBeenCalledWith(
194+
signatureRequest1,
195+
);
196+
197+
const signatureRequest2 = generateMockSignatureRequest();
198+
const coverageResultReceived2 = new Promise<void>((resolve) => {
199+
baseMessenger.subscribe(
200+
'ShieldController:coverageResultReceived',
201+
(_coverageResult) => resolve(),
202+
);
203+
});
204+
baseMessenger.publish(
205+
'SignatureController:stateChange',
206+
{
207+
signatureRequests: {
208+
[signatureRequest2.id]: signatureRequest2,
209+
},
210+
} as SignatureControllerState,
211+
undefined as never,
212+
);
213+
214+
expect(await coverageResultReceived2).toBeUndefined();
215+
expect(backend.checkSignatureCoverage).toHaveBeenCalledWith(
216+
signatureRequest2,
217+
);
218+
});
145219
});

packages/shield-controller/src/ShieldController.ts

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@ import type {
33
ControllerStateChangeEvent,
44
RestrictedMessenger,
55
} from '@metamask/base-controller';
6+
import {
7+
SignatureRequestType,
8+
type SignatureRequest,
9+
type SignatureStateChange,
10+
} from '@metamask/signature-controller';
611
import type {
712
TransactionControllerStateChangeEvent,
813
TransactionMeta,
@@ -82,7 +87,9 @@ type AllowedActions = never;
8287
/**
8388
* The external events available to the ShieldController.
8489
*/
85-
type AllowedEvents = TransactionControllerStateChangeEvent;
90+
type AllowedEvents =
91+
| SignatureStateChange
92+
| TransactionControllerStateChangeEvent;
8693

8794
/**
8895
* The messenger of the {@link ShieldController}.
@@ -134,6 +141,11 @@ export class ShieldController extends BaseController<
134141
previousTransactions: TransactionMeta[] | undefined,
135142
) => void;
136143

144+
readonly #signatureControllerStateChangeHandler: (
145+
signatureRequests: Record<string, SignatureRequest>,
146+
previousSignatureRequests: Record<string, SignatureRequest> | undefined,
147+
) => void;
148+
137149
constructor(options: ShieldControllerOptions) {
138150
const {
139151
messenger,
@@ -157,6 +169,8 @@ export class ShieldController extends BaseController<
157169
this.#transactionHistoryLimit = transactionHistoryLimit;
158170
this.#transactionControllerStateChangeHandler =
159171
this.#handleTransactionControllerStateChange.bind(this);
172+
this.#signatureControllerStateChangeHandler =
173+
this.#handleSignatureControllerStateChange.bind(this);
160174
}
161175

162176
start() {
@@ -165,13 +179,54 @@ export class ShieldController extends BaseController<
165179
this.#transactionControllerStateChangeHandler,
166180
(state) => state.transactions,
167181
);
182+
183+
this.messagingSystem.subscribe(
184+
'SignatureController:stateChange',
185+
this.#signatureControllerStateChangeHandler,
186+
(state) => state.signatureRequests,
187+
);
168188
}
169189

170190
stop() {
171191
this.messagingSystem.unsubscribe(
172192
'TransactionController:stateChange',
173193
this.#transactionControllerStateChangeHandler,
174194
);
195+
196+
this.messagingSystem.unsubscribe(
197+
'SignatureController:stateChange',
198+
this.#signatureControllerStateChangeHandler,
199+
);
200+
}
201+
202+
#handleSignatureControllerStateChange(
203+
signatureRequests: Record<string, SignatureRequest>,
204+
previousSignatureRequests: Record<string, SignatureRequest> | undefined,
205+
) {
206+
const signatureRequestsArray = Object.values(signatureRequests);
207+
const previousSignatureRequestsArray = Object.values(
208+
previousSignatureRequests ?? {},
209+
);
210+
const previousSignatureRequestsById = new Map<string, SignatureRequest>(
211+
previousSignatureRequestsArray.map((request) => [request.id, request]),
212+
);
213+
for (const signatureRequest of signatureRequestsArray) {
214+
const previousSignatureRequest = previousSignatureRequestsById.get(
215+
signatureRequest.id,
216+
);
217+
218+
// Check coverage if the signature request is new and has type
219+
// `personal_sign`.
220+
if (
221+
!previousSignatureRequest &&
222+
signatureRequest.type === SignatureRequestType.PersonalSign
223+
) {
224+
this.checkSignatureCoverage(signatureRequest).catch(
225+
// istanbul ignore next
226+
(error) => log('Error checking coverage:', error),
227+
);
228+
}
229+
}
175230
}
176231

177232
#handleTransactionControllerStateChange(
@@ -208,7 +263,7 @@ export class ShieldController extends BaseController<
208263
*/
209264
async checkCoverage(txMeta: TransactionMeta): Promise<CoverageResult> {
210265
// Check coverage
211-
const coverageResult = await this.#fetchCoverageResult(txMeta);
266+
const coverageResult = await this.#backend.checkCoverage(txMeta);
212267

213268
// Publish coverage result
214269
this.messagingSystem.publish(
@@ -222,8 +277,29 @@ export class ShieldController extends BaseController<
222277
return coverageResult;
223278
}
224279

225-
async #fetchCoverageResult(txMeta: TransactionMeta): Promise<CoverageResult> {
226-
return this.#backend.checkCoverage(txMeta);
280+
/**
281+
* Checks the coverage of a signature request.
282+
*
283+
* @param signatureRequest - The signature request to check coverage for.
284+
* @returns The coverage result.
285+
*/
286+
async checkSignatureCoverage(
287+
signatureRequest: SignatureRequest,
288+
): Promise<CoverageResult> {
289+
// Check coverage
290+
const coverageResult =
291+
await this.#backend.checkSignatureCoverage(signatureRequest);
292+
293+
// Publish coverage result
294+
this.messagingSystem.publish(
295+
`${controllerName}:coverageResultReceived`,
296+
coverageResult,
297+
);
298+
299+
// Update state
300+
this.#addCoverageResult(signatureRequest.id, coverageResult);
301+
302+
return coverageResult;
227303
}
228304

229305
#addCoverageResult(txId: string, coverageResult: CoverageResult) {

packages/shield-controller/src/backend.test.ts

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { ShieldRemoteBackend } from './backend';
2-
import { generateMockTxMeta, getRandomCoverageStatus } from '../tests/utils';
2+
import { generateMockSignatureRequest, generateMockTxMeta, getRandomCoverageStatus } from '../tests/utils';
33

44
/**
55
* Setup the test environment.
@@ -141,4 +141,40 @@ describe('ShieldRemoteBackend', () => {
141141
// that the polling loop is exited as expected.
142142
await new Promise((resolve) => setTimeout(resolve, 10));
143143
});
144+
145+
describe('checkSignatureCoverage', () => {
146+
it('should check signature coverage', async () => {
147+
const { backend, fetchMock, getAccessToken } = setup();
148+
149+
// Mock init coverage check.
150+
fetchMock.mockResolvedValueOnce({
151+
status: 200,
152+
json: jest.fn().mockResolvedValue({ coverageId: 'coverageId' }),
153+
} as unknown as Response);
154+
155+
// Mock get coverage result.
156+
const status = getRandomCoverageStatus();
157+
fetchMock.mockResolvedValueOnce({
158+
status: 200,
159+
json: jest.fn().mockResolvedValue({ status }),
160+
} as unknown as Response);
161+
162+
const signatureRequest = generateMockSignatureRequest();
163+
const coverageResult =
164+
await backend.checkSignatureCoverage(signatureRequest);
165+
expect(coverageResult).toStrictEqual({ status });
166+
expect(fetchMock).toHaveBeenCalledTimes(2);
167+
expect(getAccessToken).toHaveBeenCalledTimes(2);
168+
});
169+
170+
it('throws with invalid data', async () => {
171+
const { backend } = setup();
172+
173+
const signatureRequest = generateMockSignatureRequest();
174+
signatureRequest.messageParams.data = [];
175+
await expect(
176+
backend.checkSignatureCoverage(signatureRequest),
177+
).rejects.toThrow('Signature data must be a string');
178+
});
179+
});
144180
});

packages/shield-controller/src/backend.ts

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import type { SignatureRequest } from '@metamask/signature-controller';
12
import type { TransactionMeta } from '@metamask/transaction-controller';
23

34
import type { CoverageResult, CoverageStatus, ShieldBackend } from './types';
@@ -16,6 +17,14 @@ export type InitCoverageCheckRequest = {
1617
origin?: string;
1718
};
1819

20+
export type InitSignatureCoverageCheckRequest = {
21+
chainId: string;
22+
data: string;
23+
from: string;
24+
method: string;
25+
origin?: string;
26+
};
27+
1928
export type InitCoverageCheckResponse = {
2029
coverageId: string;
2130
};
@@ -59,9 +68,7 @@ export class ShieldRemoteBackend implements ShieldBackend {
5968
this.#fetch = fetchFn;
6069
}
6170

62-
checkCoverage: (txMeta: TransactionMeta) => Promise<CoverageResult> = async (
63-
txMeta,
64-
) => {
71+
async checkCoverage(txMeta: TransactionMeta): Promise<CoverageResult> {
6572
const reqBody: InitCoverageCheckRequest = {
6673
txParams: [
6774
{
@@ -76,22 +83,46 @@ export class ShieldRemoteBackend implements ShieldBackend {
7683
origin: txMeta.origin,
7784
};
7885

79-
const { coverageId } = await this.#initCoverageCheck(reqBody);
86+
const { coverageId } = await this.#initCoverageCheck(
87+
'v1/transaction/coverage/init',
88+
reqBody,
89+
);
8090

8191
return this.#getCoverageResult(coverageId);
82-
};
92+
}
93+
94+
async checkSignatureCoverage(
95+
signatureRequest: SignatureRequest,
96+
): Promise<CoverageResult> {
97+
if (typeof signatureRequest.messageParams.data !== 'string') {
98+
throw new Error('Signature data must be a string');
99+
}
100+
101+
const reqBody: InitSignatureCoverageCheckRequest = {
102+
chainId: signatureRequest.chainId,
103+
data: signatureRequest.messageParams.data,
104+
from: signatureRequest.messageParams.from,
105+
method: signatureRequest.type,
106+
origin: signatureRequest.messageParams.origin,
107+
};
108+
109+
const { coverageId } = await this.#initCoverageCheck(
110+
'v1/signature/coverage/init',
111+
reqBody,
112+
);
113+
114+
return this.#getCoverageResult(coverageId);
115+
}
83116

84117
async #initCoverageCheck(
85-
reqBody: InitCoverageCheckRequest,
118+
path: string,
119+
reqBody: unknown,
86120
): Promise<InitCoverageCheckResponse> {
87-
const res = await this.#fetch(
88-
`${this.#baseUrl}/v1/transaction/coverage/init`,
89-
{
90-
method: 'POST',
91-
headers: await this.#createHeaders(),
92-
body: JSON.stringify(reqBody),
93-
},
94-
);
121+
const res = await this.#fetch(`${this.#baseUrl}/${path}`, {
122+
method: 'POST',
123+
headers: await this.#createHeaders(),
124+
body: JSON.stringify(reqBody),
125+
});
95126
if (res.status !== 200) {
96127
throw new Error(`Failed to init coverage check: ${res.status}`);
97128
}

packages/shield-controller/src/types.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import type { SignatureRequest } from '@metamask/signature-controller';
12
import type { TransactionMeta } from '@metamask/transaction-controller';
23

34
export type CoverageResult = {
@@ -9,4 +10,7 @@ export type CoverageStatus = (typeof coverageStatuses)[number];
910

1011
export type ShieldBackend = {
1112
checkCoverage: (txMeta: TransactionMeta) => Promise<CoverageResult>;
13+
checkSignatureCoverage: (
14+
signatureRequest: SignatureRequest,
15+
) => Promise<CoverageResult>;
1216
};

packages/shield-controller/tests/mocks/backend.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,8 @@ export function createMockBackend() {
88
checkCoverage: jest.fn().mockResolvedValue({
99
status: 'covered',
1010
}),
11+
checkSignatureCoverage: jest.fn().mockResolvedValue({
12+
status: 'covered',
13+
}),
1114
};
1215
}

0 commit comments

Comments
 (0)