Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 104 additions & 79 deletions src/main/java/org/codarama/redlock4j/RedlockCountDownLatch.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

/**
* A distributed countdown latch that allows one or more threads to wait until a set of operations being performed in
Expand Down Expand Up @@ -169,106 +173,138 @@ public boolean await(long timeout, TimeUnit unit) throws InterruptedException {

/**
* Decrements the count of the latch, releasing all waiting threads if the count reaches zero.
*
*
* <p>
* If the current count is greater than zero then it is decremented. If the new count is zero then all waiting
* threads are re-enabled for thread scheduling purposes.
* </p>
*
*
* <p>
* If the current count equals zero then nothing happens.
* </p>
*/
public void countDown() {
// Decrement the count atomically using Redis DECR
int successfulNodes = 0;
long newCount = -1;
int quorum = config.getQuorum();
CountDownLatch quorumLatch = new CountDownLatch(1);
AtomicInteger successCount = new AtomicInteger(0);

// Execute atomic decrement + conditional publish on all nodes in parallel
for (RedisDriver driver : redisDrivers) {
try {
// Atomically decrement the count
long count = driver.decr(latchKey);
newCount = count;
successfulNodes++;

logger.debug("Decremented latch {} count to {} on {}", latchKey, count, driver.getIdentifier());
} catch (Exception e) {
logger.debug("Failed to decrement latch count on {}: {}", driver.getIdentifier(), e.getMessage());
}
CompletableFuture.runAsync(() -> {
try {
// Atomic: decrement and publish if zero in single Lua script
long count = driver.decrAndPublishIfZero(latchKey, channelKey, "zero");
logger.debug("Decremented latch {} count to {} on {}", latchKey, count, driver.getIdentifier());
if (successCount.incrementAndGet() >= quorum) {
quorumLatch.countDown(); // Signal quorum reached
}
} catch (Exception e) {
logger.debug("Failed to decrement latch count on {}: {}", driver.getIdentifier(), e.getMessage());
}
});
}

if (successfulNodes >= config.getQuorum()) {
// Wait for quorum (not all nodes)
try {
quorumLatch.await();
logger.debug("Successfully decremented latch {} count on quorum", latchKey);

// If count reached zero, publish notification
if (newCount <= 0) {
publishZeroNotification();
}
} else {
logger.warn("Failed to decrement latch {} count on quorum of nodes", latchKey);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
logger.warn("Interrupted while decrementing latch {}", latchKey);
}
}

/**
* Returns the current count.
*
*
* <p>
* This method is typically used for debugging and testing purposes.
* </p>
*
*
* @return the current count
*/
public long getCount() {
// Use Redis GET to retrieve the current count
int successfulReads = 0;
long totalCount = 0;
int quorum = config.getQuorum();
CountDownLatch quorumLatch = new CountDownLatch(1);
List<Long> results = new ArrayList<>();

// Execute GET on all nodes in parallel
for (RedisDriver driver : redisDrivers) {
try {
String countStr = driver.get(latchKey);
if (countStr != null) {
long count = Long.parseLong(countStr);
totalCount += count;
successfulReads++;
CompletableFuture.runAsync(() -> {
try {
String countStr = driver.get(latchKey);
if (countStr != null) {
synchronized (results) {
results.add(Long.parseLong(countStr));
if (results.size() >= quorum) {
quorumLatch.countDown(); // Signal quorum reached
}
}
}
} catch (Exception e) {
logger.debug("Failed to read latch count from {}: {}", driver.getIdentifier(), e.getMessage());
}
} catch (Exception e) {
logger.debug("Failed to read latch count from {}: {}", driver.getIdentifier(), e.getMessage());
}
});
}

if (successfulReads >= config.getQuorum()) {
// Return average count (simple approach, could use median for better accuracy)
long avgCount = totalCount / successfulReads;
return Math.max(0, avgCount); // Never return negative
// Wait for quorum (not all nodes)
try {
quorumLatch.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
logger.warn("Interrupted while reading latch {} count", latchKey);
return 0;
}

synchronized (results) {
if (results.size() >= quorum) {
// Return average count
long totalCount = results.stream().mapToLong(Long::longValue).sum();
long avgCount = totalCount / results.size();
return Math.max(0, avgCount); // Never return negative
}
}

logger.warn("Failed to read latch {} count from quorum of nodes", latchKey);
return 0; // Conservative fallback - assume completed
}

/**
* Initializes the latch count in Redis.
* Initializes the latch count in Redis using parallel operations with early quorum return.
*/
private void initializeLatch(int count) {
String countValue = String.valueOf(count);
int successfulNodes = 0;
long expirationMs = config.getDefaultLockTimeoutMs() * 10;
int quorum = config.getQuorum();

// Latch to signal when quorum is reached
CountDownLatch quorumLatch = new CountDownLatch(1);
AtomicInteger successCount = new AtomicInteger(0);

// Execute initialization on all nodes in parallel
List<CompletableFuture<Boolean>> futures = new ArrayList<>(redisDrivers.size());
for (RedisDriver driver : redisDrivers) {
try {
// Use setex to initialize with a long expiration (10x lock timeout)
driver.setex(latchKey, countValue, config.getDefaultLockTimeoutMs() * 10);
successfulNodes++;
} catch (Exception e) {
logger.warn("Failed to initialize latch on {}: {}", driver.getIdentifier(), e.getMessage());
}
futures.add(CompletableFuture.supplyAsync(() -> {
try {
driver.setex(latchKey, countValue, expirationMs);
if (successCount.incrementAndGet() >= quorum) {
quorumLatch.countDown(); // Signal quorum reached
}
return true;
} catch (Exception e) {
logger.warn("Failed to initialize latch on {}: {}", driver.getIdentifier(), e.getMessage());
return false;
}
}));
}

if (successfulNodes < config.getQuorum()) {
logger.warn("Failed to initialize latch {} on quorum of nodes (only {} of {} succeeded)", latchKey,
successfulNodes, redisDrivers.size());
} else {
logger.debug("Successfully initialized latch {} with count {} on {} nodes", latchKey, count,
successfulNodes);
// Wait for quorum (not all nodes)
try {
quorumLatch.await();
logger.debug("Successfully initialized latch {} with count {} on quorum", latchKey, count);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
logger.warn("Interrupted while initializing latch {}", latchKey);
}
}

Expand Down Expand Up @@ -309,29 +345,14 @@ public void onError(Throwable error) {
}
}

/**
* Publishes a notification that the latch has reached zero.
*/
private void publishZeroNotification() {
for (RedisDriver driver : redisDrivers) {
try {
long subscribers = driver.publish(channelKey, "zero");
logger.debug("Published zero notification for latch {} to {} subscribers on {}", latchKey, subscribers,
driver.getIdentifier());
} catch (Exception e) {
logger.debug("Failed to publish zero notification on {}: {}", driver.getIdentifier(), e.getMessage());
}
}
}

/**
* Resets the latch to its initial count.
*
*
* <p>
* <b>Warning:</b> This is not part of the standard CountDownLatch API and should be used with caution. It's
* provided for scenarios where you need to reuse a latch.
* </p>
*
*
* <p>
* This operation is not atomic and may lead to race conditions if called while other threads are waiting or
* counting down.
Expand All @@ -340,14 +361,18 @@ private void publishZeroNotification() {
public void reset() {
logger.debug("Resetting latch {} to initial count {}", latchKey, initialCount);

// Delete the existing latch using DEL
// Delete the existing latch using parallel DEL
List<CompletableFuture<Void>> futures = new ArrayList<>(redisDrivers.size());
for (RedisDriver driver : redisDrivers) {
try {
driver.del(latchKey);
} catch (Exception e) {
logger.warn("Failed to delete latch on {}: {}", driver.getIdentifier(), e.getMessage());
}
futures.add(CompletableFuture.runAsync(() -> {
try {
driver.del(latchKey);
} catch (Exception e) {
logger.warn("Failed to delete latch on {}: {}", driver.getIdentifier(), e.getMessage());
}
}));
}
CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();

// Reset the local latch
localLatch = new CountDownLatch(1);
Expand Down
14 changes: 14 additions & 0 deletions src/main/java/org/codarama/redlock4j/driver/JedisRedisDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ public class JedisRedisDriver implements RedisDriver {
private static final String SET_IF_VALUE_MATCHES_SCRIPT = "if redis.call('get', KEYS[1]) == ARGV[1] then "
+ " return redis.call('set', KEYS[1], ARGV[2], 'PX', ARGV[3]) " + "else " + " return nil " + "end";

private static final String DECR_AND_PUBLISH_IF_ZERO_SCRIPT = "local v = redis.call('decr', KEYS[1]); "
+ "if v <= 0 then redis.call('publish', KEYS[2], ARGV[1]) end; " + "return v";

/**
* Strategy for CAS/CAD operations.
*/
Expand Down Expand Up @@ -291,6 +294,17 @@ public long decr(String key) throws RedisDriverException {
}
}

@Override
public long decrAndPublishIfZero(String key, String channel, String message) throws RedisDriverException {
try (Jedis jedis = jedisPool.getResource()) {
Object result = jedis.eval(DECR_AND_PUBLISH_IF_ZERO_SCRIPT, java.util.Arrays.asList(key, channel),
Collections.singletonList(message));
return result != null ? ((Number) result).longValue() : 0;
} catch (JedisException e) {
throw new RedisDriverException("Failed to execute DECR_AND_PUBLISH script on " + identifier, e);
}
}

@Override
public String get(String key) throws RedisDriverException {
try (Jedis jedis = jedisPool.getResource()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ public class LettuceRedisDriver implements RedisDriver {
private static final String SET_IF_VALUE_MATCHES_SCRIPT = "if redis.call('get', KEYS[1]) == ARGV[1] then "
+ " return redis.call('set', KEYS[1], ARGV[2], 'PX', ARGV[3]) " + "else " + " return nil " + "end";

private static final String DECR_AND_PUBLISH_IF_ZERO_SCRIPT = "local v = redis.call('decr', KEYS[1]); "
+ "if v <= 0 then redis.call('publish', KEYS[2], ARGV[1]) end; " + "return v";

/**
* Strategy for CAS/CAD operations.
*/
Expand Down Expand Up @@ -315,6 +318,17 @@ public long decr(String key) throws RedisDriverException {
}
}

@Override
public long decrAndPublishIfZero(String key, String channel, String message) throws RedisDriverException {
try {
Object result = commands.eval(DECR_AND_PUBLISH_IF_ZERO_SCRIPT, io.lettuce.core.ScriptOutputType.INTEGER,
new String[]{key, channel}, message);
return result != null ? ((Number) result).longValue() : 0;
} catch (Exception e) {
throw new RedisDriverException("Failed to execute DECR_AND_PUBLISH script on " + identifier, e);
}
}

@Override
public String get(String key) throws RedisDriverException {
try {
Expand Down
16 changes: 16 additions & 0 deletions src/main/java/org/codarama/redlock4j/driver/RedisDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,22 @@ boolean setIfValueMatches(String key, String newValue, String expectedCurrentVal
*/
long decr(String key) throws RedisDriverException;

/**
* Atomically decrements the value of a key and publishes a message to a channel if the new value is zero or less.
* This combines DECR and conditional PUBLISH into a single atomic operation.
*
* @param key
* the key to decrement
* @param channel
* the channel to publish to if count reaches zero
* @param message
* the message to publish
* @return the value after decrementing
* @throws RedisDriverException
* if there's an error communicating with Redis
*/
long decrAndPublishIfZero(String key, String channel, String message) throws RedisDriverException;

/**
* Gets the value of a key.
*
Expand Down
Loading
Loading