AxonML provides comprehensive distributed training support:
| Strategy | Description | Memory | Communication |
|---|---|---|---|
| DDP | Data parallelism with gradient sync | Full model per GPU | All-reduce gradients |
| FSDP | Fully sharded data parallel | Sharded across GPUs | All-gather params |
| Pipeline | Model split across stages | Partitioned model | Point-to-point |
use axonml::distributed::{World, DDP};

fn main() {
    // Initialize the distributed world
    let world = World::init();
    let rank = world.rank();
    let world_size = world.world_size();
    println!("Process {}/{}", rank, world_size);

    // Create the model (a full replica on every rank)
    let model = Sequential::new()
        .add(Linear::new(784, 256))
        .add(ReLU)
        .add(Linear::new(256, 10));

    // Wrap in DDP
    let mut ddp_model = DDP::new(model, world.default_group().clone());

    // Broadcast initial parameters from rank 0 so all replicas start identical
    ddp_model.sync_parameters();

    // Training loop (same as single GPU)
    let mut optimizer = Adam::new(ddp_model.parameters(), 0.001);
    for _epoch in 0..10 {
        for (inputs, targets) in train_loader.iter() {
            let output = ddp_model.forward(&inputs);
            let loss = loss_fn.compute(&output, &targets);
            optimizer.zero_grad();
            loss.backward();
            // All-reduce gradients across replicas before the optimizer step
            ddp_model.sync_gradients();
            optimizer.step();
        }
    }
}
# Launch with 4 GPUs via torchrun (pass the compiled binary, not a Python script)
torchrun --nproc_per_node=4 --no-python ./target/release/train
# Or with MPI
mpirun -np 4 ./target/release/train
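Conceptually, the sync_gradients() call is an all-reduce that averages the per-replica gradients (the usual data-parallel convention), so every rank applies an identical parameter update. A minimal single-process sketch of that averaging in plain Rust (simulated replicas, not the AxonML API):

// Simulate gradient averaging across N data-parallel replicas.
// Each replica computed its own gradient on a different mini-batch shard.
fn all_reduce_mean(per_rank_grads: &[Vec<f32>]) -> Vec<f32> {
    let n = per_rank_grads.len() as f32;
    let len = per_rank_grads[0].len();
    let mut avg = vec![0.0f32; len];
    for grads in per_rank_grads {
        for (a, g) in avg.iter_mut().zip(grads) {
            *a += g;
        }
    }
    for a in &mut avg {
        *a /= n;
    }
    avg
}

fn main() {
    // Gradients from 4 hypothetical replicas for a 3-parameter model.
    let grads = vec![
        vec![1.0, 2.0, 3.0],
        vec![3.0, 2.0, 1.0],
        vec![0.0, 4.0, 2.0],
        vec![4.0, 0.0, 2.0],
    ];
    // After the all-reduce, every replica holds the same averaged gradient.
    assert_eq!(all_reduce_mean(&grads), vec![2.0, 2.0, 2.0]);
}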
FSDP implements ZeRO optimization stages:
| Stage | Sharded | Memory Reduction |
|---|---|---|
| ZeRO-1 | Optimizer states | ~4x |
| ZeRO-2 | + Gradients | ~8x |
| ZeRO-3 | + Parameters | ~Nx (N = world size) |
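The reduction factors in the table follow the usual mixed-precision Adam accounting: per parameter, 2 bytes of fp16 weights, 2 bytes of fp16 gradients, and 12 bytes of optimizer state (fp32 master weights, momentum, variance), or 16 bytes in total. A rough per-GPU estimate under that assumption (a sketch, not an AxonML API):

// Approximate per-GPU memory (bytes per parameter) for mixed-precision Adam:
// 2 (fp16 params) + 2 (fp16 grads) + 12 (optimizer state).
fn bytes_per_param(stage: u8, world_size: f64) -> f64 {
    let (params, grads, optim) = (2.0, 2.0, 12.0);
    match stage {
        1 => params + grads + optim / world_size,     // ZeRO-1
        2 => params + (grads + optim) / world_size,   // ZeRO-2
        3 => (params + grads + optim) / world_size,   // ZeRO-3
        _ => params + grads + optim,                  // no sharding (DDP)
    }
}

fn main() {
    let n = 64.0; // world size
    // Multiply by the parameter count (e.g. 7e9) for total bytes per GPU.
    println!("baseline: {:.2} bytes/param", bytes_per_param(0, n));
    for stage in 1u8..=3 {
        println!("ZeRO-{}:   {:.2} bytes/param", stage, bytes_per_param(stage, n));
    }
}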
use axonml::distributed::{FSDP, ShardingStrategy, CPUOffload};
let model = create_large_model();
// Wrap in FSDP (ZeRO-3 by default)
let fsdp_model = FSDP::new(model, world.default_group().clone())
    .sharding_strategy(ShardingStrategy::FullShard) // ZeRO-3
    .cpu_offload(CPUOffload::None);
// Training is the same
let output = fsdp_model.forward(&input);
use axonml::distributed::ShardingStrategy;
// ZeRO-2: Shard optimizer + gradients
let fsdp = FSDP::new(model, pg)
    .sharding_strategy(ShardingStrategy::ShardGradOp);

// ZeRO-3: Full sharding
let fsdp = FSDP::new(model, pg)
    .sharding_strategy(ShardingStrategy::FullShard);

// No sharding (like DDP)
let fsdp = FSDP::new(model, pg)
    .sharding_strategy(ShardingStrategy::NoShard);
use axonml::distributed::CPUOffload;
// Offload parameters and gradients to CPU
let fsdp = FSDP::new(model, pg)
    .cpu_offload(CPUOffload::Full);

// Offload only optimizer states
let fsdp = FSDP::new(model, pg)
    .cpu_offload(CPUOffload::OptimizerStates);
Pipeline parallelism splits the model across GPUs by layers:
use axonml::distributed::{Pipeline, PipelineSchedule, PipelineStage};
// Define stages
let stage0 = Sequential::new()
    .add(Linear::new(784, 1024))
    .add(ReLU);
let stage1 = Sequential::new()
    .add(Linear::new(1024, 1024))
    .add(ReLU);
let stage2 = Sequential::new()
    .add(Linear::new(1024, 10));

// Create pipeline
let pipeline = Pipeline::new(vec![
    PipelineStage::new(stage0, Device::CUDA(0)),
    PipelineStage::new(stage1, Device::CUDA(1)),
    PipelineStage::new(stage2, Device::CUDA(2)),
])
.num_microbatches(8)
.schedule(PipelineSchedule::GPipe);

// Forward pass handles micro-batching
let output = pipeline.forward(&input);
use axonml::distributed::PipelineSchedule;
// GPipe: All forward, then all backward
let pipeline = pipeline.schedule(PipelineSchedule::GPipe);
// 1F1B: Interleaved forward/backward
let pipeline = pipeline.schedule(PipelineSchedule::Interleaved1F1B);
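The choice of schedule mainly affects the pipeline "bubble": with p stages and m micro-batches, a GPipe-style schedule leaves roughly a (p - 1) / (m + p - 1) fraction of each step idle, which is why raising num_microbatches matters. A quick sketch of that arithmetic (general pipeline math, not an AxonML API):

// Idle ("bubble") fraction of a GPipe-style schedule with `stages` pipeline
// stages and `microbatches` micro-batches per training step.
fn bubble_fraction(stages: u32, microbatches: u32) -> f64 {
    (stages as f64 - 1.0) / (microbatches as f64 + stages as f64 - 1.0)
}

fn main() {
    // 3 stages (as in the example above) with 8 micro-batches: 20% idle.
    println!("{:.2}", bubble_fraction(3, 8));  // 0.20
    // 32 micro-batches shrink the bubble to about 6%.
    println!("{:.2}", bubble_fraction(3, 32)); // 0.06
}

1F1B has the same bubble size, but it bounds the activations that must be kept in memory to roughly one micro-batch per in-flight stage rather than all m micro-batches at once, which is the usual reason to prefer it.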
use axonml::distributed::*;
let pg = ProcessGroup::mock();
let mut tensor = Tensor::randn(&[100]);
// All-reduce (sum)
all_reduce_sum(&mut tensor, &pg);
// All-reduce (mean)
all_reduce_mean(&mut tensor, &pg);
// All-gather
let gathered = all_gather(&tensor, &pg);
// Broadcast from rank 0
broadcast(&mut tensor, &pg);
// Reduce-scatter
let scattered = reduce_scatter_sum(&tensor, &pg);
// Barrier (synchronization)
barrier(&pg);
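These collectives compose: an all-reduce is equivalent to a reduce-scatter followed by an all-gather, which is the identity that lets ZeRO-style sharding keep its communication volume close to plain DDP. A single-process sketch of that identity in plain Rust (simulated per-rank buffers, not the AxonML API):

// reduce-scatter + all-gather == all-reduce (the decomposition ring
// all-reduce algorithms use).
fn reduce_scatter_sum(per_rank: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let world = per_rank.len();
    let chunk = per_rank[0].len() / world;
    (0..world)
        .map(|r| {
            // Rank r receives the element-wise sum of chunk r from every rank.
            (0..chunk)
                .map(|i| per_rank.iter().map(|v| v[r * chunk + i]).sum())
                .collect()
        })
        .collect()
}

fn all_gather(chunks: &[Vec<f32>]) -> Vec<f32> {
    // Every rank ends up with the concatenation of all ranks' chunks.
    chunks.iter().flatten().copied().collect()
}

fn main() {
    let per_rank = vec![vec![1.0, 2.0, 3.0, 4.0], vec![10.0, 20.0, 30.0, 40.0]];
    let scattered = reduce_scatter_sum(&per_rank);     // [[11, 22], [33, 44]]
    let reduced = all_gather(&scattered);              // [11, 22, 33, 44]
    assert_eq!(reduced, vec![11.0, 22.0, 33.0, 44.0]); // == all-reduce (sum)
}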
use axonml::distributed::{send, recv};
if rank == 0 {
    send(&tensor, 1, &pg); // Send to rank 1
} else if rank == 1 {
    let received = recv(0, &pg); // Receive from rank 0
}
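Point-to-point calls are the building blocks for patterns such as a ring exchange. The fragment below is a hypothetical sketch based on the send/recv signatures shown above; it assumes an even world size so that the even/odd ordering pairs every blocking send with a matching recv and cannot deadlock.

use axonml::distributed::{send, recv};

// One ring step: each rank sends its tensor to (rank + 1) % N and receives
// from (rank + N - 1) % N. Even ranks send first, odd ranks receive first.
let next = (rank + 1) % world_size;
let prev = (rank + world_size - 1) % world_size;
let received = if rank % 2 == 0 {
    send(&tensor, next, &pg);
    recv(prev, &pg)
} else {
    let r = recv(prev, &pg);
    send(&tensor, next, &pg);
    r
};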
use axonml::distributed::{World, ProcessGroup};
let world = World::init();
// Default group (all ranks)
let default_pg = world.default_group();
// Create subgroup
let ranks = vec![0, 1]; // Only ranks 0 and 1
let subgroup = world.new_group(ranks);
// Check membership
if subgroup.contains(world.rank()) {
    // Do something for this subgroup
}
For very large layers, individual weight matrices can be split across GPUs (tensor parallelism):
use axonml::distributed::{ColumnParallelLinear, RowParallelLinear};
// Column parallel: split output features
let col_linear = ColumnParallelLinear::new(1024, 4096, &pg);
// Row parallel: split input features
let row_linear = RowParallelLinear::new(4096, 1024, &pg);
// Typical usage in transformer MLP:
// x -> ColumnParallel -> GELU -> RowParallel -> output
let h = col_linear.forward(&x);
let h = h.gelu();
let output = row_linear.forward(&h);
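The column/row pairing works because the matrix multiplies decompose cleanly: a column-parallel layer computes X * [W1 | W2] = [X*W1 | X*W2] with no communication, and the row-parallel layer that follows computes X1*W1 + X2*W2, where summing the per-GPU partial products is a single all-reduce. A tiny single-process check of the row-parallel identity (plain Rust, not the AxonML API):

// Verify X*W = X1*W1 + X2*W2 when X is split by columns and W by rows.
// The per-shard products are what each GPU computes locally; summing them
// is the all-reduce performed by the row-parallel layer.
fn matmul(x: &[Vec<f32>], w: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let (n, k, m) = (x.len(), w.len(), w[0].len());
    let mut y = vec![vec![0.0; m]; n];
    for i in 0..n {
        for p in 0..k {
            for j in 0..m {
                y[i][j] += x[i][p] * w[p][j];
            }
        }
    }
    y
}

fn add(a: &[Vec<f32>], b: &[Vec<f32>]) -> Vec<Vec<f32>> {
    a.iter()
        .zip(b)
        .map(|(ra, rb)| ra.iter().zip(rb).map(|(x, y)| x + y).collect())
        .collect()
}

fn main() {
    // X is 1x4; W is 4x2. Shard the inner dimension across two "GPUs".
    let x = vec![vec![1.0, 2.0, 3.0, 4.0]];
    let w = vec![
        vec![1.0, 0.0],
        vec![0.0, 1.0],
        vec![1.0, 1.0],
        vec![2.0, 0.0],
    ];
    let full = matmul(&x, &w);

    // GPU 0 holds X[:, 0..2] and W[0..2, :]; GPU 1 holds the other halves.
    let x0 = vec![vec![1.0, 2.0]];
    let x1 = vec![vec![3.0, 4.0]];
    let w0 = w[0..2].to_vec();
    let w1 = w[2..4].to_vec();
    let partial_sum = add(&matmul(&x0, &w0), &matmul(&x1, &w1)); // the "all-reduce"

    assert_eq!(full, partial_sum);
}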
Combine multiple strategies:
// 8 GPUs: 2-way tensor parallel × 4-way data parallel
let tp_group = world.new_group(vec![0, 1]); // GPUs 0-1
let dp_group = world.new_group(vec![0, 2, 4, 6]); // One from each TP group
// Apply tensor parallelism to large layers
let attn = TensorParallelAttention::new(hidden_size, num_heads, &tp_group);
// Wrap entire model in DDP for data parallelism
let ddp_model = DDP::new(model, dp_group);
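Group membership for this kind of layout is plain modular arithmetic: with tp_size-way tensor parallelism on consecutive ranks, a rank's tensor-parallel peers share rank / tp_size and its data-parallel peers share rank % tp_size. A sketch of computing both lists (a hypothetical helper, not part of AxonML):

// For `world_size` ranks with `tp_size`-way tensor parallelism on consecutive
// ranks, return (tensor-parallel peers, data-parallel peers) of `rank`.
fn parallel_groups(rank: usize, world_size: usize, tp_size: usize) -> (Vec<usize>, Vec<usize>) {
    let tp_group = (0..tp_size)
        .map(|i| (rank / tp_size) * tp_size + i)
        .collect();
    let dp_group = (0..world_size / tp_size)
        .map(|i| i * tp_size + rank % tp_size)
        .collect();
    (tp_group, dp_group)
}

fn main() {
    // 8 GPUs, 2-way tensor parallel x 4-way data parallel, as above.
    let (tp, dp) = parallel_groups(0, 8, 2);
    assert_eq!(tp, vec![0, 1]);
    assert_eq!(dp, vec![0, 2, 4, 6]);

    let (tp, dp) = parallel_groups(5, 8, 2);
    assert_eq!(tp, vec![4, 5]);
    assert_eq!(dp, vec![1, 3, 5, 7]);
}

In the snippet above, vec![0, 1] and vec![0, 2, 4, 6] are exactly these two lists from the point of view of rank 0.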
use axonml::distributed::FSDPMemoryStats;
let stats = fsdp_model.memory_stats();
println!("Peak memory: {} MB", stats.peak_memory_mb);
println!("Current allocation: {} MB", stats.current_allocation_mb);