Data <- read.delim("http://web.as.uky.edu/statistics/users/pbreheny/701/data/earnings.txt")
# Since R 4.0.0, read.delim() keeps text columns as character. In that case
# levels() is NULL and as.numeric() gives NA. Convert Ethnicity to a factor
# explicitly. Everything below, including the plots, relies on
# levels(Data$Ethnicity).
Data$Ethnicity <- factor(Data$Ethnicity)
source("http://web.as.uky.edu/statistics/users/pbreheny/701/S13/notes/fun.R")

## Setup
# Response: log earnings.
# NOTE(review): log(0) is -Inf; this assumes every Earnings value is
# positive -- confirm for this data file.
y <- log(Data$Earnings)
# Ethnicity index 1..J, used to index the JAGS arrays.
eid <- as.integer(Data$Ethnicity)
# Height centered at its sample mean, so the intercepts a[j,k] refer to a
# person of average height.
Height <- Data$Height - mean(Data$Height)
# Age groups [18,35), [35,50), [50,65); index 1..K.
Age <- cut(Data$Age, c(18, 35, 50, 65), right = FALSE)
aid <- as.integer(Age)
# Ages outside the bins (or missing ages) become NA, and JAGS cannot use NA
# as an array index. Fail here with a clear message instead of inside jags().
if (anyNA(aid)) {
  stop(sum(is.na(aid)), " subject(s) have Age missing or outside [18, 65); ",
       "JAGS cannot use NA as an index.", call. = FALSE)
}
J <- nlevels(Data$Ethnicity)
K <- nlevels(Age)
n <- nrow(Data)
# Ethnicity x age cell label, index, and count. These are not among the data
# names passed to JAGS.
Interaction <- factor(paste(Data$Ethnicity, Age))
iid <- as.integer(Interaction)
L <- nlevels(Interaction)
# Zero mean vector for the bivariate (intercept, slope) group effects.
z <- c(0, 0)

## Model
# Names of the R objects that jags() hands to JAGS as data. z is the
# bivariate zero mean used by the G and D priors below.
jData <- list("y", "J", "K", "n", "Height", "aid", "eid", "z")
# JAGS model, written as an R function so it can be passed to jags() below.
# Each ethnicity (j) x age-group (k) cell has its own intercept a[j,k] and
# height slope b[j,k]. The (a, b) pairs are shrunk towards an additive
# overall + ethnicity + age structure.
# Note: JAGS dnorm/dmnorm take a precision, not a variance; hence sigma.y^(-2)
# and the Tau.* precision matrices.
model <- function() {
  ## Likelihood
  # Log earnings are normal around the cell-specific line in centered height,
  # with residual SD sigma.y.
  for (i in 1:n){
    y[i] ~ dnorm(y.hat[i], sigma.y^(-2))
    y.hat[i] <- a[eid[i],aid[i]] + b[eid[i],aid[i]]*Height[i]
  }
  sigma.y ~ dunif(0, 100)
  
  ## Prior for a, b
  # T[j,k,] = (intercept, slope) for cell (j, k). It is drawn around the
  # additive prediction mu + G[j,] + D[k,], with interaction precision
  # matrix Tau.T.
  for (j in 1:J) {
    for (k in 1:K) {
      a[j,k] <- T[j,k,1]
      b[j,k] <- T[j,k,2]
      T[j,k,1:2] ~ dmnorm(T.hat[j,k,], Tau.T)
      for (l in 1:2) {
        T.hat[j,k,l] <- mu[l] + G[j,l] + D[k,l]
      }
    }
  }
  
  ## Hyperpriors
  # mu = overall (intercept, slope), with a vague normal prior
  # (precision 1e-4, i.e. sd 100).
  # G[j,] = ethnicity effects (precision Tau.E).
  # D[k,] = age-group effects (precision Tau.A).
  # Both effect vectors have mean z = (0, 0).
  for (j in 1:2) {
    mu[j] ~ dnorm(0, 0.0001)
  }
  for (j in 1:J){
    G[j,1:2] ~ dmnorm (z, Tau.E)
  }
  for (k in 1:K){
    D[k,1:2] ~ dmnorm (z, Tau.A)
  }
  
  ## Priors for covariance matrices
  # Wishart priors (3 degrees of freedom) on the three 2x2 precision
  # matrices. Index i: 1 = interaction (T), 2 = age (A), 3 = ethnicity (E).
  # Each scale matrix W[i,,] is diagonal. Its diagonal entries are xi^(-2),
  # with xi ~ Uniform(0, 10); the off-diagonal entries are fixed at 0.
  Tau.T ~ dwish(W[1,,], 3)
  Tau.A ~ dwish(W[2,,], 3)
  Tau.E ~ dwish(W[3,,], 3)
  for (i in 1:3) {
    for (j in 1:2) {
      xi[i,j] ~ dunif(0,10)
      W[i,j,j] <- xi[i,j]^(-2)
      W[i,j,3-j] <- 0
    }
  }
  
  ## Quantities of interest
  # Convert precisions to covariances. Then report, for each level:
  # - the intercept SD ([1]) and the slope SD ([2]);
  # - the intercept-slope correlation:
  #   rho[1] = age, rho[2] = ethnicity, rho[3] = interaction.
  Sigma.T <- inverse(Tau.T)
  Sigma.A <- inverse(Tau.A)
  Sigma.E <- inverse(Tau.E)
  sigma.t[1] <- sqrt(Sigma.T[1,1])
  sigma.t[2] <- sqrt(Sigma.T[2,2])
  sigma.a[1] <- sqrt(Sigma.A[1,1])
  sigma.a[2] <- sqrt(Sigma.A[2,2])
  sigma.e[1] <- sqrt(Sigma.E[1,1])
  sigma.e[2] <- sqrt(Sigma.E[2,2])
  rho[1] <- Sigma.A[1,2]/(sigma.a[1]*sigma.a[2])
  rho[2] <- Sigma.E[1,2]/(sigma.e[1]*sigma.e[2])
  rho[3] <- Sigma.T[1,2]/(sigma.t[1]*sigma.t[2])
}
# library(), not require(): a missing package should stop the script here
# rather than surface later as "could not find function jags".
library(R2jags)
# NOTE(review): presumably draws once so that .Random.seed exists before
# jags() runs; kept from the original -- confirm whether it is still needed.
invisible(runif(1))
# Argument names spelled out. The original relied on partial matching
# (model= -> model.file, param= -> parameters.to.save).
# model.file may be an R function containing the BUGS model.
fit <- jags(
  data = jData,
  parameters.to.save = c("mu", "a", "b", "sigma.y", "sigma.a", "sigma.e",
                         "sigma.t", "rho"),
  model.file = model,
  n.iter = 30000,
  n.thin = 1
)
# Put the posterior draws (a, b, rho, sigma.*, ...) on the search path. The
# sections below refer to them by name.
attach.jags(fit)

## R2jags output
plot(fit)
# Explicit print() so the summary also appears when the script is source()d.
print(fit)

## Convergence
# Convert once; both diagnostics use the same mcmc.list.
fit_mcmc <- as.mcmc(fit)
# Only the per-parameter PSRFs are used. Skip the multivariate PSRF, which is
# unused here and can error when the chains' covariance is near-singular.
print(max(gelman.diag(fit_mcmc, multivariate = FALSE)$psrf))
print(min(effectiveSize(fit_mcmc)))

################
## Posteriors ##
################

## a, b
# Posterior means over draws (the first array dimension) for every
# ethnicity (rows) x age group (columns) cell.
pma <- apply(a, 2:3, mean)
pmb <- apply(b, 2:3, mean)
dimnames(pma) <- dimnames(pmb) <- list(levels(Data$Ethnicity), levels(Age))
# Diverging blue -> gray -> red palette for the heatmaps.
col <- colorRampPalette(c("#008DFFFF", "gray90", "#FF4E37FF"))(100)
# levelplot() comes from lattice, which nothing in this script attaches, so
# call it through its namespace. A trellis object only renders when printed,
# so print() explicitly; otherwise nothing is drawn when the script is
# source()d.
# exp(pma): exp(intercept) on the earnings scale at mean height
# (y is log earnings and Height is centered).
print(lattice::levelplot(exp(pma), xlab = "", ylab = "", col.regions = col))
# exp(5 * pmb): multiplicative change in earnings per 5 units of height.
print(lattice::levelplot(exp(5 * pmb), xlab = "", ylab = "", col.regions = col))

## Rho
psm(rho)

## sigma
psm(sigma.y)
psm(sigma.a)
psm(sigma.e)
psm(sigma.t)
# Posterior-mean SDs for the intercept, in the order:
# residual, age, ethnicity, interaction.
# Column 1 of each two-column SD draw matrix is the intercept component.
# The squared SDs are then shown as shares of their sum.
sigma.alpha <- c(
  mean(sigma.y),
  vapply(list(sigma.a, sigma.e, sigma.t), function(s) mean(s[, 1]), numeric(1))
)
sigma.alpha^2 / sum(sigma.alpha^2)

## Regression lines
# One panel per ethnicity x age cell (age groups run down each column).
# Each panel shows:
# - gray: a sample of posterior draws of the fitted line;
# - points: the jittered data;
# - blue: the posterior-mean line.
xlim <- range(Height)
ylim <- c(6, 12)
# About 100 evenly spaced posterior draws. The third positional argument of
# seq() is `by`, so seq(1, N, 100) took every 100th draw and the round()
# did nothing. round() is only needed with length.out, which was the intent.
lind <- unique(round(seq(1, dim(a)[1], length.out = 100)))
# Restore the caller's graphics layout when done.
op <- par(mfcol = c(K, J))
for (j in seq_len(J)) {
  for (k in seq_len(K)) {
    # which() drops NA comparisons, so indices are always valid positions.
    ind <- which(eid == j & aid == k)
    plot(0, type = "n", xlab = "Height (inches from mean)", ylab = "log(Earnings)",
         xlim = xlim, ylim = ylim)
    for (l in lind) abline(a[l, j, k], b[l, j, k], col = "gray80")
    points(jitter(Height[ind]), jitter(y[ind]), pch = 19, cex = 0.5)
    abline(pma[j, k], pmb[j, k], col = "blue", lwd = 3)
    mtext(paste(levels(Data$Ethnicity)[j], levels(Age)[k], sep = ": "), cex = 0.8)
  }
}
par(op)

## Regression lines on original scale
# Draw a fitted log-scale line back-transformed to earnings in thousands:
# exp(a + b * x) / 1000.
# `x` is the height grid; it defaults to the global `xx` defined below,
# as before. Other arguments are forwarded to lines().
expline <- function(a, b, ..., x = xx) {
  lines(x, exp(a + x * b) / 1000, ...)
}
xlim <- range(Height)
ylim <- c(0, 100)
# About 100 evenly spaced posterior draws. Use length.out, not the positional
# `by` that seq(1, N, 100) would use.
lind <- unique(round(seq(1, dim(a)[1], length.out = 100)))
# Height grid for the back-transformed curves.
xx <- seq(xlim[1], xlim[2], length.out = 99)
# Restore the caller's graphics layout when done.
op <- par(mfcol = c(K, J))
for (j in seq_len(J)) {
  for (k in seq_len(K)) {
    # which() drops NA comparisons, so indices are always valid positions.
    ind <- which(eid == j & aid == k)
    plot(0, type = "n", xlab = "Height (inches from mean)",
         ylab = "Earnings (thousands)", xlim = xlim, ylim = ylim, las = 1)
    for (l in lind) expline(a[l, j, k], b[l, j, k], col = "gray80")
    points(jitter(Height[ind]), exp(jitter(y[ind])) / 1000, pch = 19, cex = 0.5)
    expline(pma[j, k], pmb[j, k], col = "blue", lwd = 3)
    mtext(paste(levels(Data$Ethnicity)[j], levels(Age)[k], sep = ": "), cex = 0.8)
  }
}
par(op)
