As asked
Walk me through what happens when you call loss.backward() in PyTorch. How does autograd build the computation graph, and what are the memory implications of retaining it?
Sample answer outline
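A strong answer explains that PyTorch builds the computation graph dynamically during the forward pass: each operation on a tensor that requires gradients produces an output whose grad_fn points to a backward node, and that node saves whatever tensors it needs to compute its gradient later. loss.backward() traverses the graph in reverse from the loss, applying the chain rule at each node and accumulating (adding, not overwriting) the results into the .grad of the leaf tensors. By default the saved tensors are freed as backward() runs, so the graph can be traversed only once. retain_graph=True is needed to run backward over the same graph more than once, but it keeps the saved intermediate activations alive until the graph itself is no longer referenced, which can raise peak GPU memory substantially (a doubling is plausible for activation-heavy models, but the actual overhead is workload-dependent).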
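A candidate could illustrate the outline with a small example along these lines. It is a minimal sketch on CPU tensors, and the variable names (x, y, loss) are illustrative rather than part of the question.

```python
import torch

# Leaf tensor: created by the user, so it has no grad_fn of its own.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Each operation during the forward pass adds a node to the graph.
y = x * 2              # y.grad_fn is the backward node for the multiply
loss = (y ** 2).sum()  # loss.grad_fn is the backward node for the sum

print(x.is_leaf, x.grad_fn)   # leaf tensors report grad_fn as None
print(y.grad_fn)              # backward node recorded for the multiply
print(loss.grad_fn)           # backward node recorded for the final sum

# Reverse traversal from the loss: the chain rule is applied node by node
# and the result is accumulated into x.grad.
loss.backward()

# loss = sum((2x)^2) = sum(4x^2), so d(loss)/dx = 8x.
print(x.grad)
```

Only the leaf tensor x ends up with a populated .grad here; intermediate tensors such as y do not keep their gradients unless retain_grad() is called on them.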
Expect these follow-ups
- What happens if you call backward() twice without retain_graph=True?
- How does gradient accumulation work, and why is it used in large-batch training?
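Both follow-ups can be checked with a short experiment. The sketch below is one possible demonstration, not a reference answer: the tensors w, v, and data are made-up toy values, and the micro-batches are assumed to be of equal size so that averaging the micro-batch losses matches the full-batch mean.

```python
import torch

# Follow-up 1: a second backward() over a freed graph.
w = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (w ** 2).sum()
loss.backward()            # saved tensors are freed during this call
try:
    loss.backward()        # the pow node's saved input is gone
    second_call_failed = False
except RuntimeError:
    second_call_failed = True

# With retain_graph=True the first call keeps the saved tensors,
# so a second call works and the gradients add up in w.grad.
w.grad = None
loss = (w ** 2).sum()
loss.backward(retain_graph=True)
loss.backward()
retained_grad = w.grad.clone()   # 2w + 2w = 4w

# Follow-up 2: gradient accumulation over equal-size micro-batches.
data = torch.arange(1.0, 9.0)    # toy "dataset" of 8 samples
v = torch.tensor(0.5, requires_grad=True)

# Reference: gradient of the full-batch mean loss.
((v * data) ** 2).mean().backward()
full_grad = v.grad.clone()

# Accumulated: one backward() per micro-batch, no optimizer step between.
v.grad = None
micro_batches = data.chunk(4)
for xb in micro_batches:
    # Scale each loss so the summed gradients equal the full-batch mean.
    micro_loss = ((v * xb) ** 2).mean() / len(micro_batches)
    micro_loss.backward()        # adds into v.grad
accum_grad = v.grad.clone()

print(second_call_failed, retained_grad, full_grad, accum_grad)
```

Because each micro-batch builds and frees its own graph, only one micro-batch of activations is alive at a time. That is why accumulation lets a large effective batch fit in limited GPU memory, and it needs no retain_graph=True.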