diff --git a/packages/horizon/contracts/payments/collectors/RecurringCollector.sol b/packages/horizon/contracts/payments/collectors/RecurringCollector.sol index e1225f6fa..662dc549f 100644 --- a/packages/horizon/contracts/payments/collectors/RecurringCollector.sol +++ b/packages/horizon/contracts/payments/collectors/RecurringCollector.sol @@ -62,12 +62,8 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC * @dev Caller must be the data service the RCA was issued to. */ function collect(IGraphPayments.PaymentTypes paymentType, bytes calldata data) external returns (uint256) { - require( - paymentType == IGraphPayments.PaymentTypes.IndexingFee, - RecurringCollectorInvalidPaymentType(paymentType) - ); try this.decodeCollectData(data) returns (CollectParams memory collectParams) { - return _collect(collectParams); + return _collect(paymentType, collectParams); } catch { revert RecurringCollectorInvalidCollectData(data); } @@ -269,10 +265,14 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC * * Emits {PaymentCollected} and {RCACollected} events. * + * @param _paymentType The type of payment to collect * @param _params The decoded parameters for the collection * @return The amount of tokens collected */ - function _collect(CollectParams memory _params) private returns (uint256) { + function _collect( + IGraphPayments.PaymentTypes _paymentType, + CollectParams memory _params + ) private returns (uint256) { AgreementData storage agreement = _getAgreementStorage(_params.agreementId); require( _isCollectable(agreement), @@ -289,7 +289,7 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC tokensToCollect = _requireValidCollect(agreement, _params.agreementId, _params.tokens); _graphPaymentsEscrow().collect( - IGraphPayments.PaymentTypes.IndexingFee, + _paymentType, agreement.payer, agreement.serviceProvider, tokensToCollect, @@ -301,7 +301,7 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC agreement.lastCollectionAt = uint64(block.timestamp); emit PaymentCollected( - IGraphPayments.PaymentTypes.IndexingFee, + _paymentType, _params.collectionId, agreement.payer, agreement.serviceProvider, diff --git a/packages/horizon/test/unit/payments/recurring-collector/collect.t.sol b/packages/horizon/test/unit/payments/recurring-collector/collect.t.sol index 8942c21bf..4382fa852 100644 --- a/packages/horizon/test/unit/payments/recurring-collector/collect.t.sol +++ b/packages/horizon/test/unit/payments/recurring-collector/collect.t.sol @@ -1,8 +1,6 @@ // SPDX-License-Identifier: MIT pragma solidity 0.8.27; -import { IGraphPayments } from "../../../../contracts/interfaces/IGraphPayments.sol"; - import { IRecurringCollector } from "../../../../contracts/interfaces/IRecurringCollector.sol"; import { RecurringCollectorSharedTest } from "./shared.t.sol"; @@ -14,32 +12,14 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest { /* solhint-disable graph/func-name-mixedcase */ - function test_Collect_Revert_WhenInvalidPaymentType(uint8 unboundedPaymentType, bytes memory data) public { - IGraphPayments.PaymentTypes paymentType = IGraphPayments.PaymentTypes( - bound( - unboundedPaymentType, - uint256(type(IGraphPayments.PaymentTypes).min), - uint256(type(IGraphPayments.PaymentTypes).max) - ) - ); - vm.assume(paymentType != IGraphPayments.PaymentTypes.IndexingFee); - - bytes memory expectedErr = abi.encodeWithSelector( - IRecurringCollector.RecurringCollectorInvalidPaymentType.selector, - paymentType - ); - vm.expectRevert(expectedErr); - _recurringCollector.collect(paymentType, data); - } - - function test_Collect_Revert_WhenInvalidData(address caller, bytes memory data) public { + function test_Collect_Revert_WhenInvalidData(address caller, uint8 unboundedPaymentType, bytes memory data) public { bytes memory expectedErr = abi.encodeWithSelector( IRecurringCollector.RecurringCollectorInvalidCollectData.selector, data ); vm.expectRevert(expectedErr); vm.prank(caller); - _recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data); + _recurringCollector.collect(_paymentType(unboundedPaymentType), data); } function test_Collect_Revert_WhenCallerNotDataService( @@ -61,7 +41,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest { ); vm.expectRevert(expectedErr); vm.prank(notDataService); - _recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data); + _recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data); } function test_Collect_Revert_WhenUnknownAgreement(FuzzyTestCollect memory fuzzy, address dataService) public { @@ -74,7 +54,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest { ); vm.expectRevert(expectedErr); vm.prank(dataService); - _recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data); + _recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data); } function test_Collect_Revert_WhenCanceledAgreementByServiceProvider(FuzzyTestCollect calldata fuzzy) public { @@ -97,7 +77,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest { ); vm.expectRevert(expectedErr); vm.prank(accepted.rca.dataService); - _recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data); + _recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data); } function test_Collect_Revert_WhenCollectingTooSoon( @@ -116,7 +96,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest { ) ); vm.prank(accepted.rca.dataService); - _recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data); + _recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data); uint256 collectionSeconds = boundSkip(unboundedCollectionSeconds, 1, accepted.rca.minSecondsPerCollection - 1); skip(collectionSeconds); @@ -136,7 +116,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest { ); vm.expectRevert(expectedErr); vm.prank(accepted.rca.dataService); - _recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data); + _recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data); } function test_Collect_Revert_WhenCollectingTooLate( @@ -163,7 +143,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest { ) ); vm.prank(accepted.rca.dataService); - _recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data); + _recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data); // skip beyond collectable time but still within the agreement endsAt uint256 collectionSeconds = boundSkip( @@ -189,7 +169,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest { ); vm.expectRevert(expectedErr); vm.prank(accepted.rca.dataService); - _recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data); + _recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data); } function test_Collect_OK_WhenCollectingTooMuch( @@ -219,7 +199,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest { ) ); vm.prank(accepted.rca.dataService); - _recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, initialData); + _recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), initialData); } // skip to collectable time @@ -240,7 +220,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest { ); bytes memory data = _generateCollectData(collectParams); vm.prank(accepted.rca.dataService); - uint256 collected = _recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data); + uint256 collected = _recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data); assertEq(collected, maxTokens); } @@ -258,9 +238,9 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest { unboundedTokens ); skip(collectionSeconds); - _expectCollectCallAndEmit(accepted.rca, fuzzy.collectParams, tokens); + _expectCollectCallAndEmit(accepted.rca, _paymentType(fuzzy.unboundedPaymentType), fuzzy.collectParams, tokens); vm.prank(accepted.rca.dataService); - uint256 collected = _recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data); + uint256 collected = _recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data); assertEq(collected, tokens); } /* solhint-enable graph/func-name-mixedcase */ diff --git a/packages/horizon/test/unit/payments/recurring-collector/shared.t.sol b/packages/horizon/test/unit/payments/recurring-collector/shared.t.sol index 8dd270b2f..2dbd0e1a0 100644 --- a/packages/horizon/test/unit/payments/recurring-collector/shared.t.sol +++ b/packages/horizon/test/unit/payments/recurring-collector/shared.t.sol @@ -16,6 +16,7 @@ import { RecurringCollectorHelper } from "./RecurringCollectorHelper.t.sol"; contract RecurringCollectorSharedTest is Test, Bounder { struct FuzzyTestCollect { FuzzyTestAccept fuzzyTestAccept; + uint8 unboundedPaymentType; IRecurringCollector.CollectParams collectParams; } @@ -106,6 +107,7 @@ contract RecurringCollectorSharedTest is Test, Bounder { function _expectCollectCallAndEmit( IRecurringCollector.RecurringCollectionAgreement memory _rca, + IGraphPayments.PaymentTypes __paymentType, IRecurringCollector.CollectParams memory _fuzzyParams, uint256 _tokens ) internal { @@ -114,7 +116,7 @@ contract RecurringCollectorSharedTest is Test, Bounder { abi.encodeCall( _paymentsEscrow.collect, ( - IGraphPayments.PaymentTypes.IndexingFee, + __paymentType, _rca.payer, _rca.serviceProvider, _tokens, @@ -126,7 +128,7 @@ contract RecurringCollectorSharedTest is Test, Bounder { ); vm.expectEmit(address(_recurringCollector)); emit IPaymentsCollector.PaymentCollected( - IGraphPayments.PaymentTypes.IndexingFee, + __paymentType, _fuzzyParams.collectionId, _rca.payer, _rca.serviceProvider, @@ -193,4 +195,15 @@ contract RecurringCollectorSharedTest is Test, Bounder { bound(_seed, 0, uint256(IRecurringCollector.CancelAgreementBy.Payer)) ); } + + function _paymentType(uint8 _unboundedPaymentType) internal pure returns (IGraphPayments.PaymentTypes) { + return + IGraphPayments.PaymentTypes( + bound( + _unboundedPaymentType, + uint256(type(IGraphPayments.PaymentTypes).min), + uint256(type(IGraphPayments.PaymentTypes).max) + ) + ); + } }