library(data.table)
library(hdrm)
library(ggplot2)
library(kableExtra)

# Data (running example) ----------------------------------------------------
# Ex9.1() returns a list holding the design matrix `X`, the response `y`, and
# a per-feature label `varType` (used below to group features as 'A'/'B'/'N').
# Unpack these into top-level objects that the rest of the script reads.
dat <- Ex9.1()
x <- dat[["X"]]
y <- dat[["y"]]
var_type <- dat[["varType"]]
n <- nrow(x)
p <- ncol(x)

# BRCA1 data set. Some gene names contain '-', which is not valid inside an R
# name; translate every '-' to '_' so the names survive formula-based model
# fitting (lm(y ~ ., data = ...)) unchanged.
brca <- read_data(brca1)
colnames(brca$X) <- chartr("-", "_", colnames(brca$X))

# Semi-penalized likelihood ratio test (SPLRT) ------------------------------
#
# Tests whether feature j contributes, by comparing two lasso fits computed on
# the same lambda sequence `lam`:
#   fit0 -- feature j removed, all remaining features penalized;
#   fit1 -- all features included, feature j left unpenalized.
# The decrease in loss at the last value of `lam` is referred to a
# chi-squared distribution with 1 degree of freedom.
#
# Arguments:
#   j       column index (in dat$X) of the feature being tested.
#   dat     list with design matrix `X` and response `y`.
#   lam     lambda sequence passed to ncvreg; the test uses lam[length(lam)].
#   sigma2  scale applied to the loss difference before the chi-squared
#           comparison. ncvreg reports the residual sum of squares as `loss`
#           for Gaussian models, so the default of 1 corresponds to a unit
#           error variance -- supply an estimate when that does not hold.
# Returns: the upper-tail chi-squared(1) p-value, a single number.
splrt <- function(j, dat, lam, sigma2 = 1) {
  x <- dat$X
  y <- dat$y
  L <- length(lam)
  # drop = FALSE keeps a (one-column) matrix when x has only two columns;
  # ncvreg requires a matrix design.
  fit0 <- ncvreg(x[, -j, drop = FALSE], y, lambda = lam, penalty = 'lasso')
  # Zero penalty factor: feature j enters the full model unpenalized.
  w <- rep(1, ncol(x))
  w[j] <- 0
  fit1 <- ncvreg(x, y, penalty.factor = w, lambda = lam, penalty = 'lasso')
  stat <- (fit0$loss[L] - fit1$loss[L]) / sigma2
  pchisq(stat, 1, lower.tail = FALSE)
}

# Comparison with mfdr ------------------------------------------------------
# Lasso fit with fixed CV folds; for every feature, report its estimate and
# local mfdr at lambda.min next to its SPLRT p-value, where the SPLRT is
# computed on the lambda path from the start down to lambda.min.
fold <- assign_fold(y, 10, seed = 10)
cvfit <- cv.ncvreg(x, y, penalty = 'lasso', fold = fold)
fit <- cvfit$fit
lam_path <- cvfit$lambda[seq_len(cvfit$min)]
# One SPLRT per column of x, so splrt_p always lines up with colnames(x) in
# the table below (no hard-coded feature count).
splrt_p <- vapply(
  seq_len(ncol(x)), splrt, double(1),
  dat = dat, lam = lam_path
)
s <- summary(fit, lambda = cvfit$lambda.min)

tab <- data.table(
  Feature = colnames(x),
  Estimate = coef(cvfit)[-1],
  mfdr = local_mfdr(fit, lambda = cvfit$lambda.min)$mfdr |> hdrm:::format_p(),
  SPLRT = splrt_p |> hdrm:::format_p()
)
# Ten features with the largest |estimate|; head() also copes with fewer
# than ten rows, where [1:10] would append NA rows.
tab[head(order(-abs(Estimate)), 10)] |>
  kbl('latex', booktabs = TRUE, digits = 2, align = 'lrrr') |>
  kable_styling(position = 'center')

# BRCA1: SPLRT p-values for the features returned by predict(type = 'vars')
# on the cross-validated fit.
cvfit <- cv.ncvreg(brca$X, brca$y, penalty = 'lasso')
ind <- predict(cvfit, type = 'vars')
p_splrt <- vapply(
  ind, splrt, double(1),
  dat = list(X = brca$X, y = brca$y), lam = cvfit$lambda[seq_len(cvfit$min)]
)

# Sample splitting --------------------------------------------------------
# Stage 1: cross-validated glmnet (lasso) on a random half of the data; keep
# the features that are nonzero at lambda.min. Stage 2: OLS on the other half
# using only those features, and read off their p-values. Features not
# selected in stage 1 keep p = 1.

# Sizes of the two halves. For even n this is exactly rep(0:1, each = n / 2),
# so the seeded splits are unchanged; for odd n it still has length n
# (each = n / 2 truncates to n - 1 and the short logical index would then be
# silently recycled over the rows of x).
half_sizes <- c(n %/% 2, n - n %/% 2)

# Single split
set.seed(7)
ind <- as.logical(sample(rep(0:1, times = half_sizes)))
cvfit <- cv.glmnet(x[ind, ], y[ind])
b <- coef(cvfit, s = cvfit$lambda.min)[-1]
sel_type <- tapply(b != 0, var_type, sum)  # number selected, per feature type
# drop = FALSE keeps a matrix when a single feature is selected, and fitting
# through a data frame keeps coefficient names equal to the column names.
xx <- x[!ind, which(b != 0), drop = FALSE]
fit <- lm(y[!ind] ~ ., data = as.data.frame(xx))
# summary(fit)
summ <- summary(fit)$coefficients[-1, , drop = FALSE]
var_id <- rownames(summ)
pval <- rep(1, ncol(x)); names(pval) <- colnames(x)
pval[var_id] <- summ[, 4]
sig_type <- tapply(pval < 0.05, var_type, sum)  # p < .05, per feature type

# Multiple splits: repeat the procedure and aggregate each feature's p-values
# by their median across splits.
n_split <- 100
P <- matrix(1, n_split, ncol(x), dimnames = list(seq_len(n_split), colnames(x)))
for (i in seq_len(n_split)) {
  ind <- as.logical(sample(rep(0:1, times = half_sizes)))
  cvfit <- cv.glmnet(x[ind, ], y[ind])
  b <- coef(cvfit, s = cvfit$lambda.min)[-1]
  xx <- x[!ind, which(b != 0), drop = FALSE]
  fit <- lm(y[!ind] ~ ., data = as.data.frame(xx))
  summ <- summary(fit)$coefficients[-1, , drop = FALSE]
  P[i, rownames(summ)] <- summ[, 4]
}
# na.rm = TRUE: lm() can return NaN p-values when stage 1 selects too many
# features for the held-out half (no residual degrees of freedom).
p_agg <- apply(P, 2, median, na.rm = TRUE)
par(mar = c(2, 4, 0.5, 0.5))
boxplot(
  p_agg[which(var_type == 'A')],
  p_agg[which(var_type == 'B')],
  p_agg[which(var_type == 'N')],
  col = "gray", frame.plot = FALSE, pch = 19,
  names = c('A', 'B', 'N'), las = 1, ylim = c(0, 1), ylab = "p"
)

# TCGA ----------------------------------------------------------------------
# Same multi-split procedure for the BRCA1 data. The split must be sized by
# the BRCA1 sample size, not by `n` (the running example's sample size): a
# logical row index of the wrong length is silently recycled by `[`.
n_brca <- nrow(brca$X)
half_sizes_brca <- c(n_brca %/% 2, n_brca - n_brca %/% 2)
P <- matrix(1, 100, ncol(brca$X), dimnames = list(1:100, colnames(brca$X)))
for (i in seq_len(100)) {
  ind <- as.logical(sample(rep(0:1, times = half_sizes_brca)))
  cvfit <- cv.glmnet(brca$X[ind, ], brca$y[ind])
  b <- coef(cvfit, s = cvfit$lambda.min)[-1]
  # drop = FALSE: with exactly one selected gene a plain vector would lose its
  # column name and the assignment into P would fail.
  xx <- brca$X[!ind, which(b != 0), drop = FALSE]
  fit <- lm(brca$y[!ind] ~ ., data = as.data.frame(xx))
  summ <- summary(fit)$coefficients[-1, , drop = FALSE]
  P[i, rownames(summ)] <- summ[, 4]
}
p_agg <- apply(P, 2, median, na.rm = TRUE)  # Sometimes lasso selects p > n in stage 1
# sum(p_agg < .05)

# Stability selection -----------------------------------------------------
# For each random half-sample, refit the lasso on the full-data lambda grid
# and record which features are nonzero at each lambda.
#   S[j, l]: fraction of half-samples in which feature j is selected at l.
#   q[l]:    average number of features selected at l.

fit <- glmnet(x, y)
n_ss <- 100
# Identical to rep(0:1, each = n / 2) for even n; still length n for odd n.
half_sizes <- c(n %/% 2, n - n %/% 2)
SS <- array(
  NA, dim = c(n_ss, ncol(x), length(fit$lambda)),
  dimnames = list(seq_len(n_ss), colnames(x), fit$lambda)
)
Q <- matrix(NA, n_ss, length(fit$lambda))
for (i in seq_len(n_ss)) {
  ind <- as.logical(sample(rep(0:1, times = half_sizes)))
  fit.i <- glmnet(x[ind, ], y[ind], lambda = fit$lambda)
  SS[i, , ] <- as.matrix(coef(fit.i)[-1, ] != 0)
  Q[i, ] <- vapply(predict(fit.i, type = "nonzero"), length, integer(1))
}
S <- apply(SS, 2:3, mean)
q <- colMeans(Q)

# FDR Bound
# Expected false selections at stability threshold pi_thr are bounded by
# q^2 / ((2 * pi_thr - 1) * p); dividing by the number of features the
# full-data lasso selects at each lambda gives the FDR bound.
pi_thr <- 0.8
ev <- q^2 / (ncol(x) * (2 * pi_thr - 1))
fdr <- ev / vapply(predict(fit, type = "nonzero"), length, integer(1))
idx_fdr10 <- max(which(fdr < .1))  # last lambda with FDR bound below 10%
idx_fdr10
which(S[, idx_fdr10] > pi_thr)

l <- fit$lambda
col <- rep("gray", ncol(x))
col[var_type == "A"] <- "red"
par(mar = c(4, 4, 0.5, 0.5))
matplot(
  l, t(S), type = "l", lty = 1, xlim = rev(range(l)), col = col, lwd = 2,
  las = 1, bty = "n", xlab = expression(lambda), ylab = "Stability"
)

# Stability selection for the BRCA1 data. glmnet's argument is spelled
# lambda.min.ratio (glmnet has no `lambda.min` argument).
fit <- glmnet(brca$X, brca$y, lambda.min.ratio = 0.1)
n_ss <- 100
n_brca <- length(brca$y)
# Identical to rep(0:1, each = n_brca / 2) for even n_brca; length n_brca
# for odd n_brca.
half_sizes_brca <- c(n_brca %/% 2, n_brca - n_brca %/% 2)
SS <- array(
  NA, dim = c(n_ss, ncol(brca$X), length(fit$lambda)),
  dimnames = list(seq_len(n_ss), colnames(brca$X), fit$lambda)
)
Q <- matrix(NA, n_ss, length(fit$lambda))
for (i in seq_len(n_ss)) {
  ind <- as.logical(sample(rep(0:1, times = half_sizes_brca)))
  fit.i <- glmnet(brca$X[ind, ], brca$y[ind], lambda = fit$lambda)
  SS[i, , ] <- as.matrix(coef(fit.i)[-1, ] != 0)
  Q[i, ] <- vapply(predict(fit.i, type = "nonzero"), length, integer(1))
}
S <- apply(SS, 2:3, mean)
q <- colMeans(Q)
l <- fit$lambda
col <- rep("gray", ncol(brca$X))
# Highlight genes whose selection frequency exceeds 0.6 at some lambda.
col[which(apply(S, 1, max) > 0.6)] <- "red"
par(mar = c(5, 5, 0.5, 0.5))
matplot(
  l, t(S), type = "l", lty = 1, xlim = rev(range(l)), col = col,
  lwd = 2, las = 1, bty = "n", xlab = expression(lambda), ylab = "Stability"
)
# table(col)

# FDR bound for TCGA stability selection (threshold pi_thr on S).
pi_thr <- 0.8
ev <- q^2 / (ncol(brca$X) * (2 * pi_thr - 1))
fdr <- ev / vapply(predict(fit, type = "nonzero"), length, integer(1))
idx_fdr10 <- max(which(fdr < .1))  # last lambda with FDR bound below 10%
idx_fdr30 <- max(which(fdr < .3))  # last lambda with FDR bound below 30%
idx_fdr10
which(S[, idx_fdr10] > pi_thr)
which(S[, idx_fdr30] > pi_thr)

# Basic bootstrap -----------------------------------------------------------
# Refit the lasso at a fixed lambda (0.07) on 1000 resamples drawn with
# replacement; row r of B holds the coefficient estimates from resample r.
n_boot <- 1000
B <- matrix(NA, n_boot, p, dimnames = list(seq_len(n_boot), colnames(x)))
for (r in seq_len(n_boot)) {
  rows <- sample(n, replace = TRUE)
  fit.i <- glmnet(x[rows, ], y[rows], lambda = 0.07)
  B[r, ] <- as.numeric(coef(fit.i)[-1])
}

is_a <- var_type == "A"
is_b <- var_type == "B"

# Bootstrap distributions of the 'A' and 'B' feature coefficients.
boxplot(
  B[, is_a], col = "gray", pch = "|", horizontal = TRUE,
  frame.plot = FALSE, cex = 0.5, las = 1, ylim = c(-1.5, 1.5)
)
boxplot(
  B[, is_b], col = "gray", pch = "|", horizontal = TRUE,
  frame.plot = FALSE, cex = 0.5, las = 1, ylim = c(-1, 1)
)

bz <- colMeans(B != 0)  # share of resamples with a nonzero estimate
Ba <- B[, is_a]
# Two-sided bootstrap p-value: twice the smaller tail share, capped at 1.
boot_p <- pmin(2 * pmin(colMeans(B >= 0), colMeans(B <= 0)), 1)
# Columns: median, 2.5% and 97.5% quantiles, bootstrap p-value.
ci <- cbind(
  apply(Ba, 2, median),
  t(apply(Ba, 2, quantile, probs = c(.025, .975))),
  boot_p[is_a]
)

# Dashed gray reference segments at beta = 1, -1, 0.5 and -0.5 (presumably
# the true coefficients of the plotted 'A' features -- confirm against
# Ex9.1); shared by both interval plots below.
draw_reference_lines <- function() {
  lines(c(1, 1), c(5.5, 6.5), col = "gray", lty = 2, lwd = 2, xpd = 1)
  lines(c(-1, -1), c(4.5, 5.5), col = "gray", lty = 2, lwd = 2)
  lines(c(0.5, 0.5), c(2.5, 4.5), col = "gray", lty = 2, lwd = 2)
  lines(c(-0.5, -0.5), c(0.5, 2.5), col = "gray", lty = 2, lwd = 2)
}

ci_plot(
  ci[c(1, 2, 3, 5, 4, 6), ], sort = FALSE, xlab = expression(beta),
  mar = c(5, 3, 0, 5), xlim = c(-1.5, 1.5)
)
draw_reference_lines()

bfit <- boot_ncvreg(x, y, penalty = 'MCP')
ci_plot(
  bfit$confidence_intervals[c(1, 4, 7, 13, 10, 16), ], sort = FALSE,
  xlab = expression(beta), mar = c(5, 3, 0, 0.5), xlim = c(-1.5, 1.5)
)
draw_reference_lines()

