library(data.table)
library(hdrm)
library(hqreg)
library(kableExtra)
library(ggplot2)

# Illustration: multinomial lasso fit on pure-noise data
# x4 is a standardized 100 x 20 Gaussian design; y4 is a 4-class label drawn
# independently of x4.
set.seed(1)
x4 <- ncvreg::std(matrix(rnorm(100 * 20), nrow = 100, ncol = 20))
y4 <- sample(1:4, size = 100, replace = TRUE)
fit <- glmnet(x4, y4, family = 'multinomial')

# coef() returns one coefficient vector per class; bind them side by side
B <- do.call(cbind, coef(fit, s = 0.05))
dim(B)
head(B)

# Linear predictors at lambda = 0.05 (third array dimension dropped)
eta <- predict(fit, newx = x4, s = 0.05, type = 'link')[, , 1]
head(eta)

# Per-class column sums of the linear predictors, then their grand total
colSums(eta)
sum(colSums(eta))

# Multinomial outcome
attach_data(Ramaswamy2001)

# keep = TRUE is requested so that the out-of-fold predictions
# (cvfit$fit.preval) are available to summary_table() further down.
cvfit <- cv.glmnet(X, y, keep = TRUE, family = "multinomial")

# cvfit <- cv.glmnet(X, y, family = 'multinomial')

# Cross-validation curve, with the x-axis limits reversed relative to
# the natural order of log(lambda)
par(mar = c(4, 4.1, 2.5, 0.5))
plot(
  cvfit,
  bty = "n",
  las = 1,
  xlab = expression(log(lambda)),
  xlim = rev(range(log(cvfit$lambda)))
)

# Get coefficient matrix: one column per class, one row per feature
# (the intercept, stored in row 1 of each class's coefficients, is dropped)
b <- vapply(
  coef(cvfit, s = 'lambda.min'),
  function(x) x[-1, 1],
  numeric(ncol(X))
)

# Keep every feature with a nonzero coefficient in at least one class.
# (Testing rowSums(b) != 0 would wrongly discard a feature whose nonzero
# coefficients happen to cancel across classes.)
bb <- b[rowSums(b != 0) > 0, , drop = FALSE]

# Order the features by coefficient value, using the last class as the primary
# sort key and earlier classes as successive tie-breakers. Zeros are replaced
# by Inf so that, within each key, unselected features sort after selected
# ones. The number of classes is taken from bb rather than hard-coded.
sort_keys <- lapply(rev(seq_len(ncol(bb))), function(j) {
  key <- bb[, j]
  key[key == 0] <- Inf
  key
})
ord <- do.call(order, sort_keys)

# Long format: Var1 = feature, Var2 = class, Freq = coefficient
df <- as.data.frame.table(bb)

# Only reorder the levels; supplying `labels` alongside permuted `levels`
# would silently rename each feature to a different feature's name.
df$Var1 <- factor(df$Var1, levels = rownames(bb)[ord])

# One panel per class: a vertical segment from 0 to each coefficient, drawn
# over a faint horizontal baseline at 0
ggplot(df, aes(Var1, Freq, color = Var2)) +
  geom_segment(aes(xend = Var1, yend = 0)) +
  annotate(
    'segment',
    x = 1,
    xend = nrow(bb),
    y = 0,
    yend = 0,
    col = '#77777722'
  ) +
  facet_grid(Var2 ~ ., switch = 'y') +
  theme_void() +
  theme(panel.spacing.y = unit(-0.75, 'lines')) +
  guides(color = 'none')

# Cross-validated confusion table for a multinomial cv.glmnet fit.
#
# Arguments:
#   cvfit  Object with components `index` (a matrix with a row named 'min')
#          and `fit.preval` (observations x classes x lambda array); in this
#          file it is a cv.glmnet fit created with keep = TRUE.
#   x      Unused; kept so existing calls keep working.
#   y      Observed classes. Coerced to a factor; the columns of `fit.preval`
#          are assumed to follow the order of its levels.
#
# Returns a list with
#   tab   matrix: entry [i, j] is the percentage (rounded) of observations in
#         true class j that were predicted as class i; an extra last row holds
#         the class counts of y, and an extra last column holds the number of
#         predictions per class (with length(y) in the corner).
#   pred  factor of predicted classes at the 'min' lambda index.
summary_table <- function(cvfit, x, y) {
  # levels(y) is NULL for non-factors, which would make every prediction empty
  y <- as.factor(y)
  lev <- levels(y)

  # Predicted class = column with the largest held-out value at the 'min' index
  cv_min <- cvfit$index['min', ]
  pred <- factor(
    lev[apply(cvfit$fit.preval[, , cv_min], 1, which.max)],
    levels = lev
  )

  # Column-wise percentages: divide each true-class column by its class size
  counts <- table(pred, y)
  tab <- round(100 * t(t(counts) / as.numeric(table(y))))
  rownames(tab) <- colnames(tab) <- abbreviate(lev, minlength = 2)

  # Margins: class sizes along the bottom, prediction counts down the right.
  # The arguments are deliberately calls (not symbols) so rbind/cbind leave
  # the added row and column unnamed.
  tab <- rbind(tab, table(y))
  tab <- cbind(tab, c(table(pred), length(y)))

  list(tab = tab, pred = pred)
}
# Confusion table from the cross-validated predictions, rendered as LaTeX
s <- summary_table(cvfit, X, y)
s$tab |>
  kbl(format = "latex", booktabs = TRUE) |>
  kable_styling(font_size = 7)

# Robust regression simulation: Setup
# One row per (loss, rep, sig) combination; `loss` varies fastest, so the two
# rows of a given (rep, sig) pair are adjacent. `se` is filled in by the loop.
res <- expand.grid(
  loss = c("Least squares", "Huber"),
  rep = 1:100,
  sig = 1:10
)
res[["n"]] <- 100
res[["p"]] <- 100
res[["se"]] <- NA_real_

# Robust regression simulation: Loop
# For every (rep, sig) scenario one data set is drawn and then analysed under
# each loss, so both rows of a scenario are fit to the same data. The squared
# estimation error ||bhat - b||^2 is stored in res$se.
pb <- progress::progress_bar$new(total = nrow(res))
for (i in seq_len(nrow(res))) {
  o <- res[i, ]

  # Draw new data whenever the scenario changes. `sig` is checked as well as
  # `rep` so the loop stays correct even if the grid has a single replicate.
  new_scenario <- i == 1 ||
    o$rep != res$rep[i - 1] ||
    o$sig != res$sig[i - 1]
  if (new_scenario) {
    x <- rnorm(o$n * o$p) |> matrix(o$n, o$p)
    # About 10% of observations are contaminated: their noise has sd = sig
    # instead of 1
    messy <- rbinom(o$n, 1, prob = 0.1)
    # True coefficients: the first two are 1, the remaining p - 2 are 0
    b <- rep(1:0, c(2, o$p - 2))
    y <- x %*% b + (1 - messy) * rnorm(o$n) + messy * rnorm(o$n, sd = o$sig)
  }

  # Analyze (the intercept is dropped from both coefficient vectors)
  if (o$loss == 'Least squares') {
    cvfit <- cv.glmnet(x, y)
    bhat <- coef(cvfit, s = 'lambda.min')[-1]
  } else {
    # Anything cv.hqreg prints is captured and discarded
    capture.output(cvfit <- cv.hqreg(x, y)) |> invisible()
    bhat <- coef(cvfit, lambda = 'lambda.min')[-1]
  }

  # Summarize: crossprod() returns a 1 x 1 matrix, so drop it to a scalar
  res$se[i] <- drop(crossprod(bhat - b))
  pb$tick()
}

# Average the squared errors over replicates within each (loss, sig) cell and
# draw one line per loss
as.data.table(res)[, .(mse = mean(se)), by = .(loss, sig)] |>
  ggplot(aes(x = sig, y = mse, group = loss, color = loss)) +
  geom_line() +
  labs(
    x = expression(sigma * ' (contamination)'),
    y = 'MSE'
  ) +
  theme_minimal()

# Fit Cox regression models
attach_data(Shedden2008)
head(Z)

# Standardize the expression matrix once, and size everything below from the
# standardized result: ncvreg::std() can return fewer columns than it is given
# (NOTE(review): it drops constant columns -- confirm against ncvreg docs), in
# which case ncol(X) would give a penalty-factor vector of the wrong length.
X_std <- ncvreg::std(X)

# Design matrix: three clinical covariates followed by the expression features
xx <- cbind(
  Age = as.numeric(Z$Age),
  Sex = Z$Sex == 'Male',
  Chemo = Z$AdjChemo == 'Yes',
  X_std
)

# Penalty factors: 0 (unpenalized) for the clinical covariates, 1 for the rest
w <- rep(0:1, c(ncol(xx) - ncol(X_std), ncol(X_std)))

# Shared fold assignment for both fits, built from this data set's outcome.
# (The previous version used `y`, which is whatever an earlier section left in
# the workspace -- NOTE(review): confirm Shedden2008 supplies no `y` -- and
# never passed `fold` to either fit.)
fold <- assign_fold(S[, 'status'], 10, 1)

# glmnet's argument is lambda.min.ratio; `lambda.min` only reached it through
# partial argument matching
cvfit1 <- cv.glmnet(
  xx,
  S,
  family = 'cox',
  penalty.factor = w,
  foldid = fold,
  lambda.min.ratio = 0.35
)
cvfit2 <- ncvreg::cv.ncvsurv(
  xx,
  S,
  penalty = 'lasso',
  penalty.factor = w,
  fold = fold,
  lambda.min = 0.35
)

# xx <- cbind(Z, X)
# w <- rep(0:1, c(ncol(Z), ncol(X)))
# cv.glmnet(xx, S, family = 'cox',
#           penalty.factor = w)

# cv.ncvsurv(xx, S, penalty = 'lasso',
#            penalty.factor = w)

# Cross-validation curve for the glmnet Cox fit, with the x-axis limits
# reversed relative to the natural order of log(lambda)
par(mar = c(4, 4.1, 2.5, 0.5))
plot(
  cvfit1,
  bty = "n",
  las = 1,
  xlim = rev(range(log(cvfit1$lambda)))
)

# Cross-validation curve for the ncvreg Cox fit
par(mar = c(4, 4.1, 2.5, 0.5))
plot(cvfit2, bty = "n")

