Sunday, January 22, 2012

Simulating average height of a random binary search tree

Recently I found a discussion on Stack Overflow about the average height of a binary search tree. The problem has been solved analytically; see, for example, Reed (2003). However, I was intrigued by one of the answers, which presented a simulation of the average tree height, and I thought it would be nice to have an implementation in R.
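For reference, the analytical result usually attributed to Reed (2003) can be written as below. The constants are quoted as rounded values and are not derived in this post, so treat this as a pointer to the paper rather than a restatement of it; H_n denotes the height of a tree built from a random permutation of size n.

```latex
\mathbb{E}[H_n] = \alpha \ln n - \beta \ln\ln n + O(1),
\qquad
\alpha \ln\!\left(\frac{2e}{\alpha}\right) = 1,\ \alpha > 2
\ \Rightarrow\ \alpha \approx 4.311,
\qquad
\beta = \frac{3}{2\ln(\alpha/2)} \approx 1.953
```

The leading term predicts growth of about α ln 2 ≈ 2.99 per doubling of n, slightly reduced by the second term. The simulated means reported further down grow by about 2.7 and 2.9 per doubling, which is in the same range.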

I decided to use the standard approach: generate random permutations of a given size and insert their elements into a tree one by one. The Performance test part of the code at the end of the post does this. The function test takes as its argument a function that returns the tree height for a given permutation.
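As a small illustration of how the test cases are laid out, the sketch below generates a few permutations with replicate and sample.int. The size and repetition count here are deliberately tiny and are assumptions made for readability; the actual test uses sizes 2^9 to 2^11 and 32 repetitions.

```r
set.seed(0)   # fixed seed so the cases are reproducible
n <- 8        # illustrative size only
reps <- 4     # illustrative number of repetitions only

# Each call to sample.int(n) is one random permutation of 1:n;
# replicate() binds them as the columns of an n-by-reps matrix.
cases <- replicate(reps, sample.int(n))

dim(cases)        # n rows, reps columns
sort(cases[, 1])  # every column contains each of 1:n exactly once
```

Storing the cases column-wise is what lets the test loop pull out one permutation at a time with cases[, j].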

I started with an implementation of a binary tree using lists (the Implementation using lists part of the code). In this approach each node is a two-element list holding its high and low subtrees, with the node's value stored as an attribute. An interesting thing to notice is that a list can be indexed recursively with a vector (the path variable), so a new node can be attached deep in the tree with a single assignment. Unfortunately, the code ran quite slowly:
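The recursive indexing mentioned above can be seen in isolation in the following minimal sketch. It reuses the node layout from the list implementation (slot 1 is high, slot 2 is low); the concrete values 5, 8 and 6 are made up for the example.

```r
# A node is a two-element list (high, low) with its value kept as an attribute.
new.node <- function(value) {
  node <- list(high = NULL, low = NULL)
  attr(node, "value") <- value
  node
}

tree <- new.node(5)
tree[[1]] <- new.node(8)        # high child of the root
tree[[c(1, 2)]] <- new.node(6)  # low child of that node, same as tree[[1]][[2]]

attr(tree[[c(1, 2)]], "value")  # 6
attr(tree[[1]], "value")        # 8
```

Indexing with a vector, as in tree[[c(1, 2)]], walks down one level per element, so the path accumulated during the search can be used directly as the insertion address.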

> test(th.list)
  size exec.time mean.height sd.height
1  512      2.37    19.56250  1.721543
2 1024      8.01    22.28125  2.158769
3 2048     31.01    25.21875  2.296342

so I decided to stop using lists.

The improved implementation (the Implementation using vectors part of the code below) uses a vector to store the values and a two-column matrix whose rows hold the indexes, in the values vector, of each node's high and low children. The code runs noticeably faster, and the gap grows with tree size (about 1.7 times faster at size 512 and about 4.3 times faster at size 2048):
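To make the layout concrete, here is a hand-built sketch of the two structures for the insertion order 2, 1, 3. The data and the helper name height are assumptions for illustration; the real implementation fills the structures incrementally.

```r
# Values are stored in insertion order: position 1 is the root.
values <- c(2, 1, 3)

# Row k holds the positions of node k's children; 0 means "no child".
idxs <- matrix(0, nrow = 3, ncol = 2,
               dimnames = list(NULL, c("high", "low")))
idxs[1, "high"] <- 3  # 3 > 2, and 3 is stored at position 3
idxs[1, "low"]  <- 2  # 1 < 2, and 1 is stored at position 2

# Height in nodes, following child positions from a given node.
height <- function(idx) {
  if (idx == 0) return(0)
  1 + max(height(idxs[idx, "high"]), height(idxs[idx, "low"]))
}

height(1)  # 2
```

Because the vector and matrix are preallocated to the permutation length, adding a node only writes to existing cells instead of rebuilding nested lists.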

> test(th.vector)
  size exec.time mean.height sd.height
1  512      1.43    19.56250  1.721543
2 1024      3.27    22.28125  2.158769
3 2048      7.23    25.21875  2.296342

but then I thought that one does not have to construct a tree in order to calculate its height.

The third implementation (Implementation without tree generation) does not construct the tree at all; it recursively splits the input permutation into the elements greater than and less than the first one. This approach is more than an order of magnitude faster than the second:
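The reason splitting is enough: the first element of a permutation becomes the root, and the larger (or smaller) elements, taken in their original order, are exactly the insertion sequence of the high (or low) subtree. The sketch below traces this on a five-element permutation; height.by.splitting is a hypothetical name used only here, with the same logic as the third implementation.

```r
# Hypothetical helper, written only to trace the idea on a tiny input.
height.by.splitting <- function(perm) {
  if (length(perm) < 2) return(length(perm))  # empty -> 0, single node -> 1
  root <- perm[1]            # first inserted element is the subtree root
  high <- perm[perm > root]  # insertion order of the high subtree
  low  <- perm[perm < root]  # insertion order of the low subtree
  1 + max(height.by.splitting(high), height.by.splitting(low))
}

perm <- c(3, 1, 4, 2, 5)
perm[perm > perm[1]]       # 4 5
perm[perm < perm[1]]       # 1 2
height.by.splitting(perm)  # 3 (for example the path 3 -> 1 -> 2)
```

Each level of the recursion is a pair of vectorized subsetting operations, which is where most of the speed-up over element-by-element insertion is likely to come from.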

> test(th.virtual)
  size exec.time mean.height sd.height
1  512      0.07    19.56250  1.721543
2 1024      0.17    22.28125  2.158769
3 2048      0.32    25.21875  2.296342

I wonder if it is possible to get a much better result in R?

The full code of the simulations is given below:

# Performance test
test <- function(tree.height) {
      set.seed(0)
      sizes <- 2^c(9, 10, 11)
      reps <- 32
      test.cases <- list()
      results <- data.frame(size = sizes, exec.time = NA,
                            mean.height = NA, sd.height = NA)
      for (i in seq_along(sizes)) {
            test.cases[[i]] <- replicate(reps,
                                     sample.int(sizes[i]))
      }
     
      for (i in seq_along(sizes)) {
            height <- numeric(reps)
            # system.time()[3] is the elapsed (wall-clock) time in seconds
            simtime <- system.time(
            for (j in seq_len(ncol(test.cases[[i]]))) {
                  height[j] <- tree.height(test.cases[[i]][,j])
            })[3]
            results[i,2:4] <- c(simtime,
                            mean(height), sd(height))
      }
      return(results)
}

# Implementation using lists
th.list <- function(perm) {
      # Returns the child slot to follow: 1 = high, 2 = low, 0 = value already present
      comp <- function(a, b) {
            if (a == b) { return(0) }
            if (a > b) { return(1) }
            return(2)
      }

      tree.add <- function(tree, value) {
            new.node <- function(value) {
                  node <- list(high = NULL, low = NULL)
                  attr(node, "value") <- value
                  return(node)
            }
            if (is.null(tree)) { return(new.node(value)) }
            sub.tree <- tree
            path <- numeric(0)
            while (TRUE) {
                  cr <- comp(value, attr(sub.tree, "value"))
                  if (cr == 0) {
                        return(tree)
                  }
                  path <- c(path, cr)
                  sub.tree <- sub.tree[[cr]]
                  if (is.null(sub.tree)) {
                        tree[[path]] <- new.node(value)
                        return(tree)
                  }
            }
      }

      tree.height <- function(tree, height = 0) {
            if (is.null(tree)) { return (height) }
            return (max(tree.height(tree$high, height + 1),
                              tree.height(tree$low, height + 1)))
      }
     
      tree <- NULL
      for (i in perm) {
            tree <- tree.add(tree, i)
      }
      return(tree.height(tree))
}

# Implementation using vectors
th.vector <- function(perm) {
      # Returns the column of idxs to follow: 1 = high, 2 = low, 0 = value already present
      comp <- function(a, b) {
            if (a == b) { return(0) }
            if (a > b) { return(1) }
            return(2)
      }

      tree.add <- function(value) {
            if (max.idx == 0) {
                  max.idx <<- 1
                  values[max.idx] <<- value
            } else {
                  cur.idx <- 1
                  while (TRUE) {
                        cr <- comp(value, values[cur.idx])
                        if (cr == 0) { return() }
                        next.idx <- idxs[cur.idx, cr]
                        if (next.idx == 0) {
                              max.idx <<- max.idx + 1
                              idxs[cur.idx, cr] <<- max.idx
                              values[max.idx] <<- value
                              return()
                        } else {
                              cur.idx <- next.idx
                        }
                  }
            }
      }

      tree.height <- function(idx = 1) {
            if (any(idx > max.idx, idx == 0)) { return (0) }
            return (1 + max(tree.height(idxs[idx, 1]),
                                   tree.height(idxs[idx, 2])))
      }

      max.idx <- 0
      values <- numeric(length(perm))
      idxs <- matrix(0, nrow = length(perm), ncol = 2,
            dimnames = list(NULL, c("high", "low")))
      for (i in perm) {
            tree.add(i)
      }
      return(tree.height())
}

# Implementation without tree generation
th.virtual <- function(perm) {
      if (length(perm) < 2) { return(length(perm)) }
      high <- perm[perm > perm[1]]
      low <- perm[perm < perm[1]]
      return (1 + max(th.virtual(high), th.virtual(low)))
}
