Reference: This project is based on the architecture and concepts presented in the paper *Attention Is All You Need* by Vaswani et al. We implement the core components of the Transformer (Multi-Head Attention, Feed-Forward Networks, Positional Encoding, etc.) in pure JAX, and showcase a training procedure on a dummy dataset using JAX's `jit` compilation.
The Transformer is a sequence-to-sequence model that uses self-attention mechanisms to learn contextual relations between tokens in a sequence. Unlike recurrent or convolutional models, it relies entirely on attention to draw global dependencies between input and output sequences.
Key highlights:
- No Recurrent or Convolutional Layers: All context modeling is done through attention mechanisms.
- Parallelizable: Self-attention allows parallel processing of sequences, enabling efficient training.
- Positional Encoding: Injects information about the relative or absolute position of tokens in the sequence.
This project demonstrates how to build and train a simplified Transformer using JAX for automatic differentiation and jit compilation, referencing the original equations and diagrams from Vaswani et al.
The high-level Transformer architecture from the paper consists of an Encoder stack and a Decoder stack. Each layer contains:
- Multi-Head Self-Attention (masked in the decoder’s first sub-layer)
- Add & Norm (residual connection + layer normalization)
- Feed-Forward Network (position-wise)
- Add & Norm again
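The masked self-attention in the decoder's first sub-layer prevents a position from attending to later positions. One way to build such a mask (a hypothetical helper, not part of the function list below) is a lower-triangular boolean matrix:

```python
import jax.numpy as jnp

def causal_mask(seq_len):
    # True where attention is allowed: position i may attend to positions j <= i.
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
```

This mask can then be passed to the attention function, which sets disallowed scores to a large negative value before the softmax.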
At the heart of the model is scaled dot-product attention. Given query (Q), key (K), and value (V) matrices, the attention output is computed as:
Attention(Q, K, V) = softmax( (QK^T) / sqrt(d_k) ) * V
where d_k is the dimensionality of the keys (and queries).
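This formula maps almost directly onto a few lines of JAX. A minimal sketch, using the `scaled_dot_product_attention(q, k, v, mask=None)` signature from this project's code structure:

```python
import jax.numpy as jnp
from jax.nn import softmax

def scaled_dot_product_attention(q, k, v, mask=None):
    # Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
    d_k = q.shape[-1]
    scores = q @ k.swapaxes(-2, -1) / jnp.sqrt(d_k)  # (..., seq_q, seq_k)
    if mask is not None:
        # Mask out disallowed positions with a large negative score.
        scores = jnp.where(mask, scores, -1e9)
    return softmax(scores, axis=-1) @ v
```

The `1/sqrt(d_k)` scaling keeps the dot products from growing large with dimension, which would push the softmax into regions of vanishing gradient.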
To allow the model to attend to different positions from different representation subspaces, multi-head attention is used. Multiple attention “heads” each compute scaled dot-product attention in parallel, and their outputs are concatenated:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W^O
where each head i is:
head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
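A sketch of the split/attend/concatenate pattern, matching the `multi_head_attention(q, k, v, Wq, Wk, Wv, Wo, mask=None, h=8)` signature used in this project (a compact attention helper is inlined to keep the example self-contained):

```python
import jax.numpy as jnp
from jax.nn import softmax

def scaled_dot_product_attention(q, k, v, mask=None):
    scores = q @ k.swapaxes(-2, -1) / jnp.sqrt(q.shape[-1])
    if mask is not None:
        scores = jnp.where(mask, scores, -1e9)
    return softmax(scores, axis=-1) @ v

def multi_head_attention(q, k, v, Wq, Wk, Wv, Wo, mask=None, h=8):
    def split_heads(x):
        batch, seq, d_model = x.shape
        # (batch, seq, d_model) -> (batch, h, seq, d_model // h)
        return x.reshape(batch, seq, h, d_model // h).transpose(0, 2, 1, 3)
    # Each head attends in its own d_model // h subspace, in parallel.
    heads = scaled_dot_product_attention(
        split_heads(q @ Wq), split_heads(k @ Wk), split_heads(v @ Wv), mask)
    batch, _, seq, d_head = heads.shape
    concat = heads.transpose(0, 2, 1, 3).reshape(batch, seq, h * d_head)
    return concat @ Wo
```

Note the reshape trick: instead of `h` separate weight matrices `W_i^Q`, a single `(d_model, d_model)` projection is split into heads, which is mathematically equivalent and vectorizes cleanly.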
After multi-head attention, each position is passed through a position-wise feed-forward network:
FFN(x) = max(0, xW1 + b1) * W2 + b2
This is applied to each position separately and identically, hence "position-wise."
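The FFN equation translates to a one-liner; a sketch matching the `position_wise_ffn(x, W1, b1, W2, b2, activation=...)` signature from this project:

```python
import jax
import jax.numpy as jnp

def position_wise_ffn(x, W1, b1, W2, b2, activation=jax.nn.relu):
    # FFN(x) = max(0, x W1 + b1) W2 + b2 with the default ReLU activation.
    # Broadcasting over the leading (batch, seq) axes applies it per position.
    return activation(x @ W1 + b1) @ W2 + b2
```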
Because the model contains no recurrence or convolution, it needs a way to encode sequence order. The positional encoding adds sines and cosines of varying frequencies to the embeddings:
PE(pos, 2i) = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
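These two formulas can be computed in one vectorized pass; a sketch matching the `positional_encoding(seq_len, dim_model)` signature from this project (assumes `dim_model` is even, as in the paper's configurations):

```python
import jax.numpy as jnp

def positional_encoding(seq_len, dim_model):
    pos = jnp.arange(seq_len)[:, None]        # (seq_len, 1)
    i = jnp.arange(dim_model // 2)[None, :]   # (1, dim_model // 2)
    angles = pos / jnp.power(10000.0, 2 * i / dim_model)
    pe = jnp.zeros((seq_len, dim_model))
    pe = pe.at[:, 0::2].set(jnp.sin(angles))  # even dims: sine
    pe = pe.at[:, 1::2].set(jnp.cos(angles))  # odd dims: cosine
    return pe
```

The resulting `(seq_len, dim_model)` matrix is simply added to the token embeddings.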
Each sub-layer output is added to the input (residual connection) and then normalized via Layer Normalization:
LayerOutput = LayerNorm(x + Sublayer(x))
We implement these components in pure JAX. Our code structure includes:
- `scaled_dot_product_attention(q, k, v, mask=None)`
- `multi_head_attention(q, k, v, Wq, Wk, Wv, Wo, mask=None, h=8)`
- `position_wise_ffn(x, W1, b1, W2, b2, activation=...)`
- `layer_norm(x, gamma=None, beta=None, eps=1e-6)`
- `add_and_norm(x, sublayer_out, gamma=None, beta=None, eps=1e-6)`
- `positional_encoding(seq_len, dim_model)`
- `encoder_layer(...)`, `decoder_layer(...)`
- `encoder_stack(...)`, `decoder_stack(...)`
- `transformer_forward_pass(...)`
- Scaled Dot-Product Attention: Implements the equation `QK^T / sqrt(d_k)` → softmax → multiply by `V`.
- Multi-Head Attention: Splits queries, keys, values into multiple heads, applies scaled dot-product attention, then concatenates.
- Feed-Forward Network: Two fully connected layers with a ReLU (or user-defined) activation.
- Positional Encoding: Returns a matrix of sine/cosine positional encodings.
- Add & Norm: Implements the residual connection plus layer normalization.
- Encoder/Decoder Layers: Combines multi-head attention, feed-forward, add & norm.
- Forward Pass: The full encoder-decoder pass, returning logits.
The transformer_forward_pass function integrates everything:
- Embed + Positional Encode the source tokens.
- Encode them with `encoder_stack`.
- Embed + Positional Encode the target tokens.
- Decode them with `decoder_stack`.
- Final Linear Projection to obtain logits for each target position.
We train the model using JAX for automatic differentiation and jit compilation, which greatly accelerates the forward and backward passes. For demonstration, we use a dummy dataset of random token indices.
- Number of Layers (`N`): 6
- Hidden Dimension (`d_model`): 512
- Feed-Forward Dimension (`d_ff`): 2048
- Number of Heads (`h`): 8
- Vocabulary Size: 37,000 (example)
- Dataset: We generate random source and target sequences of fixed length (e.g., 50 tokens each).
- Loss Function: A cross-entropy loss on the predicted logits vs. the true target tokens.
- Autograd and JIT:
  - We define `loss_fn` and use `jax.grad` to compute gradients automatically.
  - We wrap `loss_fn` in `jax.jit` for speed.
- SGD or Adam: Update parameters in each iteration.
- Print intermediate losses and track progress.
In practice, you would replace the dummy dataset with real data (e.g., WMT 2014 EN-DE). The code is structured to demonstrate how the components come together in a training loop using JAX.
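The training loop described above can be sketched as follows. Here a toy embedding-plus-projection model stands in for the full `transformer_forward_pass` purely so the example runs end to end, and plain SGD is used for the parameter update (Adam, e.g. via `optax`, is the other option mentioned above); the helper names are illustrative, not the project's exact code:

```python
import jax
import jax.numpy as jnp

def cross_entropy_loss(logits, targets):
    # Mean negative log-likelihood of the true target tokens.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    nll = -jnp.take_along_axis(log_probs, targets[..., None], axis=-1)
    return nll.mean()

def toy_forward(params, src):
    # Stand-in for transformer_forward_pass: embedding lookup + projection.
    return params["emb"][src] @ params["proj"]  # (batch, seq, vocab)

@jax.jit
def train_step(params, src, tgt, lr=1e-2):
    loss_fn = lambda p: cross_entropy_loss(toy_forward(p, src), tgt)
    # Loss and gradients in one pass, all jit-compiled.
    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Plain SGD update over the whole parameter pytree.
    new_params = jax.tree_util.tree_map(lambda w, g: w - lr * g, params, grads)
    return new_params, loss
```

Swapping `toy_forward` for the real forward pass (and the dummy random sequences for real batches) leaves the loop structure unchanged, which is the point of keeping the loss, gradient, and update steps generic over a parameter pytree.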