# remotes::install_github('jklosa/seagull')
suppressPackageStartupMessages({
  library(hdrm)
  # library(seagull)
  library(SGL)
  library(grpreg)
  library(kableExtra)
})

# Setup; generate toy example ---------------------------------------------

# Simulate the toy data set. The coefficient vector has 30 entries; the only
# nonzero ones are at positions 1, 2, 19 and 20.
set.seed(3)
dat <- gen_data_grp(
  n = 50, J = 10, K = 3,
  beta = c(
    1, -0.5, rep(0, 16),
    0.5, -1, rep(0, 10)
  )
)

# Pull out the pieces used throughout the rest of the script
x <- dat$X
y <- dat$y
g <- dat$group

# Plot margins for the coefficient-path figures
par(mar = c(4, 4.5, 2, 0.1))
# Draw coefficient paths for a fit returned by SGL().
#
# x:    fitted object; reads x$lambdas, x$beta and x$X.transform$X.scale.
# g:    group membership vector; only used to choose default colors.
# col:  optional line colors. When omitted, one color per group is taken from
#       pal() and repeated according to the group sizes in table(g).
# xlim: defaults to the reversed range of the lambda sequence.
# ...:  further arguments forwarded to matplot().
#
# Each row of x$beta is divided by the matching entry of X.scale before
# plotting.
plot_sgl <- function(
  x, g, lty = 1, type = "l", las = 1, bty = "n", col, xlim = rev(range(ll)),
  xlab = expression(lambda), ylab = expression(beta), ...
) {
  # `ll` must keep this name: the default for `xlim` refers to it lazily
  ll <- x$lambdas
  coef_path <- sweep(
    x$beta,
    MARGIN = 1,
    STATS = x$X.transform$X.scale,
    FUN = "/"
  )

  if (missing(col)) {
    group_sizes <- table(g)
    col <- rep(pal(length(group_sizes)), group_sizes)
  }

  matplot(
    ll, t(coef_path),
    type = type, lty = lty, col = col,
    las = las, bty = bty, xlim = xlim,
    xlab = xlab, ylab = ylab,
    ...
  )
}
# Fit SGL on the toy data for each alpha from 1 down to 0 (steps of 0.1) and
# draw one coefficient-path plot per value, labelled with that alpha.
alpha <- seq(1, 0, -.1)
for (i in seq_along(alpha)) {
  fit <- SGL(list(x = x, y = y), g, alpha = alpha[i], min.frac = 0.01)
  plot_sgl(fit, g, lwd = 2)
  mtext(bquote(alpha == .(alpha[i])))
}

# Fit the group exponential lasso ("gel" penalty) on the toy data for 11
# equally spaced values of tau in [0, 0.5]; one coefficient-path plot per
# value, labelled with that tau.
par(mar = c(4, 4.5, 2, 0.1))
# `length.out` is spelled out rather than relying on partial matching (`len`)
tau <- seq(0, 0.5, length.out = 11)
for (i in seq_along(tau)) {
  fit <- grpreg(x, y, g, penalty = "gel", tau = tau[i])
  plot(fit, lwd = 2, bty = "n")
  mtext(bquote(tau == .(tau[i])))
}

# Real-data example: cross-validated SGL and GEL fits on the glcamd data,
# both using the same fold assignment.
# NOTE(review): the code below assumes attach_data() defines `x`, `y` and
# `group` for glcamd. If the design matrix is attached under another name
# (e.g. `X`), `x` would still be the toy matrix from above -- confirm.
attach_data(glcamd)
fold <- assign_fold(y, 10) # folds are assigned before y is recoded
gene <- factor(group)
y <- 1 * (y == 'AMD')      # recode the outcome as 0/1

# Warning: SGL takes a very long time
cvfit_sgl <- list(x = x, y = y) |>
  cvSGL(
    index = gene, type = "logit", alpha = 0.5,
    foldid = fold, min.frac = 0.2
  )

# An earlier fit with lambda.min = 0.75 and alpha = 0.5 was assigned to the
# same name and immediately overwritten by this one, so only this fit is run.
cvfit_gel <- cv.grpreg(
  x, y, gene, family = 'binomial', penalty = 'gel', fold = fold,
  lambda.min = 0.5, alpha = 0.2, tau = 0.2, returnY = TRUE
)

# list(x = x, y = y) |>
#   cvSGL(index = gene, type = 'logit', alpha = 0.5)

# cv.grpreg(x, y, group = gene,
#   family = 'binomial', penalty = 'gel')

# Plot a cross-validated R^2 curve against log(lambda).
#
# cvfit: a 'cv.grpreg' object (predictions read from cvfit$Y, lambdas from
#   cvfit$lambda); any other object is read through cvfit$prevals and
#   cvfit$lambdas instead.
#
# The response is taken from the variable `y` in the calling environment.
# R^2 is 1 - exp(mean loss - mean loss at the first lambda), clipped to
# [0, 1]. Gray bars show the same quantity evaluated at the mean loss plus
# and minus one standard error.
plot_rsq <- function(cvfit) {
  from_grpreg <- inherits(cvfit, 'cv.grpreg')
  pred <- if (from_grpreg) cvfit$Y else cvfit$prevals
  lam <- if (from_grpreg) cvfit$lambda else cvfit$lambdas

  dev <- ncvreg:::loss.ncvreg(y, pred, 'binomial')
  ll <- log(lam)

  mean_dev <- colMeans(dev)
  se_dev <- apply(dev, 2, sd) / sqrt(nrow(dev))

  # Map a vector of mean losses to R^2 relative to the first lambda
  as_rsq <- function(d) {
    pmin(pmax(1 - exp(d - mean_dev[1]), 0), 1)
  }
  rsq <- as_rsq(mean_dev)
  lwr <- as_rsq(mean_dev + se_dev)
  upr <- as_rsq(mean_dev - se_dev)

  # Empty frame first, so the error bars are drawn underneath the points
  plot(
    ll, rsq,
    type = 'n', ylim = c(0, 0.1), xlim = rev(range(ll)),
    las = 1, bty = 'n',
    xlab = expression(log(lambda)), ylab = ~R^2
  )
  # arrows() warns about zero-length bars; those warnings are silenced here
  suppressWarnings(arrows(
    x0 = ll, x1 = ll, y0 = lwr, y1 = upr,
    code = 3, angle = 90, col = "gray80", length = 0.05
  ))
  points(ll, rsq, col = 'red', pch = 19, cex = 0.75)
}
# One R^2 plot per cross-validated fit, each titled with its method name
invisible(Map(
  function(cvfit, label) {
    plot_rsq(cvfit)
    mtext(label)
  },
  list(cvfit_sgl, cvfit_gel),
  c('SGL', 'GEL')
))

# Summary table comparing lasso, group lasso, GEL and SGL: maximum
# cross-validated R^2, number of genes selected, number of variants selected.

# Lasso (no grouping supplied) and group lasso, on the same folds as above
cvfit_las <- cv.grpreg(
  x, y, family = 'binomial', fold = fold, lambda.min = 0.2
)
cvfit_grp <- cv.grpreg(
  x, y, gene, family = 'binomial', fold = fold, lambda.min = 0.2
)

# SGL: R^2 along the path, relative to the loss at the first lambda
dev_sgl <- ncvreg:::loss.ncvreg(y, cvfit_sgl$prevals, 'binomial') |> colMeans()
rsq_sgl <- (1 - exp(dev_sgl - dev_sgl[1]))

# Lasso coefficients at the CV-selected lambda; element 1 is the intercept.
# `gene` is a factor, so table() on a subset still lists every level; count
# the distinct genes actually present with unique() instead.
b_las <- coef(cvfit_las, which = cvfit_las$min)
g_las <- length(unique(gene[b_las[-1] != 0]))

# SGL coefficients at the lambda minimizing lldiff. fit$beta is treated as
# having one row per column of x and no intercept row (plot_sgl rescales it
# row-wise by X.scale), so nothing is dropped before indexing `gene`.
min_sgl <- which.min(cvfit_sgl$lldiff)
b_sgl <- cvfit_sgl$fit$beta[, min_sgl]
sel_sgl <- b_sgl != 0
g_sgl <- length(unique(gene[sel_sgl]))

data.frame(
  method = c('Lasso', 'Group lasso', 'GEL', 'SGL'),
  rsq = c(
    max(summary(cvfit_las)$r.squared),
    max(summary(cvfit_grp)$r.squared),
    max(summary(cvfit_gel)$r.squared),
    max(rsq_sgl)
  ),
  genes = c(
    g_las,
    predict(cvfit_grp, type = 'ngroups'),
    predict(cvfit_gel, type = 'ngroups'),
    g_sgl
  ),
  vars = c(
    predict(cvfit_las, type = 'nvars'),
    predict(cvfit_grp, type = 'nvars'),
    predict(cvfit_gel, type = 'nvars'),
    sum(sel_sgl)
  )
) |>
  kbl(
    'latex', booktabs = TRUE, digits = c(0, 2, 0, 0), escape = FALSE,
    col.names = c('Method', '$R^2$', 'Genes', 'Variants')
  ) |>
  kable_styling()

