top | item 47201628

jacobn | 1 day ago

Cool! How do you actually implement “Reverse-mode automatic differentiation with a tape-based computational graph” in rust?

AutomataNexus | 21 hours ago

Hi jacobn, AxonML author here. Our autograd is ~3K lines of Rust. Here's the actual architecture:

  Three core pieces:

  1. The GradientFunction trait — every differentiable op implements this:

  pub trait GradientFunction: Debug + Send + Sync {
      // Given dL/d(output), compute dL/d(each input)
      fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>>;
      // Linked list of parent grad functions (the "tape" edges)
      fn next_functions(&self) -> &[Option<GradFn>];
      fn name(&self) -> &'static str;
  }

  GradFn is just an Arc<dyn GradientFunction> wrapper — cheap to clone, identity via Arc pointer address.
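  A minimal sketch of what that wrapper can look like — the trait is trimmed to name() for brevity, and everything beyond the GradFn/GradientFunction names from above is a stand-in, not AxonML's actual code:

```rust
use std::fmt::Debug;
use std::sync::Arc;

// Trimmed trait; the real one also has apply() and next_functions().
pub trait GradientFunction: Debug + Send + Sync {
    fn name(&self) -> &'static str;
}

// Cheap to clone: clones bump the Arc refcount and share one allocation.
#[derive(Clone, Debug)]
pub struct GradFn(Arc<dyn GradientFunction>);

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct GradFnId(usize);

impl GradFn {
    pub fn new(f: impl GradientFunction + 'static) -> Self {
        GradFn(Arc::new(f))
    }
    // Identity = address of the Arc's heap allocation, stable across clones.
    pub fn id(&self) -> GradFnId {
        GradFnId(Arc::as_ptr(&self.0) as *const () as usize)
    }
}

#[derive(Debug)]
struct Dummy;
impl GradientFunction for Dummy {
    fn name(&self) -> &'static str { "Dummy" }
}
```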

  2. Forward pass builds the graph implicitly. Every op creates a backward node with saved tensors + links to its
  inputs' grad functions:

  // Multiplication: d/dx(x*y) = y, d/dy(x*y) = x
  pub struct MulBackward {
      next_fns: Vec<Option<GradFn>>,  // parent grad functions
      saved_lhs: Tensor<f32>,         // saved for backward
      saved_rhs: Tensor<f32>,
  }

  impl GradientFunction for MulBackward {
      fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
          let grad_lhs = grad_output.mul(&self.saved_rhs).unwrap();
          let grad_rhs = grad_output.mul(&self.saved_lhs).unwrap();
          vec![Some(grad_lhs), Some(grad_rhs)]
      }
      fn next_functions(&self) -> &[Option<GradFn>] { &self.next_fns }
      fn name(&self) -> &'static str { "MulBackward" }
  }

  The Variable wrapper connects it:

  pub fn mul_var(&self, other: &Variable) -> Variable {
      let result = self.data() * other.data();
      let grad_fn = GradFn::new(MulBackward::new(
          self.grad_fn.clone(),   // link to lhs's grad_fn
          other.grad_fn.clone(),  // link to rhs's grad_fn
          self.data(), other.data(),  // save for backward
      ));
      Variable::from_operation(result, grad_fn, true)
  }

  3. Backward pass = DFS topological sort, then reverse walk. This is the whole engine:

  pub fn backward(output: &Variable, grad_output: &Tensor<f32>) {
      let grad_fn = output.grad_fn().unwrap();

      // Topological sort via post-order DFS
      let mut topo_order = Vec::new();
      let mut visited = HashSet::new();
      build_topo_order(&grad_fn, &mut topo_order, &mut visited);

      // Walk in reverse, accumulate gradients
      let mut grads: HashMap<GradFnId, Tensor<f32>> = HashMap::new();
      grads.insert(grad_fn.id(), grad_output.clone());

      for node in topo_order.iter().rev() {
          let grad = grads.get(&node.id()).unwrap();
          let input_grads = node.apply(grad);  // chain rule

          for (i, next_fn) in node.next_functions().iter().enumerate() {
              if let Some(next) = next_fn {
                  if let Some(ig) = &input_grads[i] {
                      grads.entry(next.id())
                          .and_modify(|g| *g = g.add(ig).unwrap())  // accumulate
                          .or_insert(ig.clone());
                  }
              }
          }
      }
  }
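
  build_topo_order isn't shown above, so here's a sketch of the post-order DFS it describes — over a toy adjacency list of node indices instead of GradFn handles (all names here are assumptions):

```rust
use std::collections::HashSet;

// edges[n] lists the parent nodes reached through n's next_functions().
fn build_topo_order(
    node: usize,
    edges: &[Vec<usize>],
    order: &mut Vec<usize>,
    visited: &mut HashSet<usize>,
) {
    if !visited.insert(node) {
        return; // already visited: shared subgraphs are walked only once
    }
    for &parent in &edges[node] {
        build_topo_order(parent, edges, order, visited);
    }
    // Post-order push: every ancestor lands before this node, so iterating
    // the result in reverse starts at the output and ends at the leaves.
    order.push(node);
}
```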

  Leaf variables use AccumulateGrad — a special GradientFunction that writes the gradient into the Variable's shared
  Arc<RwLock<Option<Tensor>>> instead of propagating further. That's how x.grad() works after backward.
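
  A sketch of that AccumulateGrad idea, with a scalar f32 standing in for Tensor<f32> and std's RwLock rather than parking_lot's:

```rust
use std::sync::{Arc, RwLock};

pub struct AccumulateGrad {
    // The same slot the leaf Variable holds; x.grad() reads it after backward.
    grad_slot: Arc<RwLock<Option<f32>>>,
}

impl AccumulateGrad {
    // Terminal node: writes the gradient into shared storage instead of
    // returning input gradients to propagate further.
    pub fn apply(&self, grad_output: f32) {
        let mut slot = self.grad_slot.write().unwrap();
        // += rather than overwrite: a leaf used in several ops sums its
        // gradient contributions across calls.
        let prev = slot.unwrap_or(0.0);
        *slot = Some(prev + grad_output);
    }
}
```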

  Key Rust-specific decisions:

  - Thread-local graph (thread_local! + HashMap<NodeId, GraphNode>) — no global lock contention, each thread gets its
  own tape
  - Arc<dyn GradientFunction> for the linked-list edges — trait objects give polymorphism, Arc gives cheap cloning and
  stable identity (pointer address = node ID)
  - parking_lot::RwLock over std::sync — faster uncontended reads for the gradient accumulators
  - Graph cleared after backward (like PyTorch's retain_graph=False) — we learned this the hard way when GRU training
  with 120 timesteps leaked ~53GB via accumulated graph nodes
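
  The thread-local tape and the post-backward clearing can be sketched roughly like this (NodeId/GraphNode and the function names are placeholders, not the real types):

```rust
use std::cell::RefCell;
use std::collections::HashMap;

type NodeId = usize;

#[derive(Clone, Debug)]
struct GraphNode {
    op: &'static str, // placeholder payload
}

thread_local! {
    // Each thread owns its own tape: no global lock to contend on.
    static GRAPH: RefCell<HashMap<NodeId, GraphNode>> = RefCell::new(HashMap::new());
}

fn record(id: NodeId, node: GraphNode) {
    GRAPH.with(|g| {
        g.borrow_mut().insert(id, node);
    });
}

fn graph_len() -> usize {
    GRAPH.with(|g| g.borrow().len())
}

// The cleanup from the last bullet: drop all nodes (and the tensors they
// saved) so iteration N's graph can't accumulate into iteration N+1.
fn clear_graph() {
    GRAPH.with(|g| g.borrow_mut().clear());
}
```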

  The "tape" isn't really a flat tape — it's a DAG of GradFn nodes linked via next_functions(). The topological sort
  flattens it into an execution order at backward time. This is the same design as PyTorch's C++ autograd engine, just
  in Rust with ownership semantics doing a lot of the memory safety work for free.