Stein’s Paradox, Decision and Pooling

Author

Michael Issa

Published

August 2024

Stein's paradox is a striking result in statistical decision theory proved by Charles Stein in 1955. As a mathematical result, it is exquisite. The underlying sensitivity to the dimensionality of the problem (as with many other results) is surprising at first but subsides as you learn to write the proof down for yourself. What you end up with is a theorem telling you that, given a decision rule, a loss function, and a function of those two (the risk function), you can prove that one estimator is better than another in a non-trivial way. It doesn't tell you that the estimator is the best, or even that it's better than most; it only tells you that, compared to this one other estimator, it's better. You might feel dejected after hearing this. Fear not! I think this result points to some very deep issues in the foundations of statistics: the purported distinction between inference and decision, and how we think of our underlying data-generating process. It also segues quite nicely into more useful techniques in contemporary statistical practice, like multilevel models.

1 Computational Environment Setup

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import warnings
import os
warnings.filterwarnings("ignore")

# Graphic configuration
c_light = "#DCBCBC"
c_light_highlight = "#C79999"
c_mid = "#B97C7C"
c_mid_highlight = "#A25050"
c_dark = "#8F2727"
c_dark_highlight = "#7C0000"

c_light_teal = "#6B8E8E"
c_mid_teal = "#487575"
c_dark_teal = "#1D4F4F"

RANDOM_SEED = 58583389
np.random.seed(RANDOM_SEED)
az.style.use("arviz-darkgrid")

plt.rcParams['font.family'] = 'serif'

plt.rcParams['xtick.labelsize'] = 12  
plt.rcParams['ytick.labelsize'] = 12  
plt.rcParams['axes.labelsize'] = 12  
plt.rcParams['axes.titlesize'] = 12   

plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.left'] = True
plt.rcParams['axes.spines.bottom'] = True

plt.rcParams['axes.xmargin'] = 0  
plt.rcParams['axes.ymargin'] = 0  

plt.subplots_adjust(left=0.15, bottom=0.15, right=0.9, top=0.85)

current_working_directory = os.getcwd()

2 Stein’s Paradox

2.1 Maximum Likelihood and the James-Stein Estimator

To report a point estimate for a parameter of interest \theta, you usually use maximum likelihood estimation, which chooses the estimate \hat{\theta} that maximizes the likelihood function,

\hat{\theta}_{\text{MLE}} = \underset{\theta}{\operatorname{arg\,max}} \, \mathcal{L}(\theta \mid x).

This provides a very natural interpretation of \hat{\theta} as the ‘most likely’ value. For example, if we sample the test scores of a school and find the average of our sample is 75%, anyone would assume that it’s most likely the population average is also 75%. What justifies this? Well, MLE justifies this. There is a mountain of results about how MLE has all these desirable properties, but we’re interested in its decision-theoretic justification.
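As a small illustration (a sketch I'm adding here, with made-up scores and a known sigma for simplicity), maximizing the normal likelihood over the mean recovers the sample average:

from scipy.optimize import minimize_scalar

scores = np.array([72, 78, 75, 71, 80, 74, 76, 73])    # hypothetical test scores (percent)

def neg_log_likelihood(theta, sigma=5.0):
    # Negative Gaussian log-likelihood of the mean, dropping additive constants
    return 0.5 * np.sum((scores - theta) ** 2) / sigma**2

mle = minimize_scalar(neg_log_likelihood, bounds=(0, 100), method='bounded').x
print(mle, scores.mean())                               # the MLE coincides with the sample mean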

Given a loss function that measures the inaccuracy of an estimator based on the squared difference or error:

L(\hat{\theta}, \theta) = (\hat{\theta} - \theta)^2

the risk of the estimator, which is the expected value of this loss function, is:

R(\hat{\theta}, \theta) = \mathbb{E}[(\hat{\theta} - \theta)^2]

The estimator is considered better if it has a lower risk. For a normal sample, comparing the sample mean \bar{X} with the sample median, the risk of the mean is less than or equal to the risk of the median for all possible values of \theta:

R(\bar{X}, \theta) \leq R(\text{median}, \theta), \quad \forall \theta \in \Theta

The mean dominates the median in this case, and the median is ruled inadmissible. You can find a proof of this result on Wikipedia.

theta = np.linspace(-5, 5, 400)

risk_mean = np.ones_like(theta)        # risk of the sample mean: variance 1, no bias
risk_median = 1.57 * risk_mean         # risk of the sample median: (pi/2) times the mean's risk
risk_half_mean = (theta**2 + 1) / 4    # risk of theta_hat = mean/2: bias^2 + variance = theta^2/4 + 1/4

plt.figure(figsize=(10, 6))
plt.plot(theta, risk_mean, label='Mean', color='green')
plt.plot(theta, risk_median, label='Median', color='blue')
plt.plot(theta, risk_half_mean, label='1/2Mean', color='red')

plt.xlabel(r'$\theta$', fontsize=14)
plt.ylabel('Risk', fontsize=14)
plt.title('Risk Functions of Different Estimators', fontsize=16)
plt.legend()
plt.grid(True)

plt.figtext(0.1, 0.01, 'Adapted from Efron and Morris (1977).', 
            ha='left', fontsize=10, color='gray')

plt.show()

We see that as long as you're estimating one or two means, the mean is admissible. But what Stein showed is that if you want to estimate three or more means together, then the mean is inadmissible. He provided one such dominating estimator, which performs better than the MLE because it "shrinks" the MLE towards 0. The coefficient of shrinkage is:

c = 1 - \frac{(N - 2)\sigma^2}{\sum_{i=1}^{N} X_i^2}

where X_i is the MLE of the ith mean (in the case of population heights, the ith sample mean) and N \geq 3. Multiplying each MLE by this coefficient of shrinkage results in a better estimate overall because it lowers the total expected squared error.
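As a quick sanity check (a simulation sketch I'm adding, not part of Efron and Morris's presentation), we can draw a k-dimensional normal observation and compare the total squared error of the MLE with that of the estimator shrunk by the coefficient above:

k, n_sim = 10, 5000
theta_true = np.full(k, 1.0)                          # arbitrary true mean vector
X = np.random.normal(theta_true, 1.0, (n_sim, k))     # one noisy observation per replication

shrink = 1 - (k - 2) / np.sum(X**2, axis=1, keepdims=True)   # James-Stein factor, sigma^2 = 1
X_js = shrink * X

mse_mle = np.mean(np.sum((X - theta_true) ** 2, axis=1))
mse_js = np.mean(np.sum((X_js - theta_true) ** 2, axis=1))
print(mse_mle, mse_js)                                # the shrunken estimator has lower total risk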

2.2 An aside on the coefficient of 1.57 in the median risk

This section may be skipped. I couldn’t find a sufficient derivation of this from a quick Google search, so I went through the work of deriving it.

Let’s derive the formula for the density of the sample median at the population median, denoted by f(\mu).

For a sample of n independent and identically distributed (i.i.d.) random variables X_1, X_2, \ldots, X_n from a continuous distribution with cumulative distribution function F(x) and probability density function f(x), the order statistics are the sorted values X_{(1)} \leq X_{(2)} \leq \dots \leq X_{(n)}.

The median M of the sample is the \left(\frac{n+1}{2}\right)th order statistic when n is odd, or the average of the two middle order statistics when n is even. The distribution of the median, M, can be approximated using the concept of order statistics. We focus on the density at the population median \mu.

The probability that the sample median M is less than or equal to a value x is given by:

P(M \leq x) = P\left(X_{\left(\frac{n+1}{2}\right)} \leq x\right)

The median is the order statistic around which half of the sample lies below and half above. The distribution of the order statistic is given by:

P(M \leq x) = \sum_{k=\left\lceil \frac{n+1}{2} \right\rceil}^{n} \binom{n}{k} [F(x)]^k [1 - F(x)]^{n-k}

However, we are interested in the density function f_M(x) at x = \mu, the population median. For odd n, the density of the median (the \left(\frac{n+1}{2}\right)th order statistic) is:

f_M(x) = \frac{n!}{\left(\left(\frac{n-1}{2}\right)!\right)^2} [F(x)]^{\frac{n-1}{2}} [1 - F(x)]^{\frac{n-1}{2}} f(x)

At x = \mu, since F(\mu) = 0.5 by definition of the population median:

f_M(\mu) = \frac{n!}{\left(\left(\frac{n-1}{2}\right)!\right)^2} \left(\frac{1}{2}\right)^{n-1} f(\mu)

Using Stirling's approximation for factorials, n! \approx \sqrt{2 \pi n} \left(\frac{n}{e}\right)^n, the coefficient simplifies for large n to approximately:

\frac{n!}{\left(\left(\frac{n-1}{2}\right)!\right)^2} \approx 2^{n-1} \sqrt{\frac{2n}{\pi}}

Thus:

f_M(\mu) \approx \sqrt{\frac{2n}{\pi}} \, f(\mu)

The sample median is asymptotically normal with mean \mu, so matching this peak density to that of a normal distribution, f_M(\mu) = \frac{1}{\sqrt{2\pi}\, \sigma_{\text{median}}}, gives the standard asymptotic variance of the sample median:

\sigma_{\text{median}}^2 = \frac{1}{2\pi [f_M(\mu)]^2} = \frac{1}{4n [f(\mu)]^2}

See the Wikipedia article on the median for this asymptotic result. For a normal distribution \mathcal{N}(\mu, \sigma^2), the PDF at the median \mu is:

f(\mu) = \frac{1}{\sqrt{2 \pi} \sigma}

Substituting this into the variance of the sample median:

\text{Var}(M) = \frac{1}{4n \left(\frac{1}{\sqrt{2 \pi} \sigma}\right)^2} = \frac{\pi \sigma^2}{2n}

So, for a sample of n i.i.d. random variables X_1, X_2, \ldots, X_n drawn from a normal distribution \mathcal{N}(\mu, \sigma^2), the sample mean follows \bar{X} \sim \mathcal{N}\left(\mu, \frac{\sigma^2}{n}\right), while for large n the sample median is approximately M \sim \mathcal{N}\left(\mu, \frac{\pi\sigma^2}{2n}\right).

The factor \frac{\pi}{2} \approx 1.57 is the ratio of the two variances.

Since the risk under the quadratic loss function (mean squared error) of an unbiased estimator of \theta is just its variance, this is where the constant of proportionality of 1.57 in the risk plot comes from.
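A quick Monte Carlo check (a sketch added here for illustration) confirms the ratio:

n, n_sim = 101, 20000                     # odd n so the median is a single order statistic
samples = np.random.normal(0.0, 1.0, (n_sim, n))
mse_mean = np.mean(samples.mean(axis=1) ** 2)
mse_median = np.mean(np.median(samples, axis=1) ** 2)
print(mse_median / mse_mean)              # close to pi/2 ~ 1.57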

3 Baseball Data

The figures for the baseball data are taken from Efron and Hastie (2010, p. 95). We start by plotting the distribution of the players' true end-of-season batting averages, the outcome we want to estimate.

import pandas as pd
baseball = pd.DataFrame({
    'player': ["Clemente", "F Robinson", "F Howard", "Johnstone", "Berry", "Spencer", "Kessinger", "L Alvarado", "Santo", "Swoboda", 
               "Unser", "Williams", "Scott", "Petrocelli", "E Rodriguez", "Campaneris", "Munson", "Alvis"],
    'hits': [18, 17, 16, 15, 14, 14, 13, 12, 11, 11, 10, 10, 10, 10, 10, 9, 8, 7],
    'times_at_bat': 45,
    'true_batting_average': [.346, .298, .276, .222, .273, .27, .263, .21, .269, .23, .264, .256, .303, .264, .226, .286, .316, .2]
})
plt.figure(figsize=(10, 3))


sns.histplot(
    baseball['true_batting_average'], 
    bins=20,  
    color=c_light, 
    kde=False,  
    stat='density', 
    alpha=0.3  
)

sns.kdeplot(
    baseball['true_batting_average'], 
    color=c_light_highlight, 
    fill=True, 
    alpha=0.67, 
    bw_adjust=0.5 
)

median = np.median(baseball['true_batting_average'])
q1 = np.percentile(baseball['true_batting_average'], 25)
q3 = np.percentile(baseball['true_batting_average'], 75)

plt.axvline(median, color=c_dark_highlight, linestyle='--', lw=2, label='Median')
y_iqr = 0.02 
plt.plot([q1, q3], [y_iqr, y_iqr], color=c_dark_highlight, lw=4, label='IQR')

plt.title('Batting Average Density', fontsize=12)
plt.xlabel('True Batting Average')
plt.legend(title='Estimate Type', loc='upper right')

plt.show()

Following the same procedure that Efron and Morris give us, we get a value of 0.212 for c, the shrinkage constant. I follow the presentation in Efron and Hastie for the derivation of the James-Stein estimate for the baseball data. Each player's batting average p_i is given by

p_i \sim \frac{\text{Binomial}(n, P_i)}{n}, \qquad n = 45

where P_i is the true average. We don’t know the true average, so we use the normal approximation of the binomial.

p_i \dot{\sim} \mathcal{N}(P_i, \sigma_0^2)

where \sigma_0^2 is the squared binomial standard error of p_i, that is, its variance:

\sigma_0^2 = \frac{\bar{p}(1 - \bar{p})}{n}

Finally, the James-Stein estimate is given by:

\hat{p}_i^{\text{JS}} = \bar{p} + \left(1 - \frac{(N - 3)\sigma_0^2}{\Sigma(p_i - \bar{p})^2}\right)(p_i - \bar{p})

The factor here adjusts how much each player's individual estimate p_i is pulled towards the group mean \bar{p}.
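As a quick check (a small sketch using the data frame columns defined above), we can recover the shrinkage constant directly from this formula:

p = baseball['hits'] / baseball['times_at_bat']
p_bar = p.mean()
N = len(p)
sigma0_sq = p_bar * (1 - p_bar) / 45                        # binomial variance of each p_i
c_js = 1 - (N - 3) * sigma0_sq / np.sum((p - p_bar) ** 2)   # James-Stein shrinkage factor
print(round(c_js, 3))                                       # approximately 0.212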

Now, to compare the estimates with and without shrinkage, we plot their densities.

c = 0.212
baseball['y'] = baseball['hits'] / baseball['times_at_bat']
baseball['y_bar'] = baseball['y'].mean()

baseball['z'] = baseball['y_bar'] + c * (baseball['y'] - baseball['y_bar'])
baseball['theta'] = baseball['true_batting_average']

melted_baseball = pd.melt(baseball, value_vars=['y', 'z'])

melted_baseball['label'] = melted_baseball['variable'].map({
    'z': 'The James-Stein estimate',
    'y': 'Early-season batting average'
})

palette = {
    'The James-Stein estimate': c_light,
    'Early-season batting average': c_mid_highlight
}

# Plot
plt.figure(figsize=(10, 5))

sns.kdeplot(data=melted_baseball, x='value', hue='label', palette=palette, fill=True, alpha=0.67)

plt.axvline(median, color=c_dark_highlight, linestyle='--', lw=2, label='Median')
plt.plot([q1, q3], [y_iqr, y_iqr], color=c_dark_highlight, lw=4, label='IQR')

plt.xlabel('Batting Average')
plt.ylabel('')
plt.title('Distribution of Batting Averages') 

plt.xlim(0, 0.6) 

plt.legend(title='Estimate Type', loc='upper right')
plt.show()

We see that the James-Stein estimates are substantially narrower in spread than the early-season batting averages, which are more diffuse. This jibes well with our claim that the James-Stein estimator is a "shrinkage" estimator. Now that we have the distributions, we'll check whether we achieved a lower error.

baseball['y_error'] = baseball['theta'] - baseball['y']
baseball['z_error'] = baseball['theta'] - baseball['z']

melted_errors = pd.melt(baseball, value_vars=['y_error', 'z_error'])
melted_errors['label'] = melted_errors['variable'].map({
    'y_error': 'Early-season error',
    'z_error': 'James-Stein error'
})

palette_errors = {
    'Early-season error': c_mid_highlight,
    'James-Stein error': c_light   
}
# Plot
plt.figure(figsize=(10, 5))

plot = sns.kdeplot(data=melted_errors, x='value', hue='label', palette=palette_errors, fill=True, alpha=0.67)

plt.axvline(x=0, linestyle='--', color=c_dark_highlight)

plt.title('Error Distribution of Batting Averages')

plt.show()

4 Is it Even Paradoxical?

At first blush, Stein's Paradox doesn't seem all that paradoxical. Okay, sure. Shrinking your estimates towards the sample mean gives a better estimate of the end-of-season average than the maximum likelihood estimate does. The underlying assumption many people have upon encountering this is that there is some common cause behind all these baseball players, and that our estimate takes the effect of this common cause into account. But James and Stein show us that you can take totally unrelated quantities, like batting averages and the proportion of foreign car imports, estimate them together, and shrink them all towards their grand average, resulting in a lower total risk! The guarantee of lower total risk doesn't rely on the quantities being related, though the size of the improvement does depend on how close the true values, such as the proportion of imported cars and the batting averages, happen to be to one another.

Another point of interest I hope to explore elsewhere (lest this turn into a paper) is that the James-Stein estimator is not shift invariant.

Let \mathbf{X} = (X_1, X_2, \dots, X_n) represent a sample of data points, and let \hat{\theta}(\mathbf{X}) be an estimator of some parameter \theta. The estimator \hat{\theta}(\mathbf{X}) is said to be shift invariant if, for any constant c, the following condition holds:

\hat{\theta}(\mathbf{X} + c) = \hat{\theta}(\mathbf{X}) + c

where \mathbf{X} + c = (X_1 + c, X_2 + c, \dots, X_n + c).
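A minimal numerical sketch (not taken from any of the cited sources) makes the failure easy to see for the version of the estimator that shrinks towards zero:

def js_toward_zero(x, sigma2=1.0):
    # James-Stein estimator shrinking towards the origin, assuming known variance sigma2
    k = len(x)
    return (1 - (k - 2) * sigma2 / np.sum(x**2)) * x

x = np.random.normal(2.0, 1.0, size=5)
shift = 10.0
print(np.allclose(js_toward_zero(x + shift), js_toward_zero(x) + shift))   # False: not shift invariant

Shifting the data changes how strongly each coordinate is shrunk, so the estimate of the shifted data is not simply the shifted estimate.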

Shift invariance is one aspect of what is called the Invariance Principle, which intuitively means that inference shouldn’t depend on the idiosyncrasies of our unit of measure or other group-invariant transformations. An interesting fact to note is that invariance might not be such an important property for Bayesians, given that it doesn’t take into account how our prior might behave. I hope to explore these connections in a future post.

5 Must Frequentists Care About Admissibility?

I.J. Good estimated through a combinatorial argument that there are at least 46,656 types of Bayesianism. According to my humble estimate, there are at least two types of Frequentists. One type, developed by Abraham Wald, is the decision-theoretic strain. The other is the Mayo-Spanos interpretation. I can think of two ways to interpret the former strain: a strict interpretation with limited scope and a loose interpretation with a purported claim to wider scope.

The first interpretation takes the tools decision theory provides (i.e., admissibility, optimality, domination, risk, etc.) and applies a strict interpretation to them. This means that if we confront a decision about our model, then the purpose it's supposed to serve and the outcomes we care about are immediately well-defined by the loss function we are considering minimizing. Most likely, we don't know the optimal loss function, but we have some idea of what a good one is for this problem.

For example, if you’re working in manufacturing and wish to measure the error in the machining of some parts for quality control purposes, then you may consider the MSE a fairly good loss function. In industrial decision-making, there is enough outside information about the procedure and the expected quality that you can operationalize the parameter of interest through your loss function. This applies to a wide range of business and industrial decision-making problems.

The difficulty with this strain of frequentism is that we don’t have an uncontroversial loss function in most scientific problems. Instead of speaking of loss functions, we can instead speak of values, and scientists are notoriously at odds regarding what they value in “good” science. They might speak of explanatory power, simplicity, scope, and/or unification. But these are all expressed in extraordinarily different ways. Not only that, but most of these values aren’t easily converted into a function.

So far, we can identify a split between two decision-theoretic-based frequentisms. One thinks Wald’s notions are only properly applied in strict cases. The other considers subsuming the entirety of science under them.

Deborah Mayo and Aris Spanos have argued for a reinterpretation of NP-style frequentism. What’s relevant for our discussion is that Mayo and Spanos warn against taking decision-making as the fundamental guiding aspect of frequentism. They stress that we should care about the fixed state of nature, which we are attempting to estimate. But this contrasts with caring primarily about admissibility. The universal quantifier in our inequality in section 2.1 requires us to quantify over all values of \theta in our \Theta. This is contrary to the fixed state of nature \theta_{true}.

The Mayo-Spanos frequentist does not care about decision-making in the traditional sense. They disagree that admissibility is a necessary or sufficient property for a good decision. Instead, they want decisions to be guided by certain frequentist properties. One recommendation is to consider only consistent estimators. A consistent estimator is one for which, when the estimate is considered as a random variable indexed by the number n of items in the dataset, the estimates converge in probability to the value that the estimator is designed to estimate as n increases. Formally, let \hat{\theta}_n be an estimator of a parameter \theta based on a sample of size n. The estimator \hat{\theta}_n is weakly consistent if:

\lim_{n \to \infty} \Pr\left(|\hat{\theta}_n - \theta| > \varepsilon\right) = 0

It is strongly consistent if:

\Pr\left(\lim_{n \to \infty} \hat{\theta}_n = \theta\right) = 1

Spanos endorses strong consistency as a necessary condition for estimators. However, we would like a definition of consistency that guides decisions when \theta is unknown. But this would require quantifying over all values of \theta! It seems a bit strange that the recommendation from the Mayo-Spanos view might involve giving up on evaluating good estimators based on better performance over all possible values of \theta.

Now that we have elaborated on some species of frequentism, here are the facets we’ve considered so far:

  1. It makes sense to quantify over all values of \theta when considering optimal estimators: (1) always, (2) sometimes, or (3) never.

  2. Estimators must be (1) admissible, (2) admissible and consistent, or (3) consistent.

  3. Loss functions are (1) always, (2) sometimes, or (3) never sufficiently fine-grained to capture what we value.

I would guess that you couldn't believe in both 1(2) and 2(1). What seems reasonable to me is to believe in 1(2), 2(2), and 3(2). I don't have a horse in this race, but I would be hesitant to endorse any operationalizing of frequentism in the way the strong decision-theoretic view does. We'll see in the next section that this worry about operationalism crops up for Bayesians too.

6 Must Bayesians Care About Admissibility?

6.1 Pragmatic Bayesianism

I'm not going to consider the varieties of Bayesianism, which total to the number of people who consider themselves Bayesian. There is a nice discussion of the philosophically relevant strains of Bayesianism and Stein's paradox in Vassend, Sober, and Fitelson. The view I'm going to take for granted is the pragmatic one that most Bayesian practitioners adopt today. I view this position as putting aside the philosophical inheritance from Savage and other subjectivists or behaviorists, recognizing that priors or credences are inadequately defined or reduced in the myriad ways they've been thought to be (Eriksson and Hajek), and being realistic about the fact that Bayesian modeling is going to serve different functions depending on the problem at hand. Models combine both a prior and a likelihood, "each of which represents some compromise among scientific knowledge, mathematical convenience, and computational tractability" (Gelman and Shalizi), and the prior is not necessarily derivable from degrees of belief, the set of default priors, or any other privileged class of priors.

I don’t think it should be controversial to say that most instances of Bayesian inference, decision-making, or predictive modeling aren’t going to agree on what the prior or likelihood mean. It’s likely going to be a process of thinking about the context of the problem (e.g., do I have to optimize for some single decision in a pipeline? Do I have to decide if a treatment is effective or not? Do I wish to understand the underlying generative procedure?, etc.) and the data.

This is all I think is necessary to say to dissuade simple interpretations that our Bayesian forefathers handed down. I don’t think you need to have a positive story that unifies all these themes about modern Bayesianism to have a good methodology as long as you can recognize that thinking about the prior and likelihood together unifies Bayesian modeling.

6.2 Bayes Rule and Admissibility

Returning to the original question: must Bayesians care about admissibility? Well, it depends on the problem. Admissibility isn't the be-all and end-all, and it isn't all that important for most problems where we don't care much about the loss function. We might care about some other property if we believe, for example, that the sample median is the optimal estimator, even if we've specified an MSE loss function for convenience. I think Bayesians actually have more flexibility than frequentists in justifying the rejection or acceptance of admissible estimators and of the loss functions we find upstream. It's well known that reasonable Bayes rules are typically admissible with respect to a specified prior, and Bayes rules and admissible estimators likely coincide in many problems. The thought I want to focus on is whether or not Bayesians should care about this fact, given that Mayo-Spanos frequentists do not.

If Bayes rules have the same nice property of admissibility, and admissibility is what characterizes frequentist inference, then Bayesianism seems superior to frequentism, given that the Bayesian has the added flexibility of the prior. But Mayo and Spanos deny that admissibility is what characterizes frequentist inference. The weird thing I find with Spanos' emphasis on \theta_{true} is that most frequentist methods are built on asymptotic assumptions. Those assumptions rely on the model being perfect, which means the data-generating process your data comes from has to be in \Theta and, with enough data, your set of estimates narrows down to the point estimate \theta_{true}. The issue is that these are very idealized assumptions. We've assumed \theta_{true} is in \Theta, and that the repeated measurements which give us our data have no heterogeneity, correlation, or variation in the underlying process. This isn't a problem that only plagues frequentists, but it seems to infect Mayo-Spanos style frequentism severely. Any misspecification of the model results in total ruin for those who are after \theta_{true}.

7 Priors, Shrinkage, and Pooling

The property that the James-Stein estimator puts on display in the baseball example is the ability to shrink each individual mean to the grand mean using a certain factor. The nice thing about this is that in Bayesian modeling, the shrinkage comes from the prior distribution. This becomes especially obvious when we use hierarchical models, which allow us to flexibly shrink our individual parameter estimates by assuming that they all come from some shared distribution yet have their own individual estimate. We strike a compromise between the grand mean and the individual group or individual means.

It’s helpful to have in mind an example when thinking of pooling. For the baseball example, we have a repeated binary trial for N baseball players. Each of these players has some ratio of hits to times-at-bat. Using complete pooling for this results in each player having the same probability of success. This means they all come from the same distribution if we were thinking of the N players as draws from a population with no variance. For our example, complete pooling would consist of having a single parameter for the chance of success for all the batters. No pooling would assume that these players are being drawn from a population with infinite variance. Each of these players has a probability of success that is completely independent of the others. Now, partial pooling strikes a compromise between the extremes. It allows us to encode uncertainty about the population of interest we’re sampling from. We can represent that population through some distribution with finite variance. The idea here is to have a hierarchical structure for each of the players’ parameters, which are describing their hitting ability.
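In symbols (a schematic sketch using the notation of the next section, where y_n is the number of hits in T_n at-bats and g_\phi is just a placeholder for whatever population distribution we choose):

\text{Complete pooling:} \quad y_n \sim \text{Binomial}(T_n, \theta) \quad \text{(one shared } \theta \text{)}

\text{No pooling:} \quad y_n \sim \text{Binomial}(T_n, \theta_n) \quad \text{(each } \theta_n \text{ independent)}

\text{Partial pooling:} \quad \theta_n \sim g_\phi, \quad y_n \sim \text{Binomial}(T_n, \theta_n) \quad \text{(} g_\phi \text{ with finite variance)}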

8 A Hierarchical Baseball Model

I follow the treatment by Bob Carpenter for the Stan partial pooling and no pooling model.

8.1 Binomial Logit with no Pooling

from cmdstanpy import CmdStanModel, cmdstan_path, set_cmdstan_path

print(baseball.iloc[:, :4])
         player  hits  times_at_bat  true_batting_average
0      Clemente    18            45                 0.346
1    F Robinson    17            45                 0.298
2      F Howard    16            45                 0.276
3     Johnstone    15            45                 0.222
4         Berry    14            45                 0.273
5       Spencer    14            45                 0.270
6     Kessinger    13            45                 0.263
7    L Alvarado    12            45                 0.210
8         Santo    11            45                 0.269
9       Swoboda    11            45                 0.230
10        Unser    10            45                 0.264
11     Williams    10            45                 0.256
12        Scott    10            45                 0.303
13   Petrocelli    10            45                 0.264
14  E Rodriguez    10            45                 0.226
15   Campaneris     9            45                 0.286
16       Munson     8            45                 0.316
17        Alvis     7            45                 0.200

First, we'll write down our no-pooling model to make sure it makes sense before we code it up. We have y_n hits out of T_n trials for each player n. Those hits follow a binomial distribution. We'll use a binomial logit parameterization for the ease of adding further effects when we implement partial pooling. Since there is no pooling and not much prior knowledge, each player's log-odds of success gets a diffuse normal prior.

\alpha_n \sim \mathcal{N}(0, 10)

\theta_n = \frac{1}{1 + e^{-\alpha_n}} = \text{logit}^{-1}(\alpha_n)

y_n \sim \text{Binomial}(T_n, \theta_n)

Now we can get our data ready for our compiled model and sample using CmdStan.

model1_logit_no_pooling.stan
data {
  int<lower=0> N;              // Number of players
  array[N] int<lower=0> Tr;     // Trials for each player
  array[N] int<lower=0> y;      // Successes for each player
}

parameters {
  vector[N] alpha;              // Logit-transformed chance of success for each player
}

model {
  alpha ~ normal(0, 10);       // Prior for alpha (logit scale)

  for (n in 1:N) {
    target += binomial_logit_lpmf(y[n] | Tr[n], alpha[n]);
  }
}

generated quantities {
   array[N] int y_pred;          // Predicted number of successes
   array[N] real p_hat_pred;              // Mean predicted probability of success

   for (n in 1:N) {
     y_pred[n] = binomial_rng(Tr[n], inv_logit(alpha[n]));
     p_hat_pred[n] = inv_logit(alpha[n]);
   }

}

We can define our data dictionary for our model and load our model using CmdStan.

stan_data = {
  'N': len(baseball['player']),
  'Tr': baseball['times_at_bat'],
  'y': baseball['hits'],
}

model_path_1 = os.path.join(current_working_directory, 'models', 'model1_logit_no_pooling.stan')
stan_model_1 = CmdStanModel(stan_file=model_path_1)

We fit the model and check for diagnostics now.

fit = stan_model_1.sample(data=stan_data, seed=RANDOM_SEED, chains=4, iter_sampling=2000, iter_warmup=1000, show_progress=False )
13:50:01 - cmdstanpy - INFO - CmdStan start processing
13:50:01 - cmdstanpy - INFO - Chain [1] start processing
13:50:01 - cmdstanpy - INFO - Chain [2] start processing
13:50:01 - cmdstanpy - INFO - Chain [3] start processing
13:50:01 - cmdstanpy - INFO - Chain [4] start processing
13:50:02 - cmdstanpy - INFO - Chain [3] done processing
13:50:02 - cmdstanpy - INFO - Chain [2] done processing
13:50:02 - cmdstanpy - INFO - Chain [4] done processing
13:50:02 - cmdstanpy - INFO - Chain [1] done processing
print(fit.diagnose())
Processing csv files: C:\Users\ISSAM_~1\AppData\Local\Temp\tmpopw2zneq\model1_logit_no_poolingd8cdx05q\model1_logit_no_pooling-20241007135001_1.csv, C:\Users\ISSAM_~1\AppData\Local\Temp\tmpopw2zneq\model1_logit_no_poolingd8cdx05q\model1_logit_no_pooling-20241007135001_2.csv, C:\Users\ISSAM_~1\AppData\Local\Temp\tmpopw2zneq\model1_logit_no_poolingd8cdx05q\model1_logit_no_pooling-20241007135001_3.csv, C:\Users\ISSAM_~1\AppData\Local\Temp\tmpopw2zneq\model1_logit_no_poolingd8cdx05q\model1_logit_no_pooling-20241007135001_4.csv

Checking sampler transitions treedepth.
Treedepth satisfactory for all transitions.

Checking sampler transitions for divergences.
No divergent transitions found.

Checking E-BFMI - sampler transitions HMC potential energy.
E-BFMI satisfactory.

Effective sample size satisfactory.

Split R-hat values satisfactory all parameters.

Processing complete, no problems detected.

Everything looks fine. Here is a summary of our model fit and generated quantities; everything looks reasonable. Afterwards, we plot a posterior predictive check to see how well our no-pooling model recovers the true batting averages.

output = fit.summary(percentiles=(5, 50, 95), sig_figs=6)
print(output)
                     Mean      MCSE    StdDev         5%        50%  \
lp__           -45.158500  0.064482  3.038530 -50.592100 -44.810600   
alpha[1]        -0.416262  0.002129  0.306855  -0.928830  -0.414002   
alpha[2]        -0.510640  0.002217  0.312236  -1.030860  -0.504429   
alpha[3]        -0.607732  0.002212  0.309840  -1.121190  -0.601879   
alpha[4]        -0.711733  0.002383  0.322517  -1.257140  -0.703914   
alpha[5]        -0.815590  0.002497  0.330557  -1.370300  -0.810103   
alpha[6]        -0.817324  0.002403  0.323989  -1.357860  -0.810950   
alpha[7]        -0.924998  0.002584  0.341652  -1.503300  -0.919467   
alpha[8]        -1.037390  0.002460  0.332247  -1.593070  -1.032730   
alpha[9]        -1.156750  0.002548  0.357371  -1.771640  -1.143840   
alpha[10]       -1.160130  0.002613  0.356644  -1.769180  -1.144900   
alpha[11]       -1.287640  0.002597  0.375680  -1.925090  -1.275410   
alpha[12]       -1.289090  0.002732  0.361543  -1.907470  -1.278770   
alpha[13]       -1.284620  0.002728  0.367312  -1.918880  -1.271320   
alpha[14]       -1.286700  0.002836  0.362499  -1.901300  -1.274040   
alpha[15]       -1.285880  0.002720  0.364642  -1.906100  -1.273940   
alpha[16]       -1.421960  0.002703  0.375895  -2.062690  -1.409960   
alpha[17]       -1.578110  0.002989  0.402005  -2.266880  -1.562700   
alpha[18]       -1.749130  0.003352  0.420999  -2.461340  -1.726590   
y_pred[1]       17.952600  0.042752  4.615440  10.000000  18.000000   
y_pred[2]       16.994200  0.043285  4.596700  10.000000  17.000000   
y_pred[3]       15.945200  0.041726  4.476690   9.000000  16.000000   
y_pred[4]       15.033600  0.040538  4.408220   8.000000  15.000000   
y_pred[5]       14.010200  0.041986  4.324840   7.000000  14.000000   
y_pred[6]       13.906600  0.040439  4.288930   7.000000  14.000000   
y_pred[7]       13.048100  0.040699  4.277860   6.000000  13.000000   
y_pred[8]       11.976400  0.040026  4.099840   6.000000  12.000000   
y_pred[9]       10.999800  0.038663  4.038720   5.000000  11.000000   
y_pred[10]      10.998900  0.037719  4.052640   5.000000  11.000000   
y_pred[11]      10.103200  0.035365  3.985330   4.000000  10.000000   
y_pred[12]      10.016000  0.038151  3.847060   4.000000  10.000000   
y_pred[13]      10.041700  0.037026  3.897100   4.000000  10.000000   
y_pred[14]       9.987250  0.036800  3.918800   4.000000  10.000000   
y_pred[15]      10.021400  0.036613  3.960130   4.000000  10.000000   
y_pred[16]       9.044750  0.036844  3.770150   3.000000   9.000000   
y_pred[17]       8.093250  0.032784  3.600060   3.000000   8.000000   
y_pred[18]       6.970120  0.032508  3.368260   2.000000   7.000000   
p_hat_pred[1]    0.399663  0.000496  0.071897   0.283162   0.397953   
p_hat_pred[2]    0.377802  0.000507  0.071677   0.262917   0.376500   
p_hat_pred[3]    0.355707  0.000492  0.069423   0.245791   0.353914   
p_hat_pred[4]    0.333036  0.000494  0.069703   0.221467   0.330945   
p_hat_pred[5]    0.311048  0.000516  0.069000   0.202571   0.307869   
p_hat_pred[6]    0.310519  0.000493  0.067619   0.204589   0.307688   
p_hat_pred[7]    0.288899  0.000510  0.068317   0.181934   0.285066   
p_hat_pred[8]    0.266587  0.000470  0.063268   0.168952   0.262555   
p_hat_pred[9]    0.245135  0.000446  0.063983   0.145339   0.241616   
p_hat_pred[10]   0.244490  0.000463  0.063809   0.145644   0.241421   
p_hat_pred[11]   0.222822  0.000426  0.063127   0.127295   0.218333   
p_hat_pred[12]   0.222094  0.000442  0.060369   0.129266   0.217760   
p_hat_pred[13]   0.223039  0.000435  0.061582   0.127987   0.219032   
p_hat_pred[14]   0.222531  0.000467  0.060792   0.129961   0.218566   
p_hat_pred[15]   0.222740  0.000428  0.061153   0.129420   0.218584   
p_hat_pred[16]   0.200900  0.000415  0.058343   0.112776   0.196240   
p_hat_pred[17]   0.178341  0.000403  0.056696   0.093903   0.173260   
p_hat_pred[18]   0.155755  0.000397  0.053106   0.078614   0.151025   

                      95%     N_Eff   N_Eff/s     R_hat  
lp__           -40.800500   2220.51   2925.57  1.001880  
alpha[1]         0.080643  20774.90  27371.50  0.999701  
alpha[2]        -0.002872  19831.10  26127.90  0.999663  
alpha[3]        -0.107446  19622.90  25853.60  0.999557  
alpha[4]        -0.192812  18320.50  24137.70  0.999712  
alpha[5]        -0.278068  17531.70  23098.40  0.999784  
alpha[6]        -0.298183  18178.70  23950.90  0.999608  
alpha[7]        -0.377633  17478.20  23027.90  0.999726  
alpha[8]        -0.507243  18242.10  24034.30  0.999547  
alpha[9]        -0.590991  19675.70  25923.10  0.999589  
alpha[10]       -0.593385  18622.10  24535.10  0.999663  
alpha[11]       -0.680491  20933.90  27580.80  0.999554  
alpha[12]       -0.718664  17515.40  23077.00  0.999799  
alpha[13]       -0.700289  18129.70  23886.40  0.999736  
alpha[14]       -0.706650  16333.60  21519.90  0.999677  
alpha[15]       -0.701349  17973.00  23679.80  0.999677  
alpha[16]       -0.830821  19333.90  25472.90  0.999811  
alpha[17]       -0.951667  18094.70  23840.20  0.999712  
alpha[18]       -1.090360  15771.40  20779.20  0.999698  
y_pred[1]       26.000000  11654.70  15355.40  0.999869  
y_pred[2]       25.000000  11277.30  14858.10  1.000030  
y_pred[3]       23.000000  11510.50  15165.40  0.999613  
y_pred[4]       22.000000  11824.90  15579.60  0.999988  
y_pred[5]       21.000000  10610.20  13979.20  1.000300  
y_pred[6]       21.000000  11248.70  14820.40  0.999858  
y_pred[7]       20.000000  11048.10  14556.10  0.999731  
y_pred[8]       19.000000  10491.90  13823.30  0.999809  
y_pred[9]       18.000000  10911.60  14376.30  0.999712  
y_pred[10]      18.000000  11543.80  15209.30  0.999806  
y_pred[11]      17.000000  12699.50  16731.90  0.999876  
y_pred[12]      17.000000  10168.40  13397.20  0.999782  
y_pred[13]      17.000000  11078.50  14596.20  0.999698  
y_pred[14]      17.000000  11340.10  14940.90  0.999782  
y_pred[15]      17.000000  11698.90  15413.60  0.999908  
y_pred[16]      16.000000  10471.10  13795.90  1.000180  
y_pred[17]      15.000000  12058.90  15887.90  0.999607  
y_pred[18]      13.000000  10736.00  14144.90  1.000020  
p_hat_pred[1]    0.520150  21005.10  27674.70  0.999734  
p_hat_pred[2]    0.499282  20010.80  26364.70  0.999679  
p_hat_pred[3]    0.473164  19912.30  26234.90  0.999576  
p_hat_pred[4]    0.451946  19943.80  26276.40  0.999709  
p_hat_pred[5]    0.430928  17870.70  23545.10  0.999889  
p_hat_pred[6]    0.426002  18822.90  24799.60  0.999608  
p_hat_pred[7]    0.406698  17935.70  23630.70  0.999745  
p_hat_pred[8]    0.375840  18143.30  23904.20  0.999545  
p_hat_pred[9]    0.356408  20588.60  27126.00  0.999568  
p_hat_pred[10]   0.355859  19013.10  25050.20  0.999637  
p_hat_pred[11]   0.336152  21971.30  28947.70  0.999567  
p_hat_pred[12]   0.327687  18694.40  24630.30  0.999722  
p_hat_pred[13]   0.331748  20071.50  26444.70  0.999656  
p_hat_pred[14]   0.330340  16912.30  22282.30  0.999655  
p_hat_pred[15]   0.331513  20384.20  26856.70  0.999669  
p_hat_pred[16]   0.303472  19771.50  26049.40  0.999790  
p_hat_pred[17]   0.278550  19791.00  26075.10  0.999653  
p_hat_pred[18]   0.251551  17855.30  23524.80  0.999803  

We have two plots: one for the no-pooling posterior predictive check and one for the batting-average probabilities (the inverse logit of \alpha). The posterior predictive quantities are fairly accurate, but the intervals are quite wide. On the other hand, some true batting averages fall completely outside their intervals: Munson and Johnstone are anomalies. Pooling should help, since an interval around the grand mean seems to capture all the values fairly well.

fig, axes = plt.subplots(1, 2, figsize=(15, 7))  # Adjust size as needed

# First subplot: Posterior Predictive Check
idata_1 = az.from_cmdstanpy(
    posterior=fit,
    observed_data={'y': baseball['hits']},
    dims={'y_pred': ['player'], 'alpha': ['player']},
    coords={'player': baseball['player']}
)

az.plot_forest(
    idata_1,
    kind='forestplot',
    var_names=["y_pred"],
    filter_vars="regex",
    combined=True,
    colors=[c_mid_highlight],
    ax=axes[0], 
)

y_ticks = axes[0].get_yticks()  
axes[0].scatter(baseball['hits'], list(reversed(y_ticks)), color='black', zorder=3, label='True Values')
axes[0].set_ylim(min(y_ticks) - 0.5, max(y_ticks) + 0.5)
axes[0].set_title('No Pooling Posterior Predictive Check with 94% HDI')

# Second subplot: Batting Average
idata_2 = az.from_cmdstanpy(
    posterior=fit,
    observed_data={'y': baseball['hits']},
    dims={'y_pred': ['player'], 'p_hat_pred': ['player']},
    coords={'player': baseball['player']}
)

az.plot_forest(
    idata_2,
    kind='forestplot',
    var_names=["p_hat_pred"],
    filter_vars="regex",
    combined=True,
    colors=[c_mid_highlight],
    ax=axes[1], 
)

y_ticks_2 = axes[1].get_yticks()  
axes[1].scatter(baseball['true_batting_average'], list(reversed(y_ticks_2)), color='black', zorder=3, label='True Batting Average')
axes[1].set_ylim(min(y_ticks_2) - 0.5, max(y_ticks_2) + 0.5)
axes[1].set_title('No Pooling Batting Average with 94% HDI')

plt.tight_layout()
plt.show()

8.2 Hierarchical Binomial Logit

We observed that we had a fairly generic fit of our data with our model, which is to be expected. Now we can implement some partial pooling to account for our knowledge that the players come from a shared population. We will use the same model as before but add a hierarchical prior on batting ability: a weakly informative hyperprior describes where the players' log-odds tend to cluster, and a half-normal prior keeps the population scale \sigma modest. The batting probabilities all lie between 0 and 1, so a unit scale for \sigma on the logit scale is more than wide enough.

\alpha_n \vert \mu, \sigma \sim \mathcal{N}(\mu, \sigma)

\mu \vert \sigma \sim \mathcal{N}(0.3, \sigma)

\sigma \sim \mathcal{N}^{+}(0, 1)

\theta_n = \frac{1}{1 + e^{-\alpha_n}} = \text{logit}^{-1}(\alpha_n)

y_n \sim \text{Binomial}(T_n, \theta_n)

model2_logit_partial_pooling.stan
data {
  int<lower=0> N;              // Number of players
  array[N] int<lower=0> Tr;     // Trials for each player
  array[N] int<lower=0> y;      // Successes for each player
}

parameters {
  vector[N] alpha;              // Logit-transformed chance of success for each player
  real mu;
  real<lower=0> sigma; 
}

model {
  sigma ~ normal(0, 1);   // Hyperprior
  mu ~ normal(0.3, sigma);    // Hyperprior
  alpha ~ normal(mu, sigma);

  for (n in 1:N) {
    target += binomial_logit_lpmf(y[n] | Tr[n], mu + alpha[n] * sigma);
  }
}

generated quantities {
   array[N] int y_pred;          // Predicted number of successes
   array[N] real p_hat_pred;              // Mean predicted probability of success


   for (n in 1:N) {
     y_pred[n] = binomial_rng(Tr[n], inv_logit(mu + alpha[n] * sigma));
     p_hat_pred[n] = inv_logit(mu + alpha[n] * sigma);
   }

}
model_path_2 = os.path.join(current_working_directory, 'models', 'model2_logit_partial_pooling.stan')
stan_model_2 = CmdStanModel(stan_file=model_path_2)

We fit the model and check for diagnostics now.

fit2 = stan_model_2.sample(data=stan_data, seed=RANDOM_SEED, chains=4, iter_sampling=2000, iter_warmup=1000, show_progress=False)
13:50:09 - cmdstanpy - INFO - CmdStan start processing
13:50:09 - cmdstanpy - INFO - Chain [1] start processing
13:50:09 - cmdstanpy - INFO - Chain [2] start processing
13:50:09 - cmdstanpy - INFO - Chain [3] start processing
13:50:09 - cmdstanpy - INFO - Chain [4] start processing
13:50:09 - cmdstanpy - INFO - Chain [2] done processing
13:50:10 - cmdstanpy - INFO - Chain [1] done processing
13:50:10 - cmdstanpy - INFO - Chain [3] done processing
13:50:10 - cmdstanpy - INFO - Chain [4] done processing
13:50:10 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: binomial_logit_lpmf: Probability parameter is -inf, but must be finite! (in 'model2_logit_partial_pooling.stan', line 19, column 4 to column 71)
Exception: binomial_logit_lpmf: Probability parameter is -inf, but must be finite! (in 'model2_logit_partial_pooling.stan', line 19, column 4 to column 71)
Exception: binomial_logit_lpmf: Probability parameter is inf, but must be finite! (in 'model2_logit_partial_pooling.stan', line 19, column 4 to column 71)
Exception: binomial_logit_lpmf: Probability parameter is -inf, but must be finite! (in 'model2_logit_partial_pooling.stan', line 19, column 4 to column 71)
Consider re-running with show_console=True if the above output is unclear!

Even though we achieve convergence with good diagnostics, we still encounter the following error while fitting our model: “Exception: binomial_logit_lpmf: Probability parameter is -inf.” This is largely because we are using a centered parameterization. A quick fix is the model below that employs a non-centered parameterization.

print(fit2.diagnose())
Processing csv files: C:\Users\ISSAM_~1\AppData\Local\Temp\tmpopw2zneq\model2_logit_partial_poolingq_orbb9w\model2_logit_partial_pooling-20241007135009_1.csv, C:\Users\ISSAM_~1\AppData\Local\Temp\tmpopw2zneq\model2_logit_partial_poolingq_orbb9w\model2_logit_partial_pooling-20241007135009_2.csv, C:\Users\ISSAM_~1\AppData\Local\Temp\tmpopw2zneq\model2_logit_partial_poolingq_orbb9w\model2_logit_partial_pooling-20241007135009_3.csv, C:\Users\ISSAM_~1\AppData\Local\Temp\tmpopw2zneq\model2_logit_partial_poolingq_orbb9w\model2_logit_partial_pooling-20241007135009_4.csv

Checking sampler transitions treedepth.
Treedepth satisfactory for all transitions.

Checking sampler transitions for divergences.
No divergent transitions found.

Checking E-BFMI - sampler transitions HMC potential energy.
E-BFMI satisfactory.

Effective sample size satisfactory.

Split R-hat values satisfactory all parameters.

Processing complete, no problems detected.
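The fix relies on a standard identity: drawing \alpha_n^{\text{raw}} \sim \mathcal{N}(0, 1) and setting

\alpha_n = \mu + \sigma \, \alpha_n^{\text{raw}}

yields exactly \alpha_n \mid \mu, \sigma \sim \mathcal{N}(\mu, \sigma), but the sampler now works with independent standard normals instead of a posterior geometry in which the scale of each \alpha_n is tied to \sigma.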
non-centered_model2.stan
data {
  int<lower=0> N;              // Number of players
  array[N] int<lower=0> Tr;     // Trials for each player
  array[N] int<lower=0> y;      // Successes for each player
}

parameters {
  vector[N] alpha_raw;          // Raw parameters to be scaled
  real mu;
  real<lower=0> sigma;
}

transformed parameters {
  vector[N] alpha;
  alpha = mu + sigma * alpha_raw; // Scaling alpha_raw by sigma
}

model {
  // Priors
  sigma ~ normal(0, 1);
  mu ~ normal(0.3, 1);
  alpha_raw ~ normal(0, 1);

  // Likelihood
  for (n in 1:N) {
    target += binomial_logit_lpmf(y[n] | Tr[n], alpha[n]);
  }
}

generated quantities {
  array[N] int y_pred;          // Predicted number of successes
  array[N] real p_hat_pred;     // Mean predicted probability of success

  for (n in 1:N) {
    p_hat_pred[n] = inv_logit(alpha[n]);
    y_pred[n] = binomial_rng(Tr[n], p_hat_pred[n]);
  }
}

We quickly rerun it to check if the errors are resolved.

model_path_2 = os.path.join(current_working_directory, 'models', 'non-centered_model2.stan')
stan_model_2 = CmdStanModel(stan_file=model_path_2)

We fit the model again and see that the errors are gone. Although the centered parameterization doesn't significantly affect our results in this case, it typically hurts convergence, because coupling the individual effects to the population scale makes the numerical evaluation of the log density harder for the sampler.

fit2 = stan_model_2.sample(data=stan_data, seed=RANDOM_SEED, chains=4, iter_sampling=2000, iter_warmup=1000, show_progress=False)
13:50:14 - cmdstanpy - INFO - CmdStan start processing
13:50:14 - cmdstanpy - INFO - Chain [1] start processing
13:50:14 - cmdstanpy - INFO - Chain [2] start processing
13:50:14 - cmdstanpy - INFO - Chain [3] start processing
13:50:14 - cmdstanpy - INFO - Chain [4] start processing
13:50:15 - cmdstanpy - INFO - Chain [2] done processing
13:50:15 - cmdstanpy - INFO - Chain [3] done processing
13:50:15 - cmdstanpy - INFO - Chain [4] done processing
13:50:16 - cmdstanpy - INFO - Chain [1] done processing

As we can see below, partial pooling moves a considerable number of the point estimates and intervals closer to the true values.

fig, axes = plt.subplots(1, 2, figsize=(15, 7))

idata_3 = az.from_cmdstanpy(
    posterior=fit2,
    observed_data={'y': baseball['hits']},
    dims={'y_pred': ['player'], 'p_hat_pred': ['player']},
    coords={'player': baseball['player']}
    
)

az.plot_forest(
    idata_3,
    kind='forestplot',
    var_names=["p_hat_pred"],
    filter_vars="regex",
    combined=True,
    colors=[c_mid_highlight],
    ax=axes[0], 
    rope = [.2, .3],
)

y_ticks_3 = axes[0].get_yticks()  
axes[0].scatter(baseball['true_batting_average'], list(reversed(y_ticks_3)), color='black', zorder=3, label='True Batting Average')
axes[0].set_ylim(min(y_ticks_3) - 0.5, max(y_ticks_3) + 0.5)
axes[0].set_title('Partial Pooling Batting Average with 94% HDI')

idata_2 = az.from_cmdstanpy(
    posterior=fit,
    observed_data={'y': baseball['hits']},
    dims={'y_pred': ['player'], 'p_hat_pred': ['player']},
    coords={'player': baseball['player']}
)

az.plot_forest(
    idata_2,
    kind='forestplot',
    var_names=["p_hat_pred"],
    filter_vars="regex",
    combined=True,
    colors=[c_mid_highlight],
    ax=axes[1], 
    rope = [.2, .3]
)

y_ticks_4 = axes[1].get_yticks()  
axes[1].scatter(baseball['true_batting_average'], list(reversed(y_ticks_4)), color='black', zorder=3, label='True Batting Average')
axes[1].set_ylim(min(y_ticks_4) - 0.5, max(y_ticks_4) + 0.5)
axes[1].set_title('No Pooling Batting Average with 94% HDI')

axes[1].set_xlim(0, 0.5) 
axes[0].set_xlim(0, 0.5)

plt.tight_layout()
plt.show()

The effect we observe here is very similar to what occurs for the James-Stein estimator. However, it is markedly different from a vanilla Bayes estimator that provides only a constant shrinkage factor through some non-empirical prior. Note that the James-Stein estimator can be derived from an empirical Bayes approach (Efron and Morris, 1973).
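To sketch that connection briefly (a standard empirical Bayes argument, summarized here rather than quoted from any one source): if the true abilities are drawn as P_i \sim \mathcal{N}(\mu_0, \tau^2) and p_i \mid P_i \sim \mathcal{N}(P_i, \sigma_0^2), then the posterior mean is

\mathbb{E}[P_i \mid p_i] = \mu_0 + \left(1 - \frac{\sigma_0^2}{\sigma_0^2 + \tau^2}\right)(p_i - \mu_0)

Estimating \mu_0 by \bar{p} and the unknown ratio \sigma_0^2 / (\sigma_0^2 + \tau^2) by its unbiased estimate (N - 3)\,\sigma_0^2 / \sum_i (p_i - \bar{p})^2 recovers exactly the James-Stein formula we used for the baseball data above.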

9 Conclusion

There is a nice interplay between the geometric properties of higher-dimensional spaces and the way probability densities behave.

from scipy.spatial import distance_matrix

def generate_points_on_hypersphere(n_points, dimensions):
    points = np.random.randn(n_points, dimensions)
    points /= np.linalg.norm(points, axis=1)[:, np.newaxis]  
    return points

n_points = 100  
dims = [3, 10, 30]  
distances = []
fig = plt.figure(figsize=(14, 6))

ax1 = fig.add_subplot(121)
colors = [c_light, c_mid, c_dark]
for i, d in enumerate(dims):
    points = generate_points_on_hypersphere(n_points, d)
    dists = distance_matrix(points, points)
    distances.append(dists.flatten())
    ax1.hist(dists.flatten(), bins=30, alpha=0.5, label=f'Dimension {d}', density=True, color=colors[i])

ax1.set_xlabel('Distance')
ax1.set_ylabel('Density')
ax1.set_title('Distribution of Distances Between Random Points on a Hypersphere')
ax1.legend()
ax1.grid()

ax2 = fig.add_subplot(122, projection='3d')
n_points_3d = 2000 
points_3d = generate_points_on_hypersphere(n_points_3d, dimensions=3)
ax2.scatter(points_3d[:, 0], points_3d[:, 1], points_3d[:, 2], alpha=0.6, color=c_dark)
ax2.set_title('Random Points on a 3D Hypersphere')

plt.tight_layout()
plt.show()

As the number of dimensions increases in Euclidean space, the geometry behaves quite differently compared to lower-dimensional spaces. In high dimensions, most of the volume of a ball is concentrated near its surface rather than around its center, so points sampled uniformly from a high-dimensional ball are likely to be found near the boundary rather than close to the origin. An additional concern is that as dimensions increase, the distances between random points become more uniform, making it difficult to distinguish between them. The distances between points tend to cluster around a mean distance, complicating the identification of whether two points are close or far apart.

The concept of volume in high-dimensional spaces is also counterintuitive. The volume of the unit ball actually peaks around dimension five and then shrinks toward zero, while the volume of the enclosing hypercube grows exponentially, so the ratio of the ball's volume to the cube's collapses as the dimension increases. The "usable" volume, where most points are likely to lie, becomes vanishingly small relative to the total volume available. When shrinkage is applied to an estimator, it effectively constrains the estimator to a smaller region of space. In high dimensions, this can lead to a significant loss of volume, especially if the shrinkage target is poorly chosen. You're likely to get a good estimate of where density is most concentrated, but you may miss any density that shows up in the tails. This is nicely illustrated in the distance distributions above, where some of the random points cluster near each other far away from the others, usually in the tails.
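To make the ratio claim concrete, here is a small sketch using the closed-form volume of the d-dimensional unit ball, \pi^{d/2} / \Gamma(d/2 + 1):

from scipy.special import gamma

def unit_ball_volume(d):
    # Volume of the d-dimensional unit ball
    return np.pi ** (d / 2) / gamma(d / 2 + 1)

for d in [2, 3, 5, 10, 20]:
    ratio = unit_ball_volume(d) / 2 ** d   # enclosing hypercube [-1, 1]^d has volume 2^d
    print(d, ratio)                        # the ratio collapses rapidly as d grows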

Overall, the behavior of estimators in high-dimensional spaces is shaped by the geometry described above. We've seen that the Stein Effect occurs when a shrinkage estimator provides better estimates than an unbiased estimator: pulling the estimates toward a fixed target (typically zero or the grand mean) trades a little bias for a large reduction in variance, which pays off precisely because of the concentration of measure. The curvature of the sphere helps justify why shrinking towards a central point, such as the origin, can be beneficial, as it reduces variance more effectively than simply using the sample mean. However, the Reverse Stein Effect highlights the risks of using shrinkage without a reliable target. If the shrinkage target is selected based on the data X without understanding its true relationship to the parameter \delta, there is a danger of shrinking towards a point that is far from the actual parameter location. The curvature of the hypersphere exacerbates this issue by making it challenging to assess how far the shrinkage target is from the true parameter, particularly when the data points are concentrated at the boundary. Perlman and Chaudhuri have a fascinating little paper that elaborates on this, with the added benefit of a mashup of Star Trek and statistics in the explanation.

10 Acknowledgements

If you've happened across Michael Betancourt's beautiful writeups, you'll recognize the overall typesetting and styling of the CSS and graphics. I really liked that style, so I've adapted it for my purposes.

Note: I’ll likely write up the references later if I have the time.