diff --git a/__mocks__/typedData/example_MerkleTreeNested.json b/__mocks__/typedData/example_MerkleTreeNested.json new file mode 100644 index 000000000..de12c7253 --- /dev/null +++ b/__mocks__/typedData/example_MerkleTreeNested.json @@ -0,0 +1,45 @@ +{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Example": [ + { "name": "flag", "type": "bool" }, + { "name": "data", "type": "Session" } + ], + "Session": [{ "name": "root", "type": "merkletree", "contains": "Policy" }], + "Policy": [ + { "name": "contractAddress", "type": "felt" }, + { "name": "selector", "type": "selector" } + ] + }, + "primaryType": "Example", + "domain": { + "name": "StarkNet Mail", + "version": "1", + "chainId": "1", + "revision": "1" + }, + "message": { + "flag": true, + "data": { + "root": [ + { + "contractAddress": "0x1", + "selector": "transfer" + }, + { + "contractAddress": "0x2", + "selector": "transfer" + }, + { + "contractAddress": "0x3", + "selector": "transfer" + } + ] + } + } +} diff --git a/__mocks__/typedData/example_baseTypes.json b/__mocks__/typedData/example_baseTypes.json index db2285843..eb7e936d4 100644 --- a/__mocks__/typedData/example_baseTypes.json +++ b/__mocks__/typedData/example_baseTypes.json @@ -10,7 +10,8 @@ { "name": "n0", "type": "felt" }, { "name": "n1", "type": "bool" }, { "name": "n2", "type": "string" }, - { "name": "n3", "type": "selector" }, + { "name": "n3_0", "type": "selector" }, + { "name": "n3_1", "type": "selector" }, { "name": "n4", "type": "u128" }, { "name": "n5", "type": "i128" }, { "name": "n6", "type": "ContractAddress" }, @@ -30,7 +31,8 @@ "n0": "0x3e8", "n1": true, "n2": "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.", - "n3": "transfer", + "n3_0": "transfer", + "n3_1": "0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e", "n4": 10, "n5": -10, "n6": "0x3e8", diff --git a/__tests__/utils/typedData.test.ts b/__tests__/utils/typedData.test.ts index 51c435bac..1e855c84a 100644 --- a/__tests__/utils/typedData.test.ts +++ b/__tests__/utils/typedData.test.ts @@ -4,6 +4,7 @@ import typedDataExample from '../../__mocks__/typedData/baseExample.json'; import exampleBaseTypes from '../../__mocks__/typedData/example_baseTypes.json'; import exampleEnum from '../../__mocks__/typedData/example_enum.json'; import exampleEnumNested from '../../__mocks__/typedData/example_enumNested.json'; +import exampleMerkleTreeNested from '../../__mocks__/typedData/example_MerkleTreeNested.json'; import examplePresetTypes from '../../__mocks__/typedData/example_presetTypes.json'; import typedDataStructArrayExample from '../../__mocks__/typedData/mail_StructArray.json'; import typedDataSessionExample from '../../__mocks__/typedData/session_MerkleTree.json'; @@ -57,7 +58,7 @@ describe('typedData', () => { ); encoded = encodeType(exampleBaseTypes.types, 'Example', TypedDataRevision.ACTIVE); expect(encoded).toMatchInlineSnapshot( - `"\\"Example\\"(\\"n0\\":\\"felt\\",\\"n1\\":\\"bool\\",\\"n2\\":\\"string\\",\\"n3\\":\\"selector\\",\\"n4\\":\\"u128\\",\\"n5\\":\\"i128\\",\\"n6\\":\\"ContractAddress\\",\\"n7\\":\\"ClassHash\\",\\"n8\\":\\"timestamp\\",\\"n9\\":\\"shortstring\\")"` + `"\\"Example\\"(\\"n0\\":\\"felt\\",\\"n1\\":\\"bool\\",\\"n2\\":\\"string\\",\\"n3_0\\":\\"selector\\",\\"n3_1\\":\\"selector\\",\\"n4\\":\\"u128\\",\\"n5\\":\\"i128\\",\\"n6\\":\\"ContractAddress\\",\\"n7\\":\\"ClassHash\\",\\"n8\\":\\"timestamp\\",\\"n9\\":\\"shortstring\\")"` ); encoded = encodeType(examplePresetTypes.types, 'Example', TypedDataRevision.ACTIVE); expect(encoded).toMatchInlineSnapshot( @@ -65,11 +66,15 @@ describe('typedData', () => { ); encoded = encodeType(exampleEnum.types, 'Example', TypedDataRevision.ACTIVE); expect(encoded).toMatchInlineSnapshot( - `"\\"Example\\"(\\"someEnum1\\":\\"EnumA\\",\\"someEnum2\\":\\"EnumB\\")\\"EnumA\\"(\\"Variant 1\\":(),\\"Variant 2\\":(\\"u128\\",\\"u128*\\"),\\"Variant 3\\":(\\"u128\\"))\\"EnumB\\"(\\"Variant 1\\":(),\\"Variant 2\\":(\\"u128\\"))"` + `"\\"Example\\"(\\"someEnum1\\":\\"enum\\",\\"someEnum2\\":\\"enum\\")\\"EnumA\\"(\\"Variant 1\\"(),\\"Variant 2\\"(\\"u128\\",\\"u128*\\"),\\"Variant 3\\"(\\"u128\\"))\\"EnumB\\"(\\"Variant 1\\"(),\\"Variant 2\\"(\\"u128\\"))"` ); encoded = encodeType(exampleEnumNested.types, 'Example', TypedDataRevision.ACTIVE); expect(encoded).toMatchInlineSnapshot( - `"\\"Example\\"(\\"someEnum\\":\\"EnumA\\")\\"EnumA\\"(\\"Variant 1\\":(),\\"Variant 2\\":(\\"u128\\",\\"StructA\\"))\\"EnumB\\"(\\"Variant A\\":(),\\"Variant B\\":(\\"StructB*\\"))\\"StructA\\"(\\"nestedEnum\\":\\"EnumB\\")\\"StructB\\"(\\"flag\\":\\"bool\\")"` + `"\\"Example\\"(\\"someEnum\\":\\"enum\\")\\"EnumA\\"(\\"Variant 1\\"(),\\"Variant 2\\"(\\"u128\\",\\"StructA\\"))\\"EnumB\\"(\\"Variant A\\"(),\\"Variant B\\"(\\"StructB*\\"))\\"StructA\\"(\\"nestedEnum\\":\\"enum\\")\\"StructB\\"(\\"flag\\":\\"bool\\")"` + ); + encoded = encodeType(exampleMerkleTreeNested.types, 'Example', TypedDataRevision.ACTIVE); + expect(encoded).toMatchInlineSnapshot( + `"\\"Example\\"(\\"flag\\":\\"bool\\",\\"data\\":\\"Session\\")\\"Policy\\"(\\"contractAddress\\":\\"felt\\",\\"selector\\":\\"selector\\")\\"Session\\"(\\"root\\":\\"merkletree\\")"` ); }); @@ -101,7 +106,7 @@ describe('typedData', () => { ); typeHash = getTypeHash(exampleBaseTypes.types, 'Example', TypedDataRevision.ACTIVE); expect(typeHash).toMatchInlineSnapshot( - `"0x1f94cd0be8b4097a41486170fdf09a4cd23aefbc74bb2344718562994c2c111"` + `"0x2fe0aa1f0baee396812084785f7907b3e1204f8b3451d6ec37b18d35f5e004d"` ); typeHash = getTypeHash(examplePresetTypes.types, 'Example', TypedDataRevision.ACTIVE); expect(typeHash).toMatchInlineSnapshot( @@ -109,11 +114,11 @@ describe('typedData', () => { ); typeHash = getTypeHash(exampleEnum.types, 'Example', TypedDataRevision.ACTIVE); expect(typeHash).toMatchInlineSnapshot( - `"0x8eb4aeac64b707f3e843284c4258df6df1f0f7fd38dcffdd8a153a495cd351"` + `"0x2d5ac0cebbe47959d53fafa1c230a3cbd4e2f17f89c461b17c4864baf54439f"` ); typeHash = getTypeHash(exampleEnumNested.types, 'Example', TypedDataRevision.ACTIVE); expect(typeHash).toMatchInlineSnapshot( - `"0x2143bb787fabace39d62e9acf8b6e97d9a369000516c3e6ffd963dc1370fc1a"` + `"0x246f8826603bb897655d8058028a31bfbb3589694df3c18a1fc36d386464752"` ); }); @@ -132,18 +137,21 @@ describe('typedData', () => { }); test('should prepare selector', () => { - const res1 = prepareSelector('myFunction'); - expect(res1).toEqual('0xc14cfe23f3fa7ce7b1f8db7d7682305b1692293f71a61cc06637f0d8d8b6c8'); + const name = 'myFunction'; + const selector = '0xc14cfe23f3fa7ce7b1f8db7d7682305b1692293f71a61cc06637f0d8d8b6c8'; - const res2 = prepareSelector( - '0xc14cfe23f3fa7ce7b1f8db7d7682305b1692293f71a61cc06637f0d8d8b6c8' - ); - expect(res2).toEqual('0xc14cfe23f3fa7ce7b1f8db7d7682305b1692293f71a61cc06637f0d8d8b6c8'); + const res1 = prepareSelector(name); + expect(res1).toEqual(selector); - const res3 = prepareSelector( - '0xc14cfe23f3fa7ce7b1f8db7d7682305b1692293f71a61cc06637f0d8d8b6c8' - ); + const res2 = prepareSelector(selector); + expect(res2).toEqual(selector); + + const res3 = prepareSelector(selector); expect(res3).not.toEqual('0xc14cfe23f3fa7ce7b1f8db7d76'); + + const res4 = prepareSelector(selector, TypedDataRevision.ACTIVE); + expect(res4).not.toEqual(selector); + expect(res4).toEqual('0x424f5e095375246eb2e25c35fdb9a1398a2b8b1f1f3956c270dbc24c46bdda'); }); test('should transform merkle tree', () => { @@ -235,7 +243,7 @@ describe('typedData', () => { TypedDataRevision.ACTIVE ); expect(hash).toMatchInlineSnapshot( - `"0x555f72e550b308e50c1a4f8611483a174026c982a9893a05c185eeb85399657"` + `"0x59d498384b95b0f1e011d810bfb195ef4ba0a9a637f235026c760f022894b87"` ); }); @@ -325,22 +333,27 @@ describe('typedData', () => { let messageHash: string; messageHash = getMessageHash(exampleBaseTypes, exampleAddress); expect(messageHash).toMatchInlineSnapshot( - `"0xdb7829db8909c0c5496f5952bcfc4fc894341ce01842537fc4f448743480b6"` + `"0x554ee8db1f628fd2cca20f80a0b76ad543e3fa1a066557ac5e3055ccd3f3724"` ); messageHash = getMessageHash(examplePresetTypes, exampleAddress); expect(messageHash).toMatchInlineSnapshot( - `"0x185b339d5c566a883561a88fb36da301051e2c0225deb325c91bb7aa2f3473a"` + `"0x7d53c86332b95e3f23a37b70d58b209a8ae25c0b97b0d27635ac0ae5ca1973"` ); messageHash = getMessageHash(exampleEnum, exampleAddress); expect(messageHash).toMatchInlineSnapshot( - `"0x6e61abaf480b1370bbf231f54e298c5f4872f40a6d2dd409ff30accee5bbd1e"` + `"0x1e6ef6a264070e0cc5209b0b93d2a0eb49cb8b393517a48aa18718feba1263a"` ); messageHash = getMessageHash(exampleEnumNested, exampleAddress); expect(messageHash).toMatchInlineSnapshot( - `"0x691fc54567306a8ea5431130f1b98299e74a748ac391540a86736f20ef5f2b7"` + `"0x5ee5e9938bacbf3a5fd145c218133ab25d02d03536b798f4ceaa80b46c3f8a1"` + ); + + messageHash = getMessageHash(exampleMerkleTreeNested, exampleAddress); + expect(messageHash).toMatchInlineSnapshot( + `"0x233570609d39af53518adee8d2d77b3fa434c9d1517e3a0e82b4f2501cd7725"` ); expect(spyPedersen).not.toHaveBeenCalled(); diff --git a/src/utils/typedData.ts b/src/utils/typedData.ts index 6f2dbb9fe..a3c05b2ef 100644 --- a/src/utils/typedData.ts +++ b/src/utils/typedData.ts @@ -21,7 +21,7 @@ import { } from './hash'; import { MerkleTree } from './merkle'; import { isBigNumberish, isHex, toHex } from './num'; -import { encodeShortString } from './shortString'; +import { encodeShortString, isShortString } from './shortString'; import { isBoolean, isString } from './typed'; interface Context { @@ -119,8 +119,8 @@ export function validateTypedData(data: unknown): data is TypedData { * // result2 = '0xc14cfe23f3fa7ce7b1f8db7d7682305b1692293f71a61cc06637f0d8d8b6c8' * ``` */ -export function prepareSelector(selector: string): string { - return isHex(selector) ? selector : getSelectorFromName(selector); +export function prepareSelector(selector: string, revision: Revision = Revision.LEGACY): string { + return revision === Revision.LEGACY && isHex(selector) ? selector : getSelectorFromName(selector); } /** @@ -170,8 +170,8 @@ export function getDependencies( if (type[type.length - 1] === '*') { dependencyTypes = [type.slice(0, -1)]; } else if (revision === Revision.ACTIVE) { - // enum base - if (type === 'enum') { + // enum or merkletree base + if (type === 'enum' || type === 'merkletree') { dependencyTypes = [contains]; } // enum element types @@ -259,26 +259,22 @@ export function encodeType( const esc = revisionConfiguration[revision].escapeTypeString; - return newTypes - .map((dependency) => { - const dependencyElements = allTypes[dependency].map((t) => { - const targetType = - t.type === 'enum' && revision === Revision.ACTIVE - ? (t as StarknetEnumType).contains - : t.type; - // parentheses handling for enum variant types - const typeString = targetType.match(/^\(.*\)$/) - ? `(${targetType - .slice(1, -1) - .split(',') - .map((e) => (e ? esc(e) : e)) - .join(',')})` - : esc(targetType); - return `${esc(t.name)}:${typeString}`; - }); - return `${esc(dependency)}(${dependencyElements})`; - }) - .join(''); + const escapedTypes = newTypes.map((dependency) => { + const dependencyElements = allTypes[dependency].map((t) => { + const targetType = t.type; + // parentheses handling for enum variant types + const typeString = targetType.match(/^\(.*\)$/) + ? `(${targetType + .slice(1, -1) + .split(',') + .map((e) => (e ? esc(e) : e)) + .join(',')})` + : `:${esc(targetType)}`; + return `${esc(t.name)}${typeString}`; + }); + return `${esc(dependency)}(${dependencyElements})`; + }); + return escapedTypes.join(''); } /** @@ -364,22 +360,24 @@ export function encodeValue( if (revision === Revision.ACTIVE) { const [variantKey, variantData] = Object.entries(data as TypedData['message'])[0]; - const parentType = types[ctx.parent as string].find((t) => t.name === ctx.key); - const enumType = types[(parentType as StarknetEnumType).contains]; + const parentType = types[ctx.parent as string].find((t) => t.name === ctx.key)!; + const enumName = (parentType as StarknetEnumType).contains; + const enumType = types[enumName]; const variantType = enumType.find((t) => t.name === variantKey) as StarknetType; const variantIndex = enumType.indexOf(variantType); + const typeHash = getTypeHash(types, enumName, revision); const encodedSubtypes = variantType.type .slice(1, -1) .split(',') + .filter((subtype) => !!subtype) .map((subtype, index) => { - if (!subtype) return subtype; const subtypeData = (variantData as unknown[])[index]; return encodeValue(types, subtype, subtypeData, undefined, revision)[1]; }); return [ type, - revisionConfiguration[revision].hashMethod([variantIndex, ...encodedSubtypes]), + revisionConfiguration[revision].hashMethod([typeHash, variantIndex, ...encodedSubtypes]), ]; } // else fall through to default return [type, getHex(data as string)]; @@ -396,7 +394,7 @@ export function encodeValue( return ['felt', root]; } case 'selector': { - return ['felt', prepareSelector(data as string)]; + return ['felt', prepareSelector(data as string, revision)]; } case 'string': { if (revision === Revision.ACTIVE) { @@ -426,14 +424,22 @@ export function encodeValue( } // else fall through to default return [type, getHex(data as string)]; } - case 'felt': - case 'shortstring': { - // TODO: should 'shortstring' diverge into directly using encodeShortString()? + case 'felt': { if (revision === Revision.ACTIVE) { assertRange(getHex(data as string), type, RANGE_FELT); } // else fall through to default return [type, getHex(data as string)]; } + case 'shortstring': { + if (revision === Revision.ACTIVE) { + if (ctx.parent === revisionConfiguration[revision].domain && ctx.key === 'revision') { + return [type, getHex(data as string)]; + } + assert(isString(data) && isShortString(data), `Type mismatch for ${type} ${data}`); + return [type, encodeShortString(data)]; + } // else fall through to default + return [type, getHex(data as string)]; + } case 'ClassHash': case 'ContractAddress': { if (revision === Revision.ACTIVE) { @@ -478,7 +484,7 @@ export function encodeData( ([ts, vs], field) => { if ( data[field.name as keyof T['message']] === undefined || - (data[field.name as keyof T['message']] === null && field.type !== 'enum') + data[field.name as keyof T['message']] === null ) { throw new Error(`Cannot encode data: missing data for '${field.name}'`); }