Skip to content

[IA 01b] - fix: Remove PaymentType constraint from RecurringCollector #1190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: ma/indexing-payments-003-trust-01-collect-on-cancel-by-payer
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,8 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC
* @dev Caller must be the data service the RCA was issued to.
*/
function collect(IGraphPayments.PaymentTypes paymentType, bytes calldata data) external returns (uint256) {
require(
paymentType == IGraphPayments.PaymentTypes.IndexingFee,
RecurringCollectorInvalidPaymentType(paymentType)
);
try this.decodeCollectData(data) returns (CollectParams memory collectParams) {
return _collect(collectParams);
return _collect(paymentType, collectParams);
} catch {
revert RecurringCollectorInvalidCollectData(data);
}
Expand Down Expand Up @@ -269,10 +265,14 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC
*
* Emits {PaymentCollected} and {RCACollected} events.
*
* @param _paymentType The type of payment to collect
* @param _params The decoded parameters for the collection
* @return The amount of tokens collected
*/
function _collect(CollectParams memory _params) private returns (uint256) {
function _collect(
IGraphPayments.PaymentTypes _paymentType,
CollectParams memory _params
) private returns (uint256) {
AgreementData storage agreement = _getAgreementStorage(_params.agreementId);
require(
_isCollectable(agreement),
Expand All @@ -289,7 +289,7 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC
tokensToCollect = _requireValidCollect(agreement, _params.agreementId, _params.tokens);

_graphPaymentsEscrow().collect(
IGraphPayments.PaymentTypes.IndexingFee,
_paymentType,
agreement.payer,
agreement.serviceProvider,
tokensToCollect,
Expand All @@ -301,7 +301,7 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC
agreement.lastCollectionAt = uint64(block.timestamp);

emit PaymentCollected(
IGraphPayments.PaymentTypes.IndexingFee,
_paymentType,
_params.collectionId,
agreement.payer,
agreement.serviceProvider,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// SPDX-License-Identifier: MIT
pragma solidity 0.8.27;

import { IGraphPayments } from "../../../../contracts/interfaces/IGraphPayments.sol";

import { IRecurringCollector } from "../../../../contracts/interfaces/IRecurringCollector.sol";

import { RecurringCollectorSharedTest } from "./shared.t.sol";
Expand All @@ -14,32 +12,14 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {

/* solhint-disable graph/func-name-mixedcase */

function test_Collect_Revert_WhenInvalidPaymentType(uint8 unboundedPaymentType, bytes memory data) public {
IGraphPayments.PaymentTypes paymentType = IGraphPayments.PaymentTypes(
bound(
unboundedPaymentType,
uint256(type(IGraphPayments.PaymentTypes).min),
uint256(type(IGraphPayments.PaymentTypes).max)
)
);
vm.assume(paymentType != IGraphPayments.PaymentTypes.IndexingFee);

bytes memory expectedErr = abi.encodeWithSelector(
IRecurringCollector.RecurringCollectorInvalidPaymentType.selector,
paymentType
);
vm.expectRevert(expectedErr);
_recurringCollector.collect(paymentType, data);
}

function test_Collect_Revert_WhenInvalidData(address caller, bytes memory data) public {
function test_Collect_Revert_WhenInvalidData(address caller, uint8 unboundedPaymentType, bytes memory data) public {
bytes memory expectedErr = abi.encodeWithSelector(
IRecurringCollector.RecurringCollectorInvalidCollectData.selector,
data
);
vm.expectRevert(expectedErr);
vm.prank(caller);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
_recurringCollector.collect(_paymentType(unboundedPaymentType), data);
}

function test_Collect_Revert_WhenCallerNotDataService(
Expand All @@ -61,7 +41,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
);
vm.expectRevert(expectedErr);
vm.prank(notDataService);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
_recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);
}

function test_Collect_Revert_WhenUnknownAgreement(FuzzyTestCollect memory fuzzy, address dataService) public {
Expand All @@ -74,7 +54,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
);
vm.expectRevert(expectedErr);
vm.prank(dataService);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
_recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);
}

function test_Collect_Revert_WhenCanceledAgreementByServiceProvider(FuzzyTestCollect calldata fuzzy) public {
Expand All @@ -97,7 +77,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
);
vm.expectRevert(expectedErr);
vm.prank(accepted.rca.dataService);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
_recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);
}

function test_Collect_Revert_WhenCollectingTooSoon(
Expand All @@ -116,7 +96,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
)
);
vm.prank(accepted.rca.dataService);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
_recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);

uint256 collectionSeconds = boundSkip(unboundedCollectionSeconds, 1, accepted.rca.minSecondsPerCollection - 1);
skip(collectionSeconds);
Expand All @@ -136,7 +116,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
);
vm.expectRevert(expectedErr);
vm.prank(accepted.rca.dataService);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
_recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);
}

function test_Collect_Revert_WhenCollectingTooLate(
Expand All @@ -163,7 +143,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
)
);
vm.prank(accepted.rca.dataService);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
_recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);

// skip beyond collectable time but still within the agreement endsAt
uint256 collectionSeconds = boundSkip(
Expand All @@ -189,7 +169,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
);
vm.expectRevert(expectedErr);
vm.prank(accepted.rca.dataService);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
_recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);
}

function test_Collect_OK_WhenCollectingTooMuch(
Expand Down Expand Up @@ -219,7 +199,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
)
);
vm.prank(accepted.rca.dataService);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, initialData);
_recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), initialData);
}

// skip to collectable time
Expand All @@ -240,7 +220,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
);
bytes memory data = _generateCollectData(collectParams);
vm.prank(accepted.rca.dataService);
uint256 collected = _recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
uint256 collected = _recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);
assertEq(collected, maxTokens);
}

Expand All @@ -258,9 +238,9 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
unboundedTokens
);
skip(collectionSeconds);
_expectCollectCallAndEmit(accepted.rca, fuzzy.collectParams, tokens);
_expectCollectCallAndEmit(accepted.rca, _paymentType(fuzzy.unboundedPaymentType), fuzzy.collectParams, tokens);
vm.prank(accepted.rca.dataService);
uint256 collected = _recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
uint256 collected = _recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);
assertEq(collected, tokens);
}
/* solhint-enable graph/func-name-mixedcase */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { RecurringCollectorHelper } from "./RecurringCollectorHelper.t.sol";
contract RecurringCollectorSharedTest is Test, Bounder {
struct FuzzyTestCollect {
FuzzyTestAccept fuzzyTestAccept;
uint8 unboundedPaymentType;
IRecurringCollector.CollectParams collectParams;
}

Expand Down Expand Up @@ -106,6 +107,7 @@ contract RecurringCollectorSharedTest is Test, Bounder {

function _expectCollectCallAndEmit(
IRecurringCollector.RecurringCollectionAgreement memory _rca,
IGraphPayments.PaymentTypes __paymentType,
IRecurringCollector.CollectParams memory _fuzzyParams,
uint256 _tokens
) internal {
Expand All @@ -114,7 +116,7 @@ contract RecurringCollectorSharedTest is Test, Bounder {
abi.encodeCall(
_paymentsEscrow.collect,
(
IGraphPayments.PaymentTypes.IndexingFee,
__paymentType,
_rca.payer,
_rca.serviceProvider,
_tokens,
Expand All @@ -126,7 +128,7 @@ contract RecurringCollectorSharedTest is Test, Bounder {
);
vm.expectEmit(address(_recurringCollector));
emit IPaymentsCollector.PaymentCollected(
IGraphPayments.PaymentTypes.IndexingFee,
__paymentType,
_fuzzyParams.collectionId,
_rca.payer,
_rca.serviceProvider,
Expand Down Expand Up @@ -193,4 +195,15 @@ contract RecurringCollectorSharedTest is Test, Bounder {
bound(_seed, 0, uint256(IRecurringCollector.CancelAgreementBy.Payer))
);
}

function _paymentType(uint8 _unboundedPaymentType) internal pure returns (IGraphPayments.PaymentTypes) {
return
IGraphPayments.PaymentTypes(
bound(
_unboundedPaymentType,
uint256(type(IGraphPayments.PaymentTypes).min),
uint256(type(IGraphPayments.PaymentTypes).max)
)
);
}
}