Standard neural networks predict a single output for a given input. That assumption is rarely questioned until you try to invert a function.
This tutorial shows exactly where and why that assumption breaks, and how Mixture Density Networks (MDNs) fix it by modeling distributions, not point estimates.
Part 1: The Forward Problem
Everything works
Start with a function that is slightly nonlinear but well-behaved:
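For concreteness, a minimal data-generation sketch. The exact function is an assumption on my part (the classic toy curve $y = x + 0.3\sin(2\pi x)$ plus noise); any gently non-monotonic curve behaves the same way:

import numpy as np

rng = np.random.default_rng(0)

# x in [0, 1]; y is a gently non-monotonic function of x plus a little noise,
# so p(y | x) stays unimodal and the conditional mean is a sensible target.
x = rng.uniform(0.0, 1.0, size=2000)
y = x + 0.3 * np.sin(2 * np.pi * x) + rng.normal(0.0, 0.05, size=x.shape)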
For each input $x$, there is essentially one valid output $y$. That means $p(y|x)$ is unimodal, the conditional mean $\mathbb{E}[y|x]$ is meaningful, and MSE is a valid objective.
A plain MLP trained with MSE converges cleanly. The network learns $\mathbb{E}[y|x]$.
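A minimal version of that fit, assuming PyTorch (the tutorial's framework is not stated) and the toy data above; the layer sizes and step count are illustrative:

import torch
import torch.nn as nn

xt = torch.tensor(x, dtype=torch.float32).unsqueeze(1)   # shape (N, 1)
yt = torch.tensor(y, dtype=torch.float32).unsqueeze(1)   # shape (N, 1)

# Plain MLP regressor trained with MSE: it converges to E[y | x].
mlp = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 50), nn.Tanh(), nn.Linear(50, 1))
opt = torch.optim.Adam(mlp.parameters(), lr=1e-3)
for step in range(3000):
    opt.zero_grad()
    loss = nn.functional.mse_loss(mlp(xt), yt)
    loss.backward()
    opt.step()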
Part 2: The Inverse Problem
Where it breaks
Now swap the axes. Same data, different question: given $y$, predict $x$.
Because the forward function is non-monotonic, multiple $x$ values map to the same $y$. Now $p(x|y)$ is multi-modal and the conditional mean lies between valid solutions.
The MLP predicts $\mathbb{E}[x \mid y]$, the average of the valid branches: a curve through empty space where no data lives. This is a loss-model mismatch: MSE assumes a unimodal conditional, and $p(x \mid y)$ is not.
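To reproduce the failure, the same recipe with inputs and targets swapped (again a sketch under the assumptions above):

# Identical architecture and loss, but now the network predicts x from y.
inv_mlp = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 50), nn.Tanh(), nn.Linear(50, 1))
opt = torch.optim.Adam(inv_mlp.parameters(), lr=1e-3)
for step in range(3000):
    opt.zero_grad()
    loss = nn.functional.mse_loss(inv_mlp(yt), xt)   # axes swapped: input y, target x
    loss.backward()
    opt.step()
# Wherever several x values share one y, the prediction averages the branches
# and passes through regions that contain no data.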
The same pattern shows up in inverse kinematics, vision (depth ambiguity), physics (inferring cause from effect), and control and planning. If your data has multiple valid outputs for the same input, point regression is the wrong tool.
Part 3: What MDNs Change
Distributions, not points
MDNs do not make the network more powerful. They change what it represents: parameters of a conditional distribution instead of a single value.
Each input $y$ produces $K$ candidate solutions $\mu_1, \dots, \mu_K$, each with a mixing weight $\pi_k$ and a width $\sigma_k$. The mixture density $p(x \mid y) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, \sigma_k^2)$ is just a loop:

import numpy as np

def normal_pdf(x, mu, sigma):
    # Gaussian density N(x | mu, sigma^2)
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))

def gmm_probability(x, y, net):
    # p(x | y): weighted sum of the K Gaussian components predicted from y
    pi, mu, sigma = net.forward(y)   # each of length K
    total = 0.0
    for k in range(len(pi)):
        total += pi[k] * normal_pdf(x, mu[k], sigma[k])
    return total
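With a trained net, evaluating gmm_probability over a grid of x values for a fixed y traces out the full conditional: several peaks, one per valid solution.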
Output Transformations
π (mixing weights): softmax → sum to 1
μ (component means): raw output → any real value
σ (component widths): exp → always positive
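A minimal sketch of that output head, assuming PyTorch; the class name MDNHead, hidden_dim, and the log-sigma parameterization are illustrative choices, not the tutorial's exact code:

import torch
import torch.nn as nn

class MDNHead(nn.Module):
    # Maps hidden features to the parameters of a K-component Gaussian mixture.
    def __init__(self, hidden_dim, K):
        super().__init__()
        self.pi_layer = nn.Linear(hidden_dim, K)
        self.mu_layer = nn.Linear(hidden_dim, K)
        self.log_sigma_layer = nn.Linear(hidden_dim, K)

    def forward(self, h):
        pi = torch.softmax(self.pi_layer(h), dim=-1)   # softmax → weights sum to 1
        mu = self.mu_layer(h)                          # raw output → any real value
        sigma = torch.exp(self.log_sigma_layer(h))     # exp → always positive
        return pi, mu, sigma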
The Loss: Negative Log-Likelihood
Training maximizes the likelihood the mixture assigns to the observed data, i.e. minimizes $-\log p(x \mid y)$ averaged over the training set. Trained this way, the MDN captures multiple modes: samples follow the data distribution, not the mean.
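A sketch of that loss, assuming the PyTorch MDNHead above; the log-sum-exp form is a standard numerical-stability choice, not something the tutorial specifies:

def mdn_nll(pi, mu, sigma, target):
    # Negative log-likelihood of target under the predicted mixture.
    # pi, mu, sigma: (batch, K); target: (batch, 1)
    log_component = torch.distributions.Normal(mu, sigma).log_prob(target)  # log N(target | mu_k, sigma_k)
    log_mix = torch.logsumexp(torch.log(pi) + log_component, dim=-1)        # log sum_k pi_k N(...)
    return -log_mix.mean()

Sampling is equally direct: pick a component according to π, then draw from that Gaussian.

def sample_mdn(pi, mu, sigma):
    # One sample per row: choose a component by its weight, then sample its Gaussian.
    k = torch.multinomial(pi, num_samples=1)                       # (batch, 1) chosen component
    return torch.gather(mu, 1, k) + torch.gather(sigma, 1, k) * torch.randn(k.shape)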
Playground
Draw your own
Draw a distribution and compare MLP vs MDN:
The Takeaway
# MLP: single output
y_pred = network(x)          # point estimate
loss = mse(y_pred, y_true)

# MDN: distribution parameters
pi, mu, sigma = network(x)   # K components
p = sum(pi[k] * gaussian(y, mu[k], sigma[k]) for k in range(K))
loss = -log(p)               # NLL
MDNs are not magic. They change what the network is allowed to output. When your conditional distribution is multi-modal, that change is necessary.