Mixture Density Networks

Why standard regression collapses on inverse problems

Standard neural networks predict a single output for a given input. That assumption is rarely questioned until you try to invert a function.

This tutorial shows exactly where and why that assumption breaks, and how Mixture Density Networks (MDNs) fix it by modeling distributions, not point estimates.

Part 1: The Forward Problem

Start with a function that is slightly nonlinear but well-behaved:

Forward Function $$y = x + 0.3 \sin(2\pi x) + \epsilon$$
where $\epsilon \sim \text{Uniform}(-0.1, 0.1)$ is small random noise.

For each input $x$, there is essentially one valid output $y$. That means $p(y|x)$ is unimodal, the conditional mean $\mathbb{E}[y|x]$ is meaningful, and MSE is a valid objective.
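For concreteness, a minimal sketch of how such a dataset could be generated; the input range, sample count, and seed are assumptions, and the file name is illustrative.

generate_data.py
import numpy as np

rng = np.random.default_rng(0)                # assumed seed, for reproducibility
x = rng.uniform(0.0, 1.0, size=1000)          # assumed input range and sample count
eps = rng.uniform(-0.1, 0.1, size=x.shape)    # noise: Uniform(-0.1, 0.1)
y = x + 0.3 * np.sin(2 * np.pi * x) + eps     # the forward function above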

[Interactive demo: Forward Problem y = f(x). Plots the data and the MLP fit, with live epoch and MSE readouts.]

A plain MLP trained with MSE converges cleanly. The network learns $\mathbb{E}[y|x]$.
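For reference, a minimal PyTorch sketch of that baseline; the framework choice, architecture, and optimizer settings are assumptions rather than the demo's exact configuration, and the file name is illustrative.

train_mlp.py
import torch
import torch.nn as nn

# x, y: arrays from the data-generation sketch above
X = torch.tensor(x, dtype=torch.float32).unsqueeze(1)
Y = torch.tensor(y, dtype=torch.float32).unsqueeze(1)

mlp = nn.Sequential(nn.Linear(1, 32), nn.Tanh(), nn.Linear(32, 1))   # assumed architecture
opt = torch.optim.Adam(mlp.parameters(), lr=1e-2)                    # assumed optimizer settings

for epoch in range(2000):                                            # assumed epoch count
    loss = nn.functional.mse_loss(mlp(X), Y)   # MSE pulls the fit toward E[y|x]
    opt.zero_grad()
    loss.backward()
    opt.step()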

Part 2: The Inverse Problem

Now swap the axes. Same data, different question:

The Inversion $$\text{Forward: } y = f(x) \quad \longrightarrow \quad \text{Inverse: } x = f^{-1}(y)$$
Given $y$, what was $x$? This is no longer well-posed.

Because the forward function is non-monotonic, multiple $x$ values map to the same $y$. Now $p(x|y)$ is multi-modal and the conditional mean lies between valid solutions.
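In code, the data itself does not change; only which array plays the role of input and which plays the role of target (a sketch continuing the assumed setup above, with an illustrative file name):

invert_data.py
# Same arrays as before, roles swapped: the network now maps y -> x.
Y_in  = torch.tensor(y, dtype=torch.float32).unsqueeze(1)   # input:  observed y
X_out = torch.tensor(x, dtype=torch.float32).unsqueeze(1)   # target: the x that produced it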

[Interactive demo: Inverse Problem x = f⁻¹(y). Plots the data and the MLP fit, with live epoch and MSE readouts.]

The MLP produces $\mathbb{E}[x \mid y]$, a curve through empty space. This is a loss-model mismatch.
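To make the mismatch precise: for a fixed $y$, the constant prediction $c$ that minimizes expected squared error is the conditional mean, regardless of the shape of $p(x \mid y)$.

$$\frac{\partial}{\partial c}\,\mathbb{E}\big[(x - c)^2 \mid y\big] = 2\big(c - \mathbb{E}[x \mid y]\big) = 0 \quad\Longrightarrow\quad c^{*} = \mathbb{E}[x \mid y]$$

When the modes of $p(x \mid y)$ are well separated, that mean falls between them, exactly where the data is not.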

This shows up everywhere

Inverse kinematics, vision (depth ambiguity), physics (cause from effect), control and planning. If your data has multiple valid outputs, point regression is wrong.

Part 3: What MDNs Change

MDNs do not make the network more powerful. They change what it represents: parameters of a conditional distribution instead of a single value.

Gaussian Mixture Model $$p(x \mid y) = \sum_{k=1}^{K} \pi_k(y) \cdot \mathcal{N}\big(x \mid \mu_k(y), \sigma_k(y)\big)$$

Each input $y$ produces $K$ candidate solutions, each carrying a mixing weight $\pi_k(y)$. The sum over components is just a loop:

gmm_probability.py
import math

def normal_pdf(x, mu, sigma):
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))

def gmm_probability(x, y, net):
    pi, mu, sigma = net.forward(y)   # K mixing weights, means, std deviations for this y
    total = 0.0
    for k in range(len(pi)):         # the sum over mixture components
        total += pi[k] * normal_pdf(x, mu[k], sigma[k])
    return total

Output Transformations

$\pi_k$ (mixing weights): softmax → sum to 1
$\mu_k$ (means): raw output
$\sigma_k$ (standard deviations): exp → always positive
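A minimal PyTorch sketch of these three output heads; the module and file names are illustrative, not taken from the demo.

mdn_head.py
import torch
import torch.nn as nn

class MDNHead(nn.Module):
    """Maps a hidden representation to the parameters of a K-component Gaussian mixture."""
    def __init__(self, hidden_dim, K):
        super().__init__()
        self.pi_layer = nn.Linear(hidden_dim, K)
        self.mu_layer = nn.Linear(hidden_dim, K)
        self.log_sigma_layer = nn.Linear(hidden_dim, K)

    def forward(self, h):
        pi = torch.softmax(self.pi_layer(h), dim=-1)    # mixing weights: softmax, sum to 1
        mu = self.mu_layer(h)                           # means: raw output
        sigma = torch.exp(self.log_sigma_layer(h))      # std deviations: exp, always positive
        return pi, mu, sigma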

The Loss: Negative Log-Likelihood

NLL Loss $$\mathcal{L} = -\log p(x \mid y) = -\log \sum_{k=1}^{K} \pi_k \cdot \mathcal{N}(x \mid \mu_k, \sigma_k)$$
Minimizing this loss maximizes the probability the model assigns to the true target, and it works even when the distribution is multi-modal.
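A sketch of that loss in PyTorch; the function name is illustrative, and the log-sum-exp formulation is a standard numerical-stability choice rather than something stated above.

mdn_nll.py
import torch

def mdn_nll(pi, mu, sigma, target):
    # pi, mu, sigma: (batch, K); target: (batch, 1)
    log_component = torch.distributions.Normal(mu, sigma).log_prob(target)  # log N(x | mu_k, sigma_k)
    log_mixture = torch.logsumexp(torch.log(pi) + log_component, dim=-1)    # log sum_k pi_k N(...)
    return -log_mixture.mean()                                              # average NLL over the batch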
[Interactive demo: MDN on Inverse Problem. Plots the data and MDN samples, with live epoch and NLL readouts.]

The MDN captures multiple modes. Samples follow the data distribution, not the mean.
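Sampling from the fitted mixture is a two-step draw: pick a component according to $\pi_k(y)$, then draw from that component's Gaussian. A sketch, with an illustrative function name:

sample_mdn.py
import torch

def sample_mdn(pi, mu, sigma):
    # pi, mu, sigma: (batch, K); returns one sample per row, shape (batch, 1)
    k = torch.distributions.Categorical(pi).sample().unsqueeze(-1)   # pick a component per row
    chosen_mu = torch.gather(mu, 1, k)
    chosen_sigma = torch.gather(sigma, 1, k)
    return chosen_mu + chosen_sigma * torch.randn_like(chosen_mu)    # draw from N(mu_k, sigma_k)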

Playground

Draw a distribution and compare MLP vs MDN:

[Interactive playground: Draw Your Own Distribution. Draw points, then train an MLP and an MDN with adjustable K and compare.]

The Takeaway

summary.py
# MLP: single output
y_pred = network(x)  # point estimate
loss = mse(y_pred, y_true)

# MDN: distribution parameters
pi, mu, sigma = network(x)  # K components
p = sum(pi[k] * gaussian(y, mu[k], sigma[k]) for k in range(K))
loss = -log(p)  # NLL

MDNs are not magic. They change what the network is allowed to output. When your conditional distribution is multi-modal, that change is necessary.