As asked
Implement cross-entropy loss from raw logits for a multiclass problem using the log-sum-exp trick. Explain why computing softmax then log then cross-entropy in three steps is numerically unstable.
Sample answer outline
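Computing softmax naively as exp(x) / sum(exp(x)) fails in both directions. For large logits, exp overflows to inf (in float64 once a logit exceeds about 709, in float32 already around 88), and inf / inf gives nan. For very negative logits, exp underflows to 0, so that class's probability becomes exactly 0 and the following log(0) = -inf turns the loss into inf. The stable version subtracts the max logit m = max(x) before exponentiation. This is exact, because log(sum(exp(x))) = m + log(sum(exp(x - m))). After the shift, the largest term is exp(0) = 1, so the sum is at least 1: it cannot overflow, and the log never sees 0. This gives log_softmax(x) = x - m - log(sum(exp(x - m))).
Cross-entropy is the negative log-probability of the true class: -x[label] + m + log(sum(exp(x - m))), which is exactly logsumexp(x) - x[label]. The key insight is that log(softmax(x)) simplifies algebraically to x - logsumexp(x), because the log undoes the exp in the numerator. Computing softmax first and then taking its log therefore does redundant work, and it exposes the intermediate probabilities to overflow and underflow that the combined formula never produces.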
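A minimal NumPy sketch of the failure modes described above. It computes the loss once in three explicit steps and once with the max shift, using the same extreme logits. The names `naive_cross_entropy` and `stable_cross_entropy` and the example logits are illustrative only and are not part of the reference implementation.

```python
import numpy as np


def naive_cross_entropy(logits: np.ndarray, label: int) -> float:
    # Three separate steps: softmax, then log, then take the label's term.
    probs = np.exp(logits) / np.exp(logits).sum()
    return float(-np.log(probs[label]))


def stable_cross_entropy(logits: np.ndarray, label: int) -> float:
    # Shift by the max so the largest exponent is exp(0) = 1 and the sum is >= 1.
    m = logits.max()
    lse = m + np.log(np.exp(logits - m).sum())  # logsumexp(x)
    return float(lse - logits[label])


big = np.array([1000.0, 0.0, -1000.0])
with np.errstate(all="ignore"):  # silence the overflow/invalid/divide warnings
    print(naive_cross_entropy(big, 0))  # nan: exp(1000) = inf and inf / inf = nan
    print(naive_cross_entropy(big, 2))  # inf: exp(-1000) = 0 and -log(0) = inf
print(stable_cross_entropy(big, 0))  # 0.0
print(stable_cross_entropy(big, 2))  # 2000.0
```

On moderate logits such as [1.0, 2.0, 3.0], both versions agree to within rounding. The difference appears only once exp leaves the floating-point range.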
Reference implementation (python)
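import numpy as np


def cross_entropy(logits: np.ndarray, label: int) -> float:
    """logits: (C,) raw scores, label: int class index."""
    # Shift by the max logit so the largest exponent is exp(0) = 1.
    shifted = logits - logits.max()
    # Working in shifted space is exact: the max cancels between both terms.
    log_sum_exp = np.log(np.exp(shifted).sum())
    return float(-shifted[label] + log_sum_exp)


def cross_entropy_batch(
    logits: np.ndarray, labels: np.ndarray
) -> float:
    """logits: (N, C) raw scores, labels: (N,) int class indices."""
    # Row-wise max shift; keepdims gives shape (N, 1) so it broadcasts over C.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_sum_exp = np.log(np.exp(shifted).sum(axis=1))
    # Pick each row's true-class shifted logit with fancy indexing.
    correct = shifted[np.arange(len(labels)), labels]
    # Average the per-sample losses over the batch.
    return float((-correct + log_sum_exp).mean())
Expect these follow-ups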
- How does PyTorch's F.cross_entropy implement this under the hood?
- What changes when implementing label smoothing on top of this loss?
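For the label-smoothing follow-up, a common formulation replaces the one-hot target with a mixture: 1 - eps on the true class plus eps / C spread uniformly over all C classes. PyTorch's `F.cross_entropy` documents this uniform-mixture form for its `label_smoothing` argument. Cross-entropy is linear in the target, so the loss splits into (1 - eps) times the usual negative log-likelihood plus eps times the mean of -log p over all classes. The same log-softmax machinery therefore carries over unchanged. The NumPy sketch below assumes that formulation. The names `log_softmax` and `cross_entropy_smoothed` and the default `eps=0.1` are illustrative choices, not part of the reference implementation.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Row-wise log-softmax via the max shift: x - m - log(sum(exp(x - m))).
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def cross_entropy_smoothed(
    logits: np.ndarray, labels: np.ndarray, eps: float = 0.1
) -> float:
    """logits: (N, C), labels: (N,) int; eps assumed to lie in [0, 1)."""
    n, _ = logits.shape
    logp = log_softmax(logits)
    # Target = (1 - eps) * one_hot + eps / C, so the loss splits into two terms.
    nll = -logp[np.arange(n), labels]  # usual per-sample cross-entropy
    uniform = -logp.mean(axis=1)       # cross-entropy against a uniform target
    return float(((1.0 - eps) * nll + eps * uniform).mean())
```

With eps=0.0 this reduces to cross_entropy_batch from the reference implementation. Structurally it also matches how PyTorch has long described `CrossEntropyLoss`: a combination of `LogSoftmax` and `NLLLoss` in a single call.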