diff --git a/src/allocators/HybridAllocator.sol b/src/allocators/HybridAllocator.sol index 1fbdd5a..7a29fbc 100644 --- a/src/allocators/HybridAllocator.sol +++ b/src/allocators/HybridAllocator.sol @@ -11,6 +11,9 @@ import {AllocatorLib as AL} from './lib/AllocatorLib.sol'; import {IAllocator} from '@uniswap/the-compact/interfaces/IAllocator.sol'; import {IOnChainAllocation} from '@uniswap/the-compact/interfaces/IOnChainAllocation.sol'; import {ITheCompact} from '@uniswap/the-compact/interfaces/ITheCompact.sol'; + +import {Extsload} from '@uniswap/the-compact/lib/Extsload.sol'; +import {IdLib} from '@uniswap/the-compact/lib/IdLib.sol'; import {IHybridAllocator} from 'src/interfaces/IHybridAllocator.sol'; /// @title HybridAllocator @@ -51,8 +54,33 @@ contract HybridAllocator is IHybridAllocator { revert InvalidSigner(); } _COMPACT = ITheCompact(compact_); - ALLOCATOR_ID = _COMPACT.__registerAllocator(address(this), ''); _COMPACT_DOMAIN_SEPARATOR = _COMPACT.DOMAIN_SEPARATOR(); + try _COMPACT.__registerAllocator(address(this), '') returns (uint96 allocatorId) { + ALLOCATOR_ID = allocatorId; + } catch { + // The Compact does not have a getter function for retrieving the status of allocator registration, + // so we need to calculate it manually. + uint96 allocatorId = IdLib.toAllocatorId(address(this)); + bytes32 allocatorSlot; + assembly ("memory-safe") { + // Identical to the registration logic slot calculation in The Compact: + // let allocatorSlot := or(_ALLOCATOR_BY_ALLOCATOR_ID_SLOT_SEED, allocatorId) + allocatorSlot := or(0x000044036fc77deaed2300000000000000000000000, allocatorId) + } + + bytes32 registeredAllocator = Extsload(compact_).extsload(allocatorSlot); + + assembly ("memory-safe") { + if iszero(eq(registeredAllocator, address())) { + // revert InvalidAllocatorRegistration(registeredAllocator) + mstore(0x00, 0x161ab6ea) + mstore(0x20, registeredAllocator) + revert(0x1c, 0x24) + } + } + + ALLOCATOR_ID = allocatorId; + } signers[signer_] = true; signerCount++; diff --git a/src/allocators/OnChainAllocator.sol b/src/allocators/OnChainAllocator.sol index 9025b36..8874c46 100644 --- a/src/allocators/OnChainAllocator.sol +++ b/src/allocators/OnChainAllocator.sol @@ -11,6 +11,8 @@ import {SafeTransferLib} from '@solady/utils/SafeTransferLib.sol'; import {IAllocator} from '@uniswap/the-compact/interfaces/IAllocator.sol'; import {IOnChainAllocation} from '@uniswap/the-compact/interfaces/IOnChainAllocation.sol'; import {ITheCompact} from '@uniswap/the-compact/interfaces/ITheCompact.sol'; +import {Extsload} from '@uniswap/the-compact/lib/Extsload.sol'; +import {IdLib} from '@uniswap/the-compact/lib/IdLib.sol'; import {Lock} from '@uniswap/the-compact/types/EIP712Types.sol'; /// @title OnChainAllocator @@ -42,7 +44,32 @@ contract OnChainAllocator is IOnChainAllocator { constructor(address compactContract_) { COMPACT_CONTRACT = compactContract_; COMPACT_DOMAIN_SEPARATOR = ITheCompact(COMPACT_CONTRACT).DOMAIN_SEPARATOR(); - ALLOCATOR_ID = ITheCompact(COMPACT_CONTRACT).__registerAllocator(address(this), ''); + try ITheCompact(COMPACT_CONTRACT).__registerAllocator(address(this), '') returns (uint96 allocatorId) { + ALLOCATOR_ID = allocatorId; + } catch { + // The Compact does not have a getter function for retrieving the status of allocator registration, + // so we need to calculate it manually. + uint96 allocatorId = IdLib.toAllocatorId(address(this)); + bytes32 allocatorSlot; + assembly ("memory-safe") { + // Identical to the registration logic slot calculation in The Compact: + // let allocatorSlot := or(_ALLOCATOR_BY_ALLOCATOR_ID_SLOT_SEED, allocatorId) + allocatorSlot := or(0x000044036fc77deaed2300000000000000000000000, allocatorId) + } + + bytes32 registeredAllocator = Extsload(COMPACT_CONTRACT).extsload(allocatorSlot); + + assembly ("memory-safe") { + if iszero(eq(registeredAllocator, address())) { + // revert InvalidAllocatorRegistration(registeredAllocator) + mstore(0x00, 0x161ab6ea) + mstore(0x20, registeredAllocator) + revert(0x1c, 0x24) + } + } + + ALLOCATOR_ID = allocatorId; + } } /// @inheritdoc IOnChainAllocator diff --git a/src/interfaces/IHybridAllocator.sol b/src/interfaces/IHybridAllocator.sol index 8065157..fb2f917 100644 --- a/src/interfaces/IHybridAllocator.sol +++ b/src/interfaces/IHybridAllocator.sol @@ -7,6 +7,7 @@ import {IOnChainAllocation} from '@uniswap/the-compact/interfaces/IOnChainAlloca /// @notice Interface for hybrid allocators supporting both on-chain and off-chain authorization mechanisms /// @dev Combines direct token deposit functionality with signature-based off-chain allocation authorization interface IHybridAllocator is IOnChainAllocation { + error InvalidAllocatorRegistration(address alreadyRegisteredAllocator); error Unsupported(); error InvalidIds(); error InvalidAllocatorId(uint96 allocatorId, uint96 expectedAllocatorId); diff --git a/src/interfaces/IOnChainAllocator.sol b/src/interfaces/IOnChainAllocator.sol index 823639e..540543a 100644 --- a/src/interfaces/IOnChainAllocator.sol +++ b/src/interfaces/IOnChainAllocator.sol @@ -14,6 +14,9 @@ interface IOnChainAllocator is IOnChainAllocation { bytes32 claimHash; } + /// @notice Thrown if the allocator is not successfully registered + error InvalidAllocatorRegistration(address alreadyRegisteredAllocator); + /// @notice Thrown if the caller is invalid error InvalidCaller(address caller, address expected); diff --git a/test/HybridAllocator.t.sol b/test/HybridAllocator.t.sol index 5687280..77ff900 100644 --- a/test/HybridAllocator.t.sol +++ b/test/HybridAllocator.t.sol @@ -21,6 +21,12 @@ import {IHybridAllocator} from 'src/interfaces/IHybridAllocator.sol'; import {ERC20Mock} from 'src/test/ERC20Mock.sol'; import {OnChainAllocationCaller} from 'src/test/OnChainAllocationCaller.sol'; +contract HybridAllocatorFactory { + function deploy(bytes32 salt, address compact, address signer) external returns (address) { + return address(new HybridAllocator{salt: salt}(compact, signer)); + } +} + contract HybridAllocatorTest is Test, TestHelper { TheCompact compact; address arbiter; @@ -1117,4 +1123,27 @@ contract HybridAllocatorTest is Test, TestHelper { assertFalse(allocator.signers(newSigner)); assertTrue(allocator.signers(newSigner2)); } + + function test_constructor_allowsPreRegisteredAllocator_create2() public { + HybridAllocatorFactory factory = new HybridAllocatorFactory(); + + bytes32 salt = keccak256('hybrid-allocator-pre-registered'); + bytes memory initCode = + abi.encodePacked(type(HybridAllocator).creationCode, abi.encode(address(compact), signer)); + bytes32 initCodeHash = keccak256(initCode); + + address expected = + address(uint160(uint256(keccak256(abi.encodePacked(bytes1(0xff), address(factory), salt, initCodeHash))))); + + bytes memory proof = abi.encodePacked(bytes1(0xff), address(factory), salt, initCodeHash); + + uint96 preId = compact.__registerAllocator(expected, proof); + assertEq(_toAllocatorId(expected), preId); + + address deployed = HybridAllocatorFactory(address(factory)).deploy(salt, address(compact), signer); + assertEq(deployed, expected); + + HybridAllocator newAllocator = HybridAllocator(deployed); + assertEq(newAllocator.ALLOCATOR_ID(), _toAllocatorId(deployed)); + } } diff --git a/test/OnChainAllocator.t.sol b/test/OnChainAllocator.t.sol index 2094640..fe9c027 100644 --- a/test/OnChainAllocator.t.sol +++ b/test/OnChainAllocator.t.sol @@ -15,7 +15,6 @@ import {ERC20Mock} from 'src/test/ERC20Mock.sol'; import {TheCompact} from '@uniswap/the-compact/TheCompact.sol'; import {IAllocator} from '@uniswap/the-compact/interfaces/IAllocator.sol'; -import {ITheCompact} from '@uniswap/the-compact/interfaces/ITheCompact.sol'; import {IOnChainAllocation} from '@uniswap/the-compact/interfaces/IOnChainAllocation.sol'; import {OnChainAllocator} from 'src/allocators/OnChainAllocator.sol'; @@ -27,10 +26,17 @@ import {ERC6909} from '@solady/tokens/ERC6909.sol'; import {ResetPeriod} from '@uniswap/the-compact/types/ResetPeriod.sol'; import {Scope} from '@uniswap/the-compact/types/Scope.sol'; +import {IdLib} from '@uniswap/the-compact/lib/IdLib.sol'; import {AllocatorLib} from 'src/allocators/lib/AllocatorLib.sol'; import {OnChainAllocationCaller} from 'src/test/OnChainAllocationCaller.sol'; import {TestHelper} from 'test/util/TestHelper.sol'; +contract OnChainAllocatorFactory { + function deploy(bytes32 salt, address compact) external returns (address) { + return address(new OnChainAllocator{salt: salt}(compact)); + } +} + contract OnChainAllocatorTest is Test, TestHelper { TheCompact internal compact; OnChainAllocator internal allocator; @@ -1372,6 +1378,59 @@ contract OnChainAllocatorTest is Test, TestHelper { ); } + function test_constructor_allowsPreRegisteredAllocator_create2() public { + OnChainAllocatorFactory factory = new OnChainAllocatorFactory(); + + bytes32 salt = keccak256('onchain-allocator-pre-registered'); + bytes memory initCode = abi.encodePacked(type(OnChainAllocator).creationCode, abi.encode(address(compact))); + bytes32 initCodeHash = keccak256(initCode); + + address expected = vm.computeCreate2Address(salt, initCodeHash, address(factory)); + + bytes memory proof = abi.encodePacked(bytes1(0xff), address(factory), salt, initCodeHash); + + uint96 preId = compact.__registerAllocator(expected, proof); + assertEq(_toAllocatorId(expected), preId); + + address deployed = OnChainAllocatorFactory(address(factory)).deploy(salt, address(compact)); + assertEq(deployed, expected); + + OnChainAllocator newAllocator = OnChainAllocator(deployed); + assertEq(newAllocator.ALLOCATOR_ID(), _toAllocatorId(deployed)); + } + + function test_constructor_reverts_with_already_registered_allocator_in_case_of_address_collision() public { + // Deploy Create2 factory + OnChainAllocatorFactory factory = new OnChainAllocatorFactory(); + + // Precalculate the allocator's address + bytes32 salt = keccak256('onchain-allocator-pre-registered'); + bytes memory initCode = abi.encodePacked(type(OnChainAllocator).creationCode, abi.encode(address(compact))); + bytes32 initCodeHash = keccak256(initCode); + + address expected = vm.computeCreate2Address(salt, initCodeHash, address(factory)); + + // Store a different registered allocator address (simulate an address collision) + address differentRegisteredAllocator = address(1); + + uint96 allocatorId = IdLib.toAllocatorId(expected); + bytes32 allocatorSlot; + assembly ("memory-safe") { + allocatorSlot := or(0x000044036fc77deaed2300000000000000000000000, allocatorId) + } + + vm.store(address(compact), allocatorSlot, bytes32(uint256(uint160(differentRegisteredAllocator)))); + + // Try to deploy the allocator + // Should revert with the InvalidAllocatorRegistration error, since The Compact has a different address stored in the allocator's slot + vm.expectRevert( + abi.encodeWithSelector( + IOnChainAllocator.InvalidAllocatorRegistration.selector, differentRegisteredAllocator + ) + ); + OnChainAllocatorFactory(address(factory)).deploy(salt, address(compact)); + } + function test_allocateAndRegister_tokensImmediatelyAllocated() public { uint256 amount1 = 1 ether; uint256 amount2 = 2 ether;