(Re)setting up the blog and checking that code and math render OK, using a simple problem I recently worked on a bit. FWIW, the goal was simply to compute a distributed standard deviation in one go, instead of first computing the global mean and sending it back to the nodes.

Math

We have a distributed dataset made of multiple local datasets:

$$ D = D_A \cup D_B \cup D_C \cup \cdots $$

$$ N = |D|,\quad N_A = |D_A|,\quad N_B = |D_B|,\quad \ldots $$

The global mean is:

$$ \mu = \frac{1}{N}\sum_{i=1}^{N} d_i = \mathbb{E}[D] $$

The population variance is:

$$ \sigma^2 = \frac{1}{N}\sum_{i=1}^{N}(d_i - \mu)^2 = \mathbb{E}\left[(D - \mu)^2\right] $$
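For reference, the definition translates directly into a two-pass computation: one pass for $\mu$, then a second pass over the squared deviations. That is exactly what the distributed setting makes awkward, because it needs every $d_i$ in one place and needs $\mu$ before the second pass can start. It is still handy as a baseline to check the one-pass version against. This is a minimal sketch, and `population_variance_two_pass` is a hypothetical name:

```rust
/// Two-pass population variance, straight from the definition.
/// Hypothetical helper: needs all of D in memory, so it only serves
/// as a local reference to compare the distributed version against.
fn population_variance_two_pass(values: &[f64]) -> f64 {
    let n = values.len() as f64;

    // First pass: mu = E[D]
    let mu = values.iter().sum::<f64>() / n;

    // Second pass: E[(D - mu)^2]
    values.iter().map(|&d| (d - mu) * (d - mu)).sum::<f64>() / n
}
```

On `[1.0, 2.0, 3.0, 4.0]` this returns `1.25`.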

But we do not have all of $D$ in one single place. However, the variance can also be written using only two plain averages:

$$ \sigma^2 = \mathbb{E}[D^2] - \mathbb{E}[D]^2 $$

Proof:

$$
\begin{aligned}
(D - \mu)^2 &= D^2 - 2D\mu + \mu^2 \\
\mathbb{E}[D^2 - 2D\mu + \mu^2] &= \mathbb{E}[D^2] - \mathbb{E}[2D\mu] + \mathbb{E}[\mu^2] \\
&= \mathbb{E}[D^2] - 2\mu\mathbb{E}[D] + \mu^2 \\
&= \mathbb{E}[D^2] - 2\mu^2 + \mu^2 \\
&= \mathbb{E}[D^2] - \mu^2 \\
&= \mathbb{E}[D^2] - \mathbb{E}[D]^2
\end{aligned}
$$
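As a quick numeric sanity check of the identity, take the toy dataset $D = \{1, 2, 3, 4\}$. The values are chosen only for illustration.

$$
\begin{aligned}
\mathbb{E}[D] &= \tfrac{1 + 2 + 3 + 4}{4} = 2.5 \\
\mathbb{E}[D^2] &= \tfrac{1 + 4 + 9 + 16}{4} = 7.5 \\
\mathbb{E}[D^2] - \mathbb{E}[D]^2 &= 7.5 - 6.25 = 1.25 \\
\mathbb{E}[(D - \mu)^2] &= \tfrac{2.25 + 0.25 + 0.25 + 2.25}{4} = 1.25
\end{aligned}
$$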

Now compute the two pieces separately over the distributed dataset.

$$ \mathbb{E}[D^2] = \frac{\sum_{i=1}^{N_A}(d_{i,A})^2 + \sum_{i=1}^{N_B}(d_{i,B})^2 + \cdots}{N_A + N_B + \cdots} $$

$$ \mathbb{E}[D] = \frac{\sum_{i=1}^{N_A} d_{i,A} + \sum_{i=1}^{N_B} d_{i,B} + \cdots}{N_A + N_B + \cdots} $$
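For example, suppose a toy dataset is split across two nodes as $D_A = \{1, 2\}$ and $D_B = \{3, 4\}$. The values are purely illustrative. Each node contributes its own sums, and only the totals are divided:

$$
\mathbb{E}[D^2] = \frac{(1 + 4) + (9 + 16)}{2 + 2} = \frac{30}{4} = 7.5,
\qquad
\mathbb{E}[D] = \frac{(1 + 2) + (3 + 4)}{2 + 2} = \frac{10}{4} = 2.5
$$

So $\sigma^2 = 7.5 - 2.5^2 = 1.25$, the same value the direct definition gives on $\{1, 2, 3, 4\}$.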

So each node only needs to send three numbers:

  1. $\sum_i d_i^2$
  2. $\sum_i d_i$
  3. $N_{\text{node}}$

That is:

$$ (s, t, n) $$

where:

$$ s = \text{local sum of squares},\quad t = \text{local sum},\quad n = \text{local sample count} $$
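Each component of $(s, t, n)$ is a plain sum, so two triples combine by component-wise addition. Order and grouping do not matter: exactly so for $n$, and up to floating-point rounding for $s$ and $t$. That means partials could also be merged hierarchically (say, per rack) before they reach the aggregator. This is a minimal sketch, and the `Partial` type and its `from_values`/`merge` methods are hypothetical names, not something from the post:

```rust
/// Hypothetical message type carrying one node's (s, t, n) triple.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
struct Partial {
    s: f64,   // local sum of squares
    t: f64,   // local sum
    n: usize, // local sample count
}

impl Partial {
    /// Builds the triple from a node's local values.
    fn from_values(values: &[f64]) -> Self {
        values.iter().fold(Partial::default(), |acc, &x| Partial {
            s: acc.s + x * x,
            t: acc.t + x,
            n: acc.n + 1,
        })
    }

    /// Combines two triples by component-wise addition, so merging is
    /// commutative and associative (exact for n, up to rounding for s, t).
    fn merge(self, other: Partial) -> Partial {
        Partial {
            s: self.s + other.s,
            t: self.t + other.t,
            n: self.n + other.n,
        }
    }
}
```

Merging `Partial::from_values(&[1.0, 2.0])` with `Partial::from_values(&[3.0, 4.0])` gives the same triple as `Partial::from_values(&[1.0, 2.0, 3.0, 4.0])`, namely `s = 30`, `t = 10`, `n = 4`.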

Then the central aggregator computes:

$$ m_2 = \frac{\sum_k s_k}{\sum_k n_k} $$

$$ m_1 = \frac{\sum_k t_k}{\sum_k n_k} $$

Finally:

$$ \sigma^2 = m_2 - m_1^2 $$

$$ \text{standard deviation} = \sqrt{\sigma^2} $$
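One caveat on $\sigma^2 = m_2 - m_1^2$: when the data sit far from zero relative to their spread, $m_2$ and $m_1^2$ are large and nearly equal. The subtraction can then cancel most (or all) of the significant digits in floating point.

A common remedy keeps the same single round trip but changes what each node sends: $(n, \bar d, M_2)$. Here $\bar d$ is the local mean and $M_2 = \sum_i (d_i - \bar d)^2$ is taken over the node's own samples. Summaries are then combined with the pairwise update of Chan et al.

That is a different technique from the one in this post. The sketch below is hypothetical throughout, including the `Moments` type and its methods:

```rust
/// Hypothetical per-node summary: count, local mean, and M2 (the sum of
/// squared deviations from the local mean). Alternative to (s, t, n).
#[derive(Debug, Clone, Copy)]
struct Moments {
    n: f64,
    mean: f64,
    m2: f64,
}

impl Moments {
    /// Local summary via Welford's single-pass update.
    fn from_values(values: &[f64]) -> Self {
        let mut m = Moments { n: 0.0, mean: 0.0, m2: 0.0 };
        for &x in values {
            m.n += 1.0;
            let delta = x - m.mean;
            m.mean += delta / m.n;
            m.m2 += delta * (x - m.mean);
        }
        m
    }

    /// Pairwise merge (Chan et al.):
    /// mean = mean_a + delta * n_b / n
    /// M2   = M2_a + M2_b + delta^2 * n_a * n_b / n
    fn merge(self, other: Moments) -> Moments {
        let n = self.n + other.n;
        if n == 0.0 {
            return self; // both summaries are empty
        }
        let delta = other.mean - self.mean;
        Moments {
            n,
            mean: self.mean + delta * other.n / n,
            m2: self.m2 + other.m2 + delta * delta * self.n * other.n / n,
        }
    }

    /// Population standard deviation sqrt(M2 / N).
    fn std_dev(&self) -> f64 {
        (self.m2 / self.n).sqrt()
    }
}
```

Take $D_A = \{10^9 + 1, 10^9 + 2\}$ and $D_B = \{10^9 + 3, 10^9 + 4\}$. Merging the two summaries gives $M_2 = 5$ and $\sigma = \sqrt{1.25}$. In the $(s, t, n)$ version, the sums of squares for the same data are around $4 \cdot 10^{18}$, where adjacent `f64` values are hundreds apart.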

Code

```rust
/// Computes local partial statistics (s, t, n) for the standard deviation
fn std_partial_stats(values: &[f64]) -> (f64, f64, usize) {
    // Sum of squares
    let mut s = 0.0;

    // Total sum
    let mut t = 0.0;

    // Count
    let mut n = 0;

    for &x in values {
        s += x * x;
        t += x;
        n += 1;
    }

    (s, t, n)
}

/// Aggregates local partial statistics into a population standard deviation.
/// Returns NaN if the partials contain no samples at all.
fn aggregate_std(partials: &[(f64, f64, usize)]) -> f64 {
    // Global sum of squares
    let mut global_s = 0.0;

    // Global sum
    let mut global_t = 0.0;

    // Global count
    let mut global_n = 0;

    for &(s, t, n) in partials {
        global_s += s;
        global_t += t;
        global_n += n;
    }

    // No samples anywhere: the mean and variance are undefined
    if global_n == 0 {
        return f64::NAN;
    }

    let m_2 = global_s / global_n as f64;
    let m_1 = global_t / global_n as f64;

    // sqrt(E[D^2] - E[D]^2); rounding can push the difference slightly
    // below zero for (near-)constant data, so clamp before the sqrt
    (m_2 - m_1 * m_1).max(0.0).sqrt()
}
```
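To try the whole thing end to end without a cluster, one option is to let threads play the nodes. Each thread computes its own $(s, t, n)$, and only the triples come back to be aggregated. This sketch assumes Rust 1.63+ for `std::thread::scope`, and `distributed_std` is a hypothetical helper. The two functions from the post are restated compactly (with a zero clamp and an empty-input guard) so the block compiles on its own:

```rust
use std::thread;

// Same math as std_partial_stats above, written as a fold.
fn std_partial_stats(values: &[f64]) -> (f64, f64, usize) {
    values
        .iter()
        .fold((0.0, 0.0, 0), |(s, t, n), &x| (s + x * x, t + x, n + 1))
}

// Same math as aggregate_std above: sum the triples, then sqrt(m_2 - m_1^2).
fn aggregate_std(partials: &[(f64, f64, usize)]) -> f64 {
    let (s, t, n) = partials
        .iter()
        .fold((0.0, 0.0, 0), |(gs, gt, gn), &(s, t, n)| (gs + s, gt + t, gn + n));
    if n == 0 {
        return f64::NAN; // no samples at all
    }
    let (m_2, m_1) = (s / n as f64, t / n as f64);
    (m_2 - m_1 * m_1).max(0.0).sqrt()
}

/// Hypothetical driver: each chunk stands in for one node's local data and
/// is reduced to (s, t, n) on its own thread; only the triples are collected.
fn distributed_std(chunks: &[Vec<f64>]) -> f64 {
    let partials: Vec<(f64, f64, usize)> = thread::scope(|scope| {
        let handles: Vec<_> = chunks
            .iter()
            .map(|chunk| scope.spawn(move || std_partial_stats(chunk)))
            .collect();
        handles.into_iter().map(|h| h.join().unwrap()).collect()
    });
    aggregate_std(&partials)
}
```

With three "nodes" holding `[1.0, 2.0]`, `[3.0]` and `[4.0]`, the triples sum to `(30.0, 10.0, 4)`, and the result is $\sqrt{1.25}$. This matches computing over `[1.0, 2.0, 3.0, 4.0]` in one place.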