Parameter estimation in clustered data
I am using the native pipe operator, which was introduced in R 4.1.0. This pipe operator is written as a | followed by a >. In this document, the operator |> may be displayed as a single glyph, because I am using font ligatures. If the pipe doesn’t work for you, simply replace it with the older magrittr pipe %>%.
We now explore the difference between fitting a model with population-level effects and a model with varying effects to the same data.
We’ll generate some normally distributed data consisting of three conditions, with means of 0, 1 and 2. Each condition has a standard deviation of 0.5. We’ll draw 20 observations per condition.
Note that we haven’t said anything about the structure in the data, so these could be three independent samples, or they could be data from three subjects in one condition (even though I have called the grouping variable condition).
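The data described above could be simulated along the following lines. This is only a sketch: the seed, the object names other than d, condition and response, and the use of base R are assumptions, and the original data may have been generated differently.

```r
# Sketch: three conditions with true means 0, 1 and 2 and a common SD of 0.5.
set.seed(123)  # arbitrary seed (assumption), only for reproducibility

n_per_condition <- 20
true_means <- c(A = 0, B = 1, C = 2)

d <- data.frame(
  condition = rep(names(true_means), each = n_per_condition)
)

# look up each row's true mean by condition name, then add Gaussian noise
d$response <- rnorm(nrow(d), mean = true_means[d$condition], sd = 0.5)

head(d)
```

Nothing in this data frame encodes whether the three groups are conditions or subjects; that interpretation only enters through the model formula.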
p1 <- d |>
ggplot(aes(response, condition, color = condition)) +
geom_point( size = 2) +
scale_color_brewer(type = "qual")
We can compute the condition mean and SD, and then add these to the plot.
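A minimal sketch of how such a summary table could be computed with dplyr. The stand-in data d_example is simulated here so that the block runs on its own; only the column names condition, response, mean and sd are taken from the surrounding code.

```r
library(dplyr)

# stand-in data with the same columns as d (values are simulated, not the original data)
set.seed(1)  # arbitrary seed (assumption)
d_example <- data.frame(
  condition = rep(c("A", "B", "C"), each = 20),
  response  = rnorm(60, mean = rep(c(0, 1, 2), each = 20), sd = 0.5)
)

# one row per condition, holding the sample mean and SD of the response
condition_means <- d_example |>
  group_by(condition) |>
  summarise(mean = mean(response),
            sd   = sd(response))

condition_means
```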
p1 +
geom_point(aes(mean, condition, color = condition),
data = condition_means,
size = 6)
Let’s first assume that the data are from independent samples. We want to estimate the three condition means, with the goal of comparing them. Since none of these conditions is an obvious reference category, let’s just estimate all three means.
formula1 <- response ~ 0 + condition
The statistical formula can be written as:
\[ \operatorname{response} = \beta_{\operatorname{A}}(\operatorname{condition}_{\operatorname{A}}) + \beta_{\operatorname{B}}(\operatorname{condition}_{\operatorname{B}}) + \beta_{\operatorname{C}}(\operatorname{condition}_{\operatorname{C}}) + \epsilon \]
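This states that the expected value of the response variable is \(\beta_{\operatorname{A}}\), \(\beta_{\operatorname{B}}\) or \(\beta_{\operatorname{C}}\), depending on the condition: \(\operatorname{condition}_{\operatorname{A}}\), \(\operatorname{condition}_{\operatorname{B}}\) and \(\operatorname{condition}_{\operatorname{C}}\) are indicator variables, each equal to 1 for observations in that condition and 0 otherwise.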
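To see the indicator coding concretely, here is a small sketch of the design matrix that R builds for a formula of the form ~ 0 + condition. The five-row factor and the coefficient values are made up for illustration.

```r
# Sketch: indicator (dummy) coding for `~ 0 + condition`
condition <- factor(c("A", "B", "C", "A", "C"))  # toy example, not the real data
X <- model.matrix(~ 0 + condition)
X

# with hypothetical coefficients, each row's expected value is simply
# the coefficient belonging to its own condition
beta <- c(conditionA = 0, conditionB = 1, conditionC = 2)
drop(X %*% beta)
```

Each row of X contains exactly one 1, so multiplying by the coefficient vector picks out a single condition mean per observation.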
We can inspect the default priors.
get_prior(formula1, data = d) |>
as_tibble() |> select(1:4)
# A tibble: 5 x 4
prior class coef group
<chr> <chr> <chr> <chr>
1 "" b "" ""
2 "" b "conditionA" ""
3 "" b "conditionB" ""
4 "" b "conditionC" ""
5 "student_t(3, 0, 2.5)" sigma "" ""
The priors on the regression coefficients are flat. This should be avoided, so we’ll use normal distributions centred at 1, which is just the overall mean.
priors1 <- prior(normal(1, 1), class = b)
tibble(x = seq(-3, 5, by = 0.01),
y = dnorm(x, 1, 1)) |>
ggplot(aes(x, y)) + geom_line(size = 2) +
geom_vline(xintercept = 1, linetype = 2) +
ylab("") + xlab("") +
ggtitle("Prior on means")
I usually save the model file in order to avoid having to recompile and sample unless the model specification has changed.
m1 <- brm(formula1,
prior = priors1,
data = d,
file = "../../models/02-m1")
Looking at the posterior estimates, we note that the means are similar to the sample means.
Family: gaussian
Links: mu = identity; sigma = identity
Formula: response ~ 0 + condition
Data: d (Number of observations: 30)
Samples: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
total post-warmup samples = 4000
Population-Level Effects:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
conditionA 0.03 0.18 -0.32 0.39 1.00 3986 2795
conditionB 1.15 0.18 0.79 1.51 1.00 4193 2658
conditionC 2.04 0.18 1.69 2.37 1.00 3913 2768
Family Specific Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma 0.57 0.08 0.44 0.74 1.00 3394 2812
Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
For reference, here are the sample means:
# A tibble: 3 x 3
condition mean sd
<chr> <dbl> <dbl>
1 A 0.0321 0.353
2 B 1.16 0.562
3 C 1.94 0.578
However, Stan gives us full posterior distributions, not just point estimates. We can visualize these using the type argument in mcmc_plot().
mcmc_plot(m1, type = "areas")
We can extract the population-level effects using fixef(), which is named like this to be consistent with other modelling packages in R. We can get either summaries, or the samples.
fixef(m1)
Estimate Est.Error Q2.5 Q97.5
conditionA 0.03302576 0.1814484 -0.3152898 0.3900089
conditionB 1.15011637 0.1805871 0.7891059 1.5103085
conditionC 2.03947646 0.1756655 1.6853433 2.3745034
m1_pop <- fixef(m1, summary = FALSE)
The samples can be summarized, resulting in the same numbers as above.
m1_pop |>
  as_tibble() |>
  pivot_longer(everything(),
               names_to = "condition",
               values_to = "mean") |>
group_by(condition) |>
summarise(estimate = mean(mean),
sd = sd(mean),
q2.5 = quantile(mean, 0.025),
q97.5 = quantile(mean, 0.975))
# A tibble: 3 x 5
condition estimate sd q2.5 q97.5
<chr> <dbl> <dbl> <dbl> <dbl>
1 conditionA 0.0330 0.181 -0.315 0.390
2 conditionB 1.15 0.181 0.789 1.51
3 conditionC 2.04 0.176 1.69 2.37
Now let’s treat the data as if they were clustered, e.g. repeated measures of three subjects in a single condition. If it helps, we can rename the condition variable to subject.
d <- d |> select(subject = condition, response)
We now want to estimate the population-level mean, averaged over subjects, but we are also interested in the subject-level means. This is a partial pooling model, as opposed to the no-pooling model from above.
The formula states that we are predicting response as an average effect, with varying effects for each subject.
formula2 <- response ~ 1 + (1 | subject)
\[ \begin{aligned} \operatorname{response}_{i} &\sim N \left(\alpha_{j[i]}, \sigma^2 \right) \\ \alpha_{j} &\sim N \left(\mu_{\alpha}, \sigma^2_{\alpha} \right) \text{, for subject j = 1,} \dots \text{,J} \end{aligned} \]
One way to think of this is that the subject means \(\alpha_{j}\) are themselves drawn from a normal distribution, with mean \(\mu_{\alpha}\) and SD \(\sigma_{\alpha}\). This is what makes this a hierarchical problem: the subject parameters are random draws from a super-population, and are therefore related, in the sense that they share a common distribution. In mathematical terms, the subject parameters are \(i.i.d.\) and exchangeable.
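The two-level structure can be made concrete by simulating from it. The following is a sketch of the generative process only; the parameter values, the seed and the object name sim are illustrative assumptions, not estimates from the models in this post.

```r
# Sketch: simulate from the hierarchical (varying-intercept) model
set.seed(42)  # arbitrary seed (assumption)

J <- 3             # number of subjects
n_obs <- 20        # observations per subject
mu_alpha <- 1      # population-level mean of the subject means
sigma_alpha <- 1   # SD of the subject means around mu_alpha
sigma <- 0.5       # residual SD within a subject

# level 2: subject means are draws from a common normal distribution
alpha <- rnorm(J, mean = mu_alpha, sd = sigma_alpha)

# level 1: each observation is a draw around its own subject's mean
subject <- rep(seq_len(J), each = n_obs)
response <- rnorm(J * n_obs, mean = alpha[subject], sd = sigma)

sim <- data.frame(subject = factor(subject), response = response)

# sample means per subject, to compare against the simulated alpha
tapply(sim$response, sim$subject, mean)
alpha
```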
Let’s see what priors we get for this model.
get_prior(formula2, data = d) |>
as_tibble() |> select(1:4)
# A tibble: 5 x 4
prior class coef group
<chr> <chr> <chr> <chr>
1 "student_t(3, 0.9, 2.5)" Intercept "" ""
2 "student_t(3, 0, 2.5)" sd "" ""
3 "" sd "" "subject"
4 "" sd "Intercept" "subject"
5 "student_t(3, 0, 2.5)" sigma "" ""
We now have a student_t(3, 0.9, 2.5) prior on the intercept, which is the overall mean, and a student_t(3, 0, 2.5) prior on the SD of the subjects’ varying effects around that mean. This is the term \(\sigma_{\alpha}\). Because SD parameters must be \(>0\), \(\sigma_{\alpha}\) has a half Student-t distribution. The parameter sigma is the residual standard deviation.
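Because an SD cannot be negative, a Student-t prior centred at zero is effectively folded at zero. A minimal sketch of the resulting half Student-t density in base R follows; dhalf_student_t() is a hypothetical helper written for this illustration (it assumes location 0) and is not a brms function.

```r
# Sketch: density of a half Student-t with location 0, given df and scale.
# Folding a symmetric density at zero doubles it on the positive side.
dhalf_student_t <- function(x, df, scale) {
  ifelse(x >= 0, 2 * dt(x / scale, df) / scale, 0)
}

# sanity check: the density should integrate to (approximately) 1
total <- integrate(dhalf_student_t, lower = 0, upper = Inf,
                   df = 3, scale = 2.5)$value
total

# prior median of the SD implied by student_t(3, 0, 2.5) folded at zero
prior_median <- 2.5 * qt(0.75, df = 3)
prior_median
```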
p_intercept <- tibble(x = seq(-10, 15, by = 0.01),
                      Intercept = dstudent_t(x, 3, 0.9, 2.5)) |>
  ggplot(aes(x, Intercept)) + geom_line(size = 2) +
  ylab("") + xlab("")
p_sd <- tibble(x = seq(0, 15, by = 0.01),
               sd = dstudent_t(x, 3, 0, 2.5)) |>
  ggplot(aes(x, sd)) + geom_line(size = 2) +
  ylab("") + xlab("")
# combining two plots with `+` assumes the patchwork package is loaded
p_intercept + p_sd
Using the default priors is not a good idea at all here; the priors are too wide. This means that the sampling algorithm will be trying to explore parts of the parameter space where the likelihood is very small. This is very inefficient, and Stan gives us a warning if we try to run this model.
Family: gaussian
Links: mu = identity; sigma = identity
Formula: response ~ 1 + (1 | condition)
Data: d (Number of observations: 30)
Samples: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
total post-warmup samples = 4000
Group-Level Effects:
~condition (Number of levels: 3)
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept) 1.42 0.86 0.47 3.64 1.01 733 1197
Population-Level Effects:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept 0.95 0.86 -0.93 2.69 1.00 742 692
Family Specific Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma 0.51 0.07 0.40 0.69 1.00 1733 1771
Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
The estimates for the intercept and the SD are very uncertain.
mcmc_plot(m2, type = "areas")
We can do better than that, by choosing more informative priors:
p_intercept <- tibble(x = seq(-10, 15, by = 0.01),
                      Intercept = dnorm(x, 1, 1)) |>
  ggplot(aes(x, Intercept)) + geom_line(size = 2) +
  ylab("") + xlab("")
p_sd <- tibble(x = seq(0, 15, by = 0.01),
               sd = dstudent_t(x, 3, 0, 1)) |>
  ggplot(aes(x, sd)) + geom_line(size = 2) +
  ylab("") + xlab("")
# combining two plots with `+` assumes the patchwork package is loaded
p_intercept + p_sd
Family: gaussian
Links: mu = identity; sigma = identity
Formula: response ~ 1 + (1 | subject)
Data: d (Number of observations: 60)
Samples: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
total post-warmup samples = 4000
Group-Level Effects:
~subject (Number of levels: 3)
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept) 1.10 0.52 0.46 2.44 1.00 1168 1469
Population-Level Effects:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept 1.02 0.52 -0.07 2.08 1.00 1225 1246
Family Specific Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma 0.52 0.05 0.43 0.63 1.00 1991 2231
Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
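It may not be obvious at first glance, but the posterior for the intercept is narrower in model 3: the 95% interval is \([-0.07, 2.08]\), compared with \([-0.93, 2.69]\) in model 2.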
posts <- tibble(name = str_c("m", 2:3),
model = str_c("Model ", 2:3)) |>
mutate(fit = map(name, get)) |>
mutate(post = map(fit, posterior_samples))
# head(posts)
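posts <- posts |>
  select(-fit) |>
  unnest(post)
posts_mean <- posts |>
  pivot_longer(starts_with("b_Intercept"), names_to = "subject",
               values_to = "value")
posts_mean |>
  ggplot(aes(x = value, y = model)) +
  stat_halfeye()  # assumption: the original layer is missing; stat_halfeye() is from tidybayes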
The same is true of the standard deviation of the varying intercepts: in model 3, we have a mean of \(1.10\) with a 95% CI of \([0.46, 2.44]\), whereas in model 2 we have a mean of \(1.42\) and a 95% CI of \([0.47, 3.64]\).
We are not just interested in the population-level effect; we want the subject effects too.
While you can get these using the ranef() function (ranef(m3) gives you a summary; you have to set summary = FALSE to get the samples, i.e. ranef(m3, summary = FALSE)), I really like using tidybayes for this, as its functions return tidy data frames.
If we want just the varying effects, we can use r_subject[subject], with the placeholder subject. This will be replaced by the subject ID.
m3 |>
  spread_draws(r_subject[subject])
# A tibble: 12,000 x 5
# Groups: subject [3]
subject r_subject .chain .iteration .draw
<chr> <dbl> <int> <int> <int>
1 A -1.05 1 1 1
2 A -2.29 1 2 2
3 A -0.751 1 3 3
4 A -1.98 1 4 4
5 A -1.70 1 5 5
6 A -1.21 1 6 6
7 A -1.05 1 7 7
8 A -1.15 1 8 8
9 A -0.565 1 9 9
10 A -0.888 1 10 10
# … with 11,990 more rows
The varying effects are centred at zero. If you want the subject-specific means, you have to add the population-level estimate.
m3 %>%
spread_draws(b_Intercept, r_subject[subject]) |>
median_qi(subject_mean = b_Intercept + r_subject)
# A tibble: 3 x 7
subject subject_mean .lower .upper .width .point .interval
<chr> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
1 A 0.0482 -0.177 0.276 0.95 median qi
2 B 1.15 0.932 1.38 0.95 median qi
3 C 1.92 1.69 2.14 0.95 median qi
p3_shrinkage <- m3 |>
spread_draws(b_Intercept, r_subject[subject]) |>
median_qi(subject_mean = b_Intercept + r_subject) |>
ggplot(aes(y = subject, x = subject_mean,
color = subject)) +
geom_pointinterval(aes(xmin = .lower, xmax = .upper), point_size = 4) +
scale_color_brewer(type = "qual")
p3_shrinkage +
geom_point(aes(mean, subject),
data = subject_means,
size = 2, color = "black") +
ggtitle("Hierarchical shrinkage")
Plotting the no-pooling estimates alongside the partial-pooling estimates reveals that the partial-pooling estimates are less extreme than the no-pooling ones. This is because the estimates are pulled towards the overall mean, which is known as shrinkage.
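The direction and strength of shrinkage can be illustrated with the textbook formula for a normal hierarchical model in which the variance parameters are treated as known. This is a simplification: the fitted model also estimates those variances, so the sketch below only approximates what happens there. The function partial_pool() and all numeric values are hypothetical.

```r
# Sketch: conditional posterior mean of a subject's intercept when the
# population mean (mu), the between-subject SD (tau) and the residual SD
# (sigma) are treated as known. It is a precision-weighted average of the
# subject's own sample mean and the population mean.
partial_pool <- function(ybar, n, mu, tau, sigma) {
  w <- (n / sigma^2) / (n / sigma^2 + 1 / tau^2)  # weight on the subject's own mean
  w * ybar + (1 - w) * mu
}

ybar <- c(A = 0.0, B = 1.2, C = 1.9)  # hypothetical no-pooling subject means

# many observations per subject: very little shrinkage
shrunk_20 <- partial_pool(ybar, n = 20, mu = 1, tau = 1, sigma = 0.5)
shrunk_20

# few observations per subject: estimates are pulled harder towards mu
shrunk_2 <- partial_pool(ybar, n = 2, mu = 1, tau = 1, sigma = 0.5)
shrunk_2
```

The weight on a subject's own mean grows with the number of observations and with the between-subject SD, which is why shrinkage is mild when each subject contributes a lot of data.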
Now we will simulate data from a pre-post treatment study. The treatment is applied to every subject, so it is what is commonly referred to as a within-subject variable.
The function simulate_treatment(), defined below, generates data for n_subjects subjects in two conditions, with the following parameters:
params <- tribble(~Parameter, ~`Default value`, ~Description,
"a", 3.5, "average pre-treatment effect (intercepts)",
"b", -1, "average difference between pre and post",
"sigma_a", 1, "std dev in intercepts",
"sigma_b", 0.8, "std dev in differences (slopes)",
"rho", -0.7, "correlation between intercepts and slopes",
"n_subjects", 10, "no. subjects",
"n_trials", 20, "no. trials per subject per condition",
"sigma", 0.5, "residual standard deviation")
params |>
kableExtra::kbl() |>
kableExtra::kable_paper("hover", full_width = T)
Parameter | Default value | Description |
a | 3.5 | average pre-treatment effect (intercepts) |
b | -1.0 | average difference between pre and post |
sigma_a | 1.0 | std dev in intercepts |
sigma_b | 0.8 | std dev in differences (slopes) |
rho | -0.7 | correlation between intercepts and slopes |
n_subjects | 10.0 | no. subjects |
n_trials | 20.0 | no. trials per subject per condition |
sigma | 0.5 | residual standard deviation |
simulate_treatment <- function(a = 3.5,
                               b = -1,
                               sigma_a = 1,
                               sigma_b = 0.8,
                               rho = -0.7,
                               n_subjects = 10,
                               n_trials = 20,
                               sigma = 0.5) {
  # Combine the terms
  mu <- c(a, b)
  cov_ab <- sigma_a * sigma_b * rho
  # note: despite its name, SD holds the 2 x 2 covariance matrix
  SD <- matrix(c(sigma_a^2, cov_ab,
                 cov_ab, sigma_b^2), ncol = 2)
  # sigmas <- c(sigma_a, sigma_b) # standard deviations
  # rho <- matrix(c(1, rho, # correlation matrix
  #                 rho, 1), nrow = 2)
  # # now matrix multiply to get covariance matrix
  # SD <- diag(sigmas) %*% rho %*% diag(sigmas)

  # draw one (intercept, slope) pair per subject
  varying_effects <-
    MASS::mvrnorm(n_subjects, mu, SD) |>
    # as_tibble(.name_repair = "unique") |>
    data.frame() |>
    purrr::set_names("a_j", "b_j")

  # linear predictor for each subject in the pre (0) and post (1) condition
  d_linpred <-
    varying_effects |>
    mutate(subject = 1:n_subjects) |>
    expand(nesting(subject, a_j, b_j), post = c(0, 1)) |>
    mutate(mu = a_j + b_j * post,
           sigma = sigma) |>
    mutate(treatment = ifelse(post == 0, "pre", "post"),
           treatment = factor(treatment, levels = c("pre", "post")))

  # n_trials noisy observations per subject and condition
  d <- d_linpred |>
    slice(rep(1:n(), each = n_trials)) |>
    mutate(response = rnorm(n = n(), mean = mu, sd = sigma))

  d
}
plot_linpred <- function(d, violin = FALSE) {
  # one row per subject and treatment, holding the linear predictor mu
  d_linpred <- d |>
    group_by(subject, treatment) |> distinct(mu, .keep_all = TRUE)

  p <- d_linpred |>
    ggplot(aes(x = treatment, y = mu))

  # optionally show the raw responses behind the linear predictor
  if (isTRUE(violin)) {
    p <- p +
      geom_violin(aes(x = treatment, y = response,
                      fill = treatment),
                  alpha = 0.5,
                  data = d) +
      geom_jitter(aes(x = treatment, y = response,
                      color = treatment),
                  width = 0.1, size = 2,
                  data = d)
  }

  p <- p +
    geom_line(aes(group = 1), color = "#8B9DAF", size = 1, linetype = 3) +
    geom_point(aes(fill = treatment),
               shape = 21,
               colour = "black",
               size = 4,
               stroke = 2) +
    scale_fill_brewer(type = "qual") +
    scale_color_brewer(type = "qual") +
    coord_cartesian(ylim = c(0, 8)) +
    ylab("Response") +
    theme(legend.position = "none",
          axis.ticks.x = element_blank()) +
    facet_wrap(~ subject) +
    ggtitle("Linear predictor")

  p
}
d1 <- simulate_treatment(n_subjects = 5)
plot_linpred(d1, violin = T)
d2 <- simulate_treatment(n_subjects = 5,
sigma_a = 1.2,
sigma_b = 0.2,
rho = -0.7)
plot_linpred(d2, violin = F)
d <- simulate_treatment(n_subjects = 10, n_trials = 50)
plot_linpred(d, violin = F)
\[\begin{align*} \text{response}_i & \sim \operatorname{Normal}(\mu_i, \sigma) \\ \mu_i & = \alpha_{\text{subject}_i} + \beta_{\text{subject}_i} \text{treatment}_i \\ \begin{bmatrix} \alpha_\text{subject} \\ \beta_\text{subject} \end{bmatrix} & \sim \text{MVNormal} \left (\begin{bmatrix} \alpha \\ \beta \end{bmatrix}, \mathbf{S} \right ) \\ \mathbf S & = \begin{bmatrix} \sigma_\alpha & 0 \\ 0 & \sigma_\beta \end{bmatrix} \mathbf R \begin{bmatrix} \sigma_\alpha & 0 \\ 0 & \sigma_\beta \end{bmatrix} \\ \alpha & \sim \operatorname{Normal}(0, 10) \\ \beta & \sim \operatorname{Normal}(0, 10) \\ \sigma & \sim \operatorname{HalfCauchy}(0, 1) \\ \sigma_\alpha & \sim \operatorname{HalfCauchy}(0, 1) \\ \sigma_\beta & \sim \operatorname{HalfCauchy}(0, 1) \\ \mathbf R & \sim \operatorname{LKJcorr}(2) \end{align*}\]
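The decomposition of the covariance matrix \(\mathbf S\) into standard deviations and a correlation matrix \(\mathbf R\) can be checked numerically. The sketch below uses the default values of sigma_a, sigma_b and rho from simulate_treatment(); the names R, S and S_direct are chosen here for illustration.

```r
# Sketch: S = diag(sd) %*% R %*% diag(sd) gives the same covariance matrix
# as filling in variances and the covariance by hand.
sigma_a <- 1
sigma_b <- 0.8
rho <- -0.7

sigmas <- c(sigma_a, sigma_b)        # standard deviations
R <- matrix(c(1, rho,
              rho, 1), nrow = 2)     # correlation matrix

S <- diag(sigmas) %*% R %*% diag(sigmas)

cov_ab <- sigma_a * sigma_b * rho
S_direct <- matrix(c(sigma_a^2, cov_ab,
                     cov_ab, sigma_b^2), ncol = 2)

all.equal(S, S_direct)

# going back from the covariance matrix to the correlation matrix
cov2cor(S)
```

Separating scales from correlations is what allows independent priors to be placed on the SDs and on \(\mathbf R\).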
get_prior(response ~ treatment + (treatment | subject),
data = d) |>
as_tibble() |> select(1:4)
# A tibble: 10 x 4
prior class coef group
<chr> <chr> <chr> <chr>
1 "" b "" ""
2 "" b "treatmentpost" ""
3 "lkj(1)" cor "" ""
4 "" cor "" "subject"
5 "student_t(3, 3.2, 2.5)" Intercept "" ""
6 "student_t(3, 0, 2.5)" sd "" ""
7 "" sd "" "subject"
8 "" sd "Intercept" "subject"
9 "" sd "treatmentpost" "subject"
10 "student_t(3, 0, 2.5)" sigma "" ""
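The lkj(1) prior on the correlation deserves a closer look. For a 2 x 2 correlation matrix the LKJ density of the single correlation \(\rho\) is proportional to \((1 - \rho^2)^{\eta - 1}\), which means that \((\rho + 1)/2\) follows a Beta(\(\eta\), \(\eta\)) distribution. The sketch below uses this fact to draw from the prior for several values of \(\eta\). The helper rlkj_2x2() is hypothetical; the object names r_1, r_2, r_4 and the column name V2 are chosen to match the plotting code that follows, and the original objects may have been created differently.

```r
# Sketch: draws of the correlation from a 2 x 2 LKJ(eta) prior,
# using (rho + 1) / 2 ~ Beta(eta, eta)
set.seed(2021)  # arbitrary seed (assumption)

rlkj_2x2 <- function(n, eta) {
  data.frame(V2 = 2 * rbeta(n, eta, eta) - 1)
}

r_1 <- rlkj_2x2(1e4, eta = 1)  # flat over (-1, 1)
r_2 <- rlkj_2x2(1e4, eta = 2)  # mildly favours correlations near 0
r_4 <- rlkj_2x2(1e4, eta = 4)  # favours correlations near 0 more strongly

# the spread of the prior shrinks as eta grows
prior_sds <- sapply(list(eta_1 = r_1, eta_2 = r_2, eta_4 = r_4),
                    function(r) sd(r$V2))
prior_sds
```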
ggplot(data = r_1, aes(x = V2)) +
geom_density(color = "transparent", fill = "#5e81ac", alpha = 2/3) +
geom_density(data = r_2,
color = "transparent", fill = "#a3be8c", alpha = 2/3) +
geom_density(data = r_4,
color = "transparent", fill = "#bf616a", alpha = 2/3) +
geom_text(data = tibble(x = c(.83, .62, .46),
y = c(.54, .74, 1),
label = c("eta = 1", "eta = 2", "eta = 4")),
aes(x = x, y = y, label = label),
color = "#A65141", family = "Courier") +
scale_y_continuous(NULL, breaks = NULL)
fit1_prior <- brm(response ~ treatment + (treatment | subject),
prior = prior(normal(0, 10), class = b),
data = d,
sample_prior = "only",
file = "../../models/02-fit1-prior")
fit1 <- brm(response ~ treatment + (treatment | subject),
prior = prior(normal(0, 10), class = b),
data = d,
file = "../../models/02-fit1")
Family: gaussian
Links: mu = identity; sigma = identity
Formula: response ~ treatment + (treatment | subject)
Data: d (Number of observations: 1000)
Samples: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
total post-warmup samples = 4000
Group-Level Effects:
~subject (Number of levels: 10)
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept) 0.84 0.23 0.53 1.40 1.00 1399 1558
sd(treatmentpost) 0.83 0.24 0.50 1.40 1.00 1268 1742
cor(Intercept,treatmentpost) -0.63 0.21 -0.91 -0.09 1.00 1740 2081
Population-Level Effects:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept 3.39 0.28 2.83 3.93 1.00 959 1280
treatmentpost -0.84 0.26 -1.37 -0.33 1.00 1391 1769
Family Specific Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma 0.49 0.01 0.47 0.52 1.00 3066 2416
Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
pp_check(fit1_prior, type = "dens_overlay_grouped", group = "treatment")
pp_check(fit1, type = "dens_overlay_grouped", group = "treatment")
pp_check(fit1, nsamples = 100, type ='stat', stat='median')
grid <- d |>
modelr::data_grid(subject, treatment)
# A tibble: 20 x 2
subject treatment
<int> <fct>
1 1 pre
2 1 post
3 2 pre
4 2 post
5 3 pre
6 3 post
7 4 pre
8 4 post
9 5 pre
10 5 post
11 6 pre
12 6 post
13 7 pre
14 7 post
15 8 pre
16 8 post
17 9 pre
18 9 post
19 10 pre
20 10 post
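# assumption: the original calls are missing; the .value and .prediction
# columns used below match tidybayes' add_fitted_draws() / add_predicted_draws()
fits <- grid %>%
  add_fitted_draws(fit1)
preds <- grid %>%
  add_predicted_draws(fit1)
d %>%
  ggplot(aes(y = treatment, x = response)) +
  stat_interval(aes(x = .prediction), data = preds) +
  stat_pointinterval(aes(x = .value), data = fits, .width = c(.66, .95),
                     position = position_nudge(y = -0.3)) +
  geom_point()
conditional_effects(fit1, re_formula = NA)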
