# Setup; slide 2
library(survival)
source("http://myweb.uiowa.edu/pbreheny/7210/f19/notes/fun.R")
# Restrict to patients with a recorded treatment arm
PBC <- subset(pbc, !is.na(trt))
# 0/1 indicator for trt == 1 (not among the columns retained below)
PBC$Trt <- as.numeric(PBC$trt == 1)

# Thirty standard-normal noise covariates, one row per retained patient
set.seed(2)
X <- matrix(rnorm(30 * nrow(PBC)), nrow = nrow(PBC), ncol = 30)

# Reduced data set: outcome, five clinical covariates, and the noise columns
PBC <- cbind(
  PBC[, c('time', 'status', 'trt', 'albumin', 'stage', 'hepato', 'bili')],
  X
)

# Transformation used in model 2: cap a covariate at 3.5
f <- function(x) pmin(x, 3.5)

# Three Cox models of increasing size; follow-up is converted from days to
# years and any non-zero status counts as an event.
#   fit1: treatment + albumin
#   fit2: adds stage, hepato, capped albumin and log bilirubin
#   fit3: every remaining column of PBC, including the 30 noise covariates
fit1 <- coxph(Surv(time/365.25, status!=0) ~ trt + albumin, data = PBC)
fit2 <- coxph(Surv(time/365.25, status!=0) ~ trt + stage + hepato + f(albumin) + log(bili),
              data = PBC)
fit3 <- coxph(Surv(time/365.25, status!=0) ~ ., data = PBC)

# Slide 3
# Standard deviation of the fitted linear predictor under each model
sd(fit1[["linear.predictors"]])
sd(fit2[["linear.predictors"]])
sd(fit3[["linear.predictors"]])

# Slide 4
# Histograms of each model's linear predictor, with the x-axis relabelled on
# the hazard-ratio scale (ticks at powers of 4 from 4^4 down to 4^-4).
brk <- seq(-7,7,0.5)
at <- log(4^(-(-4:4)))
# paste0() rather than paste(): the default separator of paste() put a space
# inside every label ("1/ 4", " 16").  The magnitudes are taken directly from
# 4^|k| instead of exp(|log(4^k)|) so they are exact integers.
lab <- paste0(ifelse(at >= 0, '', '1/'), 4^abs(-4:4))

# The three calls are kept explicit (not looped) because hist() builds its
# default title from the expression passed as its first argument.
hist(fit1$linear.predictors, xlim=c(-7,7), breaks=brk, xaxt='n', xlab='Hazard ratio')
axis(1, at=at, labels=lab)
mtext('Model 1')
hist(fit2$linear.predictors, xlim=c(-7,7), breaks=brk, xaxt='n', xlab='Hazard ratio')
axis(1, at=at, labels=lab)
mtext('Model 2')
hist(fit3$linear.predictors, xlim=c(-7,7), breaks=brk, xaxt='n', xlab='Hazard ratio')
axis(1, at=at, labels=lab)
mtext('Model 3')

# Slide 5
# Plot the survival curve returned by survfit(fit) together with four
# companion curves obtained by raising it to the power exp(m * SD), where SD
# is the standard deviation of the model's linear predictor and
# m = -2, -1, 1, 2.
#
# Arguments:
#   fit  Fitted model accepted by survfit(); must carry `linear.predictors`.
#   ...  Further arguments forwarded to plot() for the first curve only.
survPlot <- function(fit, ...) {
  baseline <- survfit(fit)
  pi_sd <- sd(fit$linear.predictors)
  plot(baseline, mark.time=FALSE, conf.int=FALSE, las=1, bty='n', ...)
  for (shift in c(-2, -1, 1, 2)) {
    shifted <- baseline
    shifted$surv <- baseline$surv^exp(shift * pi_sd)
    lines(shifted, mark.time=FALSE, conf.int=FALSE)
  }
}
# One survival-curve panel per model, each labelled in the top margin
invisible(Map(
  function(fit, label) {
    survPlot(fit, xlab='Time (years)', ylab='Survival')
    mtext(label)
  },
  list(fit1, fit2, fit3),
  paste('Model', 1:3)
))

# R2
# R-squared as reported by summary() for each model.
# NOTE(review): confirm the installed survival version still returns an `rsq`
# component from summary(); if it does not, these three lines yield NULL.
summary(fit1)[["rsq"]]
summary(fit2)[["rsq"]]
summary(fit3)[["rsq"]]

# The same quantity by hand for model 1: LR is twice the gain in partial
# log-likelihood over the null model, and R^2 = 1 - exp(-LR / n).
LR <- 2 * (fit1$loglik[2] - fit1$loglik[1])
1 - exp(-LR / fit1$n)

# Concordance
# Concordance computed directly from model 2's linear predictor, for
# comparison with the value reported by summary(fit2) below.
# concordance() replaces survConcordance(), which is deprecated in the
# survival package.  reverse = TRUE is required because a larger risk score
# predicts a *shorter* survival time.
concordance(Surv(time, status!=0) ~ fit2$linear.predictors, data = PBC, reverse = TRUE)

# Concordance as reported by summary() for each model
summary(fit1)$concordance
summary(fit2)$concordance
summary(fit3)$concordance

# KL, Brier prediction error
# Apparent (in-sample) prediction error curve for a fitted Cox model.
#
# For every horizon in `Time` the model's predicted survival probability for
# each subject is scored against whether that subject was still event-free,
# using either the Kullback-Leibler (log) score or the Brier (squared error)
# score.  Contributions are weighted by the inverse probability of remaining
# uncensored (IPCW), estimated by Kaplan-Meier with censoring as the event.
#
# Arguments:
#   fit    Fitted coxph object.  fit$y supplies the observed times (column 1)
#          and the event indicator (column 2).
#   Time   Numeric vector of horizons in increasing order, on the time scale
#          of fit$y.
#   score  "KL" (default) or "Brier".
#   Data   Data frame passed to survfit() as `newdata`; defaults to the global
#          PBC.  Must hold the subjects in fit$y, in the same order.
#
# Returns a numeric vector with one mean prediction error per horizon.
Error <- function(fit, Time, score=c("KL", "Brier"), Data=PBC) {
  score <- match.arg(score)
  # summary.survfit() sorts its `times`, so an unsorted grid would misalign rows
  stopifnot(is.numeric(Time), !is.unsorted(Time))

  obs_time <- fit$y[, 1]
  censored <- fit$y[, 2] == 0

  # Censoring survival function G(t).  It is wrapped in a right-continuous
  # step function so that it can be evaluated at subject-specific times *in
  # subject order*: summary(cfit, times) returns its rows in sorted-time
  # order, which would pair each subject with another subject's weight.
  cfit <- survfit(Surv(obs_time, censored) ~ 1)
  G <- stepfun(cfit$time, c(1, cfit$surv))

  # Predicted survival: one row per horizon, one column per subject.
  # extend = TRUE keeps a row for horizons beyond the last observed time.
  sfit <- survfit(fit, newdata = Data)
  S <- summary(sfit, times = Time, extend = TRUE)$surv
  if (is.null(dim(S))) S <- matrix(S, nrow = length(Time))
  stopifnot(nrow(S) == length(Time), ncol(S) == length(obs_time))

  Err <- matrix(NA_real_, nrow(S), ncol(S))
  for (i in seq_along(Time)) {
    y <- obs_time > Time[i]                       # still event-free at horizon
    CensBefore <- censored & obs_time < Time[i]   # status at horizon unknown

    # Weights: G(horizon) for subjects still at risk, G(own time) otherwise
    w <- numeric(length(y))
    if (any(y)) w[y] <- G(Time[i])
    if (any(!y)) w[!y] <- G(obs_time[!y])

    if (score == "KL") {
      z <- numeric(length(y))
      z[y] <- -log(S[i, y])
      z[!y] <- -log(1 - S[i, !y])
    } else {
      z <- (y - S[i, ])^2
    }

    z <- z / w
    # Subjects censored before the horizon contribute nothing; zeroing after
    # the division avoids 0/0 when their estimated G is 0.
    z[CensBefore] <- 0
    Err[i, ] <- z
  }
  rowMeans(Err)
}

# KL
# Apparent Kullback-Leibler prediction error of the three models on a grid of
# 99 horizons from 0 to 11; one column of Err per model.
# pal() and toplegend() are not defined in this file (presumably they come
# from the sourced fun.R).
Time <- seq(0, 11, length.out = 99)
Err <- vapply(list(fit1, fit2, fit3), Error, numeric(length(Time)), Time = Time)
matplot(Time, Err, type='s', ylab='Prediction error', las=1, bty='n', lwd=3, col=pal(3), lty=1)
toplegend(legend=paste('Model', 1:3), lwd=3, col=pal(3))
# Remember the KL range before Err is overwritten; the cross-validation plot
# reuses it as its y-axis limits
kl_lim <- range(Err)

# Brier
# Same display using the Brier score
Err <- vapply(list(fit1, fit2, fit3), Error, numeric(length(Time)),
              Time = Time, score = "Brier")
matplot(Time, Err, type='s', ylab='Prediction error', las=1, bty='n', lwd=3, col=pal(3), lty=1)
toplegend(legend=paste('Model', 1:3), lwd=3, col=pal(3))

# Shrinkage
# Shrinkage factor (LR - p) / LR for a fitted model.
#
# LR is twice the difference between the two entries of fit$loglik and p is
# the number of estimated coefficients.  Coefficients reported as NA (e.g.
# aliased columns) are not estimated, so they are excluded from p rather than
# inflating the degrees of freedom.
#
# Arguments:
#   fit  Fitted model with a length-2 `loglik` component and a coef() method.
#
# Returns a single numeric value.
g <- function(fit) {
  LR <- 2*diff(fit$loglik)
  p <- sum(!is.na(coef(fit)))
  (LR-p)/LR
}
# Shrinkage factor of each model
g(fit1)
g(fit2)
g(fit3)

# SD of shrunken PI
# Linear predictors scaled by their model's shrinkage factor, and the
# resulting standard deviations
s1 <- g(fit1) * fit1$linear.predictors
s2 <- g(fit2) * fit2$linear.predictors
s3 <- g(fit3) * fit3$linear.predictors
sd(s1)
sd(s2)
sd(s3)

# Sim
# Simulated example: only the first two of the 30 columns of X have an effect
# (log hazard ratio log(2)); survival times are exponential with rate
# exp(eta) and there is no censoring.
set.seed(1)
beta <- rep(c(log(2), 0), times = c(2, 28))
eta <- X %*% beta
y <- rexp(nrow(eta), rate = exp(eta))
fit.sim <- coxph(Surv(y) ~ X)
g(fit.sim)

# Estimated versus true linear predictor, with the identity line and the
# through-origin least-squares line
par(mar = c(5, 5, 0.5, 0.5))
plot(eta, fit.sim$linear.predictors,
     pch = 19, col = 'gray60', las = 1, bty = 'n',
     xlab = expression(eta), ylab = expression(hat(eta)))
abline(a = 0, b = 1, col = pal(2)[1], lwd = 3)
abline(a = 0, b = coef(lm(fit.sim$linear.predictors ~ 0 + eta)), col = pal(2)[2], lwd = 3)

# Same display after multiplying the estimate by the shrinkage factor
s <- g(fit.sim) * fit.sim$linear.predictors
plot(eta, s,
     pch = 19, col = 'gray60', las = 1, bty = 'n',
     xlab = expression(eta), ylab = expression(hat(eta)))
abline(a = 0, b = 1, col = pal(2)[1], lwd = 3)
abline(a = 0, b = coef(lm(s ~ 0 + eta)), col = pal(2)[2], lwd = 3)

# Cross-validation
# K-fold cross-validated prediction error curve for a Cox model.
#
# Subjects are assigned to folds 1..K cyclically in row order.  For each fold
# the model is refitted without that fold, survival is predicted for the
# held-out subjects at every horizon in `Time`, and the predictions are scored
# with the Kullback-Leibler (log) or Brier score.  Scores are divided by the
# Kaplan-Meier probability of being uncensored at the horizon, and subjects
# censored before the horizon contribute zero.
#
# Arguments:
#   fit    Fitted coxph object; fit$y supplies the observed times (column 1)
#          and event indicator (column 2) and must align row-for-row with Data.
#   Data   Data frame the model was fitted to.
#   Time   Numeric vector of horizons in increasing order, on the time scale
#          of fit$y.
#   score  "KL" (default) or "Brier".
#   K      Number of folds (default 10).
#
# Returns a numeric vector with one mean cross-validated error per horizon.
#
# NOTE(review): every subject is weighted by 1/G(horizon) here, whereas
# Error() in this file weights subjects whose time precedes the horizon by
# 1/G(own time) -- confirm which weighting is intended.
cvError <- function(fit, Data, Time, score=c("KL", "Brier"), K=10) {
  score <- match.arg(score)
  n <- nrow(Data)
  obs_time <- fit$y[, 1]
  censored <- fit$y[, 2] == 0
  # summary.survfit() sorts its `times`, so an unsorted grid would misalign rows
  stopifnot(length(obs_time) == n, K >= 2, K <= n,
            is.numeric(Time), !is.unsorted(Time))

  # Fold labels must run over 1..K (previously hard-coded as 1:10, which left
  # folds unevaluated or empty whenever K != 10)
  fold <- rep_len(seq_len(K), n)

  # Probability of being uncensored at each horizon, from a Kaplan-Meier fit
  # with censoring as the event; the step function is defined for any horizon
  cfit <- survfit(Surv(obs_time, censored) ~ 1)
  cProb <- stepfun(cfit$time, c(1, cfit$surv))(Time)

  Err <- matrix(NA_real_, length(Time), n)
  for (k in seq_len(K)) {
    held_out <- fold == k
    # update() re-evaluates the model call in this frame, so the training
    # data must be available here under the name used in `data=`
    Train <- Data[!held_out,]
    cvfit <- update(fit, formula=formula(fit), data=Train, model=TRUE)
    sfit <- survfit(cvfit, newdata = Data)
    # extend = TRUE keeps one row per horizon even when the training fold's
    # follow-up ends before the largest horizon
    S <- summary(sfit, times = Time, extend = TRUE)$surv
    if (is.null(dim(S))) S <- matrix(S, nrow = length(Time))
    stopifnot(nrow(S) == length(Time), ncol(S) == n)

    t_out <- obs_time[held_out]
    cens_out <- censored[held_out]
    for (i in seq_along(Time)) {
      y <- t_out > Time[i]      # Obs surv for left-out fold
      s <- S[i, held_out]       # Pred surv for left-out fold
      z <- double(length(s))
      if (score == "KL") {
        z[y] <- -log(s[y])
        z[!y] <- -log(1-s[!y])
      } else {
        z <- (y - s)^2
      }
      z <- z / cProb[i]
      # Censored before the horizon: status unknown, contributes nothing.
      # Zeroed after the division so a zero cProb cannot produce 0/0.
      z[cens_out & t_out < Time[i]] <- 0
      Err[i, held_out] <- z
    }
  }
  rowMeans(Err)
}
# Cross-validated KL error of the three models on the same horizon grid,
# drawn with the y-axis limits saved from the apparent KL error plot
Time <- seq(0, 11, length.out = 99)
cvErr <- vapply(list(fit1, fit2, fit3), cvError, numeric(length(Time)),
                Data = PBC, Time = Time)
matplot(Time, cvErr, type='s', ylab='Prediction error', las=1, bty='n', lwd=3, col=pal(3), lty=1, ylim=kl_lim)
toplegend(legend=paste('Model', 1:3), lwd=3, col=pal(3))
