library(hdrm)
library(knockoff)
library(ggplot2)
# Position of the first non-zero entry of `x`, or Inf when there is none.
# The script applies this row-wise to coef() output (one column per lambda)
# and uses the result to index `fit$lambda`; an Inf index yields NA there.
f <- function(x) {
  nonzero <- which(x != 0)
  if (length(nonzero) == 0L) {
    return(Inf)
  }
  nonzero[1L]
}

# TCGA
# Data set `brca1` loaded with read_data(): response, predictors, sample size.
brca <- read_data(brca1)
y_brca <- brca$y
x_brca <- brca$X
n_brca <- length(brca$y)

# Running example
# Ex9.1() provides X, y, beta and a per-variable label (varType) that colours
# the plots below.
dat <- Ex9.1()
var_type <- dat$varType

# Supplemented running example (doubling n so that n > 2p): stack a second
# copy of X and append a fresh response drawn as N(X %*% beta, 1). The number
# of new draws is taken from nrow(dat$X) rather than hard-coded.
set.seed(1)
n_dat <- nrow(dat$X)
sup <- list(
  x = rbind(dat$X, dat$X),
  y = c(dat$y, rnorm(n_dat, dat$X %*% dat$beta))
)

# QR construction of N (for illustration)
# Take the p columns of the complete Q factor that follow the first p; when
# sup$x has full column rank these are orthonormal and orthogonal to sup$x.
# The two all.equal() checks confirm N'N = I and N'X = 0. p is derived from
# the data instead of being hard-coded as 60.
p_sup <- ncol(sup$x)
qr_x <- qr(sup$x)
qx <- qr.Q(qr_x, complete = TRUE)
nx <- qx[, p_sup + seq_len(p_sup)]
all.equal(crossprod(nx), diag(p_sup))
all.equal(
  crossprod(nx, sup$x),
  matrix(0, p_sup, p_sup),
  check.attributes = FALSE
)

# Construct knockoffs and fit augmented model
# Fixed-X equicorrelated knockoffs with sigma = 1; the lasso path is fit on
# the genuine columns followed by their knockoffs. glmnet is called through
# its namespace because no library(glmnet) call appears in this script.
ko <- create.fixed(sup$x, method = 'equi', sigma = 1)
xx <- cbind(ko$X, ko$Xk)
fit <- glmnet::glmnet(xx, sup$y)

# Construct test statistics
# Row 1 of coef() is the intercept, the next p rows are the genuine variables
# and the following p rows their knockoffs. Each statistic is the lambda at
# which the variable first enters the path. A variable that never enters gets
# NA from f()/lambda indexing; set it to 0 for BOTH genuine and knockoff
# variables so no point silently drops out of the plots.
p_sup <- ncol(sup$x)
beta_all <- coef(fit)
beta_g <- beta_all[1 + seq_len(p_sup), ]
beta_k <- beta_all[1 + p_sup + seq_len(p_sup), ]
zg <- fit$lambda[apply(beta_g, 1, f)]
zk <- fit$lambda[apply(beta_k, 1, f)]
zg[is.na(zg)] <- 0
zk[is.na(zk)] <- 0

# Define polygon for plot
# Regions of the (genuine lambda, knockoff lambda) plane beyond threshold ll:
#   denom: on/below the diagonal with genuine lambda >= ll
#   numer: on/above the diagonal with knockoff lambda >= ll
# (presumably the denominator and numerator of the estimated FDP).
ll <- 0.175
denom <- data.frame(x = c(ll, 1.1, 1.1, ll), y = c(0, 0, 1.1, ll))
numer <- data.frame(x = c(ll, 1.1, 0, 0), y = c(ll, 1.1, 1.1, ll))

# FDR figure
# Entry lambda of each genuine variable against its knockoff, coloured by
# variable type; the dashed line marks equal entry.
g <- ggplot(data.frame(zg, zk, var_type), aes(zg, zk, color = var_type)) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = 2, colour = 'gray50') +
  lims(x = c(0, 1.1), y = c(0, 1.1)) +
  coord_fixed() +
  labs(
    x = expression(lambda * ' (Genuine)'),
    y = expression(lambda * ' (Knockoff)')
  ) +
  theme_minimal() +
  guides(color = guide_legend(title = ''))

# Translucent, borderless polygon for one region; `k` picks the palette colour.
region_layer <- function(region, k) {
  geom_polygon(
    aes(x, y),
    data = region,
    show.legend = FALSE,
    color = NA,
    fill = hdrm::pal(2, alpha = 0.3)[k]
  )
}

g
g + region_layer(denom, 1)
g + region_layer(denom, 1) + region_layer(numer, 2)

# Using the knockoff.filter() function
# One-call version of the analysis above: same create.fixed() settings
# (equi, sigma = 1) with the lambda-entry statistic. Per the knockoff.filter()
# documentation, offset = 0 is the original knockoff threshold (1 = knockoff+).
kf <- function(x) {
  create.fixed(x, method = 'equi', sigma = 1)
}
res <- knockoff.filter(
  X = sup$x,
  y = sup$y,
  knockoffs = kf,
  statistic = stat.glmnet_lambdasmax,
  offset = 0
)
res$selected

# Construct knockoffs for original data
# ko$y is appended to dat$y, presumably the responses for the extra rows that
# create.fixed() adds because this design is not n >= 2p (see the supplemented
# example). glmnet is called through its namespace because no library(glmnet)
# call appears in this script.
set.seed(1)
ko <- create.fixed(dat$X, method = 'equi', sigma = 1)
xx <- cbind(ko$X, ko$Xk)
fit <- glmnet::glmnet(xx, c(dat$y, ko$y))

# Construct test statistics
# Row 1 of coef() is the intercept, then p genuine rows, then p knockoff rows.
# Each statistic is the lambda at which the variable first enters; variables
# that never enter (NA) are set to 0 for BOTH genuine and knockoff variables
# so no point silently drops out of the plots.
p_dat <- ncol(dat$X)
beta_all <- coef(fit)
beta_g <- beta_all[1 + seq_len(p_dat), ]
beta_k <- beta_all[1 + p_dat + seq_len(p_dat), ]
zg <- fit$lambda[apply(beta_g, 1, f)]
zk <- fit$lambda[apply(beta_k, 1, f)]
zg[is.na(zg)] <- 0
zk[is.na(zk)] <- 0

# Define polygon for plot
# Regions beyond threshold ll: denom is on/below the diagonal with genuine
# lambda >= ll; numer is on/above the diagonal with knockoff lambda >= ll.
ll <- 0.3
denom <- data.frame(x = c(ll, 1.1, 1.1, ll), y = c(0, 0, 1.1, ll))
numer <- data.frame(x = c(ll, 1.1, 0, 0), y = c(ll, 1.1, 1.1, ll))

# Plot
# Genuine vs. knockoff entry lambda, coloured by variable type.
g <- data.frame(zg, zk, var_type) |>
  ggplot(aes(zg, zk, color = var_type)) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = 2, colour = 'gray50') +
  lims(x = c(0, 1.1), y = c(0, 1.1)) +
  coord_fixed() +
  labs(
    x = expression(lambda * ' (Genuine)'),
    y = expression(lambda * ' (Knockoff)')
  ) +
  theme_minimal() +
  guides(color = guide_legend(title = ''))

# Translucent, borderless polygon for one region; `k` picks the palette colour.
region_layer <- function(region, k) {
  geom_polygon(
    aes(x, y),
    data = region,
    show.legend = FALSE,
    color = NA,
    fill = pal(2, alpha = 0.3)[k]
  )
}

g
g + region_layer(denom, 1) + region_layer(numer, 2)

# Using knockoff.filter()
# The constructor receives y = dat$y so that the response it returns covers
# the extra rows added for this design (the filter presumably fits on that
# response). sigma = 1 matches the manual create.fixed() call above; without
# it, sigma would be estimated from y (per the create.fixed() documentation).
# That would change the extra responses relative to the manual analysis.
kf <- function(x) create.fixed(x, method = 'equi', sigma = 1, y = dat$y)
res <- knockoff.filter(
  dat$X,
  dat$y,
  knockoffs = kf,
  statistic = stat.glmnet_lambdasmax,
  offset = 0
)
res$selected
res$statistic
res$threshold

# Second-order knockoffs for the original data: create.second_order() returns
# the knockoff matrix directly, so it is bound next to the raw dat$X. glmnet
# is called through its namespace because no library(glmnet) call appears in
# this script.
set.seed(2)
ko <- create.second_order(dat$X, method = 'equi')
xx <- cbind(dat$X, ko)
fit <- glmnet::glmnet(xx, dat$y, nlambda = 200)

# Construct test statistics
# Row 1 of coef() is the intercept, then p genuine rows, then p knockoff rows;
# variables that never enter the path (NA) are set to 0.
p_dat <- ncol(dat$X)
beta_all <- coef(fit)
beta_g <- beta_all[1 + seq_len(p_dat), ]
beta_k <- beta_all[1 + p_dat + seq_len(p_dat), ]
zg <- fit$lambda[apply(beta_g, 1, f)]
zk <- fit$lambda[apply(beta_k, 1, f)]
zg[is.na(zg)] <- 0
zk[is.na(zk)] <- 0

# Define polygon for plot
# Regions beyond threshold ll: denom is on/below the diagonal with genuine
# lambda >= ll; numer is on/above the diagonal with knockoff lambda >= ll.
ll <- 0.3
denom <- data.frame(x = c(ll, 1.1, 1.1, ll), y = c(0, 0, 1.1, ll))
numer <- data.frame(x = c(ll, 1.1, 0, 0), y = c(ll, 1.1, 1.1, ll))

# Plot
# Genuine vs. second-order-knockoff entry lambda, coloured by variable type.
g <- data.frame(zg, zk, var_type) |>
  ggplot(aes(zg, zk, color = var_type)) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = 2, colour = 'gray50') +
  lims(x = c(0, 1.1), y = c(0, 1.1)) +
  coord_fixed() +
  labs(
    x = expression(lambda * ' (Genuine)'),
    y = expression(lambda * ' (Knockoff)')
  ) +
  theme_minimal() +
  guides(color = guide_legend(title = ''))

# Translucent, borderless polygon for one region; `k` picks the palette colour.
region_layer <- function(region, k) {
  geom_polygon(
    aes(x, y),
    data = region,
    show.legend = FALSE,
    color = NA,
    fill = pal(2, alpha = 0.3)[k]
  )
}

g
g + region_layer(denom, 1) + region_layer(numer, 2)

# Using knockoff.filter()
# Second-order knockoffs via the one-call interface at a target FDR of 0.2.
# The seed is reset to 2 right before the call, as for the manual
# construction above, so the random knockoffs can be reproduced.
kf <- function(x) {
  create.second_order(x, method = 'equi')
}
set.seed(2)
res <- knockoff.filter(
  X = dat$X,
  y = dat$y,
  knockoffs = kf,
  statistic = stat.glmnet_lambdasmax,
  fdr = 0.2,
  offset = 0
)
res$selected
res$statistic
res$threshold

# Disabled: the default knockoff.filter() call on the full BRCA data crashes
# the machine. Do not run.
# res <- knockoff.filter(x_brca, y_brca)

