library(hdrm)
library(PRROC)
# Load the Golub (1999) data set: `X` holds the predictors and `y` the
# class labels (compared against 'AML' / 'ALL' further down).
golub <- read_data(Golub1999)
x <- golub$X
y <- golub$y
# Sanitize predictor names with janitor so they are safe to use as labels.
colnames(x) <- colnames(x) |> janitor::make_clean_names()

# Deviance demo
# Binomial deviance of observation(s) `y` given predicted probability `p`:
# -2 * (y * log(p) + (1 - y) * log(1 - p)), vectorized with recycling.
# Uses the convention 0 * log(0) = 0, so a perfect prediction (y = 0, p = 0
# or y = 1, p = 1) gives 0 instead of NaN, and a certain-but-wrong
# prediction (y = 0, p = 1) gives Inf instead of NaN.
dev <- function(y, p) {
  # a * log(b), forced to 0 wherever a == 0 (avoids 0 * -Inf = NaN).
  # Logical indexing recycles `a == 0` the same way `a * log(b)` recycles a.
  xlogy <- function(a, b) {
    out <- a * log(b)
    out[a == 0] <- 0
    out
  }
  -2 * (xlogy(y, p) + xlogy(1 - y, 1 - p))
}
dev(0, 0.001)
dev(0, 0.5)
dev(0, 0.999)

# cv.glmnet(x, y, family = 'binomial')

# Lasso-penalized logistic regression. Setting nfolds to the sample size puts
# every observation in its own fold (leave-one-out CV); returnY = TRUE keeps
# the held-out predictions, which are read back from cvfit_lasso$Y later.
cvfit_lasso <- cv.ncvreg(
  x, y,
  family = 'binomial', penalty = 'lasso',
  lambda.min = 0.01, nfolds = length(y), returnY = TRUE
)
par(mar = c(4, 4, 3, 0.5))
plot(cvfit_lasso, ylim = c(0, 1.5), bty = 'n')

# cv.ncvreg(x, y, family = 'binomial', gamma = 4)

# MCP-penalized logistic regression (gamma = 4), again with one fold per
# observation (leave-one-out CV) and a larger lambda.min than the lasso fit.
cvfit_mcp <- cv.ncvreg(
  x, y,
  family = 'binomial', penalty = 'MCP', gamma = 4,
  lambda.min = 0.1, nfolds = length(y)
)
par(mar = c(4, 4, 3, 0.5))
plot(cvfit_mcp, ylim = c(0, 1.5), bty = 'n')

# One CV summary plot: reset the margins, then plot `cvfit` with the given
# plot.cv.ncvreg `type` and y-axis limits.
plot_cv_panel <- function(cvfit, type, ylim) {
  par(mar = c(4, 4, 3, 0.5))
  plot(cvfit, type = type, ylim = ylim, bty = 'n')
}

# R^2 ('rsq') and prediction error ('pred') for both penalties.
plot_cv_panel(cvfit_lasso, 'rsq', c(0, 0.7))
plot_cv_panel(cvfit_mcp, 'rsq', c(0, 0.7))
plot_cv_panel(cvfit_lasso, 'pred', c(0, 0.5))
plot_cv_panel(cvfit_mcp, 'pred', c(0, 0.5))

# This ensures a monotone curve.
# Replaces column 2 of pr$curve (the precision values PRROC plots on the
# y-axis) by its running maximum taken down the rows; all other columns and
# fields, including pr$auc.integral, are returned unchanged.
# NOTE(review): this is the usual "interpolated precision" only if PRROC
# stores the rows from high recall to low recall -- confirm the row order.
interpolate <- function(pr) {
  precision <- pr$curve[, 2]
  pr$curve[, 2] <- cummax(precision)
  pr
}

# Draw an interpolated precision-recall curve for `scores` and label it with
# the AUC reported by pr.curve(). Observations with y == pos are passed to
# PRROC as the positive class, those with y == neg as the negative class.
# Returns the PRROC object invisibly.
plot_pr_auc <- function(scores, y, col = hdrm::pal(2)[2],
                        pos = 'AML', neg = 'ALL') {
  pr <- pr.curve(scores[y == pos], scores[y == neg], curve = TRUE) |>
    interpolate()
  plot(
    pr, color = col, main = '', auc.main = FALSE, bty = 'n', las = 1,
    ylab = 'PPV', xlab = 'Sensitivity'
  )
  # `digits` must be named: formatC's second positional argument is `width`.
  text(
    0.6,
    0.6,
    paste('AUC:', formatC(pr$auc.integral, digits = 3, format = 'f')),
    col = col,
    cex = 1
  )
  invisible(pr)
}

# Draw the ROC curve for `scores` with a dashed chance diagonal and label it
# with the AUC. Class conventions match plot_pr_auc(). Returns the PRROC
# object invisibly.
plot_roc_auc <- function(scores, y, col = hdrm::pal(2)[1],
                         pos = 'AML', neg = 'ALL') {
  roc <- roc.curve(scores[y == pos], scores[y == neg], curve = TRUE)
  plot(
    roc, color = col, main = '', auc.main = FALSE, bty = 'n', las = 1,
    xlab = "1 - Specificity"
  )
  lines(0:1, 0:1, col = 'gray', lty = 2)
  text(
    0.3,
    0.6,
    paste('AUC:', formatC(roc$auc, digits = 3, format = 'f')),
    col = col,
    cex = 1
  )
  invisible(roc)
}

# Precision-recall (side by side with the ROC curve)
par(mfrow = c(1, 2))
# Cross-validated predictions at the CV-selected lambda (needs returnY = TRUE).
cvp <- cvfit_lasso$Y[, cvfit_lasso$min]
pr <- plot_pr_auc(cvp, y)

# ROC
roc <- plot_roc_auc(cvp, y)

# Do **NOT** do this: these linear predictors come from the fit to the whole
# data set, i.e. from a model that has already seen every observation it is
# scored on, so the curves below are optimistic compared to those from `cvp`.
pred <- cvfit_lasso$fit$linear.predictors[, cvfit_lasso$min]
pr <- plot_pr_auc(pred, y)

# ROC
roc <- plot_roc_auc(pred, y)

# cv.glmnet(x, y, family = 'binomial',
#           type.measure = 'auc')
# cv.ncvreg(x, y, family = 'binomial', returnY = TRUE)

