The Memory and Computation Cost of GPT
Memory
- Activation values: up to roughly 50 times the parameter memory at large batch sizes
- Optimizer state (AdamW): 2 times the parameter memory
- Gradients: the same as the parameter memory
- KV cache: roughly 0.5 times the parameter memory
Number of Parameters
- Number of parameters: approximately $12h^2l$, where $h$ is the hidden size and $l$ is the number of layers
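A quick sanity check of the $12h^2l$ estimate; the GPT-3-scale hyperparameters ($h = 12288$, $l = 96$) below are assumed for illustration and should land near the published 175B count.

```python
# Parameter-count estimate: params ≈ 12 * h^2 * l
# (ignores embeddings, biases, and layer norms)
def approx_params(h: int, l: int) -> int:
    return 12 * h * h * l

# GPT-3-scale hyperparameters, assumed for illustration
print(f"{approx_params(h=12288, l=96) / 1e9:.0f}B parameters")  # ~174B, close to the published 175B
```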
Memory Usage
- Parameter, gradient, and optimizer-state memory is independent of the input data.
Memory usage during training (excluding activation values)
- Parameters: $\Phi$ in float16 plus $\Phi$ in float32 (master copy): $2\Phi + 4\Phi = 6\Phi$ bytes
- Gradients: $\Phi$ in float16 plus $\Phi$ in float32: $2\Phi + 4\Phi = 6\Phi$ bytes
- AdamW optimizer states (momentum and variance): $2\Phi$ in float32: $8\Phi$ bytes
- Total: $20\Phi$ bytes (worked example below)
Memory usage during inference (excluding activation values)
- Parameters: $\Phi$ in float16: $2\Phi$ bytes
- Total: $2\Phi$ bytes
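A minimal sketch of the two totals above, assuming mixed-precision (float16/float32) training with AdamW; the GPT-3-scale parameter count is an assumed example.

```python
# Memory excluding activations, for a model with phi parameters
FP16, FP32 = 2, 4  # bytes per element

def training_bytes(phi: int) -> int:
    params = phi * (FP16 + FP32)   # fp16 weights + fp32 master copy: 6 * phi
    grads = phi * (FP16 + FP32)    # fp16 gradients + fp32 gradients:  6 * phi
    adamw = 2 * phi * FP32         # momentum + variance in fp32:      8 * phi
    return params + grads + adamw  # 20 * phi bytes in total

def inference_bytes(phi: int) -> int:
    return phi * FP16              # fp16 weights only: 2 * phi bytes

phi = 175 * 10**9  # GPT-3-scale parameter count, assumed for illustration
print(f"training:  {training_bytes(phi) / 2**30:,.0f} GiB")   # ~3,260 GiB (20 * phi bytes)
print(f"inference: {inference_bytes(phi) / 2**30:,.0f} GiB")  # ~326 GiB   (2 * phi bytes)
```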
Activation Values
Approximately $(34bsh + 5bs^2a)l$ bytes, where $b$ is the batch size, $s$ the sequence length, $h$ the hidden size, $a$ the number of attention heads, and $l$ the number of layers (float16 activations, no recomputation)
- When $b=1$ (at GPT-3 scale with $s=2048$), approximately 0.8 times the parameter memory; the ratio grows linearly with the batch size
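A sketch of the activation estimate, assuming float16 activations without recomputation; the GPT-3-scale shapes below are assumed for illustration.

```python
# Activation memory in bytes: (34*b*s*h + 5*b*s^2*a) * l, fp16 activations, no recomputation
def activation_bytes(b: int, s: int, h: int, a: int, l: int) -> int:
    return (34 * b * s * h + 5 * b * s * s * a) * l

# GPT-3-scale shapes, assumed for illustration
h, a, l, s, params = 12288, 96, 96, 2048, 175 * 10**9
act = activation_bytes(b=1, s=s, h=h, a=a, l=l)
print(f"{act / 1e9:.0f} GB, {act / (2 * params):.2f}x the fp16 weights")  # ~275 GB, ~0.79x at b=1
# The ratio scales linearly with b: b=64 gives roughly 50x, matching the summary at the top.
```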
KV Cache
Approximately $2b(s+n)hl$ elements, i.e. $4b(s+n)hl$ bytes in float16, where $s$ is the input length and $n$ the number of generated tokens
- Approximately 0.5 times the parameter memory for a moderately large batch
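A sketch of the KV cache size, assuming float16 keys and values; the batch size, input length, and output length below are assumptions for illustration, not figures from the text.

```python
# KV cache: 2 * b * (s + n) * h * l elements; 2 bytes each in fp16
def kv_cache_bytes(b: int, s: int, n: int, h: int, l: int, bytes_per_elem: int = 2) -> int:
    return 2 * b * (s + n) * h * l * bytes_per_elem

# GPT-3-scale model; b=64 sequences, 512 input tokens, 32 generated tokens (assumed for illustration)
kv = kv_cache_bytes(b=64, s=512, n=32, h=12288, l=96)
print(f"{kv / 1e9:.0f} GB, {kv / (2 * 175e9):.2f}x the fp16 weights")  # ~164 GB, ~0.47x
```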
Computation
Computational Complexity
Forward pass: approximately $24bsh^2l$ FLOPs
Relationship between Computational Complexity and Number of Parameters
Forward-pass FLOPs $\approx 2 \times$ (number of tokens) $\times$ (number of parameters):
$$ \frac{24bsh^2l}{12h^2l \cdot bs} = 2 $$
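A quick numeric check of the ratio above, with shapes assumed for illustration.

```python
# Forward FLOPs ≈ 24*b*s*h^2*l, parameters ≈ 12*h^2*l, tokens in the batch = b*s
b, s, h, l = 8, 2048, 12288, 96  # shapes assumed for illustration
forward_flops = 24 * b * s * h * h * l
params, tokens = 12 * h * h * l, b * s
print(forward_flops / (params * tokens))  # 2.0 FLOPs per token per parameter
```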
Computation Time
- Backward propagation is approximately twice the computation of forward propagation.
- Activation value recomputation is approximately the same as forward propagation.
- Total: approximately $8 \times$ (number of tokens) $\times$ (number of parameters) FLOPs, i.e. forward (2) + backward (4) + recomputation (2).
$$ \text{training time} \approx \frac{8 \times \text{number of tokens} \times \text{number of parameters}}{\text{number of GPUs} \times \text{peak FLOPS per GPU} \times \text{GPU utilization}} $$
- GPU utilization is approximately between 0.3 and 0.55.
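A sketch of the training-time estimate; the token count, GPU count, A100 peak of 312 TFLOPS (float16), and 0.45 utilization are assumptions for illustration, not figures from the text.

```python
# Training time ≈ 8 * tokens * params / (num_gpus * peak_flops_per_gpu * utilization)
def training_days(tokens: float, params: float, num_gpus: int,
                  peak_flops: float, utilization: float) -> float:
    seconds = 8 * tokens * params / (num_gpus * peak_flops * utilization)
    return seconds / 86400  # seconds per day

# Assumed example: GPT-3-scale run of 300B tokens on 1024 A100s (312 TFLOPS fp16 peak), 45% utilization
days = training_days(tokens=300e9, params=175e9, num_gpus=1024,
                     peak_flops=312e12, utilization=0.45)
print(f"{days:.0f} days")  # ~34 days
```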