library(hdrm)
library(selectiveInference)

# Data (running examples)
# Real data: attachData() makes X and y available in the workspace; copy them
# to TCGA-specific names because X and y are reused for the simulated data.
attachData(bcTCGA)
X.TCGA <- X
y.TCGA <- y
n.TCGA <- length(y.TCGA)

# Simulated data: n = 100 observations, p = 60 features, six nonzero
# coefficients given by `beta` below.
set.seed(72)
Data <- genDataABN(n = 100, p = 60, a = 6, b = 2, rho = 0.5,
                   beta = c(1, -1, 0.5, -0.5, 0.5, -0.5))
X <- Data[["X"]]
y <- Data[["y"]]
n <- dim(X)[1L]
p <- dim(X)[2L]
varType <- Data[["varType"]]


# Motivating simulations --------------------------------------------------

# Forward selection
# Null simulation: yy is pure noise, generated independently of all 100
# features. For each replicate, record the regression sum of squares T_R for
#   R1: a prespecified feature (column 1), and
#   R2: the feature most correlated with yy (the one forward selection picks).
N <- 200
R1 <- R2 <- numeric(N)
for (i in seq_len(N)) {
  yy <- scale(rnorm(100), scale = FALSE)
  XX <- std(matrix(rnorm(100 * 100), 100, 100))
  jj <- which.max(abs(crossprod(XX, yy)))
  fit1 <- lm(yy ~ XX[, 1])
  fit2 <- lm(yy ~ XX[, jj])
  # Regression SS = total SS (yy is centered) minus residual SS
  R1[i] <- crossprod(yy) - crossprod(fit1$residuals)
  R2[i] <- crossprod(yy) - crossprod(fit2$residuals)
}

# QQ plot for a prespecified feature.
# Theoretical quantiles are taken at length(R1) plotting positions so that
# qqplot() does not interpolate the sample down to a different size.
qqplot(qchisq(ppoints(length(R1)), 1), R1, las = 1, bty = "n", pch = 16,
       ylim = c(0, 8), xlab = expression(chi^2), ylab = expression(T[R]))
abline(a = 0, b = 1, lty = 2, lwd = 2, col = "gray")
mean(R1 > qchisq(.95, 1))  # empirical rejection rate at the nominal 5% level

# QQ plot when we cherry-pick the best feature
# (abline spans the full plot, which extends to max(R2) here)
qqplot(qchisq(ppoints(length(R2)), 1), R2, las = 1, bty = "n", pch = 16,
       ylim = c(0, max(R2)), xlab = expression(chi^2),
       ylab = expression(T[R]))
abline(a = 0, b = 1, lty = 2, lwd = 2, col = "gray")
mean(R2 > qchisq(.95, 1))  # empirical rejection rate at the nominal 5% level


# CovTest: Example data ---------------------------------------------------

# Simulation
# Null simulation for the covariance test: yy is pure noise, so the first
# covariance-test statistic T_C is compared against Exp(1).
# Note: selectiveInference returns only the covtest p-value, not the statistic
# itself, so the statistic is back-transformed here as T_C = qexp(1 - p)
# (assumes the reported p-value is 1 - pexp(T_C)).
set.seed(1)
N <- 200
# Named `pv` rather than `p` so the feature count p <- ncol(X) is not clobbered
pv <- numeric(N)
for (i in seq_len(N)) {
  yy <- scale(rnorm(100), scale = FALSE)
  XX <- std(matrix(rnorm(100 * 100), 100, 100))
  fit <- lar(XX, yy, maxsteps = 2)
  # Warnings from larInf() are deliberately silenced for this simulation
  pv[i] <- suppressWarnings(larInf(fit, sigma = 1)$pv.covtest[1])
}
TS <- qexp(1 - pv)
par(mar = c(4, 4, 1, 1))
qqplot(qexp(ppoints(length(TS)), 1), TS, las = 1, bty = "n", pch = 16,
       ylim = c(0, max(TS)), xlab = "Exp(1)", ylab = expression(T[C]))
abline(a = 0, b = 1, lty = 2, lwd = 2, col = "gray")
mean(TS > qexp(.95, 1))  # empirical rejection rate at the nominal 5% level


# Example data set
# Covariance test along the LARS path for the simulated data, with the noise
# SD supplied by estimateSigma(). TS is the back-transformed statistic.
fit <- lar(X, y)
res <- larInf(fit, sigma = estimateSigma(X, y)[["sigmahat"]])
Tab <- data.frame(
  Var = colnames(X)[res[["vars"]]],
  TS  = qexp(1 - res[["pv.covtest"]]),
  p   = res[["pv.covtest"]]
)
head(Tab, n = 6L)

# FDR-based stopping rule applied to the sequential covtest p-values
forwardStop(res[["pv.covtest"]])


# Selective inference -----------------------------------------------------

# Example data
X.std <- std(X)
fit <- glmnet(X.std, y, standardize = FALSE)
lam <- 25
# glmnet scales the squared-error loss by 1/n and fixedLassoInf() does not,
# so penalty `lam` corresponds to s = lam/n in glmnet. exact = TRUE refits at
# that lambda instead of interpolating along the path; this requires passing
# the original x and y again.
b <- coef(fit, s = lam / n, exact = TRUE, x = X.std, y = y)[-1]
res <- fixedLassoInf(X.std, y, b, lam, sigma = 1)
bb <- res$vmat %*% y
B <- cbind(bb, res$ci, res$pv)
dimnames(B) <- list(names(res$vars), c('Estimate', 'Lower', 'Upper', 'p'))

# Overlay the true coefficient values on the current CIplot as dashed gray
# vertical segments. Values and vertical centers are hard-coded for this
# particular selected set (presumably the rows where the truly nonzero
# features appear), so this helper only applies to the two plots below.
addTruthLines <- function() {
  truth  <- c(1, -1, 0.5, 0.5, -0.5, -0.5)
  center <- c(9,  8, 6,   4,   5,    2)
  # xpd = TRUE: the top segment (8.5 to 9.5) reaches the edge of the plot region
  segments(truth, center - 0.5, truth, center + 0.5,
           col = "gray", lty = 2, lwd = 2, xpd = TRUE)
}

# CI plot: Selective inference
CIplot(B, sort = FALSE, xlab = expression(beta), xlim = c(-4, 4))
addTruthLines()

# CI plot: OLS ignoring selection
CIplot(lm(y ~ X.std[, b != 0]), sort = FALSE, labels = names(res$vars),
       xlab = expression(beta), xlim = c(-4, 4))
addTruthLines()


# TCGA data ---------------------------------------------------------------

# Noise SD used for all TCGA inference below. estimateSigma() is printed for
# reference (about 0.4 or 0.45); a single fixed value is used so both the
# selective-inference and covtest results below use the same sigma.
sigma.TCGA <- 0.45

# Selective inference
X.std <- std(X.TCGA)
fit <- glmnet(X.std, y.TCGA, standardize = FALSE)
lam <- 115
# s = lam/n because glmnet scales the loss by 1/n; exact = TRUE refits at that
# lambda rather than interpolating (requires the original x and y).
b <- coef(fit, s = lam / n.TCGA, exact = TRUE, x = X.std, y = y.TCGA)[-1]
sum(b != 0)  # number of selected features
sh <- estimateSigma(X.std, y.TCGA)$sigmahat
sh
res <- fixedLassoInf(X.std, y.TCGA, b, lam, sigma = sigma.TCGA)
bb <- res$vmat %*% y.TCGA
B <- cbind(bb, res$ci, res$pv)
dimnames(B) <- list(names(res$vars), c('Estimate', 'Lower', 'Upper', 'p'))
CIplot(B, sort = FALSE, mar = c(4, 5, 1, 5), xlab = expression(beta))

# Covtest
fit <- lar(X.TCGA, y.TCGA, maxsteps = 11)
res <- larInf(fit, sigma = sigma.TCGA, k = 10)
Tab <- data.frame(Var = colnames(X.TCGA)[res$vars],
                  TS  = qexp(1 - res$pv.covtest),
                  p   = res$pv.covtest)
Tab
forwardStop(res$pv.covtest)
