Wait-Free Per-Object Thread-Local Storage

Abstract

In this post, I present a wait-free, per-object thread-local storage (TLS) written in Rust.

What is Wait-Freedom?

Before introducing wait-freedom, I will introduce obstruction-freedom and lock-freedom, the weaker guarantees it builds on. Note that such algorithms can still be used side by side with blocking algorithms; a program does not need to be fully non-blocking.

Obstruction-Free

An algorithm or data structure is obstruction-free if:

  1. If all threads except an arbitrary one are suspended, the remaining thread is guaranteed to make progress in a finite number of steps.

By "finite steps" we mean that, before executing the operation, we know it will take at most N steps. This means we cannot depend on the OS scheduler: we cannot tell the scheduler that we are waiting for some event, and even if we spin instead of relying on the scheduler, we cannot spin waiting for an event either. Consequently, we cannot use locking mechanisms such as mutexes, read-write locks, etc. If a thread holding a lock is suspended, every other thread blocks indefinitely as soon as it reaches the lock acquisition step.
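To make the problem with locks concrete, here is a minimal spin lock (an illustration of ours, not part of the TLS). The loop in `lock` has no bound on its number of iterations: if the thread holding the lock is suspended, every other caller keeps spinning until the holder is resumed, so the lock is not even obstruction-free.

```rust
use std::sync::atomic::{AtomicBool, Ordering::{Acquire, Relaxed, Release}};

/// A minimal spin lock, shown only to illustrate a blocking algorithm.
pub struct SpinLock {
    locked: AtomicBool,
}

impl SpinLock {
    pub const fn new() -> Self {
        SpinLock { locked: AtomicBool::new(false) }
    }

    /// Spins until the lock is acquired. Nothing bounds the number of
    /// iterations: if the holder is suspended, this loop never finishes.
    pub fn lock(&self) {
        while self
            .locked
            .compare_exchange_weak(false, true, Acquire, Relaxed)
            .is_err()
        {
            std::hint::spin_loop();
        }
    }

    /// Attempts to acquire the lock once; returns whether it succeeded.
    pub fn try_lock(&self) -> bool {
        self.locked.compare_exchange(false, true, Acquire, Relaxed).is_ok()
    }

    pub fn unlock(&self) {
        self.locked.store(false, Release);
    }
}
```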

Lock-Free

An algorithm or data structure is lock-free if:

  1. It is obstruction-free.
  2. Even when multiple threads run concurrently, at least one of them makes progress in a finite number of steps.
  3. Property 2 still holds even if an arbitrary thread is suspended.

Note that if no new threads start executing the operation, all threads will eventually make progress. However, we must also be able to suspend arbitrary threads while still satisfying property 2, and locks cannot provide that.
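The classic example of an operation that is lock-free but not wait-free is a compare-and-swap (CAS) retry loop. The sketch below increments a counter that way (the name `cas_increment` is ours): a CAS fails only because another thread changed the value, so the system as a whole always progresses, but one unlucky thread may keep losing the race without bound.

```rust
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};

/// Lock-free (but not wait-free) increment; returns the previous value.
pub fn cas_increment(counter: &AtomicUsize) -> usize {
    // A plain counter needs no memory ordering beyond atomicity.
    let mut current = counter.load(Relaxed);
    loop {
        match counter.compare_exchange(current, current + 1, Relaxed, Relaxed) {
            Ok(previous) => return previous,
            // A strong CAS fails only if another thread changed the value,
            // i.e. someone else made progress. We retry with their value,
            // and nothing bounds how many times we may lose this race.
            Err(actual) => current = actual,
        }
    }
}
```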

Wait-Free

An algorithm or data structure is wait-free if:

  1. It is lock-free.
  2. Every thread executing the operation makes progress in a finite number of steps.

This is a very strong guarantee. Still, by the end of this post I will show that an operation on this TLS walks through at most ceil(THREAD_ID_BIT_SIZE / BITS) levels of tables, where BITS is a constant of our choosing.
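For contrast, an increment built on `fetch_add` performs a single atomic read-modify-write per call, with no retry loop in our code. On targets with a native atomic add instruction (such as x86-64's `lock xadd`) each call is therefore wait-free; on targets where the compiler lowers it to an LL/SC retry loop, the guarantee weakens to lock-free. The name `faa_increment` is ours.

```rust
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};

/// Increment with one atomic read-modify-write and no retry loop in our code.
/// Returns the previous value, so concurrent callers get distinct results.
pub fn faa_increment(counter: &AtomicUsize) -> usize {
    counter.fetch_add(1, Relaxed)
}
```

The same primitive reappears later in this post, in the counter that hands out fresh thread IDs.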

Overview

The first attempt at designing a per-object thread-local storage (TLS) is usually a hash table keyed by thread ID. However, combining sequential (contiguous) memory structures with lock-free algorithms rarely ends well. Sequential memory is just too... sequential: a contiguous table has to be resized as a whole, so there is little work that can be split between threads. We need to split the structure's memory itself.

This is why ticki designed nested hash tables. I owe him a lot for this idea; it was thanks to him that I could design this TLS. The basic idea is that, when we face a collision, we create a new child table and put the colliding entries in it, using a different index computation. In my case, this index computation simply uses different bits of the hash.

First, let's fix a constant BITS >= 1; my own implementation uses BITS = 8. Then we introduce a new data type: the table. A table is an array of length pow(2, BITS), or in other words 1 << BITS. It looks like this:

use std::sync::atomic::AtomicPtr;
const BITS: usize = 8; // 1 << 8 = 256 slots per table

struct Table<T> {
    nodes: [AtomicPtr<Node<T>>; 1 << BITS],
}

As you can see, each element is an atomic pointer to a Node, and a null pointer marks an empty slot. A Node is either a child table or an entry.

enum Node<T> {
    Branch(Table<T>),
    Leaf(Entry<T>),
}

There are other ways to write the Node type, such as tagging bits of the pointer, which are more memory-efficient but also more complex. I will keep this definition for didactic purposes. The Entry is very simple too:

struct Entry<T> {
    id: usize,
    data: T,
}

And finally, this is the type for the TLS:

pub struct ThreadLocal<T> {
    top: Table<T>,
}
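The post does not show how these structures are created. A minimal sketch, assuming Rust 1.63+ for `std::array::from_fn`: the `new` constructors below are hypothetical additions of ours, and the type definitions are repeated so the block compiles on its own. An empty table simply holds a null pointer in every slot.

```rust
use std::ptr::null_mut;
use std::sync::atomic::AtomicPtr;

const BITS: usize = 8;

struct Table<T> {
    nodes: [AtomicPtr<Node<T>>; 1 << BITS],
}

enum Node<T> {
    Branch(Table<T>),
    Leaf(Entry<T>),
}

struct Entry<T> {
    id: usize,
    data: T,
}

pub struct ThreadLocal<T> {
    top: Table<T>,
}

impl<T> Table<T> {
    /// Hypothetical constructor: an empty table, every slot null.
    fn new() -> Self {
        Table { nodes: std::array::from_fn(|_| AtomicPtr::new(null_mut())) }
    }
}

impl<T> ThreadLocal<T> {
    /// Hypothetical constructor: a ThreadLocal starts as one empty top table.
    pub fn new() -> Self {
        ThreadLocal { top: Table::new() }
    }
}
```

Because the top Table is stored inline in ThreadLocal, it exists from construction on; only child tables are allocated lazily, when a collision occurs.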

You may have noticed the id field on Entry. Yes, it needs to be an integer, and we will have to generate it ourselves. I will cover ID generation later and focus on the structure first.

Handling ID Collision

Collisions are handled by shifting the ID's bits while keeping a copy of the original ID. At each level of the TLS, we take the lowest BITS bits of the shifted ID with a bitwise AND against (1 << BITS) - 1. The result is the index of the node to access in the current table. When going one level deeper, we shift the shifted ID right by another BITS bits. Suppose BITS = 3; then the TLS would look like this:

[Figure: collision handling with nested tables, BITS = 3]
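The index computation can be written as a small helper (the name `index_at` is ours). With BITS = 3, the IDs 10 (0b001_010) and 26 (0b011_010) share their lowest three bits, so they collide in the top table, but their next three bits differ, so they separate one level down:

```rust
const BITS: usize = 3;

/// Slot index of `id` in a table at `level` (0 = top table): shift away the
/// bits consumed by the levels above, then keep the lowest BITS bits.
/// `level * BITS` must stay below the bit width of usize.
fn index_at(id: usize, level: usize) -> usize {
    (id >> (level * BITS)) & ((1 << BITS) - 1)
}
```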

For instance, the get method looks like this:

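use std::sync::atomic::Ordering::Acquire;

impl<T> ThreadLocal<T> {
    pub fn get(&self) -> Option<&T> {
        let id = thread_id();
        let mut table = &self.top;
        let mut shifted = id;

        loop {
            // The lowest BITS bits of the shifted ID select our slot.
            let index = shifted & (1 << BITS) - 1;
            let in_place = table.nodes[index].load(Acquire);

            match unsafe { in_place.as_ref() } {
                // Our own entry.
                Some(Node::Leaf(entry)) if entry.id == id => {
                    break Some(&entry.data);
                },

                // A child table: descend and consume the next BITS bits.
                Some(Node::Branch(new_tbl)) => {
                    table = new_tbl;
                    shifted >>= BITS;
                }

                // Empty slot, or another thread's entry: we have no entry.
                _ => break None,
            }
        }
    }
}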

This method is wait-free: each iteration either finishes or descends one level, and a thread visits at most ceil(64 / BITS) levels on 64-bit machines and ceil(32 / BITS) levels on 32-bit machines. The depth is bounded because two distinct IDs differ in at least one bit, so they necessarily land in different slots no deeper than the level that consumes that bit; no branch is ever needed below it. A get-or-insert method handles collisions by trying to insert a new child table:

use std::mem;
use std::ptr::null_mut;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed};

impl<T> ThreadLocal<T> {
    pub fn with_init<F>(&self, init: F) -> &T
    where
        F: FnOnce() -> T,
    {
        let id = thread_id();
        let mut table = &self.top;
        // Level of the child table we would create next; a displaced entry's
        // index in that child uses its ID shifted by `depth * BITS`.
        let mut depth = 1;
        let mut shifted = id;
        let mut index = shifted & (1 << BITS) - 1;
        let mut in_place = table.nodes[index].load(Acquire);
        let mut opt_init = Some(move || Entry { id, data: init() });
        // Our leaf, kept across failed CASes so `init` runs at most once.
        let mut opt_ptr: Option<*mut Node<T>> = None;

        loop {
            match unsafe { in_place.as_ref() } {
                Some(Node::Leaf(entry)) if entry.id == id => {
                    debug_assert!(opt_ptr.is_none());
                    break &entry.data;
                },

                // Another thread's entry is in our slot: move it into a new
                // child table and try to publish that table in its place.
                Some(Node::Leaf(entry)) => {
                    let other_shifted = entry.id >> depth * BITS;
                    let other_index = other_shifted & (1 << BITS) - 1;

                    // An all-zero Table is valid: every slot is null.
                    let branch = Node::Branch(unsafe { mem::zeroed() });
                    let branch_ptr = Box::into_raw(Box::new(branch));
                    let new_tbl = match unsafe { &*branch_ptr } {
                        Node::Branch(tbl) => tbl,
                        _ => unreachable!(),
                    };
                    new_tbl.nodes[other_index].store(in_place, Relaxed);

                    match table.nodes[index].compare_exchange(
                        in_place,
                        branch_ptr,
                        AcqRel,
                        // A failure ordering may not be Release or AcqRel.
                        Acquire,
                    ) {
                        Ok(_) => {
                            table = new_tbl;
                            depth += 1;
                            shifted >>= BITS;
                            index = shifted & (1 << BITS) - 1;
                            in_place = table.nodes[index].load(Acquire);
                        },

                        Err(new) => {
                            // Another thread installed a branch first. The leaf
                            // is still live in `table`, so unlink it before
                            // freeing our unused branch.
                            new_tbl.nodes[other_index].store(null_mut(), Relaxed);
                            drop(unsafe { Box::from_raw(branch_ptr) });
                            in_place = new;
                        },
                    }
                },

                Some(Node::Branch(new_tbl)) => {
                    table = new_tbl;
                    depth += 1;
                    shifted >>= BITS;
                    index = shifted & (1 << BITS) - 1;
                    in_place = table.nodes[index].load(Acquire);
                },

                // Empty slot: try to publish our leaf.
                None => {
                    let ptr = opt_ptr.take().unwrap_or_else(|| {
                        let init = opt_init.take().unwrap();
                        let leaf = Node::Leaf(init());
                        Box::into_raw(Box::new(leaf))
                    });

                    match table.nodes[index].compare_exchange(
                        in_place,
                        ptr,
                        AcqRel,
                        Acquire,
                    ) {
                        Ok(_) => match unsafe { &*ptr } {
                            Node::Leaf(entry) => break &entry.data,
                            _ => unreachable!(),
                        },

                        Err(new) => {
                            opt_ptr = Some(ptr);
                            in_place = new;
                        },
                    }
                }
            }
        }
    }
}

This method is also wait-free. When we hit a collision, we try to insert a new table. If that CAS fails, it can only be because another thread has already replaced the leaf with a branch (slots never go back to null, and a leaf is only ever replaced by a branch), so the next iteration descends into that branch. Each level therefore costs a bounded number of iterations, and there are at most ceil(64 / BITS) levels on 64-bit machines and ceil(32 / BITS) levels on 32-bit machines.
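The depth bound can be made precise. Two distinct IDs get different indices at level L exactly when they differ somewhere in bits L * BITS through (L + 1) * BITS - 1, so the first level at which they separate is the position of their lowest differing bit divided by BITS. The helpers below (names ours, with BITS = 3 as in the collision example) compute that level and the number of levels for a given word size:

```rust
const BITS: usize = 3;

/// Number of table levels needed to consume all `word_bits` bits of an ID,
/// i.e. ceil(word_bits / BITS).
fn max_levels(word_bits: usize) -> usize {
    (word_bits + BITS - 1) / BITS
}

/// First level (0 = top) at which two distinct IDs land in different slots.
/// `a ^ b` has a bit set exactly where the IDs differ; its lowest set bit
/// falls inside the BITS-wide window used at that level.
fn split_level(a: usize, b: usize) -> usize {
    assert_ne!(a, b, "identical IDs never separate");
    (a ^ b).trailing_zeros() as usize / BITS
}
```

Since `split_level` is always smaller than `max_levels`, no chain of branches can be deeper than ceil(W / BITS) tables, where W is the width of an ID in bits.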

Destruction

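Entries are only destroyed when the structure itself is destroyed. Unused entries are reused, however: when a thread exits, its ID is freed, and a new thread that receives the same ID finds and takes over the old entry. This will become clear when you read the ID generation code.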
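The post does not show the destructor. A minimal sketch, assuming `Drop` is implemented on `Table` itself (our choice; a real implementation may place it elsewhere): since `drop` receives `&mut self`, no other thread can access the table, so each slot can be read with `AtomicPtr::get_mut` and reclaimed with `Box::from_raw`. Dropping a `Branch` drops its child table, so the recursion is at most ceil(W / BITS) levels deep. One consequence: a branch table that loses its publishing CAS still points at a live leaf, so that slot must be reset to null before the branch is freed.

```rust
use std::ptr::null_mut;
use std::sync::atomic::AtomicPtr;

const BITS: usize = 2; // tiny tables keep the example small

struct Table<T> {
    nodes: [AtomicPtr<Node<T>>; 1 << BITS],
}

enum Node<T> {
    Branch(Table<T>),
    Leaf(Entry<T>),
}

struct Entry<T> {
    id: usize,
    data: T,
}

impl<T> Table<T> {
    fn new() -> Self {
        Table { nodes: std::array::from_fn(|_| AtomicPtr::new(null_mut())) }
    }
}

impl<T> Drop for Table<T> {
    fn drop(&mut self) {
        for slot in self.nodes.iter_mut() {
            // `&mut self` means exclusive access: no atomic ordering needed.
            let ptr = *slot.get_mut();
            if !ptr.is_null() {
                // Every non-null slot came from `Box::into_raw`. Dropping a
                // `Branch` drops its child `Table`, which recurses.
                drop(unsafe { Box::from_raw(ptr) });
            }
        }
    }
}
```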

Handling Thread IDs

This is a very tricky part. My design needs an integer thread ID so that it can use its bits, and POSIX does not provide such a thing: it only provides an opaque type (pthread_t) and an equality function (pthread_equal). So we clearly have to generate thread IDs ourselves. But how? We probably have to use the platform's global TLS for that, which is somewhat unfortunate, because there is a good chance the platform's TLS is not even lock-free. However, for the sake of simplicity, I ask you to abstract away these platform implementation details.

At first, there is a pretty simple idea: use the memory address that the platform's TLS gives us. This feels like cheating, but it worked well on my machine. It looks like this:

thread_local! {
    static ID_MAKER: u8 = 0;
}

fn thread_id() -> usize {
    ID_MAKER.with(|addr| addr as *const _ as usize)
}

Does it work? Yes. Is it portable? Nope. This is a hack. If you can't afford extra allocations and you know for sure it will work on your users' machines, go ahead and use it. However, there is no guarantee that Rust or the OS will give us stable addresses. Another disadvantage is the poor distribution of addresses: first, the lowest bits are constrained by alignment; second, the other bits might follow some pattern.

Instead, we will use a linked list of IDs. Nodes are never removed physically, only logically. Inserting a node, removing it (allocating its ID) and reinserting it (freeing its ID) are all O(1) for every thread, which preserves wait-freedom. Searching for an available node, however, is O(n), where n is the length of the list when the search began (allocated nodes included). It is still bounded, because the search never walks past the back node it observed when it started.

use std::sync::atomic::{AtomicPtr, AtomicUsize};

struct IdGuard {
    bits: usize,
    node: &'static Node,
}

struct Node {
    // Set to usize::MAX while the ID is in use; otherwise holds the free ID.
    free: AtomicUsize,
    next: AtomicPtr<Node>,
}

The TLS consumes the lowest bits of an ID first. To improve the distribution of those lower bits, lower IDs tend to occupy the earlier positions of the list, so the smallest free IDs are found first. I say "tend" because it is not a guarantee: inserting a node and generating its ID are not a single atomic operation, so the scheduler might delay our ID generation while other threads insert their nodes after ours and generate their IDs first. Still, the list is sorted most of the time, since the scheduler will rarely interleave threads that unluckily.

To get O(1) insertion (i.e. no CAS loop), we keep pointers to both the front and the back of the list. Insertion just swaps the current back with a freshly allocated node and then sets the next field of the previous back. Between those two steps, a thread walking from the front may not reach the new back yet, but this temporary inconsistency is acceptable: such a thread simply stops at the end of the list and creates its own node. The list starts with a single node, so we never need to initialize the front when inserting at the back for the first time.

// The counter starts at 1 because the list starts with the node holding ID 0.
static ID_COUNTER: AtomicUsize = AtomicUsize::new(1);

static ID_LIST: Node = Node {
    free: AtomicUsize::new(0),
    next: AtomicPtr::new(null_mut())
};

static ID_LIST_BACK: AtomicPtr<Node> = AtomicPtr::new(
    &ID_LIST as *const _ as *mut _
);

The ID guard is exactly what its name says: it guards the use of a specific node and ID. Constructing the guard allocates an ID; its destructor, which runs when the owning thread exits, "reinserts" the guard's node, that is, frees the ID.

thread_local! {
    static ID: IdGuard = IdGuard::new();
}

impl Drop for IdGuard {
    fn drop(&mut self) {
        // Release pairs with the Acquire swap in `new`, so a thread that
        // reuses this ID (and its TLS entries) sees this thread's writes.
        self.node.free.store(self.bits, Release);
    }
}

Creation is simple. We first load the back pointer, which serves as our limit; otherwise our thread could keep walking the list forever while other threads keep appending nodes. If we find a free node along the way, we take it; otherwise, we create and append a new node.

impl IdGuard {
    fn new() -> Self {
        let back_then = ID_LIST_BACK.load(Acquire);

        let mut node = &ID_LIST;

        loop {
            let bits = node.free.swap(usize::MAX, Acquire);
            if bits != usize::MAX {
                break Self { node, bits };
            }

            let next = node.next.load(Acquire);

            // Stop at the end of the list or at the back we saw at the start.
            if next.is_null() || std::ptr::eq(node, back_then) {
                break Self::create_node();
            }

            node = unsafe { &*next };
        }
    }
}

When creating a node, we insert it first and increment the counter afterwards. Don't worry about overflow: every ID is backed by a node that is never freed, so we would run out of address space long before the counter could overflow.

impl IdGuard {
    fn create_node() -> Self {
        let new = Node {
            free: AtomicUsize::new(usize::MAX),
            next: AtomicPtr::new(null_mut()),
        };

        // Nodes are never freed, so the allocation is leaked on purpose.
        let node: &'static Node = Box::leak(Box::new(new));
        let raw = node as *const Node as *mut Node;

        // Claim the back first; link the previous back afterwards.
        let prev = ID_LIST_BACK.swap(raw, AcqRel);

        let bits = ID_COUNTER.fetch_add(1, Relaxed);

        unsafe { (*prev).next.store(raw, Release) };

        Self { node, bits }
    }
}

Finally, we can write a "get me my thread ID" function:

fn thread_id() -> usize {
    ID.with(|guard| guard.bits)
}
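Put together, the ID allocator fits in a single file. The sketch below assembles the pieces above with a few choices of our own: `Box::leak` for the node allocation (nodes are never freed anyway), `std::ptr::eq` for the pointer comparison, and Release/Acquire on the free flag so that a thread reusing an ID also sees the previous owner's writes. The test keeps several threads alive together with a `Barrier` and checks that they all received distinct IDs.

```rust
use std::ptr::null_mut;
use std::sync::atomic::{
    AtomicPtr, AtomicUsize,
    Ordering::{AcqRel, Acquire, Relaxed, Release},
};

struct Node {
    // usize::MAX while the ID is in use; otherwise the free ID itself.
    free: AtomicUsize,
    next: AtomicPtr<Node>,
}

struct IdGuard {
    bits: usize,
    node: &'static Node,
}

// The list starts with the node holding ID 0, so the counter starts at 1.
static ID_COUNTER: AtomicUsize = AtomicUsize::new(1);
static ID_LIST: Node = Node { free: AtomicUsize::new(0), next: AtomicPtr::new(null_mut()) };
static ID_LIST_BACK: AtomicPtr<Node> = AtomicPtr::new(&ID_LIST as *const Node as *mut Node);

thread_local! {
    static ID: IdGuard = IdGuard::new();
}

impl IdGuard {
    fn new() -> Self {
        // Walk no further than the back we saw at the start: bounded search.
        let back_then = ID_LIST_BACK.load(Acquire);
        let mut node = &ID_LIST;
        loop {
            let bits = node.free.swap(usize::MAX, Acquire);
            if bits != usize::MAX {
                return IdGuard { node, bits };
            }
            let next = node.next.load(Acquire);
            if next.is_null() || std::ptr::eq(node, back_then) {
                return Self::create_node();
            }
            node = unsafe { &*next };
        }
    }

    fn create_node() -> Self {
        // Nodes are never freed, so leaking the allocation is intentional.
        let node: &'static Node = Box::leak(Box::new(Node {
            free: AtomicUsize::new(usize::MAX),
            next: AtomicPtr::new(null_mut()),
        }));
        let raw = node as *const Node as *mut Node;
        // O(1) append: claim the back first, then link the previous back.
        let prev = ID_LIST_BACK.swap(raw, AcqRel);
        let bits = ID_COUNTER.fetch_add(1, Relaxed);
        unsafe { (*prev).next.store(raw, Release) };
        IdGuard { node, bits }
    }
}

impl Drop for IdGuard {
    fn drop(&mut self) {
        // Publish the ID as free again; pairs with the Acquire swap above.
        self.node.free.store(self.bits, Release);
    }
}

fn thread_id() -> usize {
    ID.with(|guard| guard.bits)
}
```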

Final Thoughts: This is Also ABA-free

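This structure is ABA-free because nothing is ever removed or freed while the structure is shared: a pointer observed before a CAS can never be freed and reused for a different node in the meantime. The same applies to the ID generation, whose nodes are never freed.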

Full Implementation

The full implementation is available as a Rust crate named lockfree. You can find the source code here.