Skip to content

Commit 28e3b52

Browse files
committed
fix usdc as cc fee
1 parent d037886 commit 28e3b52

File tree

7 files changed

+365
-13
lines changed

7 files changed

+365
-13
lines changed

contracts/ProtocolVault.sol

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ contract ProtocolVault is Ownable2StepUpgradeable, UUPSUpgradeable, PausableUpgr
6868
bool public lpWhitelistEnabled;
6969
uint256 public lpWhitelistEndTime;
7070
mapping(address => bool) public lpWhitelist;
71+
/// @dev cross chain fee for claims
72+
uint256 public claimCrossChainFee;
7173

7274
//receive native token
7375
receive() external payable {}
@@ -267,7 +269,12 @@ contract ProtocolVault is Ownable2StepUpgradeable, UUPSUpgradeable, PausableUpgr
267269
emit DepositToStrategy(periodId, vaultId, receiver, amount, dexNonce);
268270
}
269271

270-
function updateUnClaimed(uint256 periodId, ClaimInfo[] memory userClaimInfos) external onlyVaultCrossChainManager {
272+
function updateUnClaimed(uint256 periodId, uint256 ccFee, ClaimInfo[] memory userClaimInfos)
273+
external
274+
onlyVaultCrossChainManager
275+
{
276+
claimCrossChainFee += ccFee;
277+
271278
for (uint256 i = 0; i < userClaimInfos.length; i++) {
272279
bytes32 userId = userClaimInfos[i].accountId == bytes32(0)
273280
? userClaimInfos[i].strategyProviderId
@@ -280,11 +287,16 @@ contract ProtocolVault is Ownable2StepUpgradeable, UUPSUpgradeable, PausableUpgr
280287
emit UnClaimedUpdated(periodId, vaultId, userClaimInfos);
281288
}
282289

283-
/// @notice withdraw native token
284-
/// @param to the receiver address
285-
/// @param amount the amount to withdraw
286-
function withdrawNativeToken(address payable to, uint256 amount) external onlyOwner {
287-
to.sendValue(amount);
290+
function withdrawToken(address token, address to, uint256 amount) external onlyOwner {
291+
if (address(token) != address(0)) {
292+
if (amount > claimCrossChainFee) {
293+
revert NotEnoughCrossChainFee();
294+
}
295+
claimCrossChainFee -= amount;
296+
SafeTransferLib.safeTransfer(ERC20(token), to, amount);
297+
} else {
298+
payable(to).sendValue(amount);
299+
}
288300
}
289301

290302
//--------------------------------------CONFIG--------------------------------------------

contracts/ProtocolVaultLedger.sol

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -629,12 +629,13 @@ contract ProtocolVaultLedger is Ownable2StepUpgradeable, UUPSUpgradeable, IProto
629629
delete userClaimInfo[requestId];
630630
}
631631

632+
uint256 actualTotalFee = feePerUser * userClaimInfos.length;
632633
// Cross chain message
633634
StrategyVaultCCMessage memory message = StrategyVaultCCMessage({
634635
payloadType: PayloadType.UPDATE_USER_CLAIM,
635636
srcChainId: block.chainid,
636637
dstChainId: chainId,
637-
payload: abi.encode(periodId, userClaimInfos)
638+
payload: abi.encode(periodId, actualTotalFee, userClaimInfos)
638639
});
639640

640641
// Send cross-chain message
@@ -758,7 +759,7 @@ contract ProtocolVaultLedger is Ownable2StepUpgradeable, UUPSUpgradeable, IProto
758759
/// @param periodId period id
759760
/// @param requestIds request Id array
760761
/// @return nativeFee native token fee required
761-
function quoteClaim(uint256 chainId, uint256 periodId, bytes32[] memory requestIds)
762+
function quoteClaim(uint256 chainId, uint256 periodId, uint256 ccFee, bytes32[] memory requestIds)
762763
external
763764
view
764765
returns (uint256 nativeFee)
@@ -776,7 +777,7 @@ contract ProtocolVaultLedger is Ownable2StepUpgradeable, UUPSUpgradeable, IProto
776777
payloadType: PayloadType.UPDATE_USER_CLAIM,
777778
srcChainId: block.chainid,
778779
dstChainId: chainId,
779-
payload: abi.encode(periodId, userClaimInfos)
780+
payload: abi.encode(periodId, ccFee, userClaimInfos)
780781
});
781782

782783
// Return only the native fee

contracts/VaultCrossChainManager.sol

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ contract VaultCrossChainManager is OAppUpgradeable, IVaultCrossChainManager {
142142
IProtocolVault(vault).depositToStrategy(periodId, vault, assetsDistribution.assets);
143143
} else if (payloadType == PayloadType.UPDATE_USER_CLAIM) {
144144
//Decode the payload
145-
(uint256 periodId, ClaimInfo[] memory userClaims) = abi.decode(payload, (uint256, ClaimInfo[]));
145+
(uint256 periodId, uint256 ccFee, ClaimInfo[] memory userClaims) =
146+
abi.decode(payload, (uint256, uint256, ClaimInfo[]));
146147

147148
//Convert the amount
148149
uint256 dstChainId = strategyVaultCCmessage.dstChainId;
@@ -153,7 +154,7 @@ contract VaultCrossChainManager is OAppUpgradeable, IVaultCrossChainManager {
153154
}
154155
}
155156
//Call Protocol Vault
156-
IProtocolVault(vault).updateUnClaimed(periodId, userClaims);
157+
IProtocolVault(vault).updateUnClaimed(periodId, ccFee, userClaims);
157158
} else {
158159
revert InvalidPayloadType();
159160
}

contracts/interfaces/IProtocolVault.sol

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,11 @@ interface IProtocolVault {
4141
error NotAllowedStrategyProvider(bytes32 strategyProviderId);
4242
error InvalidClaimToken(address token);
4343
error ZeroAmount();
44+
error NotEnoughCrossChainFee();
4445

4546
function deposit(DepositParams memory depositParams) external payable;
4647
function withdraw(WithdrawParams memory withdrawParams) external payable;
4748
function claim(ClaimParams memory claimParams) external;
4849
function depositToStrategy(uint256 periodId, address receiver, uint256 amount) external;
49-
function updateUnClaimed(uint256 periodId, ClaimInfo[] memory userClaimInfos) external;
50+
function updateUnClaimed(uint256 periodId, uint256 ccFee, ClaimInfo[] memory userClaimInfos) external;
5051
}

contracts/interfaces/IProtocolVaultLedger.sol

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ interface IProtocolVaultLedger {
125125
view
126126
returns (AccountState[] memory);
127127

128-
function quoteClaim(uint256 chainId, uint256 periodId, bytes32[] memory requestIds)
128+
function quoteClaim(uint256 chainId, uint256 periodId, uint256 ccFee, bytes32[] memory requestIds)
129129
external
130130
view
131131
returns (uint256 nativeFee);

test/ProtocolVault.t.sol

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ contract TestProtocolVault is Base {
2424
error EnforcedPause();
2525
error NotEnoughUnclaimedAssets(uint256 amount);
2626
error VaultClosed();
27+
error NotEnoughCrossChainFee();
2728

2829
function setUp() public override {
2930
super.setUp();
@@ -624,4 +625,230 @@ contract TestProtocolVault is Base {
624625
protocolVault.deposit{value: nativeFee}(depositParams);
625626
vm.stopPrank();
626627
}
628+
629+
function testWithdrawCrossChainFee() public {
630+
// Setup: Create some cross chain fees
631+
uint256 periodId;
632+
bytes32 vaultId;
633+
uint256 amount = 100e6;
634+
uint256 ccFee = 20; // Total fee
635+
636+
bytes32[] memory requestIds = new bytes32[](2);
637+
requestIds[0] = keccak256(abi.encode(0));
638+
requestIds[1] = keccak256(abi.encode(1));
639+
640+
bytes memory signature = _getUpdateUnclaimedSignature(evmChainId, periodId, vaultId, requestIds);
641+
vm.deal(address(bVaultCrossChainManager), 10 ether);
642+
643+
// Add claim info on ledger
644+
svLedger.setLpClaimInfo(requestIds[0], userA_id, amount);
645+
svLedger.setLpClaimInfo(requestIds[1], userB_id, amount);
646+
647+
// Process unclaimed update with fees
648+
vm.prank(operator);
649+
svLedger.updateUnclaimed(evmChainId, periodId, ccFee, vaultId, requestIds, signature);
650+
verifyPackets(srcEid, address(aVaultCrossChainManager));
651+
652+
// Calculate expected accumulated fee: feePerUser = 20/2 = 10, actualTotalFee = 10*2 = 20
653+
uint256 expectedAccumulatedFee = (ccFee / requestIds.length) * requestIds.length;
654+
assertEq(protocolVault.claimCrossChainFee(), expectedAccumulatedFee, "Cross chain fee should be accumulated correctly");
655+
656+
// Mint tokens to vault to represent collected fees
657+
mockToken.mint(address(protocolVault), expectedAccumulatedFee);
658+
659+
// Owner withdraws part of the fees
660+
uint256 withdrawAmount = 15;
661+
uint256 ownerBalanceBefore = mockToken.balanceOf(owner);
662+
663+
vm.prank(owner);
664+
protocolVault.withdrawToken(address(mockToken), owner, withdrawAmount);
665+
666+
// Verify withdrawal
667+
uint256 ownerBalanceAfter = mockToken.balanceOf(owner);
668+
assertEq(ownerBalanceAfter - ownerBalanceBefore, withdrawAmount, "Owner should receive withdrawn amount");
669+
assertEq(protocolVault.claimCrossChainFee(), expectedAccumulatedFee - withdrawAmount, "Cross chain fee should be reduced after withdrawal");
670+
}
671+
672+
function testWithdrawCrossChainFeeExceedsAvailable() public {
673+
// Setup: Create some cross chain fees
674+
uint256 periodId;
675+
bytes32 vaultId;
676+
uint256 amount = 100e6;
677+
uint256 ccFee = 10;
678+
679+
bytes32[] memory requestIds = new bytes32[](1);
680+
requestIds[0] = keccak256(abi.encode(0));
681+
682+
bytes memory signature = _getUpdateUnclaimedSignature(evmChainId, periodId, vaultId, requestIds);
683+
vm.deal(address(bVaultCrossChainManager), 10 ether);
684+
685+
svLedger.setLpClaimInfo(requestIds[0], userA_id, amount);
686+
687+
vm.prank(operator);
688+
svLedger.updateUnclaimed(evmChainId, periodId, ccFee, vaultId, requestIds, signature);
689+
verifyPackets(srcEid, address(aVaultCrossChainManager));
690+
691+
uint256 expectedAccumulatedFee = ccFee; // Only 1 user, so actualTotalFee = ccFee
692+
assertEq(protocolVault.claimCrossChainFee(), expectedAccumulatedFee, "Cross chain fee should be accumulated");
693+
694+
// Try to withdraw more than available
695+
uint256 excessiveWithdrawAmount = expectedAccumulatedFee + 1;
696+
697+
vm.prank(owner);
698+
vm.expectRevert(abi.encodeWithSelector(NotEnoughCrossChainFee.selector));
699+
protocolVault.withdrawToken(address(mockToken), owner, excessiveWithdrawAmount);
700+
}
701+
702+
function testMultipleFeeAccumulationAndWithdrawal() public {
703+
uint256 periodId;
704+
bytes32 vaultId;
705+
uint256 amount = 100e6;
706+
707+
// First batch of fees
708+
uint256 ccFee1 = 12;
709+
bytes32[] memory requestIds1 = new bytes32[](3);
710+
requestIds1[0] = keccak256(abi.encode(0));
711+
requestIds1[1] = keccak256(abi.encode(1));
712+
requestIds1[2] = keccak256(abi.encode(2));
713+
714+
bytes memory signature1 = _getUpdateUnclaimedSignature(evmChainId, periodId, vaultId, requestIds1);
715+
vm.deal(address(bVaultCrossChainManager), 10 ether);
716+
717+
for (uint256 i = 0; i < requestIds1.length; i++) {
718+
svLedger.setLpClaimInfo(requestIds1[i], userA_id, amount);
719+
}
720+
721+
vm.prank(operator);
722+
svLedger.updateUnclaimed(evmChainId, periodId, ccFee1, vaultId, requestIds1, signature1);
723+
verifyPackets(srcEid, address(aVaultCrossChainManager));
724+
725+
uint256 expectedFee1 = (ccFee1 / requestIds1.length) * requestIds1.length; // 4 * 3 = 12
726+
assertEq(protocolVault.claimCrossChainFee(), expectedFee1, "First fee accumulation should be correct");
727+
728+
// Second batch of fees
729+
uint256 ccFee2 = 21;
730+
bytes32[] memory requestIds2 = new bytes32[](2);
731+
requestIds2[0] = keccak256(abi.encode(3));
732+
requestIds2[1] = keccak256(abi.encode(4));
733+
734+
bytes memory signature2 = _getUpdateUnclaimedSignature(evmChainId, periodId, vaultId, requestIds2);
735+
736+
for (uint256 i = 0; i < requestIds2.length; i++) {
737+
svLedger.setLpClaimInfo(requestIds2[i], userB_id, amount);
738+
}
739+
740+
vm.prank(operator);
741+
svLedger.updateUnclaimed(evmChainId, periodId, ccFee2, vaultId, requestIds2, signature2);
742+
verifyPackets(srcEid, address(aVaultCrossChainManager));
743+
744+
uint256 expectedFee2 = (ccFee2 / requestIds2.length) * requestIds2.length; // 10 * 2 = 20
745+
uint256 totalExpectedFee = expectedFee1 + expectedFee2; // 12 + 20 = 32
746+
assertEq(protocolVault.claimCrossChainFee(), totalExpectedFee, "Total fee accumulation should be correct");
747+
748+
// Mint tokens to represent collected fees
749+
mockToken.mint(address(protocolVault), totalExpectedFee);
750+
751+
// Owner withdraws all fees
752+
uint256 ownerBalanceBefore = mockToken.balanceOf(owner);
753+
754+
vm.prank(owner);
755+
protocolVault.withdrawToken(address(mockToken), owner, totalExpectedFee);
756+
757+
// Verify complete withdrawal
758+
uint256 ownerBalanceAfter = mockToken.balanceOf(owner);
759+
assertEq(ownerBalanceAfter - ownerBalanceBefore, totalExpectedFee, "Owner should receive all fees");
760+
assertEq(protocolVault.claimCrossChainFee(), 0, "Cross chain fee should be zero after full withdrawal");
761+
}
762+
763+
function testWithdrawNativeToken() public {
764+
// Send some native tokens to the vault
765+
uint256 nativeAmount = 1 ether;
766+
vm.deal(address(protocolVault), nativeAmount);
767+
768+
uint256 ownerBalanceBefore = owner.balance;
769+
770+
// Withdraw native tokens
771+
vm.prank(owner);
772+
protocolVault.withdrawToken(address(0), owner, nativeAmount);
773+
774+
uint256 ownerBalanceAfter = owner.balance;
775+
assertEq(ownerBalanceAfter - ownerBalanceBefore, nativeAmount, "Owner should receive native tokens");
776+
assertEq(address(protocolVault).balance, 0, "Vault should have no native tokens left");
777+
}
778+
779+
function testSingleUserCcFeeCollectionAndWithdrawal() public {
780+
// Setup: Create cross chain fee scenario with single user
781+
uint256 periodId;
782+
bytes32 vaultId;
783+
uint256 userAsset = 1000e6; // User has 1000 USDC to claim
784+
uint256 ccFee = 50; // 50 USDC cross chain fee
785+
786+
bytes32[] memory requestIds = new bytes32[](1);
787+
requestIds[0] = keccak256(abi.encode("singleUser"));
788+
789+
bytes memory signature = _getUpdateUnclaimedSignature(evmChainId, periodId, vaultId, requestIds);
790+
vm.deal(address(bVaultCrossChainManager), 10 ether);
791+
792+
// Add claim info for single user
793+
svLedger.setLpClaimInfo(requestIds[0], userA_id, userAsset);
794+
795+
// Process unclaimed update with fees
796+
vm.prank(operator);
797+
svLedger.updateUnclaimed(evmChainId, periodId, ccFee, vaultId, requestIds, signature);
798+
verifyPackets(srcEid, address(aVaultCrossChainManager));
799+
800+
// For single user: feePerUser = ccFee / 1 = 50, actualTotalFee = 50 * 1 = 50
801+
uint256 expectedFeePerUser = ccFee; // 50
802+
uint256 expectedActualTotalFee = ccFee; // 50 (same as ccFee for single user)
803+
uint256 expectedUserAssets = userAsset - expectedFeePerUser; // 1000 - 50 = 950
804+
805+
// Verify user claim info (asset should be reduced by fee)
806+
UserClaimedInfo memory userClaimedInfo = protocolVault.getUserClaimedInfo(userA_id);
807+
assertEq(userClaimedInfo.unClaimedAssets, expectedUserAssets, "User assets should be reduced by ccFee");
808+
assertEq(userClaimedInfo.requestIds.length, 1, "User should have 1 request ID");
809+
assertEq(userClaimedInfo.requestIds[0], requestIds[0], "Request ID should match");
810+
811+
// Verify protocol vault accumulated the correct fee
812+
assertEq(protocolVault.claimCrossChainFee(), expectedActualTotalFee, "Protocol should collect exactly the ccFee amount");
813+
814+
// Mint tokens to vault to represent the collected fees
815+
mockToken.mint(address(protocolVault), expectedActualTotalFee);
816+
817+
// Owner withdraws the collected fees
818+
uint256 ownerBalanceBefore = mockToken.balanceOf(owner);
819+
820+
vm.prank(owner);
821+
protocolVault.withdrawToken(address(mockToken), owner, expectedActualTotalFee);
822+
823+
// Verify owner received the fees
824+
uint256 ownerBalanceAfter = mockToken.balanceOf(owner);
825+
assertEq(ownerBalanceAfter - ownerBalanceBefore, expectedActualTotalFee, "Owner should receive all collected fees");
826+
827+
// Verify protocol vault fee counter is reset
828+
assertEq(protocolVault.claimCrossChainFee(), 0, "Cross chain fee should be zero after withdrawal");
829+
830+
// Verify the fee amount calculation
831+
assertEq(expectedActualTotalFee, ccFee, "For single user, actualTotalFee should equal original ccFee");
832+
833+
// Additional verification: User can still claim their remaining assets
834+
mockToken.mint(address(protocolVault), expectedUserAssets);
835+
836+
uint256 userBalanceBefore = mockToken.balanceOf(userA);
837+
ClaimParams memory claimParams = ClaimParams({
838+
roleType: RoleType.LP,
839+
token: address(mockToken),
840+
brokerHash: ORDERLY_BROKER
841+
});
842+
843+
vm.prank(userA);
844+
protocolVault.claim(claimParams);
845+
846+
uint256 userBalanceAfter = mockToken.balanceOf(userA);
847+
assertEq(userBalanceAfter - userBalanceBefore, expectedUserAssets, "User should receive assets minus fee");
848+
849+
// Final verification: User claim info should be cleared
850+
userClaimedInfo = protocolVault.getUserClaimedInfo(userA_id);
851+
assertEq(userClaimedInfo.unClaimedAssets, 0, "User unclaimed assets should be zero after claim");
852+
assertEq(userClaimedInfo.requestIds.length, 0, "User request IDs should be cleared after claim");
853+
}
627854
}

0 commit comments

Comments
 (0)