require(survival)
source("http://myweb.uiowa.edu/pbreheny/7210/f17/notes/fun.R")
Data <- read.delim("http://myweb.uiowa.edu/pbreheny/data/Pike1966.txt")

# Setup
# Log-likelihood of a gamma model for right-censored survival times.
#
# theta: numeric vector c(lam, a) -- gamma rate `lam` and shape `a`.
# Reads the global data frame `Data`: `Time` is divided by 365.25
# (presumably days -> years; confirm against the data source) and `Death`
# flags observed events (1) versus censored observations (0).
# Observed events contribute log f(t); censored ones contribute log S(t).
# Returns the scalar log-likelihood.
l <- function(theta) {
  lam <- theta[1]
  a <- theta[2]
  time_yr <- Data$Time / 365.25
  event <- Data$Death == 1
  # Sum each term only over the observations it applies to instead of
  # weighting by d / (1 - d): in R 0 * -Inf is NaN, so the weighted form
  # turns the whole likelihood into NaN whenever an unused term is infinite
  # (e.g. the log-density of a censored time at t = 0).
  sum(dgamma(time_yr[event], shape = a, rate = lam, log = TRUE)) +
    sum(pgamma(time_yr[!event], shape = a, rate = lam,
               lower.tail = FALSE, log.p = TRUE))
}

# Slide 16
# Joint MLE of c(lam, a): fnscale = -1 makes optim() maximize l(), within
# the box [0.1, 100] for both parameters. The Hessian is kept for Wald SEs.
opt <- optim(c(1, 1), l, lower = 0.1, upper = 100,
             control = list(fnscale = -1), method = "L-BFGS-B", hessian = TRUE)
if (opt$convergence != 0) {
  # L-BFGS-B can stop early without erroring; make that visible.
  warning("optim() did not report convergence: ", opt$message, call. = FALSE)
}
H <- opt$hessian
mle <- opt$par
# Kaplan-Meier curve with the fitted gamma survival function overlaid.
fit <- survfit(Surv(Time/365.25, Death) ~ 1, data = Data)
# NOTE(review): `Plot` is not base R; presumably supplied by fun.R -- confirm.
Plot(fit, xlab = "Time on study (Years)", conf.int = FALSE)
tt <- seq(0, 1, length.out = 199)
lines(tt, pgamma(tt, shape = mle[2], rate = mle[1], lower.tail = FALSE),
      lwd = 3, col = pal(2)[1])

# Slide 18: Likelihoods
# Grid of rate values at which the likelihood curves are evaluated.
N <- 149
ll <- seq(10, 60, length.out = N)
# Log-likelihood as a function of the shape (`sig`) for a fixed rate `lam`.
# The shape comes first so optimize() can maximize over it directly.
# NOTE: this name masks base::ls() for the remainder of the session.
ls <- function(sig, lam) l(c(lam, sig))
# Profile log-likelihood for the rate, relative to the joint maximum:
#   prof(lam) = max_a l(c(lam, a)) - l(mle)
# The shape is maximized numerically over (0.0001, 1000) via optimize().
#
# lam: numeric vector of rate values.
# Returns a numeric vector the same length as `lam` (empty in, empty out).
# The values are approximately <= 0 when `mle` is the joint maximizer.
prof <- function(lam) {
  l_max <- l(c(mle[1], mle[2]))  # loop-invariant: evaluate once
  vals <- vapply(lam, function(rate) {
    # $objective is ls() evaluated at the returned maximizer, which is
    # exactly l(c(rate, shape_hat)); there is no need to re-evaluate l().
    optimize(ls, interval = c(0.0001, 1000), lam = rate,
             maximum = TRUE)$objective
  }, numeric(1), USE.NAMES = FALSE)
  vals - l_max
}
# "Estimated" (conditional) log-likelihood for the rate, relative to l(mle).
# The shape is held fixed at its MLE mle[2] rather than re-maximized.
#
# lam: numeric vector of rate values.
# Returns a numeric vector the same length as `lam` (empty in, empty out).
est <- function(lam) {
  shape_hat <- mle[2]
  l_max <- l(c(mle[1], shape_hat))  # loop-invariant: evaluate once
  vals <- vapply(lam, function(rate) l(c(rate, shape_hat)), numeric(1),
                 USE.NAMES = FALSE)
  vals - l_max
}
# Profile (color 2) vs. estimated/conditional (color 1) relative
# log-likelihood curves for the rate over the grid `ll`.
plot(ll, prof(ll),
     col = pal(2)[2], lwd = 3, type = "l", las = 1, bty = "n",
     xlab = expression(lambda), ylab = "Likelihood",
     xlim = c(20, 50), ylim = c(-3, 0))
lines(ll, est(ll),
      col = pal(2)[1], lwd = 3)
toplegend(legend = c("Estimated", "Profile"),
          col = pal(2), lwd = 3)

# Slide 19: Wald approx
# Log-likelihood in the rate alone, with the shape pinned at its MLE.
l.bad <- function(lam) {
  l(c(lam, mle[2]))
}
opt.bad <- optim(par = 1, fn = l.bad, lower = 0.1, upper = 100,
                 control = list(fnscale = -1), method = "L-BFGS-B",
                 hessian = TRUE)
# Observed information = negated Hessian of the log-likelihood at its maximum.
I <- -H
I.bad <- drop(-opt.bad$hessian)
# The SE from the [1, 1] entry of the inverted 2x2 information accounts for
# estimating the shape. The conditional SE treats the shape as known.
SE <- sqrt(solve(I)[1, 1])
SE.bad <- sqrt(1 / I.bad)
# Quadratic (Wald) approximations to the two relative log-likelihoods.
plot(ll, -(ll - mle[1])^2 / (2 * SE^2),
     col = pal(2)[2], lwd = 3, type = "l", las = 1, bty = "n",
     xlab = expression(lambda), ylab = "Likelihood (Wald approx)",
     xlim = c(20, 50), ylim = c(-3, 0))
lines(ll, -(ll - mle[1])^2 / (2 * SE.bad^2),
      col = pal(2)[1], lwd = 3)
toplegend(legend = c("Estimated", "Profile"),
          col = pal(2), lwd = 3)

# Slide 21: Bayes
library(R2jags)  # library() fails loudly if R2jags is missing; require() only warns
# NOTE(review): presumably initializes .Random.seed before jags() draws its seed -- confirm.
invisible(runif(1))

# Full model: flat priors on both the shape `a` and the rate `lam`.
# Censoring uses JAGS' dinterval. Censored times are passed as NA in `t`
# with cens = 1 (t[i] > c[i]). Observed times have cens = 0 (t[i] <= c[i]).
model <- function() {
  for (i in 1:n) {
    cens[i] ~ dinterval(t[i], c[i])
    t[i] ~ dgamma(a, lam)
  }
  a ~ dunif(0, 1000)
  lam ~ dunif(0, 1000)
}
jData <- with(Data, list(n = nrow(Data),
                         t = ifelse(Death == 1, Time/365.25, NA),
                         c = Time/365.25,
                         cens = 1 - Death))
fit <- jags(model.file = model, data = jData,
            parameters.to.save = c("a", "lam"),
            n.chains = 1, n.iter = 21000, n.burnin = 1000, n.thin = 1,
            DIC = FALSE)
lam <- fit$BUGSoutput$sims.list$lam

# "Bad Bayes": the shape is fixed at its MLE (supplied as data), and only
# `lam` has a prior.
model.bad <- function() {
  for (i in 1:n) {
    cens[i] ~ dinterval(t[i], c[i])
    t[i] ~ dgamma(a, lam)
  }
  lam ~ dunif(0, 1000)
}
jData <- with(Data, list(n = nrow(Data), a = mle[2],
                         t = ifelse(Death == 1, Time/365.25, NA),
                         c = Time/365.25,
                         cens = 1 - Death))
# Must fit model.bad, not model. Under `model`, supplying `a` as data would
# make it an observed node of the dunif() prior rather than a fixed constant.
# `a` is a constant here, so only `lam` is monitored.
fit.bad <- jags(model.file = model.bad, data = jData,
                parameters.to.save = "lam",
                n.chains = 1, n.iter = 21000, n.burnin = 1000, n.thin = 1,
                DIC = FALSE)
lam.bad <- fit.bad$BUGSoutput$sims.list$lam
dnplot(cbind(lam.bad, lam), xlab = expression(lambda))
toplegend(legend = c("Bad Bayes", "Bayes"), col = pal(2), lwd = 3)

# Slide 22: Intervals
# 95% likelihood-ratio intervals: solve 2 * relative log-lik = -qchisq(0.95, 1)
# on each side of the MLE, first for the profile and then for the estimated curve.
uniroot(function(lam) 2 * prof(lam) + qchisq(0.95, df = 1),
        interval = c(0.001, mle[1]))$root
uniroot(function(lam) 2 * prof(lam) + qchisq(0.95, df = 1),
        interval = c(mle[1], 1000))$root
uniroot(function(lam) 2 * est(lam) + qchisq(0.95, df = 1),
        interval = c(0.001, mle[1]))$root
uniroot(function(lam) 2 * est(lam) + qchisq(0.95, df = 1),
        interval = c(mle[1], 1000))$root
# 95% Wald intervals (profile SE, then conditional SE).
mle[1] + qnorm(c(0.025, 0.975)) * SE
mle[1] + qnorm(c(0.025, 0.975)) * SE.bad
# 95% equal-tailed posterior credible intervals.
quantile(lam, c(0.025, 0.975))
quantile(lam.bad, c(0.025, 0.975))

