require(rpart)
require(party)
require(partykit)

## Slide 4: an empty (x1, x2) square cut into five rectangular regions.
## par(bty = "n") is set globally on purpose, so the later plots also
## omit the box around the plotting region.
par(bty = "n")
plot(c(0, 10), c(0, 10), type = "n", xaxt = "n", yaxt = "n",
     xlab = expression(x[1]), ylab = expression(x[2]))
polygon(c(0, 10, 10, 0), c(0, 0, 10, 10), col = "white")
## One vectorised call draws all four partition lines:
##   x1 = 7 over the full height; x2 = 3 for x1 < 7;
##   x1 = 2 and x1 = 5 above x2 = 3.
segments(x0 = c(7, 0, 2, 5), x1 = c(7, 7, 2, 5),
         y0 = c(0, 3, 3, 3), y1 = c(10, 3, 10, 10))

## Accompanying tree (slide 7): simulate points uniformly on the square,
## label them by the slide-4 region they fall in, and fit rpart to
## recover that partition.
x1 <- 10*runif(1000)
x2 <- 10*runif(1000)
## The five regions are disjoint, so a single nested ifelse() gives the same
## labels as assigning region by region. Points lying exactly on a boundary
## (probability zero) keep the empty label "".
y <- ifelse(x1 > 7, "R5",
     ifelse(x1 < 7 & x2 < 3, "R4",
     ifelse(x1 > 5 & x1 < 7 & x2 > 3, "R3",
     ifelse(x1 > 2 & x1 < 5 & x2 > 3, "R2",
     ifelse(x1 < 2 & x2 > 3, "R1", "")))))
fit <- rpart(y ~ x1 + x2, maxdepth = 4)
plot(fit)
## Round the split points to whole numbers so the labels drawn by text()
## match the integer cut-points shown on slide 4 (display only).
fit$splits[, "index"] <- round(fit$splits[, "index"])
text(fit, digits = 1)

## Slide 9: a decision tree for the WeatherPlay data, built by hand with
## partykit's partynode()/partysplit() rather than fitted.
## NOTE(review): split variable ids index the columns of WeatherPlay;
## presumably 1 = outlook (3 levels -> 3 kids), 3 = humidity (cut at 75),
## 4 = windy -- confirm with names(WeatherPlay).
data("WeatherPlay", package = "partykit")
pn <- partynode(
  1L,
  split = partysplit(1L, index = 1:3),
  kids = list(
    partynode(2L, split = partysplit(3L, breaks = 75),
              kids = list(partynode(3L, info = "yes"),
                          partynode(4L, info = "no"))),
    partynode(5L, info = "yes"),
    partynode(6L, split = partysplit(4L, index = 1:2),
              kids = list(partynode(7L, info = "yes"),
                          partynode(8L, info = "no")))
  )
)
## Attach the data to the hand-built structure, then plot without node ids.
py <- party(pn, WeatherPlay)
plot(py, ip_args = list(id = FALSE), tp_args = list(id = FALSE))

## Smoking data (shs): a tab-delimited file read directly from the web.
## NOTE(review): needs network access; confirm the URL is still served.
shs <- read.delim(
  file = "http://web.as.uky.edu/statistics/users/pbreheny/621/data/shs.txt"
)

## Pruning (slide 25): grow a full tree (cp = 0) and inspect its cp table.
## The "xerror" column comes from rpart's internal cross-validation, which
## assigns folds at random. Fix the seed so that column -- and the pruned
## tree chosen from it on slide 26 -- is reproducible between runs.
set.seed(1)
fit0 <- rpart(Cotinine ~ ., data = shs, cp = 0)
fit0$cptable
plotcp(fit0)
## Hand-drawn version of the cp plot: in-sample ("rel error") and
## cross-validated ("xerror") relative error against tree size, where
## size = number of splits + 1 (terminal nodes).
n <- nrow(fit0$cptable)
Y <- fit0$cptable[, c("rel error", "xerror")]
col <- c("#FF4E37FF", "#008DFFFF")
matplot(seq_len(n), Y, type = "o", pch = 19, lwd = 2, lty = 1, xaxt = "n",
        xlab = "Size of tree", ylab = "Relative error", col = col)
axis(1, at = seq_len(n), labels = fit0$cptable[, "nsplit"] + 1)
## The legend reuses `col` so its colours match the two matplot lines.
legend("topright", legend = c("In-sample", "Cross-validated"), pch = 19,
       lwd = 2, lty = 1, col = col, ncol = 2)

## Slide 26: prune back to the cp value with the smallest cross-validated
## error (first minimum if tied), then plot the unpruned and pruned trees
## one after the other.
alpha <- fit0$cptable[which.min(fit0$cptable[, "xerror"]), "CP"]
fit <- prune(fit0, cp = alpha)
invisible(lapply(list(fit0, fit), function(tree) {
  plot(as.party(tree), ip_args = list(id = FALSE), tp_args = list(id = FALSE))
}))

## Slide 28: the pruned tree `fit` drawn two ways -- base rpart graphics
## with text labels, then partykit's display (node ids suppressed).
plot(fit)
text(fit, pretty = 0)
plot(
  as.party(fit),
  ip_args = list(id = FALSE),
  tp_args = list(id = FALSE)
)

## Slide 29: the pruned rpart tree beside a conditional inference tree
## (ctree) fitted to the same smoking data.
plot(
  as.party(fit),
  ip_args = list(id = FALSE),
  tp_args = list(id = FALSE)
)
fit1 <- ctree(Cotinine ~ ., data = shs)
plot(
  fit1,
  ip_args = list(id = FALSE),
  tp_args = list(id = FALSE)
)

## Slide 35: a linear trend fitted by lm (a straight line) and by a
## regression tree (a step function), overlaid on the simulated data.
x <- runif(1000)
y <- 5*x + rnorm(1000)
nd <- data.frame(x = seq(0, 1, length.out = 101))
fit.lm <- lm(y ~ x)
fit.tr <- rpart(y ~ x)
plot(x, y, pch = 19, col = "gray", cex = 0.5)
xx <- nd$x  # reuse the prediction grid so the x and y coordinates match
lines(xx, predict(fit.lm, nd))
yy <- predict(fit.tr, nd)
## Draw each constant piece of the tree fit separately, so the steps are not
## joined by vertical segments. `brk` holds the last grid index of each
## piece, preceded by 0. The end point comes from the prediction length
## instead of a hard-coded grid size.
brk <- c(0, which(diff(yy) != 0), length(yy))
for (i in seq_len(length(brk) - 1)) {
  ind <- (brk[i] + 1):brk[i + 1]
  lines(xx[ind], yy[ind], col = "red", lwd = 3)
}
