source('http://myweb.uiowa.edu/pbreheny/7600/s16/notes/fun.R')
require(ncvreg)

# Slide 6
# Depends on sim1.RData, which is produced by running sim1.R
load("sim1.RData")

# Average MSE over its first dimension (presumably simulation replicates --
# confirm against sim1.R); what remains is plotted against SNR, one curve per
# column
M <- apply(MSE, 2:3, mean)
col <- pal(2)

# Absolute error curves for the two methods
matplot(SNR, M, type='l', lwd=3, lty=1, col=col,
        xlab="SNR", ylab="Mean squared error", las=1, bty='n')
toplegend(legend=c("Lasso", "MCP"), lwd=3, col=col)

# Second column relative to the first (MCP relative to lasso, going by the
# legend order above); the dashed gray line marks a ratio of 1
rMSE <- M[,2]/M[,1]
matplot(SNR, rMSE, type='l', lwd=3, lty=1, col=col[2],
        xlab="SNR", ylab="Relative MSE", las=1, bty='n')
abline(h=1, col='gray', lwd=2, lty=2)

# Slide 7
# Requires running the code in sim2.R
load("sim2.RData")

# One color per SNR level
col <- pal(length(SNR))

# Two side-by-side panels with room on top for a shared legend. The previous
# graphical parameters are kept so they can be restored once the figure is
# done; otherwise every later figure in this script would be drawn into this
# two-panel layout with its narrow right margin.
op <- par(mfrow=c(1,2), mar=c(5,5,1,0.5), oma=c(0,0,3,0))

# Left panel: MCP. mMSE is averaged over its first dimension and transposed so
# that each curve runs along the gamma axis (plotted on a log2 scale).
matplot(log2(gam), t(apply(mMSE, 2:3, mean)), type='l', xaxt='n', lwd=3, lty=1, col=col,
        xlab=expression(gamma), ylab="Mean squared error", las=1, bty='n')
axis(1, at=1:5, labels=2^(1:5))
mtext('MCP')

# Right panel: SCAD. Same layout, but the tick labels are shifted up by one
# (presumably because SCAD was run at gamma + 1 -- confirm against sim2.R).
matplot(log2(gam), t(apply(sMSE, 2:3, mean)), type='l', xaxt='n', lwd=3, lty=1, col=col,
        xlab=expression(gamma), ylab="Mean squared error", las=1, bty='n')
axis(1, at=1:5, labels=2^(1:5)+1)
mtext('SCAD')

# The leading "SNR: " entry is drawn in white so that it acts as a legend title
toplegend(legend=c("SNR: ", SNR), lwd=3, col=c('white',col))

# Restore the graphical parameters that were in effect before this figure
par(op)

# Slide 20
# Generate a data set under a fixed seed, fit along a 500-value lambda grid,
# and plot the resulting fit
set.seed(4)
Data <- genData(20, 50, J0=4)
fit <- ncvreg(Data$X, Data$y, nlambda=500)
plot(fit, bty='n')

# Slide 21
# Refit on the std()-transformed design and the mean-centered response
X <- std(Data$X)
y <- Data$y - mean(Data$y)
n <- length(y)
fit <- ncvreg(X, y, nlambda=500)

# Two neighboring positions on the lambda path: the one stored in
# fit$convex.min and the one immediately after it
ind1 <- fit$convex.min
ind2 <- ind1 + 1
l1 <- fit$lambda[ind1]
l2 <- fit$lambda[ind2]

# Coefficients at those two positions, with the first entry dropped
# (presumably the intercept -- confirm against coef() for ncvreg fits)
b1 <- coef(fit, which=ind1)[-1]
b2 <- coef(fit, which=ind2)[-1]
# MCP penalty, evaluated elementwise.
#
# Arguments:
#   theta  numeric vector of coefficients; only their absolute values matter
#   l      penalty level (a single number)
#   a      shape parameter of the penalty (a single number, default 3)
#
# Returns an unnamed numeric vector of the same length as theta, with
#   l*|theta| - theta^2/(2*a)   where |theta| <  a*l
#   a*l^2/2                     where |theta| >= a*l
# so the penalty is constant once |theta| reaches a*l. Missing values in theta
# propagate as NA; zero-length input gives zero-length output.
MCP <- function(theta, l, a=3) {
  x <- abs(theta)
  # ifelse() picks per element instead of multiplying by 0/1 indicators, so an
  # infinite (or overflowing) quadratic term cannot leak NaN into the flat
  # region. as.numeric() drops names and keeps the result numeric even when
  # theta is empty.
  as.numeric(ifelse(x < a*l, l*x - x^2/(2*a), 0.5*a*l^2))
}
# Penalized least-squares objective at coefficient vector b and penalty level
# lam: the squared residual norm scaled by 1/(2*n), plus the MCP penalty summed
# over the entries of b. X, y and n are not arguments; they are looked up in
# the enclosing environment. The value is a 1 x 1 matrix because crossprod()
# returns a matrix.
Q <- function(b, lam) {
  resid <- y - X %*% b
  loss <- 1/(2*n) * crossprod(resid)
  penalty <- sum(MCP(b, lam))
  loss + penalty
}
# Trace the objective along the straight line through b1 (x = 0) and b2
# (x = 1), extended by half a unit beyond each end, at three penalty levels
x <- seq(-0.5, 1.5, length.out=101)

# Preallocate one result vector per penalty level
q1 <- q2 <- q3 <- numeric(length(x))
for (i in seq_along(x)) {
  # Point on the line for this x; shared by all three evaluations
  b_x <- x[i]*b2 + (1-x[i])*b1
  q1[i] <- Q(b_x, 1.2*l1)
  q2[i] <- Q(b_x, l1/2)
  q3[i] <- Q(b_x, 2*l1)
}

plot(x, q1, type='l', las=1, ylim=c(0.1, 1.4), bty='n', xaxt='n', ylab=expression(Q(beta)), xlab='')
# Label only the two positions that correspond to b1 and b2
axis(1, at=c(-0.5, 0, 1, 1.5), labels=expression("", beta[1], beta[2], ""))
lines(x, q2)
lines(x, q3)

# Curve labels sit to the right of the plot region, at the height where each
# curve ends; xpd=TRUE allows drawing outside the plot region. The lambda
# values are hard-coded and are meant to correspond to 1.2*l1, l1/2 and 2*l1
# for the seed used above -- re-check them if the data or seed changes.
text(1.9, tail(q1,1), expression(lambda==0.25), xpd=TRUE)
text(1.9, tail(q2,1), expression(lambda==0.11), xpd=TRUE)
text(1.9, tail(q3,1), expression(lambda==0.42), xpd=TRUE)
