Sunday, June 24, 2012

Optimal sorting using rpart

Some time ago I read a nice post, "Solving easy problems the hard way", in which linear regression is used to solve an interesting puzzle. Following that idea, I used rpart to find an optimal decision tree for sorting five elements.

It is well known that seven comparisons are enough to sort five elements. Interestingly, the rpart function can be adjusted to find an optimal decision tree for this task.
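Calling a seven-comparison tree "optimal" relies on a matching lower bound. A binary decision tree whose paths contain at most d comparisons has at most 2^d leaves, and sorting needs a separate leaf for each of the 5! = 120 possible orderings:

```latex
$$
2^{d} \ge 5! = 120
\;\Longrightarrow\;
d \ge \lceil \log_2 120 \rceil = 7,
\qquad \text{since } 2^{6} = 64 < 120 \le 128 = 2^{7}.
$$
```

So no comparison-based method can guarantee fewer than seven comparisons for five elements, which makes a tree that never needs more than seven optimal.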

To see how, let us first generate the appropriate training data with the following code:

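library(gtools)
perm5 <- permutations(5, 5)   # all 120 permutations of 1:5, one per row
comb5 <- combinations(5, 2)   # all 10 pairs (i, j) with i < j

# perm5c must be a factor: its sorted levels ("12345", ..., "54321") follow
# the row order of perm5, so its integer codes are row indices of perm5
# (temp1 below relies on this; R >= 4.0 no longer converts strings by default)
data.set <- data.frame(perm5c = apply(perm5, 1,
    paste, collapse = ""), stringsAsFactors = TRUE)
for (k in 1:nrow(comb5)) {
    i <- comb5[k, 1]
    j <- comb5[k, 2]
    # TRUE if value i appears before value j in the permutation
    val <- factor(apply(perm5, 1, function(v) {
        which(i == v) < which(j == v) }))
    data.set[paste("v", i, j, sep = "")] <- val
}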

We get a data set with 120 rows. Each row contains a column perm5c holding one permutation of the numbers 1 to 5, and 10 columns (v12, v13, ..., v45) recording, for each pair i < j, whether value i appears before value j in that permutation. We want a decision tree that partitions this data set into one-element leaves using at most seven decisions on any path.
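One-element leaves are possible only because the 10 comparison columns together pin down each permutation: knowing every pairwise order determines the full order. The base-R sketch below (using expand.grid and combn instead of gtools; the names perms, pairs and cmp are illustrative) rebuilds the same comparison table and checks that no two rows coincide:

```r
# Base-R sketch (no gtools): rebuild the 120 x 10 comparison table and
# check that no two permutations share the same comparison outcomes.
grid  <- as.matrix(expand.grid(rep(list(1:5), 5)))
perms <- grid[apply(grid, 1, function(r) length(unique(r)) == 5), ]
pairs <- combn(5, 2)                 # 10 columns: pairs (i, j) with i < j

# cmp[p, k] is TRUE when value pairs[1, k] comes before value pairs[2, k]
# in permutation p (the same information as the v-columns of data.set)
cmp <- t(apply(perms, 1, function(v) {
  pos <- order(v)                    # pos[i] = position of value i in v
  pos[pairs[1, ]] < pos[pairs[2, ]]
}))

dim(cmp)            # 120 10
anyDuplicated(cmp)  # 0: every permutation has its own comparison vector
```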

Simply applying rpart in the following way:

library(rpart)
tree.model <- rpart(perm5c ~ ., data = data.set,
    control = rpart.control(minsplit = 1, cp = 0))
plot(tree.model)

unfortunately produces a tree that needs ten decisions in the worst case.
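To read the worst case off a fitted tree rather than its plot, one can use rpart's node numbering: the root is node 1 and node k has children 2k and 2k + 1, so a leaf numbered k is reached after floor(log2(k)) comparisons. The helper worst_case_depth below is hypothetical (not part of rpart); it only uses the node numbers stored as row names of fit$frame and the "<leaf>" marker in fit$frame$var:

```r
# Hypothetical helper (not part of rpart): worst-case number of comparisons
# in a fitted tree. rpart numbers the root 1 and the children of node k as
# 2k and 2k + 1, so node k lies at depth floor(log2(k)).
worst_case_depth <- function(fit) {
  nodes  <- as.numeric(rownames(fit$frame))     # node numbers
  leaves <- nodes[fit$frame$var == "<leaf>"]    # terminal nodes only
  max(floor(log2(leaves)))
}
```

Calling worst_case_depth(tree.model) on either tree in this post should agree with the worst case read from its plot.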

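The trick is that rpart lets you supply your own splitting method, including the function that scores candidate splits. Here the splitting rule should favour balanced splits (for example, a 30-30 split is better than a 40-20 split): each comparison can at best halve the set of remaining permutations, and since 120 is close to 2^7 = 128, a seven-level tree leaves little room for lopsided splits. Here is the code that does the work: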

library(rpart)

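# eval: once a node holds a single permutation, label it with that
# permutation as an integer (e.g. 12345); otherwise the label is NaN.
# The deviance is the number of rows minus one, so it is zero only in
# one-element leaves. y holds the integer codes of the factor perm5c,
# which are row indices into perm5.
temp1 <- function(y, wt, parms) {
    lab <- if (length(y) == 1)
        as.integer(paste(perm5[y, ], collapse = "")) else NaN
    list(label = lab, deviance = length(y) - 1)
}

# split: prefer balanced splits. For a two-way split into a and b rows
# the goodness nrow(perm5)^2 - (a^2 + b^2) is largest when a = b.
temp2 <- function(y, wt, x, parms, continuous) {
    list(goodness = nrow(perm5)^2 - sum(table(x)^2),
        direction = unique(x))
}

# init: a single response column, no extra parameters, trivial summary
temp3 <- function(y, offset, parms, wt) {
    list(y = y, parms = 0, numresp = 1, numy = 1,
         summary = function(yval, dev, wt, ylevel, digits) {
                  "Summary" })
}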

tree.model <- rpart(perm5c ~ ., data = data.set,
    control = rpart.control(minsplit = 1, cp = 0),
    method = list(eval = temp1, split = temp2, init = temp3))
plot(tree.model)
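To make the effect of temp2 concrete, the sketch below evaluates the same goodness formula for the 30-30 and 40-20 splits mentioned earlier; balance_goodness is a hypothetical name, and its default n = 120 mirrors nrow(perm5):

```r
# Hypothetical helper mirroring temp2: n^2 minus the sum of squared child
# sizes, where n = nrow(perm5) = 120 in the post. For a fixed total the sum
# of squares is smallest, and so the goodness largest, for equal children.
balance_goodness <- function(counts, n = 120) n^2 - sum(counts^2)

balance_goodness(c(30, 30))  # 14400 - 1800 = 12600
balance_goodness(c(40, 20))  # 14400 - 2000 = 12400
```

Because n is the same constant for every candidate split of a node, the choice between splits depends only on the sum of squared child sizes.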

If you print tree.model, you can see which path leads to each classification.
As the figure below shows, at most seven comparisons are needed:


Additionally, it is easy to modify the code to produce optimal trees for other numbers of elements.
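As a starting point for other sizes, the sketch below builds the same kind of training data for n elements in base R (no gtools) and computes the lower bound ceiling(log2(n!)) to compare a resulting tree against. The names all_perms, make_sort_data and lower_bound are hypothetical. Rows come out in lexicographic order, matching the sorted levels of the perm factor, so its codes still index the permutation matrix; temp1 and temp2 would need to refer to that matrix instead of perm5. Since n! grows quickly, keep n small.

```r
# Hypothetical helper: all permutations of 1:n as rows of a matrix,
# in lexicographic order.
all_perms <- function(n) {
  if (n == 1) return(matrix(1L, 1, 1))
  sub <- all_perms(n - 1)
  do.call(rbind, lapply(seq_len(n), function(first) {
    rest <- setdiff(seq_len(n), first)   # remaining values, sorted
    # relabel the (n - 1)-permutations with the remaining values
    cbind(first, matrix(rest[sub], ncol = n - 1), deparse.level = 0)
  }))
}

# Hypothetical helper: training data for sorting n elements, laid out
# like data.set in the post (response perm plus one factor per pair).
make_sort_data <- function(n) {
  perms <- all_perms(n)
  pairs <- combn(n, 2)                   # each column is a pair (i, j), i < j
  d <- data.frame(perm = factor(apply(perms, 1, paste, collapse = "")))
  for (k in seq_len(ncol(pairs))) {
    i <- pairs[1, k]
    j <- pairs[2, k]
    # TRUE if value i appears before value j in the permutation
    d[[paste0("v", i, "_", j)]] <- factor(apply(perms, 1, function(v)
      which(v == i) < which(v == j)))
  }
  list(perms = perms, data = d)
}

# Information-theoretic lower bound on worst-case comparisons
lower_bound <- function(n) ceiling(log2(factorial(n)))

sort6 <- make_sort_data(6)
dim(sort6$data)    # 720 16: perm plus 15 comparison columns
lower_bound(6)     # 10
```

The bound is only a lower bound; whether the balanced-split rule reaches it has to be checked for each n, for example by inspecting the depth of the fitted tree.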
