Skip to content
21 changes: 21 additions & 0 deletions .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,24 @@ AGENT_PRIVATE_KEY=your_private_key_here_without_0x_prefix
# */30 * * * * - every 30 minutes
REBALANCE_SCHEDULE=*/5 * * * *

# =============================================================================
# Protection & Safety
# =============================================================================

# Maximum gas price allowed for transactions (in Gwei)
# For L2s like Base, 0.1 to 1.0 is common.
MAX_GAS_PRICE_GWEI=1.0

# Swap slippage tolerance in basis points (1 bps = 0.01%)
# 50 = 0.5%, 100 = 1.0%
SWAP_SLIPPAGE_BPS=50

# Rebalance threshold (0.1 = 10% deviation)
# Rebalance only triggers if the current portfolio distribution deviates from target by more than this.
DEVIATION_THRESHOLD=0.1

# StateView contract address (used to fetch position liquidity)
# Mainnet: 0x7ffe42c4a5deea5b0fec41c94c136cf115597227
# Sepolia: 0xe1dd9c3fa50edb962e442f60dfbc432e24537e4c
# Base Sepolia: 0x571291b572ed32ce6751a2cb2486ebee8defb9b4
STATEVIEW_CONTRACT_ADDR=0x7ffe42c4a5deea5b0fec41c94c136cf115597227
32 changes: 32 additions & 0 deletions cmd/rebalance/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"log/slog"
"os"
"os/signal"
"strconv"
"syscall"

"github.com/ethereum/go-ethereum/common"
Expand Down Expand Up @@ -60,14 +61,45 @@ func main() {
})

// Initialize agent service
stateViewAddr := common.HexToAddress(os.Getenv("STATEVIEW_CONTRACT_ADDR"))
agentSvc := agent.New(
vaultSource,
nil, // TODO: strategySvc
sgn,
ethClient,
logger,
stateViewAddr,
)

// Load protection settings from env
swapSlippage := os.Getenv("SWAP_SLIPPAGE_BPS")
maxGasPrice := os.Getenv("MAX_GAS_PRICE_GWEI")
devThreshold := os.Getenv("DEVIATION_THRESHOLD")

sSlippage := int64(50) // default 0.5%
if swapSlippage != "" {
if val, err := strconv.ParseInt(swapSlippage, 10, 64); err == nil {
sSlippage = val
}
}

mGasPrice := 1.0 // default 1.0 Gwei
if maxGasPrice != "" {
if val, err := strconv.ParseFloat(maxGasPrice, 64); err == nil {
mGasPrice = val
}
}

dThreshold := 0.1 // default 10%
if devThreshold != "" {
if val, err := strconv.ParseFloat(devThreshold, 64); err == nil {
dThreshold = val
}
}

agentSvc.SetProtectionSettings(sSlippage, mGasPrice)
agentSvc.SetDeviationThreshold(dThreshold)

ctx, cancel := context.WithCancel(context.Background())
exitCode := 0

Expand Down
205 changes: 172 additions & 33 deletions internal/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@ import (
"context"
"fmt"
"log/slog"
"math/big"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethclient"

"remora/internal/allocation"
"remora/internal/coverage"
"remora/internal/liquidity/poolid"
"remora/internal/signer"
"remora/internal/strategy"
"remora/internal/vault"
Expand All @@ -29,13 +33,16 @@ type RebalanceResult struct {

// Service is the main agent orchestrator.
type Service struct {
vaultSource VaultSource
strategySvc strategy.Service
signer *signer.Signer
ethClient *ethclient.Client
logger *slog.Logger
vaultSource VaultSource
strategySvc strategy.Service
signer *signer.Signer
ethClient *ethclient.Client
logger *slog.Logger
stateViewAddr common.Address

deviationThreshold float64
swapSlippageBps int64
maxGasPriceGwei float64
}

// New creates a new agent service.
Expand All @@ -45,17 +52,32 @@ func New(
signer *signer.Signer,
ethClient *ethclient.Client,
logger *slog.Logger,
stateViewAddr common.Address,
) *Service {
return &Service{
vaultSource: vaultSource,
strategySvc: strategySvc,
signer: signer,
ethClient: ethClient,
logger: logger,
deviationThreshold: defaultDeviationThreshold,
stateViewAddr: stateViewAddr,
deviationThreshold: 0.1,
swapSlippageBps: 50, // default: 0.5%
maxGasPriceGwei: 1.0, // default: 1.0 Gwei (suitable for many L2s)
}
}

// SetProtectionSettings updates the protection settings for the service.
func (s *Service) SetProtectionSettings(swapSlippageBps int64, maxGasPriceGwei float64) {
s.swapSlippageBps = swapSlippageBps
s.maxGasPriceGwei = maxGasPriceGwei
}

// SetDeviationThreshold updates the threshold for rebalance decision.
func (s *Service) SetDeviationThreshold(threshold float64) {
s.deviationThreshold = threshold
}

// Run executes one round of rebalance check for all vaults.
func (s *Service) Run(ctx context.Context) ([]RebalanceResult, error) {
addresses, err := s.vaultSource.GetVaultAddresses(ctx)
Expand All @@ -77,7 +99,8 @@ func (s *Service) Run(ctx context.Context) ([]RebalanceResult, error) {

// processVault handles rebalance logic for a single vault.
func (s *Service) processVault(ctx context.Context, vaultAddr common.Address) RebalanceResult {
s.logger.InfoContext(ctx, "processing vault", slog.String("address", vaultAddr.Hex()))

s.logger.Info("processing vault", slog.String("address", vaultAddr.Hex()))

// Step 1: Create vault client
auth, err := s.signer.TransactOpts()
Expand All @@ -91,40 +114,156 @@ func (s *Service) processVault(ctx context.Context, vaultAddr common.Address) Re
}

// Step 2: Get vault state and current positions
// TODO: state, err := vaultClient.GetState(ctx)
// TODO: currentPositions, err := vaultClient.GetPositions(ctx)
_ = vaultClient
state, err := vaultClient.GetState(ctx)
if err != nil {
s.logger.Error("failed to get vault state", slog.Any("error", err))
return RebalanceResult{VaultAddress: vaultAddr, Reason: "get_state_error"}
}

// Step 3: Compute target positions using strategy service
// TODO: targetResult, err := s.computeTargetPositions(ctx, state.PoolKey)
// Convert vault.PoolKey to poolid.PoolKey
liqPoolKey := poolid.PoolKey{
Currency0: state.PoolKey.Currency0.Hex(),
Currency1: state.PoolKey.Currency1.Hex(),
Fee: uint32(state.PoolKey.Fee.Uint64()), //nolint:gosec // fee fits in uint24
TickSpacing: int32(state.PoolKey.TickSpacing.Int64()), //nolint:gosec // tickSpacing fits in int24
Hooks: state.PoolKey.Hooks.Hex(),
}

// Step 4: Calculate deviation between current and target
// TODO: deviation := s.calculateDeviation(currentPositions, targetResult)
computeParams := &strategy.ComputeParams{
PoolKey: liqPoolKey,
BinSizeTicks: 200, // TODO: Configurable
TickRange: 1000, // TODO: Configurable
AlgoConfig: coverage.DefaultConfig(),
}

// Step 5: Check if rebalance is needed
// TODO: if deviation < s.deviationThreshold { return skipped }
targetResult, err := s.strategySvc.ComputeTargetPositions(ctx, computeParams)
if err != nil {
s.logger.Error("failed to compute target", slog.Any("error", err))
return RebalanceResult{VaultAddress: vaultAddr, Reason: "strategy_error"}
}

// Step 6: Execute rebalance
// TODO: err := s.executeRebalance(ctx, vaultClient, targetResult)
// Step 4: Calculate Total Assets (Idle + Invested)
// We need decimals and balances
token0 := state.PoolKey.Currency0
token1 := state.PoolKey.Currency1

return RebalanceResult{
VaultAddress: vaultAddr,
Rebalanced: false,
Reason: "not_implemented",
// TODO: Cache decimals
decimals0, err := s.getTokenDecimals(ctx, token0)
if err != nil {
s.logger.Error("failed to get token0 decimals", slog.Any("error", err))
return RebalanceResult{VaultAddress: vaultAddr, Reason: "token_error"}
}
}
decimals1, err := s.getTokenDecimals(ctx, token1)
if err != nil {
s.logger.Error("failed to get token1 decimals", slog.Any("error", err))
return RebalanceResult{VaultAddress: vaultAddr, Reason: "token_error"}
}

// Get Idle Balances
idle0, err := s.getTokenBalance(ctx, token0, vaultAddr)
if err != nil {
s.logger.Error("failed to get token0 balance", slog.Any("error", err))
return RebalanceResult{VaultAddress: vaultAddr, Reason: "balance_error"}
}
idle1, err := s.getTokenBalance(ctx, token1, vaultAddr)
if err != nil {
s.logger.Error("failed to get token1 balance", slog.Any("error", err))
return RebalanceResult{VaultAddress: vaultAddr, Reason: "balance_error"}
}

// Get Invested Balances (from current positions)
positions, err := vaultClient.GetPositions(ctx)
if err != nil {
s.logger.Error("failed to get positions", slog.Any("error", err))
return RebalanceResult{VaultAddress: vaultAddr, Reason: "get_positions_error"}
}

invested0 := big.NewInt(0)
invested1 := big.NewInt(0)

// We use the Strategy's SqrtPriceX96 to estimate current position value
// Note: accurate value requires getting the real positions info including uncollected fees,
// but here we just estimate principal from liquidity.
for _, pos := range positions {
// Fetch real liquidity from POSM/StateView
liquidity, err := s.getPositionLiquidity(ctx, state.Posm, state.PoolID, pos.TokenID)
if err != nil {
s.logger.Warn("failed to get position liquidity",
slog.String("tokenID", pos.TokenID.String()),
slog.Any("error", err))
continue
}
pos.Liquidity = liquidity

if pos.Liquidity == nil || pos.Liquidity.Sign() == 0 {
continue
}

// =============================================================================
// Private methods to implement
// =============================================================================
// Calculate amounts for this position
// allocation.GetAmount0ForLiquidity needs sqrtPriceX96, sqrtPriceA, sqrtPriceB, liquidity
sqrtPriceAX96 := allocation.TickToSqrtPriceX96(int(pos.TickLower))
sqrtPriceBX96 := allocation.TickToSqrtPriceX96(int(pos.TickUpper))

amt0 := allocation.GetAmount0ForLiquidity(targetResult.SqrtPriceX96, sqrtPriceAX96, sqrtPriceBX96, pos.Liquidity)
amt1 := allocation.GetAmount1ForLiquidity(targetResult.SqrtPriceX96, sqrtPriceAX96, sqrtPriceBX96, pos.Liquidity)

invested0.Add(invested0, amt0)
invested1.Add(invested1, amt1)
}

// computeTargetPositions computes target positions for a vault.
// Flow: PoolKey -> liquidity.GetDistribution -> strategy.ComputeTargetPositions
// func (s *Service) computeTargetPositions(ctx context.Context, poolKey vault.PoolKey) (*strategy.ComputeResult, error)
// Step 4.5: Deviation Check
// Decide if we really need to rebalance
deviation := s.calculateDeviation(positions, targetResult)
s.logger.Info("deviation calculated", slog.Float64("deviation", deviation), slog.Float64("threshold", s.deviationThreshold))

// calculateDeviation calculates deviation between current and target positions.
// func (s *Service) calculateDeviation(current []vault.Position, target *strategy.ComputeResult) float64
if deviation < s.deviationThreshold {
return RebalanceResult{
VaultAddress: vaultAddr,
Rebalanced: false,
Reason: "deviation_below_threshold",
}
}

// Sum total
total0 := new(big.Int).Add(idle0, invested0)
total1 := new(big.Int).Add(idle1, invested1)

// Step 5: Allocate
poolState := allocation.PoolState{
SqrtPriceX96: targetResult.SqrtPriceX96,
CurrentTick: int(targetResult.CurrentTick),
Token0Decimals: int(decimals0),
Token1Decimals: int(decimals1),
}

userFunds := allocation.UserFunds{
Amount0: total0,
Amount1: total1,
}

// executeRebalance executes rebalance transactions.
// Flow: 1. Burn all existing positions 2. Mint new positions
// func (s *Service) executeRebalance(ctx context.Context, client vault.Vault, target *strategy.ComputeResult) error
allocationResult, err := allocation.Allocate(targetResult.Segments, userFunds, poolState, state.SwapAllowed)
if err != nil {
s.logger.Error("failed to allocate", slog.Any("error", err))
return RebalanceResult{VaultAddress: vaultAddr, Reason: "allocation_error"}
}

s.logger.Info("allocation computed",
slog.String("swap_amount", allocationResult.SwapAmount.String()),
slog.Bool("zero_for_one", allocationResult.SwapToken0To1),
slog.Int("new_positions", len(allocationResult.Positions)),
)

// Step 6: Execute rebalance
err = s.executeRebalance(ctx, vaultClient, positions, allocationResult, targetResult.SqrtPriceX96)
if err != nil {
s.logger.Error("failed to execute rebalance", slog.Any("error", err))
return RebalanceResult{VaultAddress: vaultAddr, Reason: "execution_error"}
}

return RebalanceResult{
VaultAddress: vaultAddr,
Rebalanced: true,
Reason: "success",
}
}
Loading
Loading