# Setup
# rjags supplies the JAGS interface (jags.model, jags.samples); survival
# supplies Surv, survfit, survreg, and coxph, all of which are called below.
library(rjags)
library(survival)
# Course helper script, fetched over the network every time this file runs.
# NOTE(review): presumably defines pal(), dnplot(), toplegend(), and Plot(),
# which are called below but not defined in this file -- confirm.
source("http://myweb.uiowa.edu/pbreheny/7210/f19/notes/fun.R")
# Data set read from a remote delimited text file; the code below uses its
# Time, Death, and Group columns.
Data <- read.delim("https://s3.amazonaws.com/pbreheny-data-sets/Pike1966.txt")
# The piecewise exponential sections call jags.model('piecewise.jag'), so that
# model file must first be downloaded into the working directory from:
# http://myweb.uiowa.edu/pbreheny/7210/f19/notes/piecewise.jag

# Slide 7: Reference prior
# The prior is drawn as Gamma(1, 0); the posterior used for the curve, the
# tail probability, and the interval is Gamma(d + 1, v).
d <- sum(Data$Death)          # Number of observed deaths
v <- sum(Data$Time / 365.25)  # Total time on study (presumably days -> years)
lam <- sort(c(seq(0, 3, length.out = 199), 1))  # Grid on [0, 3] plus the point 1
plot(
  lam, dgamma(lam, shape = 1, rate = 0),
  type = "l", col = "gray", lwd = 3, bty = "n", las = 1,
  ylim = c(0, 1.75),
  xlab = expression(lambda), ylab = expression(p(lambda))
)
lines(lam, dgamma(lam, shape = d + 1, rate = v), col = pal(2)[2], lwd = 3)
pgamma(1, shape = d + 1, rate = v)                # Posterior P(lambda <= 1)
qgamma(c(0.025, 0.975), shape = d + 1, rate = v)  # 2.5% and 97.5% quantiles

# Slide 8: Informative prior
# Gamma(3, 3) prior on lambda. With d deaths over v units of follow-up the
# conjugate posterior is Gamma(d + 3, v + 3); the same posterior is used for
# the plotted curve, the tail probability, and the interval so that all three
# describe one distribution.
plot(lam, dgamma(lam, 3, 3), col='gray', lwd=3, type='l', bty='n', ylim=c(0, 1.75),
     xlab=expression(lambda), ylab=expression(p(lambda)), las=1)
lines(lam, dgamma(lam, d+3, v+3), col=pal(2)[2], lwd=3)
pgamma(1, d+3, v+3)               # Posterior P(lambda <= 1)
qgamma(c(.025, .975), d+3, v+3)   # 2.5% and 97.5% posterior quantiles

# Bayes: Slides 12-13
# Gamma model for the survival times with Uniform(0, 1000) priors on both
# shape and rate. Censoring is expressed through dinterval(): t[i] is passed
# as NA for subjects without an observed death, tos[i] is their time on study,
# and cens[i] = 1 - Death marks them.
model_text <- textConnection(
'model {
  for (i in 1:n) {
    cens[i] ~ dinterval(t[i], tos[i])
    t[i] ~ dgamma(shape, rate)
  }
  shape ~ dunif(0, 1000)
  rate ~ dunif(0, 1000)
}')
jagsData <- list(
  n    = nrow(Data),
  t    = ifelse(Data$Death == 1, Data$Time / 365.25, NA),
  tos  = Data$Time / 365.25,
  cens = 1 - Data$Death
)
model <- jags.model(
  file = model_text, data = jagsData,
  n.chains = 1, n.adapt = 1000
)
post <- jags.samples(model, variable.names = c("rate", "shape"), n.iter = 20000)
rate <- post$rate  # Kept for the comparison with the "bad" model below
dnplot(rate, col = pal(2)[2], xlab = expression(lambda))

# Bad empirical Bayes
# Same gamma likelihood as above, but shape is not given a prior: it is
# supplied as data, fixed at mle[2], so only rate is sampled.
mle <- c(35.01, 22.23)
badmodel_text <- textConnection(
'model {
  for (i in 1:n) {
    cens[i] ~ dinterval(t[i], c[i])
    t[i] ~ dgamma(shape, rate)
  }
  rate ~ dunif(0, 1000)
}')
jagsData <- list(
  n     = nrow(Data),
  shape = mle[2],
  t     = ifelse(Data$Death == 1, Data$Time / 365.25, NA),
  c     = Data$Time / 365.25,
  cens  = 1 - Data$Death
)
badmodel <- jags.model(
  file = badmodel_text, data = jagsData,
  n.chains = 1, n.adapt = 1000
)
post.bad <- jags.samples(
  badmodel,
  variable.names = c("rate", "shape"),
  n.iter = 20000
)
rate.bad <- post.bad$rate
# Overlay the two posteriors for rate; legend order matches the column order
dnplot(cbind(rate.bad, rate), xlab = expression(lambda))
toplegend(legend = c("Bad Bayes", "Bayes"), col = pal(2), lwd = 3)

# Intervals
# 2.5% and 97.5% posterior quantiles of rate under each model
quantile(rate, probs = c(0.025, 0.975))
quantile(rate.bad, probs = c(0.025, 0.975))

# Piecewise illustration vs Weibull: K=4
# Approximates a Weibull(shape = 1.5) hazard on [0, 1] by a step function
# with K equal-width pieces, then compares the implied survival curves.
K <- 4
tt <- seq(0, 1, length.out = 99)        # Plotting grid
t_cut <- seq(0, 1, length.out = K + 1)  # Interval boundaries
# Evaluation point inside each interval: 45% of the way from its left edge
t_mid <- t_cut[1:K] + 0.45 * diff(t_cut)
# Weibull hazard (density / survival) at those points
lam <- dweibull(t_mid, shape = 1.5) /
  pweibull(t_mid, shape = 1.5, lower.tail = FALSE)

# Hazard: smooth Weibull curve with the step approximation on top
plot(
  tt,
  dweibull(tt, shape = 1.5) / pweibull(tt, shape = 1.5, lower.tail = FALSE),
  type = "l", bty = "n", las = 1, col = pal(2)[1], lwd = 3,
  xlab = "Time", ylab = "Hazard"
)
lines(
  tt, lam[as.integer(cut(tt, t_cut, include.lowest = TRUE))],
  col = pal(2)[2], lwd = 3, type = "s"
)

# Survival: the piecewise cumulative hazard is linear between cut points
H <- approxfun(t_cut, c(0, cumsum(lam * diff(t_cut))))  # Cumulative hazard
plot(
  tt, pweibull(tt, shape = 1.5, lower.tail = FALSE),
  type = "l", bty = "n", las = 1, col = pal(2)[1], lwd = 3,
  xlab = "Time", ylab = "Survival"
)
lines(tt, exp(-H(tt)), col = pal(2)[2], lwd = 3)

# Piecewise illustration vs Weibull: K=10
# Same comparison as for K=4, with ten equal-width pieces on [0, 1].
K <- 10
tt <- seq(0, 1, length.out = 99)        # Plotting grid
t_cut <- seq(0, 1, length.out = K + 1)  # Interval boundaries
# Evaluation point inside each interval: 45% of the way from its left edge
t_mid <- t_cut[1:K] + 0.45 * diff(t_cut)
# Weibull hazard (density / survival) at those points
lam <- dweibull(t_mid, shape = 1.5) /
  pweibull(t_mid, shape = 1.5, lower.tail = FALSE)

# Hazard: smooth Weibull curve with the step approximation on top
plot(
  tt,
  dweibull(tt, shape = 1.5) / pweibull(tt, shape = 1.5, lower.tail = FALSE),
  type = "l", bty = "n", las = 1, col = pal(2)[1], lwd = 3,
  xlab = "Time", ylab = "Hazard"
)
lines(
  tt, lam[as.integer(cut(tt, t_cut, include.lowest = TRUE))],
  col = pal(2)[2], lwd = 3, type = "s"
)

# Survival: the piecewise cumulative hazard is linear between cut points
H <- approxfun(t_cut, c(0, cumsum(lam * diff(t_cut))))  # Cumulative hazard
plot(
  tt, pweibull(tt, shape = 1.5, lower.tail = FALSE),
  type = "l", bty = "n", las = 1, col = pal(2)[1], lwd = 3,
  xlab = "Time", ylab = "Survival"
)
lines(tt, exp(-H(tt)), col = pal(2)[2], lwd = 3)

# Piecewise exponential model
set.seed(1)
# Distinct observed death times, in ascending order
uniq <- sort(unique(Data$Time[Data$Death == 1]))
# Cut points: 0, the midpoints between consecutive death times, and one
# time unit beyond the longest time on study
a <- c(0, uniq[-length(uniq)] + diff(uniq) / 2, max(Data$Time) + 1)
jagsData <- list(
  n      = nrow(Data),        # Number of subjects
  J      = length(uniq),      # Num of gaps between failure times
  K      = length(uniq),      # Num of lambda values to estimate
  t      = Data$Time,         # Time on study
  d      = Data$Death,        # 1 if event (death) observed
  Z      = Data$Group - 1.5,  # Group (+0.5 / -0.5)
  a      = a,                 # Cut points
  period = 1:length(uniq)     # Maps lambdas to intervals
)
fit <- jags.model(
  file = "piecewise.jag", data = jagsData,
  n.chains = 4, n.adapt = 1000
)
post <- jags.samples(fit, variable.names = c("beta", "lam"), n.iter = 10000)
post1 <- post  # Keep a copy of this fit's samples under a separate name

# diagnostics
# gelman.diag() is applied to the chains of beta and of lam
gelman.diag(as.mcmc.list(post$beta))
gelman.diag(as.mcmc.list(post$lam))

# beta
# Summary, density plot, and 2.5% / 97.5% posterior quantiles for beta
summary(as.mcmc.list(post$beta))
dnplot(post$beta, col = pal(2)[2], xlab = expression(beta))
post$beta  # Posterior mean
quantile(post$beta, probs = c(0.025, 0.975))

# Baseline survival
# S(t, lam, a): survival function of a piecewise exponential distribution.
#   t   - times at which to evaluate survival
#   lam - hazard on each interval (one value per gap between cut points)
#   a   - cut points (one more than the number of hazards)
# Returns exp(-H(t)), where the cumulative hazard H is interpolated linearly
# between its values at the cut points; t outside the range of a gives NA.
S <- function(t, lam, a) {
  cum_haz_at_cuts <- c(0, cumsum(lam * diff(a)))
  cum_haz <- approx(a, cum_haz_at_cuts, xout = t)$y
  exp(-cum_haz)
}
tt <- seq(0, max(Data$Time), length.out = 99)  # Plotting grid over follow-up
km_fit <- survfit(Surv(Time, Death) ~ 1, data = Data)  # Kaplan-Meier reference
pm_lam <- apply(post$lam, MARGIN = 1, FUN = mean)  # Posterior mean hazards
# One row per posterior draw, one column per lambda
L <- as.matrix(as.mcmc.list(post$lam))
# Survival curve on tt for every draw (one column per draw), then pointwise
# 2.5% / 97.5% quantiles across draws
post_s <- apply(L, MARGIN = 1, FUN = S, t = tt, a = a)
ci_s <- apply(post_s, MARGIN = 1, FUN = quantile, probs = c(0.025, 0.975))
Plot(km_fit, col = "gray")
lines(tt, S(tt, pm_lam, a), col = pal(2)[2], lwd = 3)
polygon(
  x = c(tt, rev(tt)),
  y = c(ci_s[1, ], rev(ci_s[2, ])),
  col = pal(2, alpha = 0.4)[2], border = NA
)

# Survival of each group
# Group enters the model as Z = -0.5 / +0.5, so each group's hazard is the
# baseline hazard scaled by exp(-0.5 * beta) or exp(+0.5 * beta).
km_fit <- survfit(Surv(Time, Death) ~ Group, Data)
# Curves at the posterior means of lam and beta
pm_lam1 <- apply(post$lam, 1, mean) * exp(mean(post$beta) * -0.5)
pm_lam2 <- apply(post$lam, 1, mean) * exp(mean(post$beta) * 0.5)
# Draws of beta, flattened across chains to line up with the rows of L
b <- unlist(as.mcmc.list(post$beta))
# Per-draw survival curves on tt (one column per draw). vapply with an
# explicit template keeps the result a length(tt)-row numeric matrix, and
# seq_along avoids the 1:0 trap if there are no draws.
post_s1 <- vapply(
  seq_along(b),
  function(i) S(tt, L[i, ] * exp(-0.5 * b[i]), a = a),
  numeric(length(tt))
)
ci_s1 <- apply(post_s1, 1, quantile, c(0.025, 0.975))
post_s2 <- vapply(
  seq_along(b),
  function(i) S(tt, L[i, ] * exp(0.5 * b[i]), a = a),
  numeric(length(tt))
)
ci_s2 <- apply(post_s2, 1, quantile, c(0.025, 0.975))
Plot(km_fit, col='gray', xlim=c(0, max(Data$Time)), legend='none')
lines(tt, S(tt, pm_lam1, a), col=pal(2)[1], lwd=3)
lines(tt, S(tt, pm_lam2, a), col=pal(2)[2], lwd=3)
# Pointwise 2.5% / 97.5% bands, colored to match each group's curve
polygon(c(tt, rev(tt)), c(ci_s1[1,], rev(ci_s1[2,])), col=pal(2, alpha=0.4)[1], border=NA)
polygon(c(tt, rev(tt)), c(ci_s2[1,], rev(ci_s2[2,])), col=pal(2, alpha=0.4)[2], border=NA)

# Piecewise exponential model, K=4
set.seed(1)

# First grouping: the K hazards are assigned to the death-time intervals by
# recycling 1:K to the number of intervals and sorting
jagsData2 <- jagsData
jagsData2$K <- 4
jagsData2$period <- sort(
  rep_len(1:jagsData2$K, length.out = length(uniq))
)  # Suboptimal
model <- jags.model(
  file = "piecewise.jag", data = jagsData2,
  n.chains = 4, n.adapt = 1000
)
post2 <- jags.samples(model, variable.names = c("beta", "lam"), n.iter = 10000)

# Second grouping: the first interval gets a hazard of its own and the
# remaining K - 1 hazards are shared among the other intervals
jagsData3 <- jagsData2
jagsData3$period <- c(
  1,
  sort(rep_len(2:jagsData3$K, length.out = length(uniq) - 1))
)
model <- jags.model(
  file = "piecewise.jag", data = jagsData3,
  n.chains = 4, n.adapt = 1000
)
post3 <- jags.samples(model, variable.names = c("beta", "lam"), n.iter = 10000)

# Baseline survival
# Indexing by period expands the K hazards back to one value per interval,
# so S() can be used with the same cut points a as before.
km_fit <- survfit(Surv(Time, Death) ~ 1, data = Data)

# First grouping (jagsData2 / post2)
pm_lam2 <- apply(post2$lam, MARGIN = 1, FUN = mean)[jagsData2$period]
L2 <- as.matrix(as.mcmc.list(post2$lam))
post_s2 <- apply(L2[, jagsData2$period], MARGIN = 1, FUN = S, t = tt, a = a)
ci_s2 <- apply(post_s2, MARGIN = 1, FUN = quantile, probs = c(0.025, 0.975))

# Second grouping (jagsData3 / post3)
pm_lam3 <- apply(post3$lam, MARGIN = 1, FUN = mean)[jagsData3$period]
L3 <- as.matrix(as.mcmc.list(post3$lam))
post_s3 <- apply(L3[, jagsData3$period], MARGIN = 1, FUN = S, t = tt, a = a)
ci_s3 <- apply(post_s3, MARGIN = 1, FUN = quantile, probs = c(0.025, 0.975))

# One figure per grouping: Kaplan-Meier in gray, posterior-mean curve on top
Plot(km_fit, col = "gray")
lines(tt, S(tt, pm_lam2, a), col = pal(2)[2], lwd = 3)
Plot(km_fit, col = "gray")
lines(tt, S(tt, pm_lam3, a), col = pal(2)[2], lwd = 3)

# Comparison of methods
# One row per method: estimate of the Group effect with lower / upper bounds
B <- matrix(
  NA,
  nrow = 5, ncol = 3,
  dimnames = list(
    c("Exponential", "Weibull", "Cox", "BPE, K=29", "BPE, K=4"),
    c("Est", "Lower", "Upper")
  )
)

# Exponential survreg fit. The coefficient is negated and the interval ends
# swapped (presumably to match the sign convention of the Cox / Bayesian fits).
fit <- survreg(Surv(Data$Time, Data$Death) ~ Group, data = Data, dist = "exponential")
rownames(fit$var) <- colnames(fit$var) <- names(fit$coef)  # Fix weird survreg bug
B[1, ] <- -c(coef(fit)[2], rev(confint(fit, 2)))

# Weibull survreg fit: same sign flip, additionally divided by the scale
fit <- survreg(Surv(Data$Time, Data$Death) ~ Group, data = Data, dist = "weibull")
rownames(fit$var) <- colnames(fit$var) <- c(names(fit$coef), "scale")  # Fix weird survreg bug
B[2, ] <- -c(coef(fit)[2], rev(confint(fit, 2))) / fit$scale

# Cox proportional hazards fit
fit <- coxph(Surv(Time, Death) ~ Group, data = Data)
B[3, ] <- c(coef(fit), confint(fit))

# Bayesian piecewise exponential fits: posterior mean and 2.5% / 97.5% quantiles
B[4, ] <- c(mean(post$beta), quantile(post$beta, probs = c(0.025, 0.975)))
B[5, ] <- c(mean(post3$beta), quantile(post3$beta, probs = c(0.025, 0.975)))
B
