
Implementing a Distributed Countdown Latch

Introduction

Before diving into the concept of a distributed countdown latch, it’s essential to first understand what countdown latches are and how they function.

Countdown Latches

Countdown latches are concurrency constructs that allow one or more threads to wait until a set of operations being performed in other threads completes. On the JVM, the standard implementation is java.util.concurrent.CountDownLatch. A countdown latch is initialized with a given count. The await methods block until the current count reaches zero as a result of calls to the countDown() method, after which all waiting threads are released and any subsequent calls to await return immediately. This is a one-shot mechanism: the count cannot be reset.
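The one-shot behaviour can be observed directly with the JDK's java.util.concurrent.CountDownLatch. In this minimal sketch (the oneShotDemo function is just an illustrative name), extra countDown() calls after the count reaches zero are ignored, and await() returns immediately from then on:

```kotlin
import java.util.concurrent.CountDownLatch

// Shows the one-shot semantics of the JDK latch and returns its final count
fun oneShotDemo(): Long {
    val latch = CountDownLatch(2)

    latch.countDown()
    latch.countDown()   // count is now 0: the latch is open
    latch.countDown()   // extra calls are ignored; the count never goes negative

    latch.await()       // returns immediately because the count is already 0
    latch.await()       // and keeps doing so: there is no way to reset the count

    return latch.count
}

fun main() {
    println("final count = ${oneShotDemo()}")  // final count = 0
}
```

If another round of coordination is needed, a new latch has to be created.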

The purpose of a countdown latch is to allow one thread to wait for other threads to finish their tasks. The thread calling await() will wait until the count reaches zero, and only then will it continue its execution. Other threads will invoke countDown() to decrease the count.

A classic use case of a countdown latch is in parallel computing, where multiple threads are spawned, and a certain thread needs to wait until all other threads have finished their tasks.
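A minimal in-process sketch of this fan-out/fan-in pattern, using the JDK latch (runWorkers is a hypothetical helper written for illustration): each worker counts down once when it finishes, and the calling thread waits on the latch before reading the combined result.

```kotlin
import java.util.concurrent.CountDownLatch
import java.util.concurrent.atomic.AtomicInteger
import kotlin.concurrent.thread

// Spawns `workers` threads; each does some simulated work and counts down once.
// The calling thread blocks on the latch, then returns how many tasks completed.
fun runWorkers(workers: Int): Int {
    val done = CountDownLatch(workers)
    val completed = AtomicInteger(0)

    repeat(workers) { id ->
        thread {
            try {
                Thread.sleep(10L * id)      // simulate work of varying length
                completed.incrementAndGet()
            } finally {
                done.countDown()            // always count down, even if the work fails
            }
        }
    }

    done.await()                             // blocks until every worker has counted down
    return completed.get()
}

fun main() {
    println("completed tasks: ${runWorkers(5)}")  // completed tasks: 5
}
```

Calling countDown() in a finally block is a common precaution: if a worker threw before counting down, the waiting thread would block forever.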

Distributed Countdown Latch

While countdown latches are highly useful in controlling flow in multithreaded applications running within a single JVM, they have a limitation: they are not inherently designed to work across multiple JVMs or multiple machines. This is where a distributed countdown latch comes in.

A distributed countdown latch has the same functionality as a countdown latch, but it can operate across different machines in a distributed system. It provides a mechanism for threads on different machines to synchronize their activities. The need for a distributed countdown latch arises in a microservices-based architecture, where different services may run on separate JVMs or different physical/virtual machines.

For a distributed countdown latch, we need a coordination service or shared data store that every participating machine can reach. Options include Apache ZooKeeper, Redis, and Hazelcast. Here, we will use Redis as the coordination backend for our distributed countdown latch.

Distributed Countdown Latch Implementation

Basic Implementation in Kotlin

Here’s a basic sketch of how a Redis-based CountDownLatch could be implemented using the Jedis library, a simple Redis client for Java/Kotlin:

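import redis.clients.jedis.Jedis
import redis.clients.jedis.Transaction
import redis.clients.jedis.params.SetParams

class DistributedCountDownLatch(private val jedis: Jedis, private val latchName: String, total: Long) {

    init {
        require(total >= 0) { "total must not be negative" }
        // SET ... NX creates the key only if it does not exist yet, so two clients
        // constructing the same latch at the same time cannot overwrite each other
        synchronized(jedis) {
            jedis.set(latchName, total.toString(), SetParams.setParams().nx())
        }
    }

    // A Jedis connection is not thread-safe, so every Redis call below is
    // guarded by a lock on the shared connection
    fun countDown() {
        synchronized(jedis) {
            while (true) {
                jedis.watch(latchName)
                val count = jedis.get(latchName)?.toLongOrNull()
                if (count == null || count <= 0) {
                    jedis.unwatch()
                    return
                }
                val transaction: Transaction = jedis.multi()
                transaction.decr(latchName)
                // exec() reports an aborted transaction (the watched key changed
                // after WATCH) with a null or empty result; retry in that case
                val result = transaction.exec()
                if (!result.isNullOrEmpty()) return
            }
        }
    }

    fun await() {
        while (getCount() > 0) {
            Thread.sleep(500)  // delay between each check, taken without holding the lock
        }
    }

    // A missing key is treated as a count of zero
    fun getCount(): Long = synchronized(jedis) {
        jedis.get(latchName)?.toLongOrNull() ?: 0L
    }
}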

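In this implementation, a Redis key stores the count. The countDown() method uses the Redis WATCH command to watch the key, reads the current value, and then decrements it inside a transaction (MULTI/EXEC). If another client modifies the key between WATCH and EXEC, Redis aborts the transaction instead of applying a decrement based on a stale value. An aborted transaction has to be retried; otherwise that countDown() call is silently lost. Note also that a single Jedis connection is not thread-safe, so threads sharing one connection must not issue commands on it concurrently. The await() method polls the key every 500 ms until the count reaches 0.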
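WATCH/MULTI/EXEC is a form of optimistic concurrency control: read the value, attempt a conditional write, and retry if someone else got there first. The same loop can be modelled in-process with AtomicLong.compareAndSet, which makes the retry requirement easy to see. This is an in-memory illustration only; InMemoryLatchCounter and contendedCountDown are hypothetical names, not part of Jedis or any other library.

```kotlin
import java.util.concurrent.atomic.AtomicLong
import kotlin.concurrent.thread

// Hypothetical in-memory analogue of the WATCH/MULTI/EXEC countDown():
// compareAndSet plays the role of EXEC and fails if the value changed since it was read
class InMemoryLatchCounter(initial: Long) {
    private val count = AtomicLong(initial)

    fun countDown() {
        while (true) {
            val current = count.get()                             // like WATCH + GET
            if (current <= 0) return                              // like UNWATCH: nothing to do
            if (count.compareAndSet(current, current - 1)) return // like a successful EXEC
            // compareAndSet failed: another thread changed the value, so retry
        }
    }

    fun getCount(): Long = count.get()
}

// Runs `threads` threads, each calling countDown() `callsPerThread` times
fun contendedCountDown(threads: Int, callsPerThread: Int, initial: Long): Long {
    val counter = InMemoryLatchCounter(initial)
    val workers = List(threads) {
        thread { repeat(callsPerThread) { counter.countDown() } }
    }
    workers.forEach { it.join() }
    return counter.getCount()
}

fun main() {
    // 8 x 1000 calls against a count of 5000: surplus calls are ignored
    println(contendedCountDown(threads = 8, callsPerThread = 1000, initial = 5000))  // 0
    // 4 x 100 calls against a count of 1000: no decrement is lost
    println(contendedCountDown(threads = 4, callsPerThread = 100, initial = 1000))   // 600
}
```

Returning after a failed compareAndSet instead of retrying could leave the second result above 600, which is the same lost-update problem a WATCH-based countDown() has if it ignores an aborted EXEC.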

Implementing distributed synchronization primitives, like a CountDownLatch, can be complex and error-prone, which is why libraries such as Redisson are commonly used.
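One gap in the polling await() above is that it blocks forever if a participant crashes before calling countDown(). The JDK latch offers await(timeout, unit) for this situation; a polling equivalent might look like the sketch below. The awaitZero function and its parameters are hypothetical. It takes the count as a function so it can wrap any backend's count (for example `{ latch.getCount() }`) and can be exercised without Redis.

```kotlin
import java.util.concurrent.atomic.AtomicLong
import kotlin.concurrent.thread

// Hypothetical timeout-aware polling helper. Returns true if the count reached
// zero before the deadline, false if the timeout expired first.
fun awaitZero(
    currentCount: () -> Long,
    timeoutMillis: Long,
    pollIntervalMillis: Long = 500
): Boolean {
    val deadline = System.nanoTime() + timeoutMillis * 1_000_000
    while (true) {
        if (currentCount() <= 0) return true
        val remainingMillis = (deadline - System.nanoTime()) / 1_000_000
        if (remainingMillis <= 0) return false
        // Never sleep past the deadline
        Thread.sleep(minOf(pollIntervalMillis, remainingMillis))
    }
}

fun main() {
    val count = AtomicLong(1)
    // Nobody counts down: gives up after roughly 200 ms
    println(awaitZero({ count.get() }, timeoutMillis = 200, pollIntervalMillis = 50))  // false

    // A background thread counts down after 100 ms, well within the 2 s timeout
    thread { Thread.sleep(100); count.decrementAndGet() }
    println(awaitZero({ count.get() }, timeoutMillis = 2000, pollIntervalMillis = 50)) // true
}
```

Returning a Boolean mirrors the JDK's await(timeout, unit) and leaves it to the caller to decide whether a timeout is an error.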

Implementation using Redisson

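Redisson is a Redis client for Java (and therefore Kotlin) that provides distributed versions of common Java objects and synchronization primitives on top of Redis. Its RCountDownLatch takes care of the race conditions for us: the count is decremented atomically on the Redis server by a Lua script, and waiting clients are notified through Redis pub/sub instead of polling. This lets us write high-level code.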

Here’s the Kotlin code:

import org.redisson.Redisson
import org.redisson.api.RedissonClient
import org.redisson.api.RCountDownLatch
import org.redisson.config.Config

class DistributedCountDownLatch(val latchCount: Int, val latchName: String) {

    private val redissonClient: RedissonClient
    private val latch: RCountDownLatch

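    init {
        // Connect to a single Redis server; adjust the address for your deployment
        val config = Config()
        config.useSingleServer().setAddress("redis://127.0.0.1:6379")
        redissonClient = Redisson.create(config)
        latch = redissonClient.getCountDownLatch(latchName)
        // trySetCount only sets the count if the latch is not currently in use
        // (no count set, or already at zero), so every participant can call it.
        // It takes a Long, hence the conversion from Int.
        latch.trySetCount(latchCount.toLong())
    }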

    fun countDown() {
        latch.countDown()
    }

    fun await() {
        latch.await()
    }

    fun getCount(): Long {
        return latch.count
    }

    fun close() {
        redissonClient.shutdown()
    }
}

To use this DistributedCountDownLatch, create an instance in each participating process with the same latch name and the initial count, that is, the number of countDown() calls to expect. Processes in the distributed system then decrement the count by calling countDown(). A process that needs all the countdowns to complete calls await(), which blocks until the count reaches zero. Call close() when finished to shut down the Redisson client.

Using the Latch 

Let’s consider a simple scenario in which one thread must wait until several worker threads have finished their work. We will create five threads that each count down the latch once their work is done, and one thread that waits for the count to reach zero.

This is how we can use the Jedis-based DistributedCountDownLatch from the first example:

import redis.clients.jedis.Jedis
import kotlin.concurrent.thread

fun main() {
    // Initialize jedis and connect to the local Redis instance (default port 6379)
    val jedis = Jedis("localhost")

    // Remove a latch left over from a previous run; otherwise the constructor
    // would keep the old (possibly already zero) count
    jedis.del("myLatch")

    // Create a DistributedCountDownLatch with 5 as the total count
    val latch = DistributedCountDownLatch(jedis, "myLatch", 5)

    // Create 5 threads that count down the latch
    val workers = (1..5).map { i ->
        thread(start = true) {
            println("Thread $i doing work...")
            Thread.sleep((Math.random() * 1000).toLong())  // Simulate time taken to do work
            println("Thread $i finished work, counting down latch...")
            latch.countDown()
            println("Thread $i finished countdown, latch count is now: ${latch.getCount()}")
        }
    }

    // Create a thread that waits for the latch to reach zero
    val waiter = thread(start = true) {
        println("Waiting thread started, waiting for latch to reach zero...")
        latch.await()
        println("Latch reached zero, resuming execution...")
    }

    // Wait for all threads to finish before closing the connection
    (workers + waiter).forEach { it.join() }
    jedis.close()
}

In this code, we first initialize Jedis and connect it to the local Redis instance. Then, we create a DistributedCountDownLatch with the total count set to 5. This means we expect five countDown() calls before any thread calling await() can proceed.

We then create five threads that simulate doing some work by sleeping for a random amount of time. After “doing the work”, each thread calls countDown() on the latch and prints the current count of the latch.

Finally, we create a thread that calls await() on the latch. This thread will not proceed until the count of the latch has reached zero, which will only happen once all other threads have called countDown().

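This is a simplified example: for convenience, all participants run as threads in a single JVM and share one connection. In a real deployment, the countDown() and await() calls would typically come from separate processes or machines, each with its own Redis connection and its own latch instance pointing at the same key, which is exactly the situation a distributed countdown latch is designed for.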