library(hdrm)

############################
#### Case study 1: BRCA ####
############################

# Load the BRCA1 example data (X and y used below are presumably supplied by
# attach_data()) and seed the RNG so cross-validation results are reproducible.
attach_data("brca1")
set.seed(1)

# cvfit <- cv.glmnet(X, y)
# fit <- cvfit$glmnet.fit

# Demo code: lasso path from cv.glmnet, drawn from the first (largest) lambda
# down to lambda.min, followed by the cross-validation error curve.
cvfit <- cv.glmnet(X, y)
fit <- cvfit$glmnet.fit
xlim <- log(c(fit$lambda[1], cvfit$lambda.min))
plot(fit, xlim = xlim, xvar = "lambda")
plot(cvfit)

# xlim <- log(c(fit$lambda[1], cvfit$lambda.min))
# plot(fit, xlim=xlim, xvar="lambda")

# Book-quality versions of the two plots above (Figure 2.10, left/right panel)
par(mar = c(4, 4, 5, 0.5))
Fig2.10(cvfit, "left")

# plot(cvfit)

par(mar = c(4, 4, 5, 0.5))
Fig2.10(cvfit, "right")

# Best cross-validated fraction of variance explained: 1 - CV error / var(y)
max(1 - cvfit$cvm / var(y))

# fit <- ncvreg(X, y, penalty='lasso')
# AIC(fit); BIC(fit)

# Figure 2.11 for two ncvreg lasso paths: one on cv.glmnet's lambda grid, one
# with lambda.min = 0.001 (presumably AIC/BIC curves, cf. the commented code)
par(mar = c(4, 4, 2, 0.5))
Fig2.11(ncvreg(X, y, penalty = "lasso", lambda = cvfit$lambda))
Fig2.11(ncvreg(X, y, penalty = "lasso", lambda.min = 0.001))

# Here's what AICc (small-sample corrected AIC) looks like, if you're curious
# Plot AIC, BIC and AICc against lambda along a fitted penalized-regression path.
#
# fit: a fitted path object (ncvreg in this script) with components `lambda`
#   and `n`, and logLik()/AIC()/BIC() methods that return one value per lambda
#   (the logLik "df" attribute is likewise one entry per lambda).
# Side effects: prints the range of model degrees of freedom and draws the
#   plot, with lambda decreasing from left to right.
# Returns (invisibly) the n_lambda x 3 matrix of criteria (AIC, BIC, AICc).
#   AICc is NA wherever its correction is undefined, i.e. df >= n - 1.
Fig2.11_aicc <- function(fit) {
  lam <- log(fit$lambda)
  ll <- logLik(fit)
  df <- as.numeric(attr(ll, "df"))
  print(range(df))  # compare with n: AICc requires df < n - 1

  aic <- AIC(fit)
  # AICc = AIC + 2k(k + 1) / (n - k - 1). Once k >= n - 1 the denominator is
  # zero or negative, so the "correction" becomes infinite or flips sign; this
  # happens at the small-lambda end of high-dimensional paths and would wreck
  # the y-axis. Mark those points as undefined instead.
  denom <- fit$n - df - 1
  aicc <- aic + 2 * df * (df + 1) / denom
  aicc[denom <= 0] <- NA_real_
  IC <- cbind(AIC = aic, BIC = BIC(fit), AICc = aicc)

  cols <- hdrm::pal(3)
  matplot(lam, IC, xlim = rev(range(lam)), col = cols, type = "l",
          lwd = 3, lty = 1, bty = "n", xlab = expression(lambda),
          xaxt = "n", las = 1, ylab = "AIC/BIC/AICc")
  # The x-axis is on the log scale; label the ticks with lambda itself.
  at <- seq(max(lam), min(lam), length.out = 5)
  axis(1, at = at, labels = round(exp(at), 2))
  hdrm::toplegend(legend = c("AIC", "BIC", "AICc"), col = cols, lwd = 3)
  invisible(IC)
}
Fig2.11_aicc(ncvreg(X, y, penalty = "lasso", lambda.min = 0.001))

# Coefficients from the cross-validated fit (cv.glmnet's default s) that
# exceed 0.15, kept as a one-column matrix so row names stay visible.
b <- coef(cvfit)
b[which(b > 0.15), , drop = FALSE]

# Number of nonzero coefficients (intercept row included) at lambda = 0.2
b <- coef(fit, s = 0.2)
sum(b != 0)

# Prediction for observation 85 (row kept as a 1-row matrix)
predict(cvfit, newx = X[85, , drop = FALSE])

################################
#### Case study 2: Carbotax ####
################################

attach_data(Koussounadis2014)
set.seed(1)

# Combine low- and high-dimensional features: a Treatment x natural-spline(Day)
# design built from sData (intercept column dropped) alongside the columns of X.
library(splines)
sDay <- ns(sData$Day, df = 2)
X0 <- model.matrix(~ Treatment * sDay, sData)[, -1]
# Penalty factor: 0 leaves the X0 columns unpenalized, 1 penalizes X's columns.
w <- rep(0:1, c(ncol(X0), ncol(X)))
XX <- cbind(X0, X)

# glmnet equivalent (cv.glmnet stores the path fit as `glmnet.fit`, not `fit`):
# cvfit <- cv.glmnet(XX, y, penalty.factor=w)
# fit <- cvfit$glmnet.fit

cvfit <- cv.ncvreg(XX, y, penalty.factor = w, penalty = "lasso")
fit <- cvfit$fit

# plot(fit)

Fig2.12(cvfit, "left")

# plot(cvfit, type='rsq') # Only available in ncvreg

Fig2.12(cvfit, "right")

# Double exponential prior: Gaussian density (labelled "Ridge") versus the
# double-exponential / Laplace density exp(-|beta|)/2 (labelled "Lasso").
par(mar = c(5, 3, 1, 5))
x <- seq(-3, 3, length.out = 101)
# The Laplace density peaks at dexp(0)/2 = 0.5, above the normal's peak.
ylim <- c(0, dexp(0) / 2)
plot(x, dnorm(x), type = "l", col = "gray80", lwd = 3, ylim = ylim,
     yaxt = "n", bty = "n", xlab = expression(beta), ylab = "")
mtext(expression(p(beta)), side = 2, line = 1)
lines(x, dexp(abs(x)) / 2, col = "slateblue", lwd = 3)
hdrm::rightlegend(legend = c("Ridge", "Lasso"), lwd = 3,
                  col = c("gray80", "slateblue"))

