# Setup
source('https://myweb.uiowa.edu/pbreheny/7110/f25/notes/fun.R')
library(data.table)
library(lme4)
library(survival)

# Linear mixed model illustration (define the competing approaches)
# Fit one of five competing estimators of the slope on `x` for clustered data.
#
# Arguments:
#   x, y    Numeric covariate and response vectors of equal length.
#   s       Integer cluster labels; `max(s)` is used as the number of clusters.
#   method  One of "mle", "marginal", "oracle", "naive", "diff", given as a
#           length-one character or factor.
#   u       Per-observation cluster effect; only read by method "oracle".
#
# Returns a one-row data.frame, with `method` as its row name, holding
# `estimate` and `se` for the slope and `variance`, the method's estimate of
# the residual variance. Stops with an error for an unrecognized `method`.
analyze <- function(x, y, s, method, u) {
  # Callers may pass a factor (e.g. a column built by expand.grid());
  # switch() must dispatch on the label, not on the integer code.
  method <- as.character(method)
  n <- max(s)
  m <- length(y) / n
  term <- "x"
  switch(
    method,
    mle = {
      # One fixed intercept per cluster; variance uses divisor n * m rather
      # than the residual degrees of freedom.
      fit <- lm(y ~ x + as.factor(s))
      v <- sum(fit$residuals^2) / (n * m)
    },
    marginal = {
      # Random intercept per cluster.
      fit <- lme4::lmer(y ~ x + (1 | s))
      v <- summary(fit)$sigma^2
    },
    oracle = {
      # Remove the supplied cluster effect before regressing.
      yy <- y - u
      fit <- lm(yy ~ x)
      v <- sum(fit$residuals^2) / (n * m)
    },
    naive = {
      # Ignore clustering altogether.
      fit <- lm(y ~ x)
      v <- summary(fit)$sigma^2
    },
    diff = {
      # Difference consecutive observations (2k - 1, 2k) and regress through
      # the origin. Indices stay within range, so no NA rows are created and
      # the fit does not depend on lm() dropping them.
      pair <- seq_len(length(y) %/% 2)
      yy <- y[2 * pair] - y[2 * pair - 1]
      xx <- x[2 * pair] - x[2 * pair - 1]
      fit <- lm(yy ~ 0 + xx)
      # Divisor is the number of observations, not the number of pairs.
      v <- sum(fit$residuals^2) / length(y)
      term <- "xx"
    },
    stop("Unknown method: ", method, call. = FALSE)
  )
  b <- coef(summary(fit))[term, ]
  data.frame(
    estimate = b[["Estimate"]],
    se = b[["Std. Error"]],
    variance = v,
    row.names = method
  )
}

# Data generating mechanism
# Simulate `n` clusters of `m` observations each, with a standard-normal
# cluster effect drawn independently of the uniform covariate.
#
# Returns a list with covariate `x`, per-observation cluster effect `u`,
# response `y` (normal, mean x + u, unit sd) and cluster labels `s`.
gen_data <- function(n = 100, m = 2) {
  total <- n * m
  x <- runif(total)
  cluster_effect <- rnorm(n)
  # Repeat each cluster's effect for its m consecutive observations.
  u <- rep(cluster_effect, each = m)
  y <- rnorm(total, mean = x + u)
  list(x = x, u = u, y = y, s = rep(1:n, each = m))
}

# Simulation
set.seed(1)
# One row per (method, replicate); results are filled in by the loop below.
res <- expand.grid(
  method = c("mle", "marginal", "oracle", "naive", "diff"),
  replicate = 1:1000,
  estimate = NA_real_,
  variance = NA_real_
)
pb <- progress::progress_bar$new(total = nrow(res))
for (i in seq_len(nrow(res))) {
  # expand.grid() varies `method` fastest, so a change in `replicate` marks
  # the first row of a new replicate: draw one data set and reuse it for
  # every method in that replicate.
  if (i == 1 || res$replicate[i] != res$replicate[i - 1]) {
    dat <- gen_data()
  }
  fit <- analyze(
    x = dat$x,
    y = dat$y,
    s = dat$s,
    method = res$method[i],
    u = dat$u
  )
  res$estimate[i] <- fit$estimate
  res$variance[i] <- fit$variance
  pb$tick()
}

# Summarize results
# Per-method average estimate, RMSE around the true slope of 1, and average
# variance estimate, relabelled and rendered as a LaTeX table.
as.data.table(res)[
  ,
  .(
    BetaAvg = mean(estimate),
    BetaRMSE = sqrt(mean((estimate - 1)^2)),
    Variance = mean(variance)
  ),
  by = method
][
  ,
  method := forcats::fct_recode(
    method,
    Profile = "mle",
    Mixed = "marginal",
    Oracle = "oracle",
    Naive = "naive",
    Conditional = "diff"
  )
] |>
  knitr::kable(format = "latex", booktabs = TRUE, digits = 2) |>
  kableExtra::kable_styling(position = "center")

# LMM: What if u is correlated with x?
# Simulate `n` clusters of `m` observations each, where the covariate is
# centered on the cluster effect (so `x` and `u` are correlated).
#
# Returns a list with covariate `x`, per-observation cluster effect `u`,
# response `y` (normal, mean x + u, unit sd) and cluster labels `s`.
gen_data <- function(n = 100, m = 2) {
  total <- m * n
  cluster_effect <- rnorm(n)
  # Repeat each cluster's effect for its m consecutive observations.
  u <- rep(cluster_effect, each = m)
  x <- rnorm(total, mean = u)
  y <- rnorm(total, mean = x + u)
  list(x = x, u = u, y = y, s = rep(1:n, each = m))
}

# Simulation
# One row per (method, replicate). There is no set.seed() call in this
# section, so the draws continue from the current RNG state.
res <- expand.grid(
  method = c("mle", "marginal", "oracle", "naive", "diff"),
  replicate = 1:1000,
  estimate = NA_real_,
  variance = NA_real_
)
pb <- progress::progress_bar$new(total = nrow(res))
for (i in seq_len(nrow(res))) {
  # `method` varies fastest, so a new `replicate` value means a new data set,
  # shared by all methods of that replicate.
  if (i == 1 || res$replicate[i] != res$replicate[i - 1]) {
    dat <- gen_data()
  }
  fit <- analyze(
    x = dat$x,
    y = dat$y,
    s = dat$s,
    method = res$method[i],
    u = dat$u
  )
  res$estimate[i] <- fit$estimate
  res$variance[i] <- fit$variance
  pb$tick()
}

# Summarize results
# Per-method average estimate, RMSE around the true slope of 1, and average
# variance estimate, relabelled and rendered as a LaTeX table.
as.data.table(res)[
  ,
  .(
    BetaAvg = mean(estimate),
    BetaRMSE = sqrt(mean((estimate - 1)^2)),
    Variance = mean(variance)
  ),
  by = method
][
  ,
  method := forcats::fct_recode(
    method,
    Profile = "mle",
    Mixed = "marginal",
    Oracle = "oracle",
    Naive = "naive",
    Conditional = "diff"
  )
] |>
  knitr::kable(format = "latex", booktabs = TRUE, digits = 2) |>
  kableExtra::kable_styling(position = "center")

# Quadrature example #1: A polynomial
# 3-point Gauss-Hermite rule; by default GHrule() returns a matrix whose
# columns include the nodes "z" and the weights "w".
gq <- lme4::GHrule(3)
with(as.data.frame(gq), sum(w * z^4))  # Exactly correct
with(as.data.frame(gq), sum(w * z^6))  # Incorrect (15 is true answer)

# Quadrature example #2: Something exotic
# Monte Carlo estimate of E[sqrt(|Z| + |Z^3|)] from one million normal draws.
z <- rnorm(1e6)
mean(sqrt(abs(z) + abs(z^3)))
# Same expectation by Gauss-Hermite quadrature with 20 and then 100 nodes;
# inside with(), `z` and `w` are the rule's own nodes and weights.
GH <- lme4::GHrule(20, asMatrix = FALSE)
with(GH, sum(w * sqrt(abs(z) + abs(z^3))))
GH <- lme4::GHrule(100, asMatrix = FALSE)
with(GH, sum(w * sqrt(abs(z) + abs(z^3))))

# Quadrature: Variance of median
# Monte Carlo reference: the medians of 100,000 standard-normal samples,
# each of size n.
n <- 11
m <- vapply(
  seq_len(100000),
  function(i) median(rnorm(n)),
  numeric(1)
)
# Integrand for the variance of the sample median of `n` standard-normal
# draws: x^2 times the density of the middle order statistic, with the
# dnorm(x) factor left out.
# NOTE(review): the omitted dnorm(x) is presumably supplied by the
# Gauss-Hermite weights this function is summed against -- confirm.
#
# Reads `n` from the enclosing environment. `n` must be odd so that the
# median is the single order statistic of rank (n + 1) / 2.
f <- function(x) {
  stopifnot(n %% 2 == 1)
  # Number of observations on each side of the median (5 when n = 11).
  k <- (n - 1) / 2
  n * choose(n - 1, k) * x^2 * pnorm(x)^k * (1 - pnorm(x))^k
}
# Compare four values for the variance of the median: Monte Carlo, the
# asymptotic formula pi / (2n), and Gauss-Hermite quadrature of `f` with 20
# and 100 nodes. Inside with(), `z` and `w` are the rule's nodes and weights.
GH20 <- lme4::GHrule(20, asMatrix = FALSE)
GH100 <- lme4::GHrule(100, asMatrix = FALSE)
data.frame(
  Variance = c(
    var(m),
    pi / (2 * n),
    with(GH20, sum(w * f(z))),
    with(GH100, sum(w * f(z)))
  ),
  row.names = c(
    "Monte Carlo ($N=100,000$)",
    "Asymptotic",
    "Gauss-Hermite ($K=20$)",
    "Gauss-Hermite ($K=100$)"
  )
) |>
  knitr::kable(
    format = "latex",
    booktabs = TRUE,
    digits = 4,
    escape = FALSE
  ) |>
  kableExtra::kable_styling(position = "center")

# Logistic simulation (define the competing approaches)
# Estimate the coefficient of `x` in a logistic model for clustered binary
# data, using one of four approaches.
#
# Arguments:
#   x, y    Covariate and response vectors of equal length.
#   s       Cluster labels.
#   method  One of "mle", "marginal", "naive", "conditional", given as a
#           length-one character or factor.
#
# Returns the estimated coefficient of `x` as a length-one named numeric.
# Stops with an error for an unrecognized `method`.
analyze <- function(x, y, s, method) {
  # Callers may pass a factor (e.g. a column built by expand.grid());
  # switch() must dispatch on the label, not on the integer code.
  method <- as.character(method)
  b <- switch(
    method,
    # One fixed intercept per cluster.
    mle = coef(glm(y ~ x + as.factor(s), family = "binomial"))["x"],
    marginal = {
      # Random intercept per cluster; messages from the fit are silenced.
      fit <- suppressMessages(glmer(y ~ x + (1 | s), family = "binomial"))
      fixef(fit)["x"]
    },
    # Ignore clustering altogether.
    naive = coef(glm(y ~ x, family = "binomial"))["x"],
    conditional = {
      # Conditional logistic regression stratified by cluster.
      fit <- clogit(
        y ~ x + strata(s),
        data = data.frame(y = y, x = x, s = s)
      )
      coef(fit)["x"]
    },
    stop("Unknown method: ", method, call. = FALSE)
  )
  b
}

# Data generating mechanism
# Simulate clustered binary data: `n` clusters, `m` observations per cluster,
# with a normal (sd = 2) cluster effect on the logit scale.
#
# Returns a list with covariate `x`, per-observation cluster effect `u`,
# binary response `y` and cluster labels `s`. Clusters are interleaved: the
# label vector 1..n and the effect vector are each repeated `m` times as a
# whole, so `u` and `s` stay aligned.
gen_data <- function(n = 100, m = 2) {
  total <- n * m
  x <- rnorm(total)
  cluster_effect <- rnorm(n, sd = 2)
  u <- rep(cluster_effect, m)
  inv_logit <- binomial()$linkinv
  y <- rbinom(total, size = 1, prob = inv_logit(x + u))
  list(x = x, u = u, y = y, s = rep(1:n, m))
}

# Simulation
set.seed(1)
# One row per (method, replicate); `estimate` is filled in by the loop below.
res <- expand.grid(
  method = c("naive", "mle", "conditional", "marginal"),
  replicate = 1:1000,
  estimate = NA_real_
)
pb <- progress::progress_bar$new(total = nrow(res))
for (i in seq_len(nrow(res))) {
  # `method` varies fastest, so a new `replicate` value means a new data set,
  # shared by all methods of that replicate.
  if (i == 1 || res$replicate[i] != res$replicate[i - 1]) {
    dat <- gen_data()
  }
  res$estimate[i] <- analyze(
    x = dat$x,
    y = dat$y,
    s = dat$s,
    method = res$method[i]
  )
  pb$tick()
}

# Summarize results
# Per-method mean estimate and RMSE around the true coefficient of 1,
# relabelled and rendered as a LaTeX table.
as.data.table(res)[
  ,
  .(
    Mean = mean(estimate),
    RMSE = sqrt(mean((estimate - 1)^2))
  ),
  by = method
][
  ,
  method := forcats::fct_recode(
    method,
    Profile = "mle",
    Marginal = "marginal",
    Naive = "naive",
    Conditional = "conditional"
  )
] |>
  knitr::kable(format = "latex", booktabs = TRUE, digits = 2) |>
  kableExtra::kable_styling(position = "center")

