# Setup; slide 2
library(survival)
source("http://myweb.uiowa.edu/pbreheny/7210/f18/notes/fun.R")
# Analysis data set: the rows of pbc whose treatment indicator (trt) is not missing.
PBC <- subset(pbc, !is.na(trt))
# Cap a numeric vector from above at 3.5: values larger than 3.5 are replaced
# by 3.5, everything else is returned unchanged. Used inside the model
# formulas as f(albumin).
f <- function(x) {
  upper <- 3.5
  pmin(x, upper)
}
# All three models share the same response: follow-up time divided by 365.25,
# with any non-zero status counted as an event.

# Model 1: treatment and albumin only.
fit1 <- coxph(
  Surv(time / 365.25, status != 0) ~ trt + albumin,
  data = PBC
)

# Model 2: adds stage, hepato, log(bili) and uses the capped albumin f(albumin).
fit2 <- coxph(
  Surv(time / 365.25, status != 0) ~
    trt + stage + hepato + f(albumin) + log(bili),
  data = PBC
)

# Model 3: model 2 plus 30 columns of standard-normal noise (one row per
# subject in PBC). The seed is fixed so the noise matrix is reproducible.
set.seed(2)
X <- matrix(rnorm(30 * nrow(PBC)), nrow = nrow(PBC), ncol = 30)
fit3 <- coxph(
  Surv(time / 365.25, status != 0) ~
    trt + stage + hepato + f(albumin) + log(bili) + X,
  data = PBC
)

# Slide 3
# Spread (standard deviation) of the fitted linear predictor under each model.
sd(fit1[["linear.predictors"]])
sd(fit2[["linear.predictors"]])
sd(fit3[["linear.predictors"]])

# Slide 4
# Histograms of each model's linear predictor. The x-axis is drawn by hand so
# that the ticks (placed at log(4^k), k = -4..4) are labelled on the
# hazard-ratio scale: 1/256, ..., 1/4, 1, 4, ..., 256.
brk <- seq(-7, 7, 0.5)
at <- log(4^(-(-4:4)))
# paste0() rather than paste(): the default sep = " " produced labels such as
# "1/ 4" and " 16". round() removes any floating-point noise from exp(log(.))
# so the labels are clean integers.
lab <- paste0(c('1/', '')[1 * (at >= 0) + 1], round(exp(abs(at))))

hist(fit1$linear.predictors, xlim = c(-7, 7), breaks = brk, xaxt = 'n',
     xlab = 'Hazard ratio')
axis(1, at = at, labels = lab)
mtext('Model 1')

hist(fit2$linear.predictors, xlim = c(-7, 7), breaks = brk, xaxt = 'n',
     xlab = 'Hazard ratio')
axis(1, at = at, labels = lab)
mtext('Model 2')

hist(fit3$linear.predictors, xlim = c(-7, 7), breaks = brk, xaxt = 'n',
     xlab = 'Hazard ratio')
axis(1, at = at, labels = lab)
mtext('Model 3')

# Slide 5
# Plot the survival curve returned by survfit(fit) and overlay four more
# curves obtained by raising it to the power exp(m * SD), m = -2, -1, 1, 2,
# where SD is the standard deviation of the model's linear predictors.
#
# Arguments:
#   fit  a fitted model accepted by survfit() that carries $linear.predictors
#   ...  passed on to plot() for the first curve (axis labels etc.)
survPlot <- function(fit, ...) {
  base_curve <- survfit(fit)
  plot(base_curve, mark.time = FALSE, conf.int = FALSE, las = 1, bty = 'n', ...)
  # The spread of the linear predictor does not change between curves.
  lp_sd <- sd(fit$linear.predictors)
  for (n_sd in c(-2, -1, 1, 2)) {
    # Reuse the survfit object so lines() can draw it; only $surv is replaced.
    shifted <- base_curve
    shifted$surv <- base_curve$surv^exp(n_sd * lp_sd)
    lines(shifted, mark.time = FALSE, conf.int = FALSE)
  }
}
# One figure per model, with the model name written in the top margin.
survPlot(fit1, ylab = "Survival", xlab = "Time (years)")
mtext(text = "Model 1")
survPlot(fit2, ylab = "Survival", xlab = "Time (years)")
mtext(text = "Model 2")
survPlot(fit3, ylab = "Survival", xlab = "Time (years)")
mtext(text = "Model 3")

# R2
# The rsq component of each model summary.
# NOTE(review): newer releases of survival may no longer return `rsq` from
# summary(); confirm against the installed version.
summary(fit1)[["rsq"]]
summary(fit2)[["rsq"]]
summary(fit3)[["rsq"]]
# Likelihood-ratio-based R^2 for model 1 computed by hand: LR is twice the
# difference between the two log-likelihoods stored on the fit.
LR <- 2 * (fit1$loglik[2] - fit1$loglik[1])
1 - exp(-LR / fit1$n)

# Concordance
# Concordance computed directly from model 2's linear predictor.
# survConcordance() is deprecated in the survival package; concordance() is
# its replacement. reverse = TRUE because a larger Cox linear predictor means
# a higher hazard, i.e. a shorter survival time.
concordance(Surv(time, status != 0) ~ fit2$linear.predictors, data = PBC,
            reverse = TRUE)
# Concordance as stored in each model summary.
summary(fit1)$concordance
summary(fit2)$concordance
summary(fit3)$concordance

# KL, Brier prediction error
# Prediction error of a fitted Cox model over a grid of time points.
#
# For every time t in `Time` and every subject, the predicted survival
# probability S(t) is compared with the observed status "observed time > t":
#   KL:    -log(S(t)) if the subject's time exceeds t, -log(1 - S(t)) otherwise
#   Brier: (status - S(t))^2
# Subjects censored before t contribute 0. Each time point's scores are then
# divided by the Kaplan-Meier estimate of remaining uncensored at t (fitted
# with status == 0 as the "event") and averaged over all subjects.
#
# Arguments:
#   fit    fitted coxph model; observed times/status are read from fit$y, so
#          the model must have been fitted with time in years (time/365.25)
#   Time   numeric vector of evaluation times, on the same scale as fit$y
#   score  "KL" (default) or "Brier"
#   data   data frame with columns `time` and `status`; must be the data the
#          model was fitted to, because its rows are matched with fit$y.
#          Defaults to the global PBC, as before.
# Returns: numeric vector with one prediction-error value per element of Time.
Error <- function(fit, Time, score = c("KL", "Brier"), data = PBC) {
  score <- match.arg(score)
  # Censoring distribution: censoring (status == 0) plays the role of the event.
  cfit <- survfit(Surv(time / 365.25, status == 0) ~ 1, data)
  # Predicted survival curve for every subject in `data`.
  sfit <- survfit(fit, newdata = data)
  S <- summary(sfit, Time)$surv       # rows index Time, columns index subjects
  cProb <- summary(cfit, Time)$surv   # P(still uncensored) at each time
  Score <- matrix(NA, nrow(S), ncol(S))
  for (i in seq_along(Time)) {
    CensBefore <- fit$y[, 2] == 0 & fit$y[, 1] < Time[i]
    y <- fit$y[, 1] > Time[i]
    if (score == "KL") {
      Score[i, y] <- -log(S[i, y])
      Score[i, !y] <- -log(1 - S[i, !y])
    } else {
      Score[i, ] <- (y - S[i, ])^2
    }
    # Status at Time[i] is unknown for subjects censored earlier: no contribution.
    Score[i, CensBefore] <- 0
  }
  Err <- sweep(Score, 1, cProb, '/')
  rowMeans(Err)
}

# KL
# Evaluation grid: 99 equally spaced time points from 0 to 11.
# `length.out` is spelled out; the original `len=` relied on partial argument
# matching.
Time <- seq(0, 11, length.out = 99)
# One column of prediction error per model (default score is "KL").
Err <- cbind(Error(fit1, Time), Error(fit2, Time), Error(fit3, Time))
matplot(Time, Err, type = 's', ylab = 'Prediction error', las = 1, bty = 'n',
        lwd = 3, col = pal(3), lty = 1)
toplegend(legend = paste('Model', 1:3), lwd = 3, col = pal(3))

# Brier
Err <- cbind(
  Error(fit1, Time, score = "Brier"),
  Error(fit2, Time, score = "Brier"),
  Error(fit3, Time, score = "Brier")
)
matplot(Time, Err, type = 's', ylab = 'Prediction error', las = 1, bty = 'n',
        lwd = 3, col = pal(3), lty = 1)
toplegend(legend = paste('Model', 1:3), lwd = 3, col = pal(3))

# Shrinkage
# Shrinkage factor (LR - p) / LR for a fitted model, where LR is twice the
# difference of the log-likelihoods stored in fit$loglik and p is the number
# of coefficients returned by coef(fit).
g <- function(fit) {
  lr_stat <- 2 * diff(fit$loglik)
  n_coef <- length(coef(fit))
  (lr_stat - n_coef) / lr_stat
}
# Shrinkage factor of each model.
g(fit1)
g(fit2)
g(fit3)

# SD of shrunken PI: every linear predictor is multiplied by its model's
# shrinkage factor before its spread is computed.
s1 <- g(fit1) * fit1$linear.predictors
s2 <- g(fit2) * fit2$linear.predictors
s3 <- g(fit3) * fit3$linear.predictors
sd(s1)
sd(s2)
sd(s3)

# Sim
# Simulated example: exponential event times whose log hazard depends only on
# the first two columns of X (coefficient log(2) each); the other 28
# coefficients are zero. A Cox model is then fitted on all 30 columns.
set.seed(1)
beta <- rep(c(log(2), 0), times = c(2, 28))
eta <- X %*% beta
y <- rexp(n = length(eta), rate = exp(eta))
fit.sim <- coxph(Surv(y) ~ X)
g(fit.sim)

# True versus estimated linear predictor. First line: the identity; second
# line: least-squares slope through the origin.
par(mar = c(5, 5, 0.5, 0.5))
plot(
  eta, fit.sim$linear.predictors,
  pch = 19, col = 'gray60', las = 1, bty = 'n',
  xlab = expression(eta), ylab = expression(hat(eta))
)
abline(a = 0, b = 1, col = pal(2)[1], lwd = 3)
abline(a = 0, b = coef(lm(fit.sim$linear.predictors ~ 0 + eta)),
       col = pal(2)[2], lwd = 3)

# Same display after multiplying the estimate by the shrinkage factor.
s <- g(fit.sim) * fit.sim$linear.predictors
plot(
  eta, s,
  pch = 19, col = 'gray60', las = 1, bty = 'n',
  xlab = expression(eta), ylab = expression(hat(eta))
)
abline(a = 0, b = 1, col = pal(2)[1], lwd = 3)
abline(a = 0, b = coef(lm(s ~ 0 + eta)), col = pal(2)[2], lwd = 3)
