Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion contracts/evm/ERC20Custody.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
pragma solidity 0.8.26;

import { IERC20Custody } from "./interfaces/IERC20Custody.sol";
import { IGatewayEVM, MessageContext } from "./interfaces/IGatewayEVM.sol";
import { IGatewayEVM, MessageContext} from "./interfaces/IGatewayEVM.sol";

import { RevertContext } from "../../contracts/Revert.sol";

Expand Down
12 changes: 9 additions & 3 deletions contracts/evm/GatewayEVM.sol
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { INotSupportedMethods } from "../../contracts/Errors.sol";
import { RevertContext, RevertOptions, Revertable } from "../../contracts/Revert.sol";
import { ZetaConnectorBase } from "./ZetaConnectorBase.sol";
import { IERC20Custody } from "./interfaces/IERC20Custody.sol";
import { Callable, IGatewayEVM, MessageContext } from "./interfaces/IGatewayEVM.sol";
import "./interfaces/IGatewayEVM.sol";

import "@openzeppelin/contracts-upgradeable/access/AccessControlUpgradeable.sol";
import "@openzeppelin/contracts-upgradeable/proxy/utils/Initializable.sol";
Expand Down Expand Up @@ -452,7 +452,13 @@ contract GatewayEVM is
private
returns (bytes memory)
{
return Callable(destination).onCall{ value: msg.value }(messageContext, data);
if (messageContext.amount == 0) {
return Callable(destination).onCall{ value: msg.value }(
LegacyMessageContext({ sender: messageContext.sender }), data
);
} else {
return CallableV2(destination).onCall{ value: msg.value }(messageContext, data);
}
}

// @dev prevent spoofing onCall and onRevert functions
Expand All @@ -463,7 +469,7 @@ contract GatewayEVM is
functionSelector := calldataload(data.offset)
}

if (functionSelector == Callable.onCall.selector) {
if (functionSelector == Callable.onCall.selector || functionSelector == CallableV2.onCall.selector) {
revert NotAllowedToCallOnCall();
}

Expand Down
2 changes: 1 addition & 1 deletion contracts/evm/Registry.sol
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ contract Registry is BaseRegistry, IRegistry {
/// @param context Information about the cross-chain message
/// @param data The encoded function call to execute
function onCall(
MessageContext calldata context,
LegacyMessageContext calldata context,
bytes calldata data
)
external
Expand Down
7 changes: 1 addition & 6 deletions contracts/evm/ZetaConnectorBase.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,7 @@ import "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol";

import { RevertContext } from "../../contracts/Revert.sol";
import {
IGatewayEVM,
IGatewayEVMErrors,
IGatewayEVMEvents,
MessageContext
} from "../../contracts/evm/interfaces/IGatewayEVM.sol";
import { IGatewayEVM, MessageContext } from "../../contracts/evm/interfaces/IGatewayEVM.sol";
import "../../contracts/evm/interfaces/IZetaConnector.sol";

/// @title ZetaConnectorBase
Expand Down
2 changes: 1 addition & 1 deletion contracts/evm/interfaces/IERC20Custody.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pragma solidity 0.8.26;

import { RevertContext } from "../../../contracts/Revert.sol";

import { MessageContext } from "./IGatewayEVM.sol";
import {MessageContext} from "./IGatewayEVM.sol";
import "@openzeppelin/contracts/token/ERC20/IERC20.sol";

/// @title IERC20CustodyEvents
Expand Down
25 changes: 23 additions & 2 deletions contracts/evm/interfaces/IGatewayEVM.sol
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,32 @@ interface IGatewayEVM is IGatewayEVMErrors, IGatewayEVMEvents {

/// @notice Message context passed to execute function.
/// @param sender Sender from omnichain contract.
struct MessageContext {
struct LegacyMessageContext {
address sender;
}

/// @notice Interface implemented by contracts receiving authenticated calls.
interface Callable {
function onCall(MessageContext calldata context, bytes calldata message) external payable returns (bytes memory);
function onCall(LegacyMessageContext calldata context, bytes calldata message) external payable returns (bytes memory);
}

/// @notice Message context passed to execute function.
/// @param sender Sender from omnichain contract.
/// @param asset The address of the asset.
/// @param amount The amount of the asset.
struct MessageContext {
address sender;
address asset;
uint256 amount;
}

/// @notice Interface implemented by contracts receiving authenticated calls with new MessageContext.
interface CallableV2 {
function onCall(
MessageContext calldata context,
bytes calldata message
)
external
payable
returns (bytes memory);
}
41 changes: 41 additions & 0 deletions contracts/zevm/GatewayZEVM.sol
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,47 @@ contract GatewayZEVM is
);
}

/// @notice Withdraw ZRC20 tokens and call a smart contract on an external chain.
/// @param receiver The receiver address on the external chain.
/// @param amount The amount of tokens to withdraw.
/// @param zrc20 The address of the ZRC20 token.
/// @param message The calldata to pass to the contract call.
/// @param version The number representing message context version.
/// @param callOptions Call options including gas limit, arbirtrary call flag and message context version.
/// @param revertOptions Revert options.
function withdrawAndCall(
bytes memory receiver,
uint256 amount,
address zrc20,
bytes calldata message,
uint256 version,
CallOptions calldata callOptions,
RevertOptions calldata revertOptions
)
external
whenNotPaused
{
if (receiver.length == 0) revert ZeroAddress();
if (amount == 0) revert InsufficientZRC20Amount();
if (callOptions.gasLimit < MIN_GAS_LIMIT) revert InsufficientGasLimit();
if (message.length + revertOptions.revertMessage.length > MAX_MESSAGE_SIZE) revert MessageSizeExceeded();

uint256 gasFee = _withdrawZRC20WithGasLimit(amount, zrc20, callOptions.gasLimit);
emit WithdrawnAndCalledV2(
msg.sender,
0,
receiver,
zrc20,
amount,
gasFee,
IZRC20(zrc20).PROTOCOL_FLAT_FEE(),
message,
version,
callOptions,
revertOptions
);
}

/// @notice Withdraw ZETA tokens to an external chain.
//// @param receiver The receiver address on the external chain.
//// @param amount The amount of tokens to withdraw.
Expand Down
45 changes: 45 additions & 0 deletions contracts/zevm/interfaces/IGatewayZEVM.sol
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,32 @@ interface IGatewayZEVMEvents {
CallOptions callOptions,
RevertOptions revertOptions
);

/// @notice Emitted when a withdraw and call is made.
/// @param sender The address from which the tokens are withdrawn.
/// @param chainId Chain id of external chain.
/// @param receiver The receiver address on the external chain.
/// @param zrc20 The address of the ZRC20 token.
/// @param value The amount of tokens withdrawn.
/// @param gasfee The gas fee for the withdrawal.
/// @param protocolFlatFee The protocol flat fee for the withdrawal.
/// @param message The calldata passed to the contract call.
/// @param version The number representing message context version.
/// @param callOptions Call options including gas limit, arbirtrary call flag and message context version.
/// @param revertOptions Revert options.
event WithdrawnAndCalledV2(
address indexed sender,
uint256 indexed chainId,
bytes receiver,
address zrc20,
uint256 value,
uint256 gasfee,
uint256 protocolFlatFee,
bytes message,
uint256 version,
CallOptions callOptions,
RevertOptions revertOptions
);
}

/// @title IGatewayZEVMErrors
Expand Down Expand Up @@ -161,6 +187,25 @@ interface IGatewayZEVM is IGatewayZEVMErrors, IGatewayZEVMEvents {
)
external;

/// @notice Withdraw ZRC20 tokens and call a smart contract on an external chain.
/// @param receiver The receiver address on the external chain.
/// @param amount The amount of tokens to withdraw.
/// @param zrc20 The address of the ZRC20 token.
/// @param message The calldata to pass to the contract call.
/// @param version The number representing message context version.
/// @param callOptions Call options including gas limit, arbirtrary call flag and message context version.
/// @param revertOptions Revert options.
function withdrawAndCall(
bytes memory receiver,
uint256 amount,
address zrc20,
bytes calldata message,
uint256 version,
CallOptions calldata callOptions,
RevertOptions calldata revertOptions
)
external;

/// @notice Withdraw ZETA tokens and call a smart contract on an external chain.
/// @param receiver The receiver address on the external chain.
/// @param amount The amount of tokens to withdraw.
Expand Down
13 changes: 10 additions & 3 deletions test/ERC20Custody.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ contract ERC20CustodyTest is Test, IGatewayEVMErrors, IGatewayEVMEvents, IReceiv
address tssAddress;
address foo;
RevertContext revertContext;
MessageContext arbitraryCallMessageContext = MessageContext({ sender: address(0) });
MessageContext arbitraryCallMessageContext =
MessageContext({ sender: address(0), asset: address(0), amount: 0 });

error EnforcedPause();
error NotWhitelisted();
Expand Down Expand Up @@ -349,11 +350,17 @@ contract ERC20CustodyTest is Test, IGatewayEVMErrors, IGatewayEVMEvents, IReceiv
uint256 balanceBeforeCustody = token.balanceOf(address(custody));

vm.expectEmit(true, true, true, true, address(receiver));
emit ReceivedOnCall(sender, message);
emit ReceivedOnCallV2(sender, address(token), amount, message);
vm.expectEmit(true, true, true, true, address(custody));
emit WithdrawnAndCalled(address(receiver), address(token), amount, message);
vm.prank(tssAddress);
custody.withdrawAndCall(MessageContext({ sender: sender }), address(receiver), address(token), amount, message);
custody.withdrawAndCall(
MessageContext({ sender: sender, asset: address(token), amount: amount }),
address(receiver),
address(token),
amount,
message
);

// Verify that the tokens were not transferred to the destination address
uint256 balanceAfter = token.balanceOf(destination);
Expand Down
40 changes: 36 additions & 4 deletions test/GatewayEVM.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ contract GatewayEVMTest is Test, IGatewayEVMErrors, IGatewayEVMEvents, IReceiver
address foo;
RevertOptions revertOptions;
RevertContext revertContext;
MessageContext arbitraryCallMessageContext = MessageContext({ sender: address(0) });
MessageContext arbitraryCallMessageContext =
MessageContext({ sender: address(0), asset: address(0), amount: 0 });

error EnforcedPause();
error AccessControlUnauthorizedAccount(address account, bytes32 neededRole);
Expand Down Expand Up @@ -186,14 +187,33 @@ contract GatewayEVMTest is Test, IGatewayEVMErrors, IGatewayEVMEvents, IReceiver
gateway.execute(arbitraryCallMessageContext, address(receiver), data);
}

// Test with legacy MessageContext
function testForwardCallToReceiveOnCallUsingAuthCall() public {
vm.expectEmit(true, true, true, true, address(receiver));
address sender = address(0x123);
emit ReceivedOnCall(sender, bytes("1"));
vm.expectEmit(true, true, true, true, address(gateway));
emit Executed(address(receiver), 0, bytes("1"));
vm.prank(tssAddress);
gateway.execute(MessageContext({ sender: sender }), address(receiver), bytes("1"));
gateway.execute(
MessageContext({ sender: sender, asset: address(0), amount: 0 }), address(receiver), bytes("1")
);
}

// Test with new MessageContext
function testForwardCallToReceiveOnCallUsingAuthCallV2() public {
vm.expectEmit(true, true, true, true, address(receiver));
address sender = address(0x123);
address asset = address(token);
uint256 amount = 100;
bytes memory data = bytes("1");
emit ReceivedOnCallV2(sender, asset, amount, data);
vm.expectEmit(true, true, true, true, address(gateway));
emit Executed(address(receiver), 0, data);
vm.prank(tssAddress);
gateway.execute(
MessageContext({ sender: sender, asset: asset, amount: amount }), address(receiver), bytes("1")
);
}

function testForwardCallToReceiveNonPayableFailsIfSenderIsNotTSS() public {
Expand All @@ -219,7 +239,9 @@ contract GatewayEVMTest is Test, IGatewayEVMErrors, IGatewayEVMEvents, IReceiver

vm.prank(owner);
vm.expectRevert(abi.encodeWithSelector(AccessControlUnauthorizedAccount.selector, owner, TSS_ROLE));
gateway.execute(MessageContext({ sender: address(0x123) }), address(receiver), data);
gateway.execute(
MessageContext({ sender: address(0x123), asset: address(0), amount: 0 }), address(receiver), data
);
}

function testForwardCallToReceivePayable() public {
Expand Down Expand Up @@ -262,6 +284,16 @@ contract GatewayEVMTest is Test, IGatewayEVMErrors, IGatewayEVMEvents, IReceiver
gateway.execute(arbitraryCallMessageContext, address(receiver), data);
}

function testForwardCallToReceiveOnCallV2Fails() public {
bytes memory data = abi.encodeWithSignature(
"onCall((address,address,uint256),bytes)", address(123), address(456), 100, bytes("")
);

vm.prank(tssAddress);
vm.expectRevert(NotAllowedToCallOnCall.selector);
gateway.execute(arbitraryCallMessageContext, address(receiver), data);
}

function testForwardCallToReceiveOnRevertFails() public {
bytes memory data = abi.encodeWithSignature("onRevert((address,address,uint256,bytes))");

Expand All @@ -283,7 +315,7 @@ contract GatewayEVMTest is Test, IGatewayEVMErrors, IGatewayEVMEvents, IReceiver

vm.prank(tssAddress);
vm.expectRevert(ZeroAddress.selector);
gateway.execute(MessageContext({ sender: address(0x123) }), address(0), data);
gateway.execute(MessageContext({ sender: address(0x123), asset: address(0), amount: 0 }), address(0), data);
}

function testForwardCallToReceiveNoParamsTogglePause() public {
Expand Down
3 changes: 2 additions & 1 deletion test/GatewayEVMZEVM.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ contract GatewayEVMZEVMTest is
address ownerEVM;
address destination;
address tssAddress;
MessageContext arbitraryCallMessageContext = MessageContext({ sender: address(0) });
MessageContext arbitraryCallMessageContext =
MessageContext({ sender: address(0), asset: address(0), amount: 0 });

// zevm
address payable proxyZEVM;
Expand Down
Loading
Loading