Implementing a Binary Tree in Rust for fun

2023-02-16

In this post, we'll look at a Binary Tree implementation in Rust and write a function that takes in the root of a binary tree that contains number (`i32`) values and returns the total sum of all values in the tree.

Implementing a binary tree in Rust is a great way to learn about the language's ownership and borrowing rules, as well as its features for managing memory. We'll write the logic in both iterative and recursive fashion.

THE THEORY

If you don't know what a binary tree is, here is a quick primer:

A tree data structure in computer science represents a hierarchical tree structure with a set of connected nodes.

A tree is a binary tree if:

• it has exactly one root node
• each node has at most two children (known as left child and right child)
• and there is exactly one path between the root and any other node

This means that binary trees don't have circular dependencies.

Here are a few examples of a binary tree:

``````// ✅ valid binary tree
//      a
//    /   \
//   b     c
//  / \     \
// d   e     f
//    /       \
//   g         h

// ✅ Here's another binary tree
//      a
//       \
//        b
//       /
//      c
//       \
//        d
//         \
//          e

// ✅ And this is also a binary tree
//       a
//       |
//       b
``````

The following are not binary trees:

``````// 🚨 one of the nodes has three children
//       a
//    /  |  \
//   b   c   d
//  / \       \
// e   f       g
//    /         \
//   h           i

// 🚨 There is no root node
//        a
//      /   \
//     b --- c

// 🚨 There is more than one path from
// root to another node
//        a
//        |
//        b
//      /   \
//     c --- d
``````

Now let's see what a Rust implementation will look like.

SHOW ME THE CODE

Let's start by representing a binary tree using a `struct`.

``````use std::{cell::RefCell, rc::Rc};

#[derive(Debug, Clone)]
pub struct TreeNode {
    val: i32,
    left: Option<TreeNodeRef>,
    right: Option<TreeNodeRef>,
}

type TreeNodeRef = Rc<RefCell<TreeNode>>;
``````

So we have a `TreeNode` struct that represents a node in a binary tree. It has a `val` (of type `i32`) and two child nodes, `left` and `right`, of type `TreeNodeRef`, which is an alias for `Rc<RefCell<TreeNode>>`. The alias saves some keystrokes and makes the `struct` definition clearer. Either child can be absent, so both fields are wrapped in `Option`.
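To see how these types fit together in practice, here is a minimal sketch that builds a three-node tree. The `leaf` and `small_tree` functions are hypothetical helpers introduced only for this illustration; they are not part of the original source code.

```rust
use std::{cell::RefCell, rc::Rc};

#[derive(Debug, Clone)]
pub struct TreeNode {
    val: i32,
    left: Option<TreeNodeRef>,
    right: Option<TreeNodeRef>,
}

type TreeNodeRef = Rc<RefCell<TreeNode>>;

// Hypothetical helper (not in the original code): wraps a value
// in a childless node behind `Rc<RefCell<..>>`.
fn leaf(val: i32) -> TreeNodeRef {
    Rc::new(RefCell::new(TreeNode { val, left: None, right: None }))
}

// Builds:
//      1
//    /   \
//   2     3
fn small_tree() -> TreeNodeRef {
    let root = leaf(1);
    // `borrow_mut` hands out a mutable view of the node even though
    // `root` itself is not declared `mut`.
    root.borrow_mut().left = Some(leaf(2));
    root.borrow_mut().right = Some(leaf(3));
    root
}
```

Note that children are attached after the root is created, which is only possible because the node sits inside a `RefCell`.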

Now, why did we wrap the `TreeNode` inside a `RefCell` which is itself wrapped inside a `Rc`?

`RefCell<T>` (like `Cell<T>`) is a type that allows interior mutability: you can mutate the contents of a value even if you only have a shared reference (`&T`) to it. Normally, you'd need an exclusive reference (`&mut T`) to update the contents ('inherited mutability'). This does not mean that you can get around Rust's memory safety rules. Borrows of a `RefCell<T>` are checked at runtime instead of at compile time, so the compiler won't complain if you attempt to borrow a value that is already mutably borrowed, but the thread will `panic` at runtime.

By wrapping the `left` and `right` nodes in `RefCell<T>`, you don't have to keep track of mutable and immutable parts of the `struct`. This is especially useful if you're working with a recursive data structure.
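The runtime checking described above can be observed directly. The sketch below uses a hypothetical `refcell_demo` function and relies on `try_borrow_mut`, the non-panicking counterpart of `borrow_mut` in the standard library, to show a conflicting borrow being refused.

```rust
use std::cell::RefCell;

// Returns (value after mutation, whether a second mutable borrow was refused).
fn refcell_demo() -> (i32, bool) {
    // `cell` is not declared `mut`, yet its contents can change.
    let cell = RefCell::new(1);
    *cell.borrow_mut() += 41;

    // Hold one mutable borrow and ask for another. `try_borrow_mut`
    // reports the conflict as an `Err`; plain `borrow_mut` would panic here.
    let first = cell.borrow_mut();
    let second_refused = cell.try_borrow_mut().is_err();
    drop(first); // release the borrow so the value can be read again

    let value = *cell.borrow();
    (value, second_refused)
}
```

The compiler accepts all of this; the conflict only shows up when the code runs.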

`Rc<T>` (short for reference counted) "provides shared ownership of a value of type `T`, allocated in the heap". It allows multiple parts of your code to share ownership of a value, with no single exclusive owner. We need a heap-allocated pointer here because `TreeNode` is a recursive data structure: a `TreeNode` refers to other `TreeNode`s through its `left` and `right` fields, and Rust cannot compute the size of a type that directly contains itself.

Plain references (`&TreeNode`) are awkward for this, because every node would need an owner that outlives the tree, and the lifetimes quickly become hard to express.

`Rc` sidesteps the problem by counting the handles to each value and dropping the value when the last handle goes away. One caveat: if two `Rc`s ever point at each other, their counts never reach zero and the memory leaks, so truly cyclic links (such as a pointer back to a parent) are usually stored as `Weak<T>`. A binary tree has no cycles, so this does not come up here.
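To make the reference counting concrete, here is a small sketch. The `rc_demo` function is a hypothetical name used only for illustration; `Rc::clone` and `Rc::strong_count` are standard library functions.

```rust
use std::rc::Rc;

// Returns the strong count at three points: after creation, while a
// second handle exists, and after that handle is dropped.
fn rc_demo() -> (usize, usize, usize) {
    let first = Rc::new(String::from("node"));
    let before = Rc::strong_count(&first);

    // `Rc::clone` copies the pointer and bumps the count; the `String`
    // itself is not duplicated.
    let second = Rc::clone(&first);
    let shared = Rc::strong_count(&first);

    drop(second);
    let after = Rc::strong_count(&first);
    (before, shared, after)
}
```

This is why cloning a `TreeNodeRef` is cheap: it only creates another handle to the same node.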

Brilliant, I hope you got all that. Now let's look at the iterative logic to calculate the total sum of all number values in the binary tree.

``````pub fn tree_sum(root: TreeNodeRef) -> i32 {
    let mut sum = 0i32;
    // We'll use a `vec` as a
    // stack LIFO data structure.
    // Start by adding the root node
    // to the stack.
    let mut stack = vec![root];

    // current points to top most
    // item in the stack
    while let Some(current) = stack.pop() {
        sum += current.borrow().val;

        // if there is a right node,
        // then push it on top of the stack
        if let Some(right) = &current.borrow().right {
            stack.push(Rc::clone(right));
        };
        // if there is a left node,
        // then push it on top of the stack
        if let Some(left) = &current.borrow().left {
            stack.push(Rc::clone(left));
        };
    }
    sum
}
``````

Hopefully, the code comments are sufficient to explain the logic. We are traversing the binary tree using a depth-first search. So if the tree looks like:

``````//      a
//    /   \
//   b     c
//  / \     \
// d   e     f
//    /       \
//   g         h
``````

Then, the search will access the values in this order:

``````// a -> b -> d -> e -> g -> c -> f -> h
``````

For each node it visits, it will add the value of that node to the running total `sum`.
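The visiting order can be checked with a small sketch that records values instead of summing them. It assumes the letters map to numbers (a=1, b=2, c=3, d=4, e=5, f=6, g=7, h=8); the `node`, `example_tree`, and `visit_order` functions are hypothetical helpers, not part of the original source.

```rust
use std::{cell::RefCell, rc::Rc};

#[derive(Debug, Clone)]
pub struct TreeNode {
    val: i32,
    left: Option<TreeNodeRef>,
    right: Option<TreeNodeRef>,
}

type TreeNodeRef = Rc<RefCell<TreeNode>>;

// Hypothetical helper: builds a node from a value and optional children.
fn node(val: i32, left: Option<TreeNodeRef>, right: Option<TreeNodeRef>) -> TreeNodeRef {
    Rc::new(RefCell::new(TreeNode { val, left, right }))
}

// The example tree, with a=1, b=2, c=3, d=4, e=5, f=6, g=7, h=8.
fn example_tree() -> TreeNodeRef {
    let e = node(5, Some(node(7, None, None)), None); // g is e's left child
    let b = node(2, Some(node(4, None, None)), Some(e));
    let f = node(6, None, Some(node(8, None, None))); // h is f's right child
    let c = node(3, None, Some(f));
    node(1, Some(b), Some(c))
}

// Same stack-based walk as `tree_sum`, but it records each value
// instead of adding it up.
fn visit_order(root: TreeNodeRef) -> Vec<i32> {
    let mut order = Vec::new();
    let mut stack = vec![root];
    while let Some(current) = stack.pop() {
        let borrowed = current.borrow();
        order.push(borrowed.val);
        // Right goes on first so that left is popped (visited) first.
        if let Some(right) = &borrowed.right {
            stack.push(Rc::clone(right));
        }
        if let Some(left) = &borrowed.left {
            stack.push(Rc::clone(left));
        }
    }
    order
}
```

Under that numbering, `visit_order(example_tree())` yields `[1, 2, 4, 5, 7, 3, 6, 8]`, which corresponds to a, b, d, e, g, c, f, h.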

And here is the same logic implemented recursively:

``````pub fn tree_sum_recursive(root: Option<&TreeNodeRef>) -> i32 {
    // if `root` has `Some`thing
    // return `root.val` + left_node_val + right_node_val
    if let Some(root) = root {
        root.borrow().val
            // recursively call left path
            + tree_sum_recursive(root.borrow().left.as_ref())
            // recursively call right path
            + tree_sum_recursive(root.borrow().right.as_ref())
    } else {
        // root is None (i.e. empty or null)
        // so return `0`
        0
    }
}
``````

I find the recursive approach concise, elegant, and easier to read, but YMMV. The base case is that if `root` is `None` we return `0`; otherwise we return `root.val` + `tree_sum_recursive(left)` + `tree_sum_recursive(right)`.
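As a quick sanity check, both versions should agree on any tree. The sketch below repeats slightly condensed versions of the two functions so that it is self-contained; the `node` and `compare_sums` helpers and the sample values are assumptions made for this illustration.

```rust
use std::{cell::RefCell, rc::Rc};

#[derive(Debug, Clone)]
pub struct TreeNode {
    val: i32,
    left: Option<TreeNodeRef>,
    right: Option<TreeNodeRef>,
}

type TreeNodeRef = Rc<RefCell<TreeNode>>;

// Hypothetical helper: builds a node from a value and optional children.
fn node(val: i32, left: Option<TreeNodeRef>, right: Option<TreeNodeRef>) -> TreeNodeRef {
    Rc::new(RefCell::new(TreeNode { val, left, right }))
}

// Iterative sum using an explicit stack.
pub fn tree_sum(root: TreeNodeRef) -> i32 {
    let mut sum = 0;
    let mut stack = vec![root];
    while let Some(current) = stack.pop() {
        let n = current.borrow();
        sum += n.val;
        if let Some(right) = &n.right {
            stack.push(Rc::clone(right));
        }
        if let Some(left) = &n.left {
            stack.push(Rc::clone(left));
        }
    }
    sum
}

// Recursive sum; `None` is the base case.
pub fn tree_sum_recursive(root: Option<&TreeNodeRef>) -> i32 {
    match root {
        None => 0,
        Some(root) => {
            let n = root.borrow();
            n.val
                + tree_sum_recursive(n.left.as_ref())
                + tree_sum_recursive(n.right.as_ref())
        }
    }
}

// Sums the same tree both ways; also exercises the empty-tree base case.
fn compare_sums() -> (i32, i32, i32) {
    //      3
    //    /   \
    //  11     4
    //  / \     \
    // 4  -2     1
    let left = node(11, Some(node(4, None, None)), Some(node(-2, None, None)));
    let right = node(4, None, Some(node(1, None, None)));
    let root = node(3, Some(left), Some(right));

    let recursive = tree_sum_recursive(Some(&root));
    let iterative = tree_sum(root); // takes ownership of the root handle
    (iterative, recursive, tree_sum_recursive(None))
}
```

One design difference worth noticing: the recursive version only borrows the tree, while the iterative one takes ownership of the root handle, so the recursive call has to come first here (or the root would need to be cloned).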

You can get the full source code here (includes unit tests).

CONCLUSION

Overall, implementing a binary tree is a great way to learn about Rust's powerful type system. You have to think carefully about ownership and borrowing rules, and build a recursive data structure safely with types like `Rc<T>` and `RefCell<T>`, in order to write safe, efficient, and (arguably) expressive code.