library(hdrm)
library(knockoff)
library(ggplot2)


# Running example. X2, y2 double the sample (n = 100 -> 200): X2 stacks a
# second copy of X, and y2 appends freshly simulated responses for that copy.
set.seed(72)
Data <- genDataABN(n = 100, p = 60, a = 6, b = 2, rho = 0.5,
                   beta = c(1, -1, 0.5, -0.5, 0.5, -0.5))
X <- Data$X
X2 <- rbind(Data$X, Data$X)
y <- Data$y
# Second-half responses: mean X %*% beta, sd 1 (rnorm default). This presumably
# matches how genDataABN() generates Data$y -- confirm. The count comes from
# nrow() rather than a literal 100, so n only has to be set in genDataABN().
# drop() turns the 1-column matrix product into a plain mean vector.
y2 <- c(Data$y, rnorm(nrow(Data$X), drop(Data$X %*% Data$beta)))
varType <- Data$varType


# Knockoff filter: Fixed X ------------------------------------------------

# Construct knockoffs and fit augmented model (genuine columns first, then
# knockoffs). No library(glmnet) call appears in this script, so glmnet is
# called through its namespace; this works whether or not another package
# attaches it.
tmp <- create.fixed(X2, method = "equi", sigma = 1)
XX <- cbind(tmp$X, tmp$Xk)
fit <- glmnet::glmnet(XX, y2)

# Construct test statistics. For each genuine (zg) and knockoff (zk) variable,
# take the lambda at the first path step where its coefficient is nonzero, or
# 0 if it never enters.
# - Row 1 of coef() is the intercept, so genuine variables occupy rows
#   1 + (1:p) and knockoffs rows 1 + p + (1:p). p is read from the design
#   rather than hard-coded.
# - which(...)[1] is NA when a variable never enters. Unlike min(which(...)),
#   it does not warn.
Bg <- coef(fit)[1 + seq_len(ncol(tmp$X)), , drop = FALSE]
Bk <- coef(fit)[1 + ncol(tmp$X) + seq_len(ncol(tmp$Xk)), , drop = FALSE]
zg <- fit$lambda[apply(Bg, 1, function(b) which(b != 0)[1])]
zk <- fit$lambda[apply(Bk, 1, function(b) which(b != 0)[1])]
# Variables that never enter the path get 0. Doing this for zg too keeps them
# in the plot instead of having ggplot drop NA rows.
zg[is.na(zg)] <- 0
zk[is.na(zk)] <- 0

# FDR figure
# Threshold `ll` and the two shaded regions drawn on the (zg, zk) scatter:
# - `denom`: below the dashed y = x line with genuine lambda >= ll.
# - `numer`: its reflection, above the line with knockoff lambda >= ll.
ll <- 0.175
numer <- data.frame(x = c(ll, 1.1, 0, 0), y = c(ll, 1.1, 1.1, ll))
denom <- data.frame(x = c(ll, 1.1, 1.1, ll), y = c(0, 0, 1.1, ll))

# Base scatter: one point per variable, coloured by type, with a y = x guide
g <- ggplot(data.frame(zg, zk, varType), aes(zg, zk, colour = varType)) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = 2, colour = "gray50") +
  xlim(0, 1.1) +
  ylim(0, 1.1) +
  coord_fixed() +
  labs(x = expression(lambda * " (Genuine)"),
       y = expression(lambda * " (Knockoff)"))
g

# Shade the denominator region alone, then both regions
g + geom_polygon(aes(x, y), data = denom, show.legend = FALSE,
                 colour = NA, fill = pal(2, alpha = 0.3)[1])
g +
  geom_polygon(aes(x, y), data = denom, show.legend = FALSE,
               colour = NA, fill = pal(2, alpha = 0.3)[1]) +
  geom_polygon(aes(x, y), data = numer, show.legend = FALSE,
               colour = NA, fill = pal(2, alpha = 0.3)[2])

# Demo usage with the classroom settings: equicorrelated fixed-X knockoffs and
# the lambda-entry statistic. Offset = 1 would add 1 to the numerator, as
# discussed in class; offset = 0 is used here.
kf <- function(X) {
  create.fixed(X, method = "equi")
}
res <- knockoff.filter(
  X2, y2,
  knockoffs = kf,
  statistic = stat.glmnet_lambdasmax,
  offset = 0
)
res$selected

# Demo usage with the package defaults for the knockoff construction and the
# statistic (sdp instead of equicorrelation, coef difference instead of lambda)
res <- knockoff.filter(X2, y2, knockoffs = create.fixed, offset = 0)
res$selected


# Augmentation approach ---------------------------------------------------

# Construct knockoffs and fit augmented model.
# - X here is the original 100 x 60 design (n < 2p).
# - sigma = 1 presumably lets create.fixed() append extra rows with simulated
#   responses.
# - tmp$y is assumed to hold only those extra responses (no y is passed), so
#   it is stacked under the observed y. TODO confirm against the
#   knockoff::create.fixed docs.
# - No library(glmnet) call appears in this script, so glmnet is called
#   through its namespace.
set.seed(1)
tmp <- create.fixed(X, method = "equi", sigma = 1)
XX <- cbind(tmp$X, tmp$Xk)
fit <- glmnet::glmnet(XX, c(y, tmp$y))

# Construct test statistics. For each genuine (zg) and knockoff (zk) variable,
# take the lambda at the first path step where its coefficient is nonzero, or
# 0 if it never enters.
# - Row 1 of coef() is the intercept, so genuine variables occupy rows
#   1 + (1:p) and knockoffs rows 1 + p + (1:p). p is read from the design
#   rather than hard-coded.
# - which(...)[1] is NA when a variable never enters. Unlike min(which(...)),
#   it does not warn.
Bg <- coef(fit)[1 + seq_len(ncol(tmp$X)), , drop = FALSE]
Bk <- coef(fit)[1 + ncol(tmp$X) + seq_len(ncol(tmp$Xk)), , drop = FALSE]
zg <- fit$lambda[apply(Bg, 1, function(b) which(b != 0)[1])]
zk <- fit$lambda[apply(Bk, 1, function(b) which(b != 0)[1])]
# Variables that never enter the path get 0. Doing this for zg too keeps them
# in the plot instead of having ggplot drop NA rows.
zg[is.na(zg)] <- 0
zk[is.na(zk)] <- 0

# Define polygon for plot
# Threshold `ll` and the two shaded regions:
# - `denom`: below the y = x line with genuine lambda >= ll.
# - `numer`: its reflection, above the line with knockoff lambda >= ll.
ll <- 0.6
numer <- data.frame(x = c(ll, 1.1, 0, 0), y = c(ll, 1.1, 1.1, ll))
denom <- data.frame(x = c(ll, 1.1, 1.1, ll), y = c(0, 0, 1.1, ll))

# Plot: one point per variable, coloured by type, with a dashed y = x guide
g <- ggplot(data.frame(zg, zk, varType), aes(zg, zk, colour = varType)) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = 2, colour = "gray50") +
  xlim(0, 1.1) +
  ylim(0, 1.1) +
  coord_fixed() +
  labs(x = expression(lambda * " (Genuine)"),
       y = expression(lambda * " (Knockoff)"))
g

# Shade the denominator region alone, then both regions
g + geom_polygon(aes(x, y), data = denom, show.legend = FALSE,
                 colour = NA, fill = pal(2, alpha = 0.3)[1])
g +
  geom_polygon(aes(x, y), data = denom, show.legend = FALSE,
               colour = NA, fill = pal(2, alpha = 0.3)[1]) +
  geom_polygon(aes(x, y), data = numer, show.legend = FALSE,
               colour = NA, fill = pal(2, alpha = 0.3)[2])

# Same analysis through knockoff.filter on the original n = 100 design. Here
# create.fixed() receives the observed response (y = Data$y) rather than a
# fixed sigma.
kf <- function(X) {
  create.fixed(X, method = "equi", y = Data$y)
}
res <- knockoff.filter(
  Data$X, Data$y,
  knockoffs = kf,
  statistic = stat.glmnet_lambdasmax,
  offset = 0
)
res$selected


# Model-X -----------------------------------------------------------------

# Construct knockoffs and fit augmented model (genuine columns first, then
# knockoffs). No library(glmnet) call appears in this script, so glmnet is
# called through its namespace.
set.seed(2)
XK <- create.second_order(X, method = "equi")
XX <- cbind(Data$X, XK)
fit <- glmnet::glmnet(XX, y, nlambda = 200)

# Construct test statistics. For each genuine (zg) and knockoff (zk) variable,
# take the lambda at the first path step where its coefficient is nonzero, or
# 0 if it never enters.
# - Row 1 of coef() is the intercept, so genuine variables occupy rows
#   1 + (1:p) and knockoffs rows 1 + p + (1:p). p is read from the design
#   rather than hard-coded.
# - which(...)[1] is NA when a variable never enters. Unlike min(which(...)),
#   it does not warn.
Bg <- coef(fit)[1 + seq_len(ncol(Data$X)), , drop = FALSE]
Bk <- coef(fit)[1 + ncol(Data$X) + seq_len(ncol(XK)), , drop = FALSE]
zg <- fit$lambda[apply(Bg, 1, function(b) which(b != 0)[1])]
zk <- fit$lambda[apply(Bk, 1, function(b) which(b != 0)[1])]
zg[is.na(zg)] <- 0
zk[is.na(zk)] <- 0

# Define polygon for plot
# Threshold `ll` and the two shaded regions:
# - `denom`: below the y = x line with genuine lambda >= ll.
# - `numer`: its reflection, above the line with knockoff lambda >= ll.
ll <- 0.3
numer <- data.frame(x = c(ll, 1.1, 0, 0), y = c(ll, 1.1, 1.1, ll))
denom <- data.frame(x = c(ll, 1.1, 1.1, ll), y = c(0, 0, 1.1, ll))

# Plot: one point per variable, coloured by type, with a dashed y = x guide
g <- ggplot(data.frame(zg, zk, varType), aes(zg, zk, colour = varType)) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = 2, colour = "gray50") +
  xlim(0, 1.1) +
  ylim(0, 1.1) +
  coord_fixed() +
  labs(x = expression(lambda * " (Genuine)"),
       y = expression(lambda * " (Knockoff)"))
g

# Shade the denominator region alone, then both regions
g + geom_polygon(aes(x, y), data = denom, show.legend = FALSE,
                 colour = NA, fill = pal(2, alpha = 0.3)[1])
g +
  geom_polygon(aes(x, y), data = denom, show.legend = FALSE,
               colour = NA, fill = pal(2, alpha = 0.3)[1]) +
  geom_polygon(aes(x, y), data = numer, show.legend = FALSE,
               colour = NA, fill = pal(2, alpha = 0.3)[2])

# Using classroom settings: equicorrelated second-order model-X knockoffs and
# the lambda-entry statistic, with offset = 0
kf <- function(X) {
  create.second_order(X, method = "equi")
}
res <- knockoff.filter(
  Data$X, Data$y,
  knockoffs = kf,
  statistic = stat.glmnet_lambdasmax,
  offset = 0
)
res$selected

# Using the package defaults for the knockoff construction and the statistic
res <- knockoff.filter(Data$X, Data$y, offset = 0)
res$selected
