diff --git a/src/allocators/HybridAllocator.sol b/src/allocators/HybridAllocator.sol index c207ed5..2fe739a 100644 --- a/src/allocators/HybridAllocator.sol +++ b/src/allocators/HybridAllocator.sol @@ -159,6 +159,7 @@ contract HybridAllocator is IHybridAllocator { bytes32 typehash, bytes32 witness ) public payable returns (bytes32, uint256[] memory, uint256) { + recipient = AL.getRecipient(recipient); idsAndAmounts = _actualIdsAndAmounts(idsAndAmounts); (bytes32 claimHash, uint256[] memory registeredAmounts) = _COMPACT.batchDepositAndRegisterFor{value: msg.value}( diff --git a/src/allocators/OnChainAllocator.sol b/src/allocators/OnChainAllocator.sol index ea30abf..d27f49a 100644 --- a/src/allocators/OnChainAllocator.sol +++ b/src/allocators/OnChainAllocator.sol @@ -133,6 +133,7 @@ contract OnChainAllocator is IOnChainAllocator { revert InvalidExpiration(expires, block.timestamp); } + recipient = AL.getRecipient(recipient); nonce = _getAndUpdateNonce(msg.sender, recipient); uint256[2][] memory idsAndAmounts = new uint256[2][](commitments.length); diff --git a/src/allocators/lib/AllocatorLib.sol b/src/allocators/lib/AllocatorLib.sol index c8b4046..a68d796 100644 --- a/src/allocators/lib/AllocatorLib.sol +++ b/src/allocators/lib/AllocatorLib.sol @@ -261,6 +261,13 @@ library AllocatorLib { return ecrecover(digest, v, r, s); } + function getRecipient(address recipient) internal view returns (address) { + assembly ("memory-safe") { + recipient := xor(recipient, mul(caller(), iszero(recipient))) + } + return recipient; + } + function splitId(uint256 id) internal pure returns (uint96 allocatorId_, address token_) { return (splitAllocatorId(id), splitToken(id)); } diff --git a/test/HybridAllocator.t.sol b/test/HybridAllocator.t.sol index 8aa6727..4310c0f 100644 --- a/test/HybridAllocator.t.sol +++ b/test/HybridAllocator.t.sol @@ -526,6 +526,35 @@ contract HybridAllocatorTest is Test, TestHelper { assertTrue(allocator.isClaimAuthorized(createdHash, address(0), address(0), 0, 0, new uint256[2][](0), '')); } + function test_allocateAndRegister_success_emptyRecipientBecomesCaller() public { + uint256[2][] memory idsAndAmounts = new uint256[2][](1); + idsAndAmounts[0][0] = _toId(Scope.Multichain, ResetPeriod.TenMinutes, address(allocator), address(usdc)); + idsAndAmounts[0][1] = defaultAmount; + + // Provide tokens + vm.prank(user); + usdc.transfer(address(allocator), defaultAmount); + assertEq(usdc.balanceOf(address(allocator)), defaultAmount); + + vm.prank(user); + (bytes32 claimHash, uint256[] memory registeredAmounts, uint256 nonce) = allocator.allocateAndRegister( + address(0), /* allocate for an empty recipient */ + idsAndAmounts, + arbiter, + defaultExpiration, + BATCH_COMPACT_TYPEHASH, + '' + ); + + // Ensure the allocation happened for the caller (user), not address(0) + assertTrue(compact.isRegistered(user, claimHash, BATCH_COMPACT_TYPEHASH)); + assertTrue(allocator.isClaimAuthorized(claimHash, address(0), address(0), 0, 0, new uint256[2][](0), '')); + assertEq(registeredAmounts[0], defaultAmount); + assertEq(usdc.balanceOf(address(compact)), defaultAmount); + assertEq(compact.balanceOf(address(user), idsAndAmounts[0][0]), defaultAmount); + assertEq(nonce, 1); + } + function test_allocateAndRegister_slot() public { uint256[2][] memory idsAndAmounts = new uint256[2][](1); idsAndAmounts[0][0] = _toId(Scope.Multichain, ResetPeriod.TenMinutes, address(allocator), address(0)); diff --git a/test/OnChainAllocator.t.sol b/test/OnChainAllocator.t.sol index d6dfe25..da45976 100644 --- a/test/OnChainAllocator.t.sol +++ b/test/OnChainAllocator.t.sol @@ -1505,4 +1505,36 @@ contract OnChainAllocatorTest is Test, TestHelper { vm.expectRevert(abi.encodeWithSelector(IOnChainAllocator.InsufficientBalance.selector, recipient, id2, 0, 1)); allocator.attest(address(this), recipient, address(this), id2, 1); } + + function test_allocateAndRegister_emptyRecipientBecomesCaller() public { + Lock[] memory commitments = new Lock[](1); + commitments[0] = _makeLock(address(usdc), defaultAmount); + + usdc.mint(address(allocator), defaultAmount); + + vm.prank(caller); + (bytes32 claimHash, uint256[] memory registeredAmounts, uint256 nonce) = allocator.allocateAndRegister( + address(0), /* allocate for an empty recipient */ + commitments, + arbiter, + defaultExpiration, + BATCH_COMPACT_TYPEHASH, + bytes32(0) + ); + + uint256[2][] memory idsAndAmounts = new uint256[2][](1); + idsAndAmounts[0][0] = _toId(Scope.Multichain, ResetPeriod.TenMinutes, address(allocator), address(usdc)); + idsAndAmounts[0][1] = defaultAmount; + + assertEq(nonce, _composeNonceUint(caller, 1)); + assertEq(registeredAmounts.length, 1); + assertEq(registeredAmounts[0], defaultAmount); + // Ensure the allocation happened for the caller, not address(0) + assertEq(ERC6909(address(compact)).balanceOf(caller, idsAndAmounts[0][0]), defaultAmount); + assertTrue(allocator.isClaimAuthorized(claimHash, arbiter, caller, nonce, defaultExpiration, idsAndAmounts, '')); + assertTrue(compact.isRegistered(caller, claimHash, BATCH_COMPACT_TYPEHASH)); + bytes32 claimHashRecreated = + _createClaimHash(caller, arbiter, nonce, defaultExpiration, commitments, bytes32(0)); + assertEq(claimHashRecreated, claimHash); + } }