use axonml::prelude::*;
use axonml_nn::{Module, CrossEntropyLoss};
use axonml_optim::{Adam, Optimizer};
/// Runs a basic supervised training loop over `batches` for `epochs` epochs.
///
/// Per batch: forward pass → loss → `zero_grad` → backward → optimizer step.
/// Prints the average loss once per epoch.
///
/// * `model`   - the module to train (left in training mode on return).
/// * `batches` - `(inputs, targets)` pairs; may be empty.
/// * `epochs`  - number of full passes over `batches`.
fn train<M: Module>(
model: &mut M,
batches: &[(Variable, Variable)],
epochs: usize,
) {
let mut optimizer = Adam::new(model.parameters(), 0.001);
let loss_fn = CrossEntropyLoss::new();
for epoch in 0..epochs {
// Enable training behavior (dropout, batch-norm stat updates, etc.).
model.train();
let mut total_loss = 0.0;
for (inputs, targets) in batches {
let output = model.forward(inputs);
let loss = loss_fn.compute(&output, targets);
// Clear stale gradients before accumulating new ones.
optimizer.zero_grad();
loss.backward();
optimizer.step();
// Scalar loss lives at element 0 of the flattened data.
total_loss += loss.data().to_vec()[0];
}
// `.max(1)` avoids a 0/0 -> NaN when `batches` is empty.
let avg_loss = total_loss / batches.len().max(1) as f32;
println!("Epoch {}: Loss = {:.4}", epoch, avg_loss);
}
}
All optimizers take Vec<Parameter> and a learning rate. Options are set via a with_options(...) constructor or chainable builder methods.
use axonml_optim::SGD;
let opt = SGD::new(model.parameters(), 0.01);
// Builder chain
let opt = SGD::new(model.parameters(), 0.01)
.momentum(0.9)
.nesterov(true)
.weight_decay(1e-4);
// Or all-at-once:
let opt = SGD::with_options(params, /*lr=*/0.01, /*momentum=*/0.9, /*weight_decay=*/1e-4, /*nesterov=*/true);
use axonml_optim::{Adam, AdamW};
let opt = Adam::new(model.parameters(), 0.001);
// `betas` takes a tuple
let opt = Adam::new(model.parameters(), 0.001)
.betas((0.9, 0.999))
.eps(1e-8)
.weight_decay(0.0);
// AdamW (decoupled weight decay)
let opt = AdamW::new(model.parameters(), 0.001)
.betas((0.9, 0.999))
.weight_decay(0.01);
use axonml_optim::LAMB;
// LAMB's builder takes two f32s, not a tuple
let opt = LAMB::new(model.parameters(), 0.001)
.betas(0.9, 0.999)
.weight_decay(0.01);
use axonml_optim::RMSprop;
let opt = RMSprop::new(model.parameters(), 0.01)
.alpha(0.99)
.eps(1e-8)
.momentum(0.0);
All schedulers take an &impl Optimizer and implement the LRScheduler trait. Call scheduler.step(&mut optimizer) once per epoch (or once per optimization step, depending on the scheduler).
use axonml_optim::{Adam, Optimizer, StepLR, MultiStepLR, ExponentialLR,
CosineAnnealingLR, OneCycleLR, WarmupLR, ReduceLROnPlateau,
LRScheduler};
let mut opt = Adam::new(model.parameters(), 0.1);
// Step: decay by gamma every step_size epochs
let mut sch = StepLR::new(&opt, /*step_size=*/10, /*gamma=*/0.1);
// MultiStep: decay at specific milestones
let mut sch = MultiStepLR::new(&opt, vec![30, 60, 90], 0.1);
// Exponential
let mut sch = ExponentialLR::new(&opt, 0.95);
// Cosine annealing — takes t_max (period in steps)
let mut sch = CosineAnnealingLR::new(&opt, 100);
// OneCycle — max_lr + total_steps
let mut sch = OneCycleLR::new(&opt, /*max_lr=*/0.1, /*total_steps=*/1000);
// Linear warmup
let mut sch = WarmupLR::new(&opt, /*warmup_steps=*/1000);
// Reduce on plateau — takes no threshold at construction
let mut sch = ReduceLROnPlateau::new(&opt);
// scheduler.step() takes a metric for ReduceLROnPlateau
Located in axonml_autograd::amp:
use axonml_autograd::amp::{autocast, AutocastGuard, is_autocast_enabled, autocast_dtype};
use axonml_core::DType;
// Function-scoped autocast
let output = autocast(DType::F16, || {
model.forward(&input)
});
// RAII guard
{
let _guard = AutocastGuard::new(DType::F16);
let output = model.forward(&input);
// guard dropped -> autocast disabled
}
if is_autocast_enabled() {
println!("Autocast is enabled at {:?}", autocast_dtype());
}
Located in axonml_optim::grad_scaler:
use axonml_optim::GradScaler;
let mut scaler = GradScaler::new();
// or: let mut scaler = GradScaler::with_scale(2_f32.powi(16));
// Inside the loop:
let scaled_loss_value = scaler.scale_loss(loss_value);
loss.backward();
// Unscale + inf/nan check — skips the optimizer step if non-finite grads
let mut grads: Vec<f32> = /* collect flat grads */ Vec::new();
if scaler.unscale_grads(&mut grads) {
optimizer.step();
}
Located in axonml_autograd::checkpoint:
use axonml_autograd::checkpoint::{checkpoint, checkpoint_sequential};
// Single-function checkpoint
let output = checkpoint(|x: &Variable| heavy_layer.forward(x), &input);
// Sequential layers, split into segments (trades recompute for peak memory)
let output = checkpoint_sequential(
/*num_layers=*/24, /*segments=*/4, &input,
|layer_idx, x| layers[layer_idx].forward(x),
);
Clipping utilities live in axonml_train::trainer:
use axonml_train::clip_grad_norm;
let total_norm = clip_grad_norm(&model.parameters(), /*max_norm=*/1.0);
no_grad is a context manager in axonml_autograd:
use axonml_autograd::no_grad;
// Evaluates `model` on `batches` and returns accuracy as a percentage (0.0–100.0).
// NOTE: this is a skeleton — the `correct` counter is a placeholder for your
// own argmax-vs-target comparison and is never incremented here.
fn evaluate<M: Module>(model: &mut M, batches: &[(Variable, Variable)]) -> f32 {
// Switch layers like dropout/batch-norm into inference behavior.
model.eval();
// Disable gradient tracking for the whole evaluation pass.
no_grad(|| {
let mut correct = 0;
let mut total = 0;
for (inputs, targets) in batches {
let output = model.forward(inputs);
// Compare argmax with targets; add your own logic here.
total += output.shape()[0]; // assumes dim 0 is the batch size — TODO confirm
// correct += ...
}
// `.max(1)` guards against a divide-by-zero when `batches` is empty.
100.0 * correct as f32 / total.max(1) as f32
})
}
The higher-level training glue lives in the dedicated axonml-train crate (split out of the umbrella in April 2026):
use axonml_train::{
TrainingConfig, TrainingHistory, TrainingMetrics, Callback, EarlyStopping,
ProgressLogger, clip_grad_norm, compute_accuracy,
};
// Model benchmarking
use axonml_train::{benchmark_model, throughput_test, profile_model_memory, ThroughputConfig};
// Adversarial training
use axonml_train::{AdversarialTrainer, fgsm_attack, pgd_attack};
Every training binary in llm-training uses a shared lifecycle.rs with pause/resume/stop/checkpoint signal handlers so weeks-long runs survive process restart, plus a train_ctl control binary. Per project policy, every training binary ships with the live training monitor (axonml::TrainingMonitor) — opting out is not permitted.
Detection has a dedicated guide — see Object Detection Training for Nexus / Phantom / NightVision.
Quick reference:
use axonml_vision::losses::{FocalLoss, GIoULoss, UncertaintyLoss, compute_centerness};
use axonml_nn::{BCEWithLogitsLoss, SmoothL1Loss};
let focal = FocalLoss::new(); // alpha=0.25, gamma=2.0
let focal2 = FocalLoss::with_params(0.25, 2.0);
let smooth = SmoothL1Loss::new();
let bce = BCEWithLogitsLoss::new();
// GIoU is a bare compute function
let gl = GIoULoss::compute(&pred_boxes, &target_boxes);
// Nexus / Phantom training steps:
use axonml_vision::training::{nexus_training_step, phantom_training_step};
// Each runs forward → target assignment → loss → backward → optimizer step.
Evaluation:
use axonml_vision::training::{compute_ap, compute_map, compute_coco_map};
let ap = compute_ap(&detections, &ground_truths, 0.5);
let m = compute_map(&all_dets, &all_gts, num_classes, 0.5);
let cm = compute_coco_map(&all_dets, &all_gts, num_classes);
The Aegis Biometric Suite (axonml-vision::models::biometric) ships with specialty losses and GPU training pipelines for all modalities:
use axonml_vision::models::biometric::losses::{
ArgusLoss, EchoLoss, ContrastiveLoss, CenterLoss, AngularMarginLoss,
CrystallizationLoss, ThemisLoss, LivenessLoss,
};
let argus_loss = ArgusLoss::new(num_classes, embed_dim);
let echo_loss = EchoLoss::new(margin);
let themis = ThemisLoss::new();
Training examples are wired up as example binaries: train_mnemosyne (LFW face), train_argus (CASIA-Iris), train_ariadne (FVC2000 fingerprint), plus bench_mnemosyne for verification-pair ROC-AUC / EER / FAR/FRR.
NightVision (YOLOX-inspired, thermal) has preset configs for each thermal domain:
use axonml_vision::models::nightvision::{NightVision, NightVisionConfig};
let model = NightVision::new(NightVisionConfig::wildlife(20)); // 20 animal species
let model = NightVision::new(NightVisionConfig::human()); // search & rescue
let model = NightVision::new(NightVisionConfig::interstellar(3, 3)); // 3-band, 3 classes
let model = NightVision::new(NightVisionConfig::multi_domain(50)); // all domains + domain tag
let model = NightVision::new(NightVisionConfig::edge(10)); // compact
Both model parameters and input tensors must live on the same device — Error::DeviceMismatch is the most common GPU training error:
use axonml_core::Device;
let device = Device::Cuda(0);
// Move model (walks Module::parameters and calls Parameter::to_device)
// Exact API lives on the Module trait's default `to_device` impl.
// Each batch:
for (inputs, targets) in batches {
let x = Variable::new(inputs.to_device(device).unwrap(), false);
let y = Variable::new(targets.to_device(device).unwrap(), false);
let output = model.forward(&x);
// loss, backward, step ...
}
Last updated: 2026-04-16 (v0.6.1)