library(hdrm)
library(SGL)
suppressPackageStartupMessages(library(grpreg))

# Setup; generate toy example ---------------------------------------------

# Simulated grouped-regression toy example: n = 50 observations from
# J = 10 groups of K = 3 features (30 coefficients in total). Assuming
# genDataGrp lays the groups out as consecutive blocks of K, only the 1st
# and 7th groups carry nonzero coefficients.
set.seed(3)
Data <- genDataGrp(
  n = 50, J = 10, K = 3,
  beta = c(1, -0.5, 0, rep(0, 15), 0.5, -1, 0, rep(0, 9))
)

# Pull out the design matrix, response and group labels used throughout.
X <- Data[["X"]]
y <- Data[["y"]]
g <- Data[["group"]]

# Sparse group lasso illustration -----------------------------------------

# Three members of the sparse group lasso family on the toy data. Per the
# SGL documentation, alpha = 1 is the lasso, alpha = 0 is the group lasso
# and alpha = 0.5 mixes the two; min.frac sets the smallest lambda on the
# path as a fraction of the largest.
fit1 <- SGL(data = list(x = X, y = y), index = g, alpha = 1, min.frac = 0.01)
fit2 <- SGL(data = list(x = X, y = y), index = g, alpha = 0.5, min.frac = 0.01)
fit3 <- SGL(data = list(x = X, y = y), index = g, alpha = 0, min.frac = 0.01)

# Plot the coefficient paths of a fitted SGL model, one line per feature.
#
# x    : a fit returned by SGL::SGL(). Reads x$lambdas, x$beta (one row per
#        feature, one column per lambda) and x$X.transform$X.scale.
# g    : group label of each feature (one entry per row of x$beta); only
#        used to choose colours when `col` is not supplied.
# col  : line colours, one per feature. Defaults to one palette colour per
#        group, matched to each feature by its label (so groups need not be
#        contiguous or sorted).
# xlim : defaults to the reversed lambda range. The default is evaluated
#        lazily, i.e. after `ll` has been assigned in the body.
# ...  : further graphical parameters passed on to matplot().
#
# Returns the rescaled coefficient matrix invisibly.
plot.sgl <- function(x, g, lty=1, type="l", las=1, bty="n", col,
                     xlim=rev(range(ll)), xlab=expression(lambda),
                     ylab=expression(beta), ...) {
  ll <- x$lambdas
  # Undo the column scaling SGL applied to X so that paths are shown on
  # the original scale of the predictors.
  B <- sweep(x$beta, 1, x$X.transform$X.scale, "/")
  if (missing(col)) {
    grp <- factor(g)
    if (length(grp) != nrow(B)) {
      stop("`g` must have one entry per row of x$beta", call. = FALSE)
    }
    # Index the palette by group, rather than rep()-ing it by group counts,
    # so colours stay correct when group labels are not sorted/contiguous.
    col <- hdrm:::pal(nlevels(grp))[as.integer(grp)]
  }
  matplot(ll, t(B), las=las, type=type, lty=lty, col=col, bty=bty,
          xlim=xlim, xlab=xlab, ylab=ylab, ...)
  invisible(B)
}
# Sweep the SGL mixing parameter from alpha = 1 down to alpha = 0 in steps
# of 0.1, drawing one coefficient-path plot per value, labelled with alpha.
alpha <- seq(1, 0, by = -0.1)
# Nothing inside the loop changes the margins, so set them once up front.
par(mar=c(4,4.5,1,0.2))
for (i in seq_along(alpha)) {
  fit <- SGL(list(x=X, y=y), g, alpha=alpha[i], min.frac=0.01)
  plot.sgl(fit, g, lwd=2)
  mtext(bquote(alpha==.(alpha[i])))
}


# GEL illustration --------------------------------------------------------

# Group exponential lasso (penalty = "gel") path with tau left at grpreg's
# default value.
fit <- grpreg(X, y, group = g, penalty = "gel")
plot(fit, lwd = 2)

# Refit the GEL path over an 11-point grid of tau in [0, 0.5] and draw one
# path plot per value, labelled with tau.
# `length.out` is spelled out: `len=` only worked through partial matching.
tau <- seq(0, 0.5, length.out=11)
# The margins are the same for every plot, so set them once before the loop.
par(mar=c(4,4.5,1,0.1))
for (i in seq_along(tau)) {
  fit <- grpreg(X, y, g, penalty="gel", tau=tau[i])
  plot(fit, lwd=2, bty="n")
  mtext(bquote(tau==.(tau[i])))
}


# Macular degeneration case study -----------------------------------------

# Load the glc-amd data set with hdrm's attachData(). The code below relies
# on it providing X, y and group.
# NOTE(review): confirm exactly which objects attachData() creates.
attachData("glc-amd")

# Group labels for the columns of X (presumably gene identifiers).
Gene <- group

# GEL: cross-validated logistic group exponential lasso, grouping the
# columns of X by Gene. `alpha` here is cv.grpreg's own argument and is
# unrelated to the SGL alpha grid defined earlier in the script.
seed <- 5
cvfit <- cv.grpreg(
  X, y,
  group = Gene, family = "binomial", penalty = "gel",
  seed = seed, lambda.min = 0.2, alpha = 0.04, tau = 1/4
)

# Cross-validated R^2 and prediction error along the path, then a summary.
plot(cvfit, type = "rsq")
plot(cvfit, type = "pred")
summary(cvfit)

# SGL (takes a long time)
# Cross-validated sparse group lasso on the same data, for comparison with
# the GEL fit above. The response is recoded as the indicator y == 'GLC'
# and the gene labels as integer group indices.
cvfit.sgl <- cvSGL(list(x=X, y=(y=='GLC')), index=as.numeric(factor(Gene)), type="logit", min.frac=0.2, alpha=0.5)

# Cross-validated R^2 along the SGL path, using the GEL fit's null deviance
# as the baseline. cvfit$null.dev is multiplied by 400 to put it on the
# same scale as cvSGL's lldiff.
# NOTE(review): 400 is presumably the number of observations (length(y)).
# Confirm that, and that lldiff and null.dev are both deviance-scale losses.
null.dev <- cvfit$null.dev*400
cve <- cvfit.sgl$lldiff
# R^2 = 1 - CV loss / null loss (the previous (cve - null.dev)/null.dev was
# the negated R^2, so max() picked the worst fit).
rsq <- 1 - cve/null.dev
max(rsq)

# Coefficients at the 19th lambda of the SGL path: number of nonzero
# features, and number of distinct genes with at least one nonzero feature.
# NOTE(review): column 19 is hard-coded; presumably the CV-selected lambda.
# Consider which.max(rsq).
b <- cvfit.sgl$fit$beta[,19]
sum(b!=0)
length(table(Gene[b!=0]))
