Hi jacob, AxonML author here. Our autograd is ~3K lines of Rust. Here's the actual architecture:
Three core pieces:
1. The GradientFunction trait — every differentiable op implements this:
pub trait GradientFunction: Debug + Send + Sync {
    // Given dL/d(output), compute dL/d(each input)
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>>;
    // Edges to parent grad functions (the "tape" edges)
    fn next_functions(&self) -> &[Option<GradFn>];
    fn name(&self) -> &'static str;
}
GradFn is just an Arc<dyn GradientFunction> wrapper — cheap to clone, identity via Arc pointer address.
2. Forward pass builds the graph implicitly. Every op creates a backward node with saved tensors + links to its
inputs' grad functions:
// Multiplication: d/dx(x*y) = y, d/dy(x*y) = x
pub struct MulBackward {
    next_fns: Vec<Option<GradFn>>, // parent grad functions
    saved_lhs: Tensor<f32>,        // saved for backward
    saved_rhs: Tensor<f32>,
}
impl GradientFunction for MulBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let grad_lhs = grad_output.mul(&self.saved_rhs).unwrap(); // dz/d(lhs) = rhs
        let grad_rhs = grad_output.mul(&self.saved_lhs).unwrap(); // dz/d(rhs) = lhs
        vec![Some(grad_lhs), Some(grad_rhs)]
    }
    fn next_functions(&self) -> &[Option<GradFn>] { &self.next_fns }
    fn name(&self) -> &'static str { "MulBackward" }
}
The Variable wrapper connects it:
pub fn mul_var(&self, other: &Variable) -> Variable {
    let result = self.data() * other.data();
    let grad_fn = GradFn::new(MulBackward::new(
        self.grad_fn.clone(),      // link to lhs's grad_fn
        other.grad_fn.clone(),     // link to rhs's grad_fn
        self.data(), other.data(), // save for backward
    ));
    Variable::from_operation(result, grad_fn, true)
}
3. Backward pass = DFS topological sort, then reverse walk. This is the whole engine:
pub fn backward(output: &Variable, grad_output: &Tensor<f32>) {
    let grad_fn = output.grad_fn().unwrap();

    // Topological sort via post-order DFS
    let mut topo_order = Vec::new();
    let mut visited = HashSet::new();
    build_topo_order(&grad_fn, &mut topo_order, &mut visited);

    // Walk in reverse, accumulate gradients keyed by node identity
    let mut grads: HashMap<GradFnId, Tensor<f32>> = HashMap::new();
    grads.insert(grad_fn.id(), grad_output.clone());

    for node in topo_order.iter().rev() {
        let grad = grads.get(&node.id()).unwrap();
        let input_grads = node.apply(grad); // chain rule
        for (i, next_fn) in node.next_functions().iter().enumerate() {
            if let Some(next) = next_fn {
                if let Some(ig) = &input_grads[i] {
                    grads.entry(next.id())
                        .and_modify(|g| *g = g.add(ig).unwrap()) // accumulate fan-in
                        .or_insert_with(|| ig.clone());
                }
            }
        }
    }
}
Leaf variables use AccumulateGrad — a special GradientFunction that writes the gradient into the Variable's shared
Arc<RwLock<Option<Tensor>>> instead of propagating further. That's how x.grad() works after backward.
Key Rust-specific decisions:
- Thread-local graph (thread_local! + HashMap<NodeId, GraphNode>) — no global lock contention, each thread gets its
own tape
- Arc<dyn GradientFunction> for the linked-list edges — trait objects give polymorphism, Arc gives cheap cloning and
stable identity (pointer address = node ID)
- parking_lot::RwLock over std::sync — faster uncontended reads for the gradient accumulators
- Graph cleared after backward (like PyTorch's retain_graph=False) — we learned this the hard way when GRU training
with 120 timesteps leaked ~53GB via accumulated graph nodes
The "tape" isn't really a flat tape — it's a DAG of GradFn nodes linked via next_functions(). The topological sort
flattens it into an execution order at backward time. This is the same design as PyTorch's C++ autograd engine, just
in Rust with ownership semantics doing a lot of the memory safety work for free.
AutomataNexus | 21 hours ago