Quantization Aware Training Implementation Guide

Quantization Aware Training

Basic information

LoRA:

  • Older adapter-style methods for domain-specific fine-tuning suffer from drawbacks such as a sequential processing bottleneck (extra modules inserted in series with the model). LoRA instead adds a trainable low-rank update in parallel with the frozen linear layer, so the feed-forward computation becomes self.linear(x) + (x @ self.lora_A @ self.lora_B) * self.scaling (see the sketch below).
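
A minimal sketch of such a layer; the rank/alpha defaults are illustrative, not values from these notes:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen linear layer plus a trainable low-rank update (illustrative sketch)."""

    def __init__(self, in_features, out_features, rank=8, alpha=16):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.linear.weight.requires_grad_(False)            # base weights stay frozen
        if self.linear.bias is not None:
            self.linear.bias.requires_grad_(False)
        # Low-rank factors: A projects down to `rank`, B projects back up.
        self.lora_A = nn.Parameter(torch.randn(in_features, rank) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features))   # zero init => no change at start
        self.scaling = alpha / rank

    def forward(self, x):
        # Frozen path plus the parallel low-rank path, exactly as in the note above.
        return self.linear(x) + (x @ self.lora_A @ self.lora_B) * self.scaling
```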

PRT (Precision Range Test)

  • Start from the lowest bit-width and increase it until the evaluation metric crosses the predetermined threshold; record that bit-width as B_min. B_max is then set by the maximum precision you need to experiment with (see the sketch below).
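
The notes describe PRT only loosely; below is a hypothetical sketch of the sweep, assuming an eval_fn(model, bits) helper that quantizes to the given bit-width and returns a metric in [0, 1]:

```python
def precision_range_test(model, eval_fn, bit_widths=(2, 3, 4, 6, 8, 16), threshold=0.95):
    """Hypothetical sketch: B_min is the lowest bit-width whose metric clears
    `threshold`; B_max is simply the highest precision you plan to test."""
    b_min = None
    for bits in sorted(bit_widths):            # start from the lowest bit-width
        metric = eval_fn(model, bits)          # e.g. quantize weights to `bits` and evaluate
        if b_min is None and metric >= threshold:
            b_min = bits                       # first bit-width that meets the threshold
    b_max = max(bit_widths)                    # upper end of the range to experiment with
    return b_min, b_max
```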

Implementation details

Loading weights

For a model loaded from the transformers library, remember to also copy over the weights of the final projection layer, not just the transformer blocks.
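
A hedged sketch of copying pretrained GPT-2 weights into a custom module while keeping an eye on the projection layer; MyGPT2 and config are placeholders for your own implementation:

```python
from transformers import GPT2LMHeadModel

pretrained = GPT2LMHeadModel.from_pretrained("gpt2")
pretrained_sd = pretrained.state_dict()

custom_model = MyGPT2(config)        # hypothetical custom implementation
custom_sd = custom_model.state_dict()

# Copy every tensor whose name and shape match. Note that HF GPT-2 stores its
# attention/FFN weights in Conv1D modules, which are transposed relative to
# nn.Linear, so transpose those if your implementation uses nn.Linear.
filtered = {k: v for k, v in pretrained_sd.items()
            if k in custom_sd and custom_sd[k].shape == v.shape}
missing, unexpected = custom_model.load_state_dict(filtered, strict=False)

# The projection layer is the easy one to forget: GPT-2 ties lm_head.weight to
# transformer.wte.weight, so either tie it in your model or copy it explicitly.
print("still missing:", missing)
```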

Initializing the model

Use the apply function in torch (nn.Module.apply) to streamline initializing the different layer types. In the current case, initialize the transformer layer together with its final projection layer and the two linear layers of the FFN.
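
A short sketch of the pattern; the std=0.02 initialization and the stand-in model are assumptions, not values from these notes:

```python
import torch.nn as nn

def _init_weights(module):
    # apply() calls this once for every submodule, depth-first.
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)   # 0.02 is the usual GPT-2 choice
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.LayerNorm):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)

# Covers the two FFN linears and the final projection in one pass;
# `model` here stands in for your transformer module.
model = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512), nn.LayerNorm(512))
model.apply(_init_weights)
```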

Inheritance of nn.Module and torch.autograd.Function
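
In QAT these two base classes typically pair up: a torch.autograd.Function implements fake quantization with a straight-through estimator in backward, and an nn.Module wraps it so it can be dropped into a model. A minimal sketch (the symmetric per-tensor scheme here is an assumption, not the exact implementation from these notes):

```python
import torch
import torch.nn as nn

class FakeQuant(torch.autograd.Function):
    """Round weights to a uniform grid in forward; pass gradients straight
    through in backward (straight-through estimator)."""

    @staticmethod
    def forward(ctx, w, num_bits=8):
        qmax = 2 ** (num_bits - 1) - 1
        scale = (w.abs().max() / qmax).clamp_min(1e-8)   # simple symmetric per-tensor scale
        return torch.round(w / scale).clamp(-qmax - 1, qmax) * scale

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None                          # STE: identity grad; none for num_bits

class QATLinear(nn.Module):
    """nn.Module wrapper that fake-quantizes its weight on every forward pass."""

    def __init__(self, in_features, out_features, num_bits=8):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        nn.init.normal_(self.weight, std=0.02)
        self.num_bits = num_bits

    def forward(self, x):
        w_q = FakeQuant.apply(self.weight, self.num_bits)
        return nn.functional.linear(x, w_q, self.bias)
```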

APPENDIX ON torch and transformers usage

Important functions in torch

  1. torch.save uses pickle under the hood. For tensors it stores the raw data, size information, and gradient requirements. For models it mainly stores the state_dict, an OrderedDict that maps each layer/parameter name to its tensor values.
  2. torch.amp.GradScaler('cuda'): part of Automatic Mixed Precision; it scales the loss so FP16 gradients do not underflow.
  3. torch.nn.Module: it's handy to inherit from this class for your customized model. You can just do model(input) to invoke the inherited __call__ method, which runs your forward pass.
  4. For dataset objects loaded with load_dataset from the datasets library, it's handy to check the .features property (or the __dict__ attribute) to understand the structure.
  5. return_tensors="pt" adds a batch dimension, so remember to index with [0] when you need the unbatched tensor (see the sketch after this list).
  6. While preparing data from a dataset, pay attention to the padding token and the EOS token; if they are the same, change the padding token to something else so padding can be distinguished from end-of-sequence.
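
A small sketch tying items 1, 5, and 6 together; the stand-in nn.Linear model and the "[PAD]" token choice are assumptions:

```python
import torch
import torch.nn as nn
from transformers import GPT2Tokenizer

# state_dict round trip (item 1): an OrderedDict mapping parameter names to tensors.
model = nn.Linear(4, 4)                                    # stand-in for any nn.Module
torch.save(model.state_dict(), "checkpoint.pt")
model.load_state_dict(torch.load("checkpoint.pt"))

# Tokenizer gotchas (items 5 and 6).
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None or tokenizer.pad_token == tokenizer.eos_token:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})   # keep padding distinct from EOS
    # (if you add a token, remember to resize the model's embeddings accordingly)

ids = tokenizer("hello world", return_tensors="pt")["input_ids"]
print(ids.shape)                                           # (1, seq_len): a batch dim was added
single = ids[0]                                            # index with [0] for the unbatched tensor
```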

Important functions in transformers

  1. GPT2Config returns a configuration object (serializable to JSON) holding all of those hyperparameters, which can then be consumed by other components or tuning methods.
  2. model.eval() turns on evaluation mode and disables dropout and similar training-only behavior (see the sketch below).
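
A quick sketch of both items; the small config values are illustrative:

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(n_layer=6, n_head=8, n_embd=512)    # illustrative small config
print(config.to_json_string())                           # item 1: config serializes to JSON

model = GPT2LMHeadModel(config)                          # randomly initialized from the config
model.eval()                                             # item 2: disables dropout for evaluation
with torch.no_grad():
    out = model(torch.randint(0, config.vocab_size, (1, 16)))
print(out.logits.shape)                                  # (1, 16, vocab_size)
```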

APPENDIX ON GPU RAM usage

GPU Memory Hierarchy & Parameter Impact Guide

Memory Hierarchy Overview

| Memory Type | Size (A100) | Bandwidth | Latency | What’s Stored |
|---|---|---|---|---|
| Registers | 256 KB/SM | ~19 TB/s | 1 cycle | Active thread variables, loop counters |
| L1 Cache/SMEM | 192 KB/SM | ~19 TB/s | ~30 cycles | Shared memory, frequently accessed data |
| L2 Cache | 40 MB | ~4 TB/s | ~200 cycles | Recently accessed data from HBM |
| HBM (VRAM) | 40-80 GB | ~2 TB/s | ~400 cycles | Model weights, activations, optimizer states |

Parameter Impact on Memory Usage

| Parameter | HBM Usage | L1/SMEM Usage | Register Usage | Impact Description |
|---|---|---|---|---|
| batch_size | High 🔴 | Medium 🟡 | Low 🟢 | Multiplies activation memory linearly |
| model_size | High 🔴 | Low 🟢 | Low 🟢 | Weights stored entirely in HBM |
| sequence_length | High 🔴 | Medium 🟡 | Low 🟢 | Quadratic for attention (seq_len²) |
| hidden_dim | High 🔴 | Medium 🟡 | Low 🟢 | Affects weight matrices & activations |
| num_layers | High 🔴 | Low 🟢 | Low 🟢 | Linear increase in weights |
| precision (FP32/16/8) | High 🔴 | Medium 🟡 | Medium 🟡 | Halves memory per precision drop |
| gradient_accumulation | Low 🟢 | Low 🟢 | Low 🟢 | Reduces batch memory requirement |
| optimizer (SGD/Adam) | High 🔴 | Low 🟢 | Low 🟢 | Adam uses 3x memory (m, v states) |

Detailed HBM Storage Breakdown

| Component | Formula | FP32 Memory | FP16 Memory | Stored Location |
|---|---|---|---|---|
| Model Weights | num_params × precision | 4 bytes/param | 2 bytes/param | HBM |
| Gradients | num_params × precision | 4 bytes/param | 2 bytes/param | HBM |
| Adam Optimizer | 2 × num_params × FP32 | 8 bytes/param | 8 bytes/param* | HBM |
| Activations | batch × seq_len × hidden × layers | Variable | Variable/2 | HBM |
| KV Cache (LLMs) | batch × heads × seq_len × dim × layers × 2 | Large | Large/2 | HBM |
| Temp Buffers | workspace for ops | ~1-2 GB | ~0.5-1 GB | HBM |

*Adam states typically stay FP32 even in mixed precision

Kernel-Level Memory Usage

| Operation | Register Pressure | L1/SMEM Usage | HBM Access Pattern |
|---|---|---|---|
| GEMM (MatMul) | High 🔴 | High 🔴 | Tiled access |
| Element-wise | Medium 🟡 | Low 🟢 | Sequential streaming |
| Softmax | Medium 🟡 | Medium 🟡 | Row-wise access |
| LayerNorm | Medium 🟡 | Medium 🟡 | Channel-wise access |
| Attention | High 🔴 | High 🔴 | Complex tiling |
| Conv2D | High 🔴 | High 🔴 | Im2col or tiled |

Optimization Strategies by Memory Type

| Memory Type | Optimization Strategy | Impact |
|---|---|---|
| HBM | Gradient checkpointing, model sharding, mixed precision | Reduce total storage |
| L2 Cache | Increase arithmetic intensity, kernel fusion | Reduce HBM traffic |
| L1/SMEM | Tile size tuning, shared memory allocation | Better data reuse |
| Registers | Loop unrolling, reduce live variables | Higher throughput |
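
As an illustration of the HBM row, a sketch combining gradient checkpointing with mixed precision; the toy block stack is an assumption and a CUDA device is required:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Gradient checkpointing: recompute activations during backward instead of storing them.
# Mixed precision: run the forward/backward in FP16 with loss scaling.
blocks = nn.ModuleList(
    [nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512)) for _ in range(4)]
).cuda()
optimizer = torch.optim.AdamW(blocks.parameters())
scaler = torch.amp.GradScaler("cuda")

x = torch.randn(8, 128, 512, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    h = x
    for block in blocks:
        h = checkpoint(block, h, use_reentrant=False)   # activations recomputed in backward
    loss = h.float().pow(2).mean()                      # placeholder loss
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```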

Practical Example: GPT-2 Medium (345M Parameters)

Memory Breakdown

| Component | Calculation | Memory Usage |
|---|---|---|
| Parameters | 345M params × 4 bytes | 1.4 GB (FP32) |
| Gradients | 345M params × 4 bytes | 1.4 GB (FP32) |
| Adam States | 345M × 2 × 4 bytes | 2.8 GB (FP32) |
| Activations | batch=8, seq=1024, 24 layers | ~4 GB |
| Total Training | Sum of above | ~9.6 GB |
| Inference Only | Parameters only | ~1.4 GB |

Memory Usage by Precision

| Precision | Weights | Gradients | Adam | Activations | Total Training |
|---|---|---|---|---|---|
| FP32 | 1.4 GB | 1.4 GB | 2.8 GB | 4 GB | 9.6 GB |
| FP16 Mixed | 0.7 GB | 0.7 GB | 2.8 GB | 2 GB | 6.2 GB |
| INT8 | 0.35 GB | N/A | N/A | 1 GB | 1.35 GB (Inference) |

Memory Calculation Formulas

Training Memory

Total_Memory = Model_Weights + Gradients + Optimizer_States + Activations + Temp_Buffers

Model Weights

Model_Memory = num_parameters × bytes_per_param

Activation Memory (Transformer)

Activation_Memory = batch_size × seq_length × hidden_dim × num_layers × 
                   (attention_heads + mlp_ratio + norm_layers)

Attention Memory

Attention_Memory = batch_size × num_heads × seq_length² × head_dim × num_layers
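
A rough calculator implementing the formulas above; activation_factor folds the (attention_heads + mlp_ratio + norm_layers) term into one tunable constant, and a value of 5 roughly reproduces the ~4 GB activation figure from the GPT-2 Medium example (temp buffers are excluded):

```python
def estimate_training_memory_gb(num_params, bytes_per_param=4, batch_size=8,
                                seq_len=1024, hidden_dim=1024, num_layers=24,
                                optimizer="adam", activation_factor=5):
    """Rule-of-thumb estimate following the Total_Memory formula above."""
    gib = 1024 ** 3
    weights = num_params * bytes_per_param
    gradients = num_params * bytes_per_param
    # Adam keeps two extra FP32 states (m and v) per parameter; otherwise assume one.
    optimizer_states = 2 * num_params * 4 if optimizer == "adam" else num_params * 4
    # Activation_Memory with the per-layer term collapsed into activation_factor.
    activations = batch_size * seq_len * hidden_dim * num_layers * activation_factor * bytes_per_param
    return (weights + gradients + optimizer_states + activations) / gib

# GPT-2 Medium defaults: prints roughly 9 GB, in the same ballpark as the ~9.6 GB table above.
print(round(estimate_training_memory_gb(345e6), 1))
```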

Common Memory Bottlenecks

| Bottleneck | Symptoms | Solution |
|---|---|---|
| OOM on forward pass | Crashes during model(input) | Reduce batch size or model size |
| OOM on backward pass | Crashes during loss.backward() | Enable gradient checkpointing |
| OOM on optimizer step | Crashes during optimizer.step() | Use gradient accumulation or efficient optimizer |
| Slow training | Low GPU utilization | Increase batch size or arithmetic intensity |
| Memory fragmentation | OOM with available memory | Clear cache: torch.cuda.empty_cache() |
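
To localize these bottlenecks, PyTorch's built-in CUDA memory counters are usually enough; a short sketch (requires a CUDA device):

```python
import torch

def report(tag):
    # Allocated = live tensors; reserved = what the caching allocator holds from the driver.
    alloc = torch.cuda.memory_allocated() / 1024 ** 3
    reserved = torch.cuda.memory_reserved() / 1024 ** 3
    print(f"{tag}: allocated {alloc:.2f} GiB, reserved {reserved:.2f} GiB")

report("before forward")
# ... run model(input), loss.backward(), optimizer.step(), calling report() after each phase ...
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GiB")

torch.cuda.empty_cache()               # releases cached blocks back to the driver (fragmentation row)
torch.cuda.reset_peak_memory_stats()   # start a fresh peak measurement for the next phase
```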