---
title: "Systems Meet AI: Why I Built a Custom GPT-2 Decoder in Rust"
description: "A systems-oriented guide to building GPT-style decoder blocks and training loops with Rust and tch-rs."
date: "2026-07-02"
tags: [Rust, GPT-2, Deep Learning, Systems Programming, AI Engineering]
keywords: [Rust AI, GPT-2 decoder, tch-rs, libtorch, transformer architecture, systems programming for AI]
image: "/My.jpeg"
imageAlt: "Murat Tut portfolio image"
aiSummary: "This article explains how Rust can be used for AI systems work, focusing on tch-rs, GPT-style decoder blocks, attention, training loops, and the tradeoffs between Python flexibility and compiled systems control."
---

*How to build type-safe, compiled-speed machine learning models and training loops in Rust, moving away from Python's runtime overhead.*

For years, Python has been the default language for artificial intelligence. Its massive ecosystem of frameworks (PyTorch, TensorFlow, JAX) and interactive scripting utilities (Jupyter, Google Colab) make it ideal for fast model development and research.

But as AI workloads move from research environments into production pipelines, some of Python's limitations start to matter:
* **The Global Interpreter Lock (GIL)**: In standard CPython, only one thread executes Python bytecode at a time. CPU-bound Python code such as data preprocessing therefore cannot scale across threads.
* **Memory overhead**: Every Python object carries type and reference-count metadata. The interpreter and its dependency stack also add a sizable baseline footprint.
* **Lack of compile-time checks**: Type errors and mismatched tensor shapes only appear at runtime, sometimes hours into a training run.

For developers seeking systems-level control, **Rust** is emerging as a powerful alternative. By using **`tch-rs`**—Rust bindings to PyTorch’s underlying C++ `libtorch` engine—we can write type-safe, compiled-speed training loops with zero garbage collection overhead.

Here is a technical guide to building a custom GPT-2 decoder block and training loop natively in Rust.

---

## The Rust AI Stack: Why tch-rs?

When you write `import torch` in Python, you are using a Python front end to PyTorch's C++ core. `libtorch` is the standalone C++ distribution of that same engine.

The `tch-rs` crate provides Rust bindings to `libtorch`. You do not have to write custom CUDA kernels or reimplement deep learning algorithms from scratch. Instead, you call the same tensor kernels, GPU backends (CUDA, and MPS on Apple silicon), and autograd engine that PyTorch uses, while Rust's ownership rules and type system govern the code around them.

```
       ┌──────────────────────────────────────────────────┐
       │             Rust Application Code                │
       └───────────┬──────────────────────────────────────┘
                   │ FFI (Foreign Function Interface)
                   ▼
          tch-rs Rust Bindings
                   │
                   ▼
         PyTorch C++ (libtorch)
                   │
           ┌───────┴───────┐
           ▼               ▼
       CUDA Kernels   MPS Kernels
```
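The FFI boundary in the diagram is also where Rust's ownership model replaces a garbage collector. As far as I understand the tch-rs source, `tch::Tensor` holds a raw pointer to a libtorch tensor and releases it in its `Drop` implementation, so native memory is released as soon as the Rust value goes out of scope. The std-only sketch below models that pattern. It uses a hypothetical `NativeTensorHandle` type and an atomic counter that stands in for native allocations. It does not call libtorch.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Hypothetical stand-in for allocations living on the C++ side of the FFI boundary.
static LIVE_NATIVE_TENSORS: AtomicUsize = AtomicUsize::new(0);

/// Hypothetical handle type: owns exactly one "native" tensor.
pub struct NativeTensorHandle {
    id: usize,
}

impl NativeTensorHandle {
    /// Simulates allocating a tensor inside the native library.
    pub fn allocate(id: usize) -> Self {
        LIVE_NATIVE_TENSORS.fetch_add(1, Ordering::SeqCst);
        NativeTensorHandle { id }
    }

    pub fn id(&self) -> usize {
        self.id
    }
}

impl Drop for NativeTensorHandle {
    // Runs deterministically when the handle goes out of scope:
    // no garbage collector decides when the native memory is released.
    fn drop(&mut self) {
        LIVE_NATIVE_TENSORS.fetch_sub(1, Ordering::SeqCst);
    }
}

/// Number of simulated native allocations that are currently alive.
pub fn live_native_tensors() -> usize {
    LIVE_NATIVE_TENSORS.load(Ordering::SeqCst)
}
```

The counter drops back to zero the moment the last handle leaves scope. A real binding would call a C `free` function here instead of decrementing a counter.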

---

## 1. The Mathematics of Causal Self-Attention

In a generative transformer such as GPT-2, the attention layer combines queries ($Q$), keys ($K$), and values ($V$) to model dependencies between tokens. In a causal decoder, each position must be prevented from attending to later (future) positions. We enforce this with an additive mask matrix ($M$) that is $0$ on and below the diagonal and $-\infty$ above it:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V$$

where:
* $Q$, $K$, $V$ are Query, Key, and Value matrices projected from the input embeddings:
  $$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$
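* $d_k$ is the dimensionality of the keys. Dividing by $\sqrt{d_k}$ keeps the dot products from growing with $d_k$. Without it, large scores push the softmax into saturated regions where its gradients become vanishingly small.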
* $M$ is the causal mask matrix:
  $$M_{i,j} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases}$$
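To make the formula concrete, here is a minimal, dependency-free sketch of single-head causal attention on plain `Vec<Vec<f64>>` matrices. The function name `causal_attention` is hypothetical. It assumes a single, non-empty sequence with no batching and no dropout. It builds the additive mask $M$ exactly as defined above and uses a max-subtracted softmax for numerical stability.

```rust
/// Scaled dot-product attention with an additive causal mask.
/// `q`, `k`, `v` are row-major `t x d` matrices for one sequence and one head.
pub fn causal_attention(q: &[Vec<f64>], k: &[Vec<f64>], v: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let t = q.len();
    let d_k = k[0].len() as f64;
    let d_v = v[0].len();
    let mut out = vec![vec![0.0; d_v]; t];

    for i in 0..t {
        // Row i of QK^T / sqrt(d_k) + M, with M = 0 for j <= i and -inf for j > i.
        let scores: Vec<f64> = (0..t)
            .map(|j| {
                let dot: f64 = q[i].iter().zip(&k[j]).map(|(a, b)| a * b).sum();
                let m = if i >= j { 0.0 } else { f64::NEG_INFINITY };
                dot / d_k.sqrt() + m
            })
            .collect();

        // Stable softmax: the row max is finite because the diagonal (j == i)
        // is never masked, and exp(-inf) evaluates to exactly 0.
        let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
        let sum: f64 = exps.iter().sum();

        // Output row i is the attention-weighted sum of the value rows.
        for j in 0..t {
            let w = exps[j] / sum;
            for c in 0..d_v {
                out[i][c] += w * v[j][c];
            }
        }
    }
    out
}
```

The first position can only attend to itself, so row 0 of the output always equals `v[0]`. Changing a later value row never changes an earlier output row.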

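In Rust, we build this mask as a boolean tensor and pass it to libtorch's fused scaled dot-product attention kernel through `tch-rs`:

```rust
// src/model.rs
use tch::{nn, nn::Module, nn::ModuleT, Kind, Tensor};

// Single-head causal self-attention; field names follow the rest of the project.
#[derive(Debug)]
struct CausalSelfAttention {
    key_linear: nn::Linear,
    query_linear: nn::Linear,
    value_linear: nn::Linear,
}

// ModuleT (rather than Module) receives a `train` flag, so attention
// dropout can be switched off at inference time.
impl ModuleT for CausalSelfAttention {
    fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
        let (_b, t, _c) = x.size3().unwrap(); // batch, sequence length, embedding

        // Project the input into queries, keys, and values: [b, t, c] each
        let q = self.query_linear.forward(x);
        let k = self.key_linear.forward(x);
        let v = self.value_linear.forward(x);

        // Boolean causal mask, broadcast over the batch: `true` on and below the
        // diagonal means "may attend". A float mask would instead be *added* to
        // the scores, so a 0/1 float matrix would not hide future tokens.
        let mask = Tensor::ones([t, t], (Kind::Bool, x.device())).tril(0);

        // Fused attention kernel. Alternatively, pass None::<Tensor> with is_causal = true.
        let dropout_p = if train { 0.1 } else { 0.0 };
        Tensor::scaled_dot_product_attention(&q, &k, &v, Some(mask), dropout_p, false, None::<f64>)
    }
}
```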
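PyTorch's `scaled_dot_product_attention` interprets `attn_mask` according to its dtype. A boolean mask marks which positions may take part in attention. A floating-point mask is added to the scores. The std-only sketch below shows why that distinction matters. Its function names are hypothetical, and it operates on a single row of attention scores. A lower-triangular mask of ones and zeros passed as a *float* mask only shifts the allowed scores by 1, so future positions still receive weight. An additive $-\infty$ bias, which is what a boolean mask amounts to, removes them.

```rust
/// Softmax of `scores + bias`, computed with the max subtracted for stability.
fn softmax_with_bias(scores: &[f64], bias: &[f64]) -> Vec<f64> {
    let biased: Vec<f64> = scores.iter().zip(bias).map(|(s, b)| s + b).collect();
    let max = biased.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = biased.iter().map(|x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Boolean semantics: `true` keeps the score, `false` becomes a -inf bias.
pub fn weights_with_bool_mask(scores: &[f64], allowed: &[bool]) -> Vec<f64> {
    let bias: Vec<f64> = allowed
        .iter()
        .map(|&ok| if ok { 0.0 } else { f64::NEG_INFINITY })
        .collect();
    softmax_with_bias(scores, &bias)
}

/// Float semantics applied to a 0/1 matrix: the mask is simply added,
/// so "disallowed" positions keep a nonzero weight.
pub fn weights_with_float_01_mask(scores: &[f64], allowed: &[bool]) -> Vec<f64> {
    let bias: Vec<f64> = allowed.iter().map(|&ok| if ok { 1.0 } else { 0.0 }).collect();
    softmax_with_bias(scores, &bias)
}
```

For the first query position (only key 0 allowed), the boolean version returns weights `[1, 0, 0]`. The 0/1 float version leaks attention onto keys 1 and 2.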

---

## 2. Managing Memory Contiguity in Rust

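In Python/PyTorch, `view` never copies. If the requested shape is incompatible with the tensor's strides, it raises a `RuntimeError` that suggests `.reshape()` instead, and `reshape` copies the data only when it has to.

`tch-rs` inherits the same rule but surfaces the failure differently. Calling `view` to merge dimensions of a transposed tensor without first calling `.contiguous()` makes libtorch throw an error, and the binding turns that error into a Rust panic. (The fallible `f_view` variant returns a `Result` instead.) The root cause is that `transpose` only swaps the shape and stride metadata, and the values are never moved in memory.

When merging attention heads back into the embedding dimension before the attention output projection, we therefore make the layout contiguous explicitly: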

```rust
// Merge attention heads back into the embedding dimension
let transposed = attention_output.transpose(1, 2); // [b, heads, t, d] -> [b, t, heads, d]; only strides change
let contiguous_tensor = transposed.contiguous();    // copies the values into a row-major buffer
let reshaped = contiguous_tensor.view([batch_size, seq_len, embed_dim]); // embed_dim = heads * d
```
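The stride mechanics can be sketched without libtorch. The hypothetical `StridedMatrix` type below is a simplified 2-D model of how libtorch describes a tensor, namely shared storage plus a shape and strides. In this model, `transpose` swaps the stride metadata without touching the buffer, `contiguous` copies the values into row-major order, and `view_flat` refuses non-contiguous layouts. libtorch's real `view` rule is more permissive than this plain contiguity check, so treat the sketch as an illustration rather than a reimplementation.

```rust
use std::rc::Rc;

/// Simplified 2-D strided tensor: shared storage plus shape and strides.
#[derive(Debug, Clone)]
pub struct StridedMatrix {
    storage: Rc<Vec<f32>>, // shared between views, never reordered in place
    shape: [usize; 2],
    strides: [usize; 2],
}

impl StridedMatrix {
    /// Builds a row-major (contiguous) matrix: the last dimension has stride 1.
    pub fn from_row_major(data: Vec<f32>, rows: usize, cols: usize) -> Self {
        assert_eq!(data.len(), rows * cols, "buffer size must equal rows * cols");
        StridedMatrix { storage: Rc::new(data), shape: [rows, cols], strides: [cols, 1] }
    }

    /// Swaps shape and strides; the storage is shared, not copied or reordered.
    pub fn transpose(&self) -> Self {
        StridedMatrix {
            storage: Rc::clone(&self.storage),
            shape: [self.shape[1], self.shape[0]],
            strides: [self.strides[1], self.strides[0]],
        }
    }

    /// Reads element (r, c) through the strides.
    pub fn get(&self, r: usize, c: usize) -> f32 {
        self.storage[r * self.strides[0] + c * self.strides[1]]
    }

    /// In this sketch, contiguous means the strides match a row-major layout.
    pub fn is_contiguous(&self) -> bool {
        self.strides == [self.shape[1], 1]
    }

    /// Copies the elements into a fresh row-major buffer.
    pub fn contiguous(&self) -> Self {
        let [rows, cols] = self.shape;
        let mut data = Vec::with_capacity(rows * cols);
        for r in 0..rows {
            for c in 0..cols {
                data.push(self.get(r, c));
            }
        }
        StridedMatrix::from_row_major(data, rows, cols)
    }

    /// Zero-copy flat view of the buffer, only allowed for contiguous layouts.
    pub fn view_flat(&self) -> Option<&[f32]> {
        if self.is_contiguous() {
            Some(self.storage.as_slice())
        } else {
            None
        }
    }
}
```

After transposing a 2×3 matrix `[[1, 2, 3], [4, 5, 6]]`, the buffer still reads `1, 2, 3, 4, 5, 6`, but the strides are `[1, 3]`. A flat view is refused until `contiguous()` produces the buffer `1, 4, 2, 5, 3, 6`.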

---

## 3. Writing the Compiled Training Loop

The training loop mirrors its PyTorch counterpart step for step. We create a `VarStore` to hold the parameters and build the model plus an AdamW optimizer over it. Then, for each batch, we run the forward pass, compute the loss, and apply the backward pass and optimizer step explicitly:

```rust
// src/main.rs
use tch::{nn, nn::Module, nn::OptimizerConfig, Kind};

fn main() -> anyhow::Result<()> {
    let block_size = 128;
    let vocab_size = 968;
    let device = find_device(); // project helper (not shown) that selects the target device

    // The Variable Store owns every trainable tensor, allocated on `device`
    let vs = nn::VarStore::new(device);

    // Build the model first; creating its layers registers their parameters in `vs`.
    // `Gpt` is the project's model type (not shown).
    let model = Gpt::new(vs.root(), vocab_size, 128, block_size, 4, 4);

    // AdamW over the trainable variables in the store
    let mut opt = nn::AdamW::default().build(&vs, 1e-4)?;

    for step in 0..100 {
        // Load one training batch (project-specific loader)
        let (xs, ys, _) = dataset::get_batch_train();

        // Forward pass: logits have shape [batch, time, vocab]
        let logits = model.forward(&xs.to_kind(Kind::Int64).to_device(device));

        // Flatten to [batch * time, vocab] and [batch * time] for cross entropy
        let (b, t, c) = logits.size3()?;
        let logits = logits.view([b * t, c]);
        let targets = ys.to_kind(Kind::Int64).to_device(device).view([b * t]);

        let loss = logits.cross_entropy_for_logits(&targets);

        // Convenience wrapper: zero_grad(), backward(), then step()
        opt.backward_step(&loss);

        if step % 10 == 0 {
            // Extract the scalar rather than Debug-printing the whole tensor
            println!("Step: {}, Loss: {:.4}", step, loss.double_value(&[]));
        }
    }
    Ok(())
}
```
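`cross_entropy_for_logits` is applied to 2-D logits and 1-D class targets, which is why the loop flattens batch and time into one axis. The dependency-free sketch below uses a hypothetical function, `cross_entropy_flat`, that operates on a row-major `[b * t, c]` slice. It computes the standard mean cross-entropy, $-\log \text{softmax}(\text{logits})_{\text{target}}$ averaged over all $b \cdot t$ positions, using a log-sum-exp for numerical stability.

```rust
/// Mean cross-entropy over logits flattened to `[b * t, c]` (row-major),
/// with one class index per row in `targets` (flattened to `[b * t]`).
pub fn cross_entropy_flat(logits: &[f64], targets: &[usize], c: usize) -> f64 {
    assert_eq!(logits.len(), targets.len() * c, "logits must have shape [b * t, c]");
    let mut total = 0.0;
    for (row, &target) in logits.chunks_exact(c).zip(targets) {
        // log(sum(exp(row))) with the row max subtracted for stability
        let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let lse = max + row.iter().map(|x| (x - max).exp()).sum::<f64>().ln();
        // -log softmax(row)[target] = logsumexp(row) - row[target]
        total += lse - row[target];
    }
    total / targets.len() as f64
}
```

As a sanity check, uniform logits over `c` classes give a loss of exactly $\ln c$. This is also roughly the value an untrained model should report at step 0.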

---

## Key Differences: PyTorch vs. tch-rs

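1. **Combined gradient step**: In Python, the usual pattern is three separate calls: `optimizer.zero_grad()`, `loss.backward()`, and `optimizer.step()`. `tch-rs` offers `opt.backward_step(&loss)`, which runs all three in that order. This reduces boilerplate and prevents the classic bug of forgetting to zero gradients. The separate calls (`opt.zero_grad()`, `loss.backward()`, `opt.step()`) remain available when you need something like gradient accumulation.
2. **Method-based indexing**: Python's slice syntax (for example `x[:, -1, :]`) has no direct Rust equivalent. `tch-rs` instead uses methods such as `narrow`, `select`, and `.i(...)` from the `IndexOp` trait, together with explicit `.view()` and `.reshape()` calls. This is more verbose, and shape or layout mismatches are still reported at runtime, not by the compiler.
3. **Variable Stores (`VarStore`)**: `tch-rs` registers parameters in a `VarStore` under hierarchical names derived from `vs.root()`. The whole model can then be checkpointed with `vs.save(path)` and restored with `vs.load(path)`.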
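Zeroing matters because backpropagation accumulates into each parameter's gradient rather than overwriting it. The std-only sketch below uses a hypothetical one-parameter `Param` type, which is not part of tch-rs, for the model $\hat{y} = w x$ with squared-error loss. It contrasts a loop that zeroes gradients before each step with one that forgets to.

```rust
/// Hypothetical scalar parameter for the model y_hat = w * x with loss (w*x - y)^2.
pub struct Param {
    pub value: f64,
    pub grad: f64,
}

impl Param {
    /// Backward pass: like autograd, this *adds* to `grad`.
    /// d/dw (w*x - y)^2 = 2 * (w*x - y) * x
    pub fn backward(&mut self, x: f64, y: f64) {
        self.grad += 2.0 * (self.value * x - y) * x;
    }

    pub fn zero_grad(&mut self) {
        self.grad = 0.0;
    }

    /// Plain gradient-descent update.
    pub fn step(&mut self, lr: f64) {
        self.value -= lr * self.grad;
    }

    /// The sequence a combined helper performs: zero, backward, step.
    pub fn backward_step(&mut self, x: f64, y: f64, lr: f64) {
        self.zero_grad();
        self.backward(x, y);
        self.step(lr);
    }
}
```

With $x = 1$, $y = 3$, and a learning rate of $0.1$, two clean steps move $w$ from $0$ to $1.08$. The loop that skips `zero_grad` reuses the stale gradient from step one and overshoots to $1.68$.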

---

## Engineering Takeaways

1. **Know what the compiler checks**: Rust catches ownership and borrowing errors, data races, and type mismatches between Rust values at compile time. `tch::Tensor`, however, carries its shape and dtype as runtime values, so a dimension mismatch still fails at runtime. Check shapes early (for example with `size3()` on the first batch) so that a mismatch surfaces in seconds rather than hours into a run.
2. **Be mindful of memory contiguity**: Transpose and permute operations in `libtorch` change strides without rearranging values. Call `.contiguous()` before `view`, or use `reshape`, which copies only when necessary.
3. **No Python runtime to ship**: `cargo build --release` produces a native binary that typically links against the libtorch shared libraries but needs no Python interpreter or virtual environment. libtorch itself is large, especially with CUDA support, so the savings come from dropping the interpreter and its dependency stack. This simplifies deployment to edge servers.
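Compile-time shape checking is possible in Rust when the dimensions are part of the type, via const generics. The `Matrix` type below is a hypothetical illustration and does not reflect how `tch::Tensor` works, since a `tch::Tensor` keeps its shape as a runtime value. With `Matrix`, a matrix product whose inner dimensions do not match simply fails to compile.

```rust
/// Fixed-size matrix whose row and column counts are part of its type.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Matrix<const R: usize, const C: usize>(pub [[f32; C]; R]);

impl<const R: usize, const C: usize> Matrix<R, C> {
    /// (R x C) * (C x K) -> (R x K). The right-hand side must have exactly
    /// C rows, otherwise the call does not type-check.
    pub fn matmul<const K: usize>(&self, rhs: &Matrix<C, K>) -> Matrix<R, K> {
        let mut out = [[0.0f32; K]; R];
        for i in 0..R {
            for j in 0..K {
                for k in 0..C {
                    out[i][j] += self.0[i][k] * rhs.0[k][j];
                }
            }
        }
        Matrix(out)
    }
}
```

For example, multiplying a `Matrix<2, 3>` by another `Matrix<2, 3>` is rejected by the type checker, because `matmul` requires a `Matrix<3, K>` on the right. The trade-off is that every dimension must be known at compile time, which is why runtime-shaped tensors remain the norm for model code.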
