
Implementing a Binary Tree in Rust for fun

2023-02-16

In this post, we'll look at a Binary Tree implementation in Rust and write a function that takes the root of a binary tree holding integer (i32) values and returns the sum of all values in the tree.

Implementing a binary tree in Rust is a great way to learn about the language's ownership and borrowing rules, as well as its features for managing memory. We'll write the logic in both iterative and recursive fashion.

THE THEORY

If you don't know what a binary tree is, here is a quick primer:

A tree data structure in computer science represents a hierarchical tree structure with a set of connected nodes.

A tree is a binary tree if:

  • it has exactly one root node
  • each node has at most two children (known as left child and right child)
  • and has exactly one path between root and any other node

This means that binary trees don't contain cycles.

Here are a few examples of a binary tree:

// ✅ valid binary tree
//      a
//    /   \
//   b     c
//  / \     \
// d   e     f
//    /       \
//   g         h

// ✅ Here's another binary tree
//      a
//       \
//        b
//       /
//      c
//       \
//        d
//         \
//          e

// ✅ And this is also a binary tree
//       a
//       |
//       b

The following are not binary trees:

// 🚨 one of the nodes has three children
//       a
//    /  |  \
//   b   c   d
//  / \       \
// e   f       g
//    /         \
//   h           i

// 🚨 There is no root node
//        a
//      /   \
//     b --- c

// 🚨 There is more than one path from
// root to another node
//        a
//        |
//        b
//      /   \
//     c --- d

Now let's see what a Rust implementation will look like.

SHOW ME THE CODE

Let's start by representing a binary tree using a struct.

use std::{cell::RefCell, rc::Rc};

#[derive(Debug, Clone)]
pub struct TreeNode {
  val: i32,
  left: Option<TreeNodeRef>,
  right: Option<TreeNodeRef>,
}

type TreeNodeRef = Rc<RefCell<TreeNode>>;

So we have a TreeNode struct that represents a node in a binary tree, and it has a val (of type i32), and two children nodes - left and right of type TreeNodeRef which is an alias for Rc<RefCell<TreeNode>>. The alias exists so that we can save some keystrokes and it also makes the struct definition clearer. Both left and right node values can be empty and therefore they are of type Option.
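To make the nesting of Rc, RefCell and Option concrete, here is a minimal sketch of how a small tree could be built by hand. The `node` and `sample_tree` functions are hypothetical helpers introduced only for this illustration; they are not part of the original code.

```rust
use std::{cell::RefCell, rc::Rc};

#[derive(Debug, Clone)]
pub struct TreeNode {
    val: i32,
    left: Option<TreeNodeRef>,
    right: Option<TreeNodeRef>,
}

type TreeNodeRef = Rc<RefCell<TreeNode>>;

// Hypothetical helper (an assumption for this sketch): wraps a
// TreeNode in Rc<RefCell<...>> so that call sites stay short.
fn node(val: i32, left: Option<TreeNodeRef>, right: Option<TreeNodeRef>) -> TreeNodeRef {
    Rc::new(RefCell::new(TreeNode { val, left, right }))
}

// Builds this tree, leaves first:
//      1
//    /   \
//   2     3
//  /
// 4
fn sample_tree() -> TreeNodeRef {
    let four = node(4, None, None);
    let two = node(2, Some(four), None);
    let three = node(3, None, None);
    node(1, Some(two), Some(three))
}
```

Children are created before their parents because each parent takes ownership of the handles it is given.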

Now, why did we wrap the TreeNode inside a RefCell which is itself wrapped inside a Rc?

RefCell<T> (like Cell<T>) is a type that allows for interior mutability, which means that it lets you mutate the contained value even if you only have a shared reference (&T) to it. Normally, you'd need an exclusive reference (&mut T) to update the contents ('inherited mutability'). This does not mean that you can get around Rust's memory safety rules: borrows of a RefCell<T> are checked at runtime instead of at compile time. The compiler won't complain if you try to borrow a value that is already mutably borrowed, but the thread will panic at runtime.
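A small sketch of both halves of that claim, using only the standard library. The function names are made up for this example. The second function uses `try_borrow_mut`, which reports a conflicting borrow as an `Err` instead of panicking, so the runtime check can be observed safely.

```rust
use std::cell::RefCell;

// Mutates the value through a shared reference (&RefCell<i32>).
// No `&mut` is needed anywhere: this is interior mutability.
fn add_one(cell: &RefCell<i32>) {
    *cell.borrow_mut() += 1;
}

// Returns true if a second mutable borrow is rejected while the
// first one is still alive. Plain `borrow_mut` would panic here;
// `try_borrow_mut` returns an Err instead.
fn second_mutable_borrow_is_rejected(cell: &RefCell<i32>) -> bool {
    let _first = cell.borrow_mut();
    cell.try_borrow_mut().is_err()
}
```

Once `_first` goes out of scope, the cell can be borrowed mutably again.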

By wrapping the left and right nodes in RefCell<T>, you don't have to keep track of mutable and immutable parts of the struct. This is especially useful if you're working with a recursive data structure.

Rc<T> (short for reference counted) "provides shared ownership of a value of type T, allocated in the heap". It allows multiple parts of your code to share ownership of a value, with no single exclusive owner. That is useful here because a TreeNode refers to other TreeNodes through its left and right fields (a recursive data structure), and we want to hand out cheap handles to those nodes, for example when pushing them onto a stack during traversal.

It is generally awkward to build such a structure with plain references (&T) in Rust, because every reference needs a lifetime that the borrow checker can prove does not outlive the node it points to.

Rc sidesteps this by counting the owners of each value and dropping the value when the count reaches zero. Note that a binary tree itself has no cycles. If you did create a true reference cycle with Rc (say, a child pointing back to its parent), the count would never reach zero and the memory would leak, which is the problem Weak<T> exists to solve.
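To see the reference counting at work, here is a minimal sketch that reads the count with `Rc::strong_count` before, during and after a second owner exists. The function name `observe_counts` is invented for this example.

```rust
use std::rc::Rc;

// Returns the strong count observed before, during, and after
// a second owner of the same value exists.
fn observe_counts() -> (usize, usize, usize) {
    let first = Rc::new(5);
    let before = Rc::strong_count(&first);
    let during = {
        // Rc::clone copies the pointer, not the value, and bumps the count.
        let _second = Rc::clone(&first);
        Rc::strong_count(&first)
    }; // `_second` is dropped here, so the count goes back down.
    let after = Rc::strong_count(&first);
    (before, during, after)
}
```

The value itself is only freed when the last owner is dropped.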

Brilliant, I hope you got all that. Now let's look at the iterative logic to calculate the total sum of all number values in the binary tree.

pub fn tree_sum(root: TreeNodeRef) -> i32 {
    let mut sum = 0i32;
    // We'll use a `vec` as a
    // stack LIFO data structure.
    // Start by adding the root node
    // to the stack.
    let mut stack = vec![root];

    // Keep popping the top most item
    // until the stack is empty.
    while let Some(current) = stack.pop() {
        // borrow the node once and reuse it below
        let node = current.borrow();
        sum += node.val;

        // if there is a right node,
        // then push it on top of the stack
        if let Some(right) = &node.right {
            stack.push(Rc::clone(right));
        }
        // if there is a left node,
        // then push it on top of the stack
        if let Some(left) = &node.left {
            stack.push(Rc::clone(left));
        }
    }
    sum
}

Hopefully, the code comments are sufficient to explain the logic. We are traversing the binary tree using a depth-first search algorithm. So if the tree looks like:

//      a
//    /   \
//   b     c
//  / \     \
// d   e     f
//    /       \
//   g         h

Then, the search will access the values in this order:

// a -> b -> d -> e -> g -> c -> f -> h

For each node it visits, it will add the value of that node to the running total sum.
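If you want to check that visiting order yourself, one option is a variant of the traversal that records each value instead of adding it up. This is a sketch under a few assumptions: `visit_order`, `node` and `example_tree` are hypothetical names made up for it, and since our nodes hold i32 values, the letters are mapped to numbers (a=1, b=2, ..., h=8).

```rust
use std::{cell::RefCell, rc::Rc};

#[derive(Debug, Clone)]
pub struct TreeNode {
    val: i32,
    left: Option<TreeNodeRef>,
    right: Option<TreeNodeRef>,
}

type TreeNodeRef = Rc<RefCell<TreeNode>>;

// Hypothetical helper: builds a node wrapped in Rc<RefCell<...>>.
fn node(val: i32, left: Option<TreeNodeRef>, right: Option<TreeNodeRef>) -> TreeNodeRef {
    Rc::new(RefCell::new(TreeNode { val, left, right }))
}

// Same stack-based traversal as the sum function, but it records
// the value of each node in the order the nodes are popped.
fn visit_order(root: TreeNodeRef) -> Vec<i32> {
    let mut order = Vec::new();
    let mut stack = vec![root];
    while let Some(current) = stack.pop() {
        let n = current.borrow();
        order.push(n.val);
        // Right goes on first so that left is popped (visited) first.
        if let Some(right) = &n.right {
            stack.push(Rc::clone(right));
        }
        if let Some(left) = &n.left {
            stack.push(Rc::clone(left));
        }
    }
    order
}

// The example tree with a=1, b=2, c=3, d=4, e=5, f=6, g=7, h=8.
fn example_tree() -> TreeNodeRef {
    let g = node(7, None, None);
    let h = node(8, None, None);
    let d = node(4, None, None);
    let e = node(5, Some(g), None);
    let f = node(6, None, Some(h));
    let b = node(2, Some(d), Some(e));
    let c = node(3, None, Some(f));
    node(1, Some(b), Some(c))
}
```

Calling `visit_order(example_tree())` yields `[1, 2, 4, 5, 7, 3, 6, 8]`, which matches a -> b -> d -> e -> g -> c -> f -> h.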

And here is the same logic implemented recursively:

pub fn tree_sum_recursive(root: Option<&TreeNodeRef>) -> i32 {
    // if `root` has `Some`thing
    // return `root.val` + left_node_val + right_node_val
    if let Some(root) = root {
        root.borrow().val
            // recursively call left path
            + tree_sum_recursive(root.borrow().left.as_ref())
            // recursively call right path
            + tree_sum_recursive(root.borrow().right.as_ref())
    } else {
        // root is None (i.e. empty or null)
        // so return `0`
        0
    }
}

I find the recursive approach concise and elegant, and easier to read but YMMV. The base case here is - if root is None then return 0, otherwise return root.val + tree_sum_recursive(left) + tree_sum_recursive(right).
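Since the recursive function takes an Option<&TreeNodeRef>, calling it looks slightly different from what you might expect. Here is a minimal usage sketch: the function body is copied from above, while `node` and `small_tree` are hypothetical helpers introduced only for this example.

```rust
use std::{cell::RefCell, rc::Rc};

#[derive(Debug, Clone)]
pub struct TreeNode {
    val: i32,
    left: Option<TreeNodeRef>,
    right: Option<TreeNodeRef>,
}

type TreeNodeRef = Rc<RefCell<TreeNode>>;

// The recursive sum from the post, unchanged.
pub fn tree_sum_recursive(root: Option<&TreeNodeRef>) -> i32 {
    if let Some(root) = root {
        root.borrow().val
            + tree_sum_recursive(root.borrow().left.as_ref())
            + tree_sum_recursive(root.borrow().right.as_ref())
    } else {
        0
    }
}

// Hypothetical helper: builds a node wrapped in Rc<RefCell<...>>.
fn node(val: i32, left: Option<TreeNodeRef>, right: Option<TreeNodeRef>) -> TreeNodeRef {
    Rc::new(RefCell::new(TreeNode { val, left, right }))
}

// Builds this tree:
//      5
//    /   \
//  -3     8
//  /
// 2
fn small_tree() -> TreeNodeRef {
    let two = node(2, None, None);
    let minus_three = node(-3, Some(two), None);
    let eight = node(8, None, None);
    node(5, Some(minus_three), Some(eight))
}
```

With `let root = small_tree();`, the call `tree_sum_recursive(Some(&root))` returns 12, and `tree_sum_recursive(None)` returns 0. Because the function only borrows the root, the tree is still usable afterwards. The iterative tree_sum takes a TreeNodeRef by value, so there you would pass `Rc::clone(&root)` if you want to keep your own handle.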

You can get the full source code here - includes unit tests.

CONCLUSION

Overall, implementing a binary tree is a great way to learn about Rust's powerful type system. You have to think carefully about ownership and borrowing rules, and build a recursive data structure safely using types like Rc<T> and RefCell<T>, in order to write safe, efficient, and (arguably) expressive code.


FOOTNOTES

This work is licensed under CC BY-NC-SA 4.0.