library(hdrm)
library(knockoff)
library(ggplot2)
# Position of the first nonzero entry of x, or Inf when x has no nonzero entry.
# Indexing a vector with Inf yields NA, which the code below relies on to flag
# features that never enter the model.
f <- function(x) {
  nonzero <- which(x != 0)
  if (length(nonzero) == 0) {
    return(Inf)
  }
  min(nonzero)
}

# Running example (supplemented)
# The design matrix is stacked on top of itself; the second copy is paired
# with newly simulated responses centred at Data$X %*% Data$beta.
Data <- Ex9.1()
X <- rbind(Data$X, Data$X)
set.seed(1)
# One new response per original row. The count is taken from the data rather
# than hard-coded, so X and y cannot end up with different lengths.
y <- c(Data$y, rnorm(nrow(Data$X), Data$X %*% Data$beta))
n <- nrow(X)
p <- ncol(X)
varType <- Data$varType

# Knockoff filter: Fixed X ------------------------------------------------

# QR construction of N (for illustration)
# N takes columns p+1, ..., 2p of the complete Q factor. The two checks below
# print TRUE when N has orthonormal columns and is orthogonal to X
# (this presumes X has full column rank and at least 2p rows).
QR <- qr(X)
Q <- qr.Q(QR, complete = TRUE)
N <- Q[, (p + 1):(2 * p)]
all.equal(crossprod(N), diag(p))
all.equal(crossprod(N, X), matrix(0, p, p), check.attributes = FALSE)

# Construct knockoffs and fit augmented model
tmp <- create.fixed(X, method = "equi", sigma = 1)
XX <- cbind(tmp$X, tmp$Xk)
fit <- glmnet(XX, y)

# Construct test statistics
# Row 1 of coef(fit) is the intercept; the next p rows belong to the genuine
# features and the p rows after that to their knockoffs (column order of XX).
B <- coef(fit)
Bg <- B[1 + seq_len(p), ]
Bk <- B[1 + p + seq_len(p), ]
# Value of lambda at the first path position where each coefficient is nonzero
zg <- fit$lambda[apply(Bg, 1, f)]
zk <- fit$lambda[apply(Bk, 1, f)]
# f() returns Inf for a feature that never enters, which indexes to NA.
# Treat both genuine and knockoff features the same way: entry at lambda = 0.
zg[is.na(zg)] <- 0
zk[is.na(zk)] <- 0
# Corner of the shaded regions drawn in the figure below
ll <- 0.175

# FDR figure
# Vertices of the two shaded regions; both start at (ll, ll).
# numer: above the diagonal and above y = ll.
# denom: below the diagonal and to the right of x = ll.
numer <- data.frame(
  x = c(ll, 1.1, 0, 0),
  y = c(ll, 1.1, 1.1, ll)
)
denom <- data.frame(
  x = c(ll, 1.1, 1.1, ll),
  y = c(0, 0, 1.1, ll)
)
shade <- hdrm:::pal(2, alpha = 0.3)

# Scatter of entry lambdas (genuine vs knockoff) with the y = x reference line
g <- ggplot(data.frame(zg, zk, varType), aes(zg, zk, color = varType)) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = 2, colour = "gray50") +
  lims(x = c(0, 1.1), y = c(0, 1.1)) +
  coord_fixed() +
  labs(
    x = expression(lambda * " (Genuine)"),
    y = expression(lambda * " (Knockoff)")
  )
g

# Same plot with the denom region shaded, then with both regions shaded
g_denom <- g +
  geom_polygon(
    aes(x, y), data = denom, show.legend = FALSE,
    colour = NA, fill = shade[1]
  )
g_denom
g_denom +
  geom_polygon(
    aes(x, y), data = numer, show.legend = FALSE,
    colour = NA, fill = shade[2]
  )

# Demo usage using classroom settings
# Runs on the supplemented X and y built at the top of the script
# (no X2 / y2 objects are defined anywhere in this file).
kf <- function(X) create.fixed(X, method = "equi")
res <- knockoff.filter(
  X, y, knockoffs = kf, statistic = stat.glmnet_lambdasmax,
  offset = 0)  # Offset=1 adds 1 to the numerator as discussed in class
res$selected

# Demo usage using defaults (sdp instead of equicorrelation, coef difference instead of lambda)
res <- knockoff.filter(X, y, knockoffs = create.fixed, offset = 0)
res$selected


# Augmentation approach ---------------------------------------------------

# Construct knockoffs and fit augmented model
# Uses the original (un-supplemented) Data$X. create.fixed() is given sigma
# but no y; the y component it returns is appended to Data$y to form the
# response for the fit (presumably the responses for the rows create.fixed()
# adds -- see the knockoff documentation).
set.seed(1)
tmp <- create.fixed(Data$X, method = "equi", sigma = 1)
XX <- cbind(tmp$X, tmp$Xk)
fit <- glmnet(XX, c(Data$y, tmp$y))

# Construct test statistics
# Row 1 of coef(fit) is the intercept; the next p rows belong to the genuine
# features and the p rows after that to their knockoffs (column order of XX).
B <- coef(fit)
Bg <- B[1 + seq_len(p), ]
Bk <- B[1 + p + seq_len(p), ]
# Value of lambda at the first path position where each coefficient is nonzero
zg <- fit$lambda[apply(Bg, 1, f)]
zk <- fit$lambda[apply(Bk, 1, f)]
# f() returns Inf for a feature that never enters, which indexes to NA.
# Treat both genuine and knockoff features the same way: entry at lambda = 0.
zg[is.na(zg)] <- 0
zk[is.na(zk)] <- 0

# Define polygon for plot
# Vertices of the two shaded regions; both start at (ll, ll).
# numer: above the diagonal and above y = ll.
# denom: below the diagonal and to the right of x = ll.
ll <- 0.3
numer <- data.frame(
  x = c(ll, 1.1, 0, 0),
  y = c(ll, 1.1, 1.1, ll)
)
denom <- data.frame(
  x = c(ll, 1.1, 1.1, ll),
  y = c(0, 0, 1.1, ll)
)
shade <- hdrm:::pal(2, alpha = 0.3)

# Plot
# Scatter of entry lambdas (genuine vs knockoff) with the y = x reference line
g <- ggplot(data.frame(zg, zk, varType), aes(zg, zk, color = varType)) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = 2, colour = "gray50") +
  lims(x = c(0, 1.1), y = c(0, 1.1)) +
  coord_fixed() +
  labs(
    x = expression(lambda * " (Genuine)"),
    y = expression(lambda * " (Knockoff)")
  )
g

# Same plot with the denom region shaded, then with both regions shaded
g_denom <- g +
  geom_polygon(
    aes(x, y), data = denom, show.legend = FALSE,
    colour = NA, fill = shade[1]
  )
g_denom
g_denom +
  geom_polygon(
    aes(x, y), data = numer, show.legend = FALSE,
    colour = NA, fill = shade[2]
  )

# Same analysis through knockoff.filter(); the knockoff constructor is handed
# Data$y in addition to the design matrix.
kf <- function(X) {
  create.fixed(X, method = "equi", y = Data$y)
}
res <- knockoff.filter(
  Data$X,
  Data$y,
  knockoffs = kf,
  statistic = stat.glmnet_lambdasmax,
  offset = 0
)
res$selected
res$statistic
res$threshold


# Model-X -----------------------------------------------------------------

# Construct knockoffs and fit augmented model
# The seed is set immediately before create.second_order() so the knockoff
# copy XK is reproducible.
set.seed(2)
XK <- create.second_order(Data$X, method = "equi")
XX <- cbind(Data$X, XK)
fit <- glmnet(XX, Data$y, nlambda = 200)

# Construct test statistics
# Row 1 of coef(fit) is the intercept; the next p rows belong to the genuine
# features and the p rows after that to their knockoffs (column order of XX).
B <- coef(fit)
Bg <- B[1 + seq_len(p), ]
Bk <- B[1 + p + seq_len(p), ]
# Value of lambda at the first path position where each coefficient is nonzero
zg <- fit$lambda[apply(Bg, 1, f)]
zk <- fit$lambda[apply(Bk, 1, f)]
# f() returns Inf for a feature that never enters, which indexes to NA;
# such features are assigned an entry value of 0.
zg[is.na(zg)] <- 0
zk[is.na(zk)] <- 0

# Define polygon for plot
# Vertices of the two shaded regions; both start at (ll, ll).
# numer: above the diagonal and above y = ll.
# denom: below the diagonal and to the right of x = ll.
ll <- 0.3
numer <- data.frame(
  x = c(ll, 1.1, 0, 0),
  y = c(ll, 1.1, 1.1, ll)
)
denom <- data.frame(
  x = c(ll, 1.1, 1.1, ll),
  y = c(0, 0, 1.1, ll)
)
shade <- hdrm:::pal(2, alpha = 0.3)

# Plot
# Scatter of entry lambdas (genuine vs knockoff) with the y = x reference line
g <- ggplot(data.frame(zg, zk, varType), aes(zg, zk, color = varType)) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = 2, colour = "gray50") +
  lims(x = c(0, 1.1), y = c(0, 1.1)) +
  coord_fixed() +
  labs(
    x = expression(lambda * " (Genuine)"),
    y = expression(lambda * " (Knockoff)")
  )
g

# Same plot with the denom region shaded, then with both regions shaded
g_denom <- g +
  geom_polygon(
    aes(x, y), data = denom, show.legend = FALSE,
    colour = NA, fill = shade[1]
  )
g_denom
g_denom +
  geom_polygon(
    aes(x, y), data = numer, show.legend = FALSE,
    colour = NA, fill = shade[2]
  )

# Using classroom settings
# Equicorrelated second-order knockoffs with the lambda-based statistic;
# the seed is reset first because knockoff construction is random.
set.seed(2)
kf <- function(X) {
  create.second_order(X, method = "equi")
}
res <- knockoff.filter(
  Data$X,
  Data$y,
  knockoffs = kf,
  statistic = stat.glmnet_lambdasmax,
  offset = 0
)
res$selected
res$statistic
res$threshold

# Using defaults
# Only offset is overridden; knockoffs and statistic are left at whatever
# knockoff.filter() uses by default.
res <- knockoff.filter(
  Data$X,
  Data$y,
  offset = 0
)
res$selected
res$statistic
res$threshold
