Unlocking the Enigma: Decoding LLM’s Memory Game
Introduction
Language models keep growing in size: PaLM has 540B parameters, and others such as OPT, GPT-3, and BLOOM sit in the 175–176B range. Running these models is challenging because of their resource requirements, often demanding multiple expensive GPUs. To address this, techniques such as quantization and distillation have been developed to reduce model size while maintaining performance.
Common data types in machine learning
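As a quick reference for the discussion below, the snippet prints the per-element size of the data types that appear throughout this post (FP32, FP16, BF16, INT8). It is a minimal sketch using PyTorch; any framework that exposes these dtypes would do.

```python
import torch

# Bytes per element for the data types that appear throughout this post.
for dtype in (torch.float32, torch.float16, torch.bfloat16, torch.int8):
    print(f"{dtype}: {torch.tensor([], dtype=dtype).element_size()} bytes per element")
```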
Memory needs of large language models
We have seen that training a model uses much more memory than simply loading it onto the GPU, because several components occupy GPU memory during training. Those components are the following:
- Model Weights 💡
- Optimizer States (2 States) 🧮🔢
- Gradients 📊
- Forward Activations Saved for Gradient Computation 🔄
- Temporary Buffers 📦
- Functionality-Specific Memory 🧠
Let's take a 1-billion-parameter model as an example and work out how much GPU memory is required for training and for inference.
GPU memory required for training
1 parameter (weight) = 4 bytes in FP32, so 1 billion parameters = 4 × 10⁹ bytes = 4 GB
This is all we need just to load the parameters (i.e. for inference), but for fine-tuning we also have to maintain the other states discussed above (per parameter):
- Optimizer (2 states, e.g. Adam's momentum and variance) => 8 bytes (FP32)
- Gradients => 4 bytes (FP32)
- Activations & temporary variables => ~8 bytes (FP32)
So to fine-tune a 1-billion-parameter LLM at full 32-bit precision we would need 4 GB (model weights) + 8 GB (two optimizer states) + 4 GB (gradients) + 8 GB (activations & other temporary variables) = 24 GB of GPU memory for training, versus 4 GB (model weights) for inference.
- 📝 **Note:** These figures assume training in full precision; most modern LLMs are trained with mixed precision.
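To make the arithmetic above reusable, here is a tiny back-of-the-envelope helper. It is only a sketch that encodes the per-parameter byte counts listed above (full FP32, no mixed precision); the function names are ours, not from any library.

```python
def training_memory_gb(num_params, bytes_weights=4, bytes_optimizer=8,
                       bytes_gradients=4, bytes_activations_and_temp=8):
    """Rough full-FP32 training estimate using the per-parameter byte counts above."""
    bytes_per_param = (bytes_weights + bytes_optimizer +
                       bytes_gradients + bytes_activations_and_temp)
    return num_params * bytes_per_param / 1e9


def inference_memory_gb(num_params, bytes_weights=4):
    """Memory needed just to hold the weights (FP32 inference)."""
    return num_params * bytes_weights / 1e9


print(training_memory_gb(1e9))   # ~24 GB for a 1B-parameter model
print(inference_memory_gb(1e9))  # ~4 GB
```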
Quantization: 8-bit optimizer states
You might have noticed above that the optimizer states take up a lot of memory. If we quantize the optimizer states from 32-bit to 8-bit, we save roughly 38% of the memory.
Quantizing the optimizer states really helps reduce the memory an LLM requires, but quantization usually loses some information. The challenge is to quantize while losing as little information as possible.
Introduction to model quantization
In simple terms, model quantization is the approximate representation of floating-point numbers as integers, with some information loss known as quantization error.
To quantize a model from FP16 to int8, we follow the steps below (a minimal sketch follows the list):
- Find the absolute max of the vector
- Find the quantization factor: int8 covers the range −127 to 127, so divide 127 by the absolute max from step 1
- Multiply each value by the quantization factor
- Round each value to the nearest integer
- To de-quantize, divide each integer value by the quantization factor
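Here is a minimal NumPy sketch of these steps (absmax quantization). The helper names and the sample vector are illustrative, not the article's exact example; the second run adds an extreme outlier to preview the problem discussed next.

```python
import numpy as np

def absmax_quantize(x):
    """Steps 1-4: find the absolute max, scale to the int8 range, round."""
    scale = 127.0 / np.max(np.abs(x))            # quantization factor
    q = np.round(x * scale).astype(np.int8)      # round each scaled value to the nearest integer
    return q, scale

def dequantize(q, scale):
    """Step 5: divide the int8 values by the quantization factor."""
    return q.astype(np.float32) / scale

x = np.array([1.2, -0.5, 3.4, 0.1], dtype=np.float32)
x_with_outlier = np.append(x, 54.5)              # same vector plus an extreme outlier

for vec in (x, x_with_outlier):
    q, scale = absmax_quantize(vec)
    err = np.abs(vec - dequantize(q, scale))
    print(f"max reconstruction error: {err.max():.4f}")
```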
What happens when we add an outlier to the list? As the magnitude of the outlier increases, the quantization (reconstruction) error keeps increasing.
In scenario 2 of the example above, we added an extreme outlier (54.5) to the list; after quantizing and then de-quantizing the values, the quantization error is very high.
In scenario 2, the reconstruction error for the first value (1.2) is large (1.2 − 0.86 = 0.34), whereas in the first scenario the reconstruction error is negligible.
In short: if there is an outlier in the list, the reconstruction error is high.
Block-Wise Quantization
In block-wise quantization, the values are divided into small blocks and each block is quantized independently. This confines the effect of outliers to their own block, which in turn reduces the quantization error on de-quantization.
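A minimal sketch of block-wise absmax quantization follows, kept self-contained; the block size and sample values are illustrative.

```python
import numpy as np

def blockwise_absmax_quantize(x, block_size=4):
    """Quantize x block by block so an outlier only degrades its own block."""
    q = np.empty(len(x), dtype=np.int8)
    scales = np.empty(int(np.ceil(len(x) / block_size)), dtype=np.float32)
    for i, start in enumerate(range(0, len(x), block_size)):
        block = x[start:start + block_size]
        scales[i] = 127.0 / np.max(np.abs(block))          # per-block quantization factor
        q[start:start + block_size] = np.round(block * scales[i]).astype(np.int8)
    return q, scales

def blockwise_dequantize(q, scales, block_size=4):
    out = q.astype(np.float32)
    for i in range(len(scales)):
        out[i * block_size:(i + 1) * block_size] /= scales[i]
    return out

# The outlier (54.5) now only affects values inside its own block;
# the first block is reconstructed almost exactly.
x = np.array([1.2, -0.5, 3.4, 0.1, 0.8, -2.1, 54.5, 0.3], dtype=np.float32)
q, scales = blockwise_absmax_quantize(x)
print(np.abs(x - blockwise_dequantize(q, scales)))
```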
Putting it all together
This is how model training works with a quantized optimizer state and almost zero degradation:
- Chunk the optimizer state into small blocks
- Apply block-wise quantization
- Normalize the data & find the closest 8-bit value
- During training, de-normalize the 8-bit values with the quantization factor
- De-quantize the optimizer state
- Update the optimizer state (and quantize it back to 8-bit)
You might have noticed that this adds a few extra multiplication steps: quantization reduces the memory footprint of an LLM, but it can increase the model's response time.
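As a concrete, hedged example, the bitsandbytes library ships 8-bit optimizers that implement this block-wise scheme. The sketch below assumes bitsandbytes is installed and a CUDA GPU is available, and uses a toy linear layer as a stand-in for an LLM; the exact API may differ between versions.

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()                    # toy stand-in for an LLM
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)   # block-wise 8-bit optimizer states

x = torch.randn(8, 1024, device="cuda")
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()        # states are de-quantized, updated, and stored back in 8-bit internally
optimizer.zero_grad()
```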
Reducing inference cost of LLM
A large language model like BLOOM-176B does not fit on a consumer GPU, not only for training but for inference as well.
The authors applied the best existing 8-bit quantization methods to LLMs at inference time and observed that zero-shot accuracy drops once the model grows beyond a certain threshold (> 2.7B parameters).
To solve this, the researchers developed a procedure for int8 matrix multiplication in the feed-forward and attention projection layers of transformers, which cuts the memory needed for inference in half while retaining full-precision performance.
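In the Hugging Face ecosystem this procedure is exposed through the bitsandbytes integration. The sketch below is one way to load a model for 8-bit inference, assuming transformers and bitsandbytes are installed and a GPU is available; newer versions replace the load_in_8bit flag with a BitsAndBytesConfig, so treat the exact arguments as version-dependent.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "bigscience/bloom-3b"   # illustrative checkpoint; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",    # spread layers across available GPUs/CPU
    load_in_8bit=True,    # store weights in int8 and use LLM.int8() matmuls at runtime
)

inputs = tokenizer("Quantization reduces memory by", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))
```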
LLM.int8(): zero degradation matrix multiplication for Large Language Models
As we saw above, if there is an outlier in the block being quantized, the reconstruction error is very high, and it grows with the magnitude of the outlier. In addition, tiny values like 0.1 can be rounded to 0, so their information is lost entirely.
One interesting observation the authors made is that as model size increases, both the number and the magnitude of outliers increase. For a 125M-parameter model, only about 1.5% of hidden-layer values are outliers, and their magnitudes are comparatively small (−3 to −7). By the time you scale up to a 6.7B model, outliers appear in roughly 75% of hidden layers and their magnitudes are very large (−40 to −61), causing very high reconstruction error and, in turn, lower model accuracy. Scaling further, the percentage of outliers does not grow, but their magnitude keeps increasing. Crucially, these outliers are very systematic and concentrated in only 6 feature dimensions across the entire transformer.
The solution is mixed-precision decomposition: perform 16-bit matrix multiplication for the outlier feature dimensions and 8-bit matrix multiplication for the other 99.9% of the dimensions.
In essence, LLM.int8() completes the matrix multiplication in three steps (a simplified sketch follows the list):
- From the input hidden states, extract the outlier columns (i.e. values larger than a certain threshold; a threshold of 6.0 is sufficient to reduce transformer performance degradation to close to zero).
- Perform the matrix multiplication of the outliers in FP16 and of the non-outliers in int8.
- De-quantize the non-outlier result and add the outlier and non-outlier results together to obtain the full result in FP16.
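Below is a simplified, CPU-runnable sketch of those three steps. It is not the real LLM.int8() CUDA kernel: the scaling is reduced to plain absmax per row/column, float32 stands in for FP16 so the example runs anywhere, and all names are illustrative.

```python
import torch

def mixed_precision_matmul(x, w, threshold=6.0):
    """Three-step mixed-precision matmul in the spirit of LLM.int8() (simplified)."""
    # Step 1: find feature dimensions (columns of x) containing outliers above the threshold.
    outlier_cols = (x.abs() > threshold).any(dim=0)

    # Step 2a: higher-precision matmul for the outlier dimensions
    # (FP16 in the paper; float32 here so the sketch runs on CPU).
    out_hi = x[:, outlier_cols] @ w[outlier_cols, :]

    # Step 2b: absmax int8 matmul for the remaining dimensions.
    x_r, w_r = x[:, ~outlier_cols], w[~outlier_cols, :]
    sx = 127.0 / x_r.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)   # per-row scale for x
    sw = 127.0 / w_r.abs().amax(dim=0, keepdim=True).clamp(min=1e-8)   # per-column scale for w
    xq = torch.round(x_r * sx).to(torch.int8)
    wq = torch.round(w_r * sw).to(torch.int8)
    acc = xq.float() @ wq.float()    # emulate the integer accumulation (exact at this size)

    # Step 3: de-quantize the int8 result and add both parts back together.
    out_lo = acc / (sx * sw)
    return out_hi + out_lo

x = torch.randn(4, 64)
x[:, 3] += 20.0                      # inject an outlier feature dimension
w = torch.randn(64, 32)
print((mixed_precision_matmul(x, w) - x @ w).abs().max())   # modest error despite the int8 path
```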
What about inference time?
A modern GPU can multiply roughly 200 elements in the time it takes to load one element from memory, i.e. multiplication is about 200 times cheaper than loading weights. Inference is therefore dominated by memory traffic, so the extra quantize/de-quantize multiplications add little overhead, while 8-bit weights halve the amount of data that has to be loaded compared with FP16.
Summary
- Memory Requirements during Training: Training large language models consumes substantial GPU memory due to factors such as model weights, optimizer states, gradients, forward activations, and temporary buffers. For a 1 billion parameter model, training may require 24 GB of GPU memory.
- Quantization for Memory Reduction: Quantization, particularly of the optimizer states to 8-bit, can reduce memory requirements by approximately 38%, at the risk of some information loss. Model quantization involves finding the absolute max, deriving a quantization factor, scaling and rounding the values, and de-quantizing.
- Block-Wise Quantization: Block-wise quantization divides values into smaller blocks, reducing outliers and quantization errors during dequantization.
- Reducing Inference Costs: As model size increases, models develop more and larger outliers, which drive up reconstruction error. To mitigate this, mixed-precision decomposition is used: 16-bit matrix multiplication for the outlier dimensions and 8-bit for the vast majority of dimensions. This minimizes error and maintains model accuracy during matrix computations.