PyTorch Training Tips

The memory-efficient weight loading techniques collected here focus on minimizing memory overhead when loading large pretrained or fine-tuned models in PyTorch. Below is a notebook-style walkthrough of the code, with explanations embedded as comments, followed by a breakdown of each approach and its rationale.

Memory-Efficient Model Weight Loading in PyTorch

# Step 1: Import necessary libraries
import torch
import gc
import time
import os
import psutil
from threading import Thread

# Step 2: Utilities for memory tracking and cleanup
def start_memory_tracking():
    """Initialize GPU memory tracking."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    else:
        print("CUDA is not available. GPU memory tracking will not work.")

def print_memory_usage():
    """Print the maximum GPU memory allocated."""
    if not torch.cuda.is_available():
        print("CUDA is not available. GPU memory tracking will not work.")
        return
    max_gpu_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)  # Convert bytes to GB
    print(f"Maximum GPU memory allocated: {max_gpu_memory:.1f} GB")

def cleanup():
    """Cleanup unused memory."""
    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(3)  # Allow memory to clear
    torch.cuda.reset_peak_memory_stats()

def memory_usage_in_gb(func, *args, **kwargs):
    """Monitor CPU memory usage during the execution of a function."""
    process = psutil.Process(os.getpid())
    baseline_mem = process.memory_info().rss / 1024 ** 3  # Convert to GB

    mem_usage = []
    done = False

    def monitor_memory():
        while not done:
            mem_usage.append(process.memory_info().rss / 1024 ** 3)  # Convert to GB
            time.sleep(0.1)

    t = Thread(target=monitor_memory)
    t.start()

    # Run the function
    func(*args, **kwargs)

    done = True
    t.join()
    peak_mem_usage_gb = max(mem_usage) - baseline_mem
    return peak_mem_usage_gb

# Step 3: Define the model setup
class GPTModel(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        # Define a simple transformer-like structure for demonstration
        self.embedding = torch.nn.Embedding(config["vocab_size"], config["emb_dim"])
        self.layers = torch.nn.ModuleList([
            torch.nn.TransformerEncoderLayer(
                d_model=config["emb_dim"],
                nhead=config["n_heads"],
                batch_first=True
            )
            for _ in range(config["n_layers"])
        ])
    
    def forward(self, x):
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x)
        return x

BASE_CONFIG = {
    "vocab_size": 50257,
    "context_length": 1024,
    "drop_rate": 0.0,
    "qkv_bias": True
}

model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}

# Choose model configuration
CHOOSE_MODEL = "gpt2-large (774M)"
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

# Step 4: Memory-efficient weight loading methods
def load_weights_sequentially():
    """Load model weights sequentially to save GPU memory."""
    start_memory_tracking()

    model = GPTModel(BASE_CONFIG).to(device)
    state_dict = torch.load("model.pth", map_location="cpu")

    # Sequentially copy weights
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in state_dict:
                param.copy_(state_dict[name].to(device))
            else:
                print(f"Warning: {name} not found in state_dict.")

    print_memory_usage()

def load_with_meta_device():
    """Load model weights using PyTorch's meta device."""
    start_memory_tracking()

    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    model = model.to_empty(device=device)
    state_dict = torch.load("model.pth", map_location=device)

    # Sequentially copy weights
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in state_dict:
                param.copy_(state_dict[name])

    print_memory_usage()

def load_with_mmap():
    """Load model weights with mmap=True for memory efficiency."""
    start_memory_tracking()

    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    # mmap=True memory-maps the checkpoint from disk instead of reading it
    # fully into RAM; assign=True is required because the model still holds
    # meta tensors, so the loaded tensors are assigned to the model directly.
    model.load_state_dict(
        torch.load("model.pth", map_location=device, mmap=True, weights_only=True),
        assign=True,
    )

    print_memory_usage()

def load_weights_individually():
    """Save and load each weight tensor separately."""
    param_dir = "model_parameters"

    # Save individual parameters
    state_dict = torch.load("model.pth", map_location="cpu")  # keep the full state_dict on CPU
    os.makedirs(param_dir, exist_ok=True)
    for name, param in state_dict.items():
        torch.save(param.cpu(), os.path.join(param_dir, f"{name}.pt"))

    # Load individual parameters
    start_memory_tracking()

    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    model = model.to_empty(device=device)

    with torch.no_grad():
        for name, param in model.named_parameters():
            weight_path = os.path.join(param_dir, f"{name}.pt")
            if os.path.exists(weight_path):
                param.copy_(torch.load(weight_path, map_location="cpu").to(device))

    print_memory_usage()

# Example usage
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Uncomment the desired method to test
# load_weights_sequentially()
# load_with_meta_device()
# load_with_mmap()
# load_weights_individually()
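
To also measure peak CPU RAM, any of the loaders can be wrapped in the memory_usage_in_gb helper defined above, for example:

# peak_cpu_gb = memory_usage_in_gb(load_with_mmap)
# print(f"Peak CPU memory used: {peak_cpu_gb:.1f} GB")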

Explanation of Methods

  1. Sequential Loading:
    • Loads model weights one by one, transferring them from CPU to GPU, minimizing GPU memory overhead.
  2. Meta Device:
    • Instantiates a “meta” model with no memory allocation initially and directly loads weights into GPU, reducing CPU memory usage.
  3. Memory Mapping (mmap=True):
    • Allows direct access to weights from disk instead of loading them fully into memory, ideal for low-memory systems.
  4. Individual Tensor Loading:
    • Saves each tensor as a separate file, loads them one at a time during model setup, and immediately applies the weight, ensuring minimal memory usage.

Together, these approaches make it possible to work with large models in PyTorch while using memory efficiently.

1. The Problem

When loading a large model's weights the naive way (instantiate the model, then call load_state_dict on the result of torch.load("model.pth")), the weights exist in memory twice at peak: once for the instantiated model and once for the loaded state_dict. This can be problematic on systems with limited GPU or CPU memory, especially for larger models like GPT-2 XL.
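
For comparison, here is a minimal sketch of that naive baseline (illustrative only; it reuses the GPTModel, BASE_CONFIG, device, and memory-tracking helpers defined above):

def load_baseline():
    """Naive loading: model weights and the loaded state_dict coexist at peak."""
    start_memory_tracking()

    model = GPTModel(BASE_CONFIG).to(device)                    # first copy of the weights
    state_dict = torch.load("model.pth", map_location=device)   # second copy, all at once
    model.load_state_dict(state_dict)
    del state_dict  # the duplicate is only released here

    print_memory_usage()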

2. Sequential Weight Loading

A solution that avoids holding two full copies of the weights on the GPU is sequential loading. The idea here is:

  1. Instantiate the model on the GPU.
  2. Load the state_dict onto the CPU.
  3. Copy the weights into the model one parameter tensor at a time, moving each tensor to the GPU as it is copied.

By doing this, you only transfer one parameter tensor at a time, so the GPU never holds the model and a full state_dict simultaneously. This method significantly reduces peak memory usage compared to the default approach.

3. Using the meta Device

The meta device in PyTorch allows you to define the model’s structure without immediately allocating memory. By using this, you can bypass CPU memory entirely and directly load weights onto the GPU, which is especially useful when CPU memory is constrained.

4. Memory Mapping (mmap=True)

The mmap=True option in PyTorch enables memory-mapped file I/O, allowing data to be accessed directly from disk rather than being fully loaded into RAM. This method is effective in environments where both CPU and GPU memory are limited but disk space is sufficient.

5. Loading Weights Individually

An even more memory-conservative option is to save and load each weight tensor separately: every parameter is written to its own file once, and at load time the model is instantiated on the meta device and the tensors are read back and copied in one at a time, so only a single weight tensor needs to sit in CPU memory at any point.

Key Takeaways:

Sequential loading and the meta device reduce peak GPU and CPU usage, mmap=True trades RAM for reads from disk, and per-tensor loading minimizes memory at the cost of extra files and I/O. Together, these techniques make it possible to load large models in resource-constrained environments.

Random Coding Stuff Saved

Here’s a collection of coding tricks and techniques I’ve saved for handling tricky cases while training models in PyTorch. These are things I often refer back to when working on projects.

1. Gradient Accumulation (For Large Batch Sizes)

When my GPU memory can’t handle large batch sizes, I use gradient accumulation over several smaller batches:

accumulation_steps = 4  # Adjust based on your memory constraints
optimizer.zero_grad()

for i, (inputs, labels) in enumerate(data_loader):
    outputs = model(inputs)
    loss = loss_fn(outputs, labels)
    loss = loss / accumulation_steps  # Normalize loss for accumulation
    loss.backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Handle any remaining gradients
if (i + 1) % accumulation_steps != 0:
    optimizer.step()
    optimizer.zero_grad()

2. Mixed Precision Training

For faster computation and reduced memory usage, I use torch.cuda.amp for mixed precision training:

scaler = torch.cuda.amp.GradScaler()

for inputs, labels in data_loader:
    optimizer.zero_grad()
    
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
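
On GPUs with native bfloat16 support, an alternative worth noting is autocast with dtype=torch.bfloat16; because bfloat16 has the same exponent range as float32, the GradScaler can usually be dropped. A sketch, assuming the same model, optimizer, loss_fn, and data_loader as above:

for inputs, labels in data_loader:
    optimizer.zero_grad()

    # bfloat16 autocast: gradient scaling is typically unnecessary because
    # bfloat16 does not underflow the way float16 does
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)

    loss.backward()
    optimizer.step()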

3. Model Checkpointing for Resuming Training

I always save and load checkpoints to resume training efficiently:

# Save checkpoint
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, 'checkpoint.pth')

# Load checkpoint
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
loss = checkpoint['loss']
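
If a checkpoint saved on one device is resumed on another (for example, trained on GPU, inspected on CPU), passing map_location when loading avoids device-mismatch errors (device here is whatever device the current run uses):

checkpoint = torch.load('checkpoint.pth', map_location=device)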

4. Custom Learning Rate Scheduler

For complex requirements, I create my own learning rate scheduler:

def custom_lr_scheduler(epoch):
    if epoch < 10:
        return 0.001
    elif epoch < 20:
        return 0.0001
    else:
        return 0.00001

for epoch in range(num_epochs):
    lr = custom_lr_scheduler(epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    # Training loop
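
For a step schedule like this, the built-in LambdaLR can express the same thing; a sketch, assuming the optimizer is created with lr=1.0 since LambdaLR multiplies the base learning rate by the factor the function returns (SGD is just a placeholder optimizer here):

from torch.optim.lr_scheduler import LambdaLR

optimizer = torch.optim.SGD(model.parameters(), lr=1.0)  # base lr of 1.0 so the factors become the actual rates
scheduler = LambdaLR(optimizer, lr_lambda=custom_lr_scheduler)

for epoch in range(num_epochs):
    # Training loop
    scheduler.step()  # advance the schedule once per epoch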

5. Distributed Data Parallel (DDP) Training

When I train models on multiple GPUs, I use DDP:

import os

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl", init_method="env://")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = model.to(local_rank)
model = DDP(model, device_ids=[local_rank])

# Usual training loop
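
One detail that goes with DDP: each process should see a different shard of the data, which is what DistributedSampler is for. A minimal sketch (dataset and batch size are placeholders):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

sampler = DistributedSampler(dataset)  # shards the dataset across processes
data_loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # so each epoch gets a different shuffle
    # Usual training loop over data_loader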

6. Gradient Clipping

To prevent exploding gradients, I clip them:

# Inside the training loop, clip after backward() and before optimizer.step()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

7. Handling Imbalanced Datasets

For imbalanced datasets, I use weighted loss functions:

from torch.nn import CrossEntropyLoss

class_weights = torch.tensor([0.1, 0.9]).to(device)  # Adjust weights
loss_fn = CrossEntropyLoss(weight=class_weights)

# Training loop
loss = loss_fn(outputs, labels)

8. Debugging NaNs in Gradients

Sometimes I run into NaNs in gradients and debug them like this:

for name, param in model.named_parameters():
    if param.grad is not None:
        if torch.isnan(param.grad).any():
            print(f"NaN detected in gradients of {name}")

9. Efficient Data Loading

I use prefetch_factor and persistent_workers to make data loading faster:

from torch.utils.data import DataLoader

data_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,            # worker processes loading batches in parallel
    prefetch_factor=2,        # batches each worker prefetches ahead of time
    persistent_workers=True,  # keep workers alive between epochs
)

10. Model Debugging with Hooks

To inspect intermediate layer outputs, I use hooks:

def hook_fn(module, input, output):
    print(f"Layer: {module.__class__.__name__}")
    print(f"Output shape: {output.shape}")

for layer in model.children():
    layer.register_forward_hook(hook_fn)
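
register_forward_hook returns a handle, so a variant of the above keeps the handles and detaches the hooks once inspection is done (otherwise they keep printing on every forward pass):

handles = [layer.register_forward_hook(hook_fn) for layer in model.children()]

# ... run a forward pass here to print layer names and output shapes ...

for handle in handles:
    handle.remove()  # detach the hooks after debugging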

These snippets have been lifesavers for me when training models in PyTorch.

Scratchpad - Ujjwal Upadhyay