Lecture notes and code samples from the ML Foundations Summer School in Bilimler Koyu. The notebook is based largely on the following resources:

  • The Annotated Transformer
  • Makemore
  • UvA lecture notes

The text is generated by ChatGPT - after careful prompting 😋

Import statements¶

In [ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from torchvision.datasets import CIFAR10
from torchvision import transforms

import numpy as np
import matplotlib.pyplot as plt

from datasets import load_dataset
from transformers import AutoTokenizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

The main task in this notebook¶

We will learn a higher-order Markov model that

  • takes as input a context of length greater than 1,
  • predicts the future based on that context: \begin{align*} p(x_0, x_1, \ldots, x_N) = p(x_0) \prod_{n=1}^N p(x_n \mid x_{<n}) \end{align*}
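As a quick illustration of this factorization, here is a minimal sketch with a hypothetical 5-token vocabulary and a uniform dummy model standing in for the learned conditionals; the joint log-probability decomposes into a sum of next-token log-probabilities:

In [ ]:
# hypothetical toy setup: a "model" that assigns uniform probability over a 5-token vocabulary
vocab_size = 5
tokens = torch.tensor([3, 1, 4, 1])   # x_0, x_1, x_2, x_3

def dummy_next_token_probs(context):
    # stand-in for a learned conditional p(x_n | x_<n); here simply uniform
    return torch.full((vocab_size,), 1.0 / vocab_size)

log_prob = torch.log(torch.tensor(1.0 / vocab_size))   # log p(x_0)
for n in range(1, len(tokens)):
    probs = dummy_next_token_probs(tokens[:n])
    log_prob = log_prob + torch.log(probs[tokens[n]])

print('log p(x_0, ..., x_3) =', log_prob.item())        # equals 4 * log(1/5)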

This model will translate German to English. Let's see what the dataset looks like.

In [2]:
tok = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-en")

if tok.bos_token is None:
    tok.add_special_tokens({"bos_token": "<bos>"})
if tok.pad_token is None:
    tok.add_special_tokens({"pad_token": "<pad>"})
if tok.eos_token is None:
    tok.add_special_tokens({"eos_token": "<eos>"})

print('Vocab size:', len(tok))

ds_train = load_dataset("bentrevett/multi30k", split="train")   # fields: "de", "en"
ds_valid = load_dataset("bentrevett/multi30k", split="validation")

# Peek at the first example
print(len(ds_train), 'data points.')
print('\nAn example data point:', ds_train[0])

de_en = '[DE] ' + ds_train[0]['de'] + ' [EN] ' + ds_train[0]['en']
print('\nInput-outputs derived from the first data point:')

split_de_en = de_en.split(' ')
for i in range(14,23):
    print('\nInput:', ' '.join(split_de_en[:i]))
    print('Output:', ' '.join(split_de_en[i:i+1]))


# print('\nInput:', de_en[:76])
# print('Output:', de_en[76:79])

# print('\nInput:', de_en[:80])
# print('Output:', de_en[80:85])

# print('\nInput:', de_en[:87])
# print('Output:', de_en[87:92])

# print('\nInput:', de_en[:93])
Vocab size: 58102
29000 data points.

An example data point: {'en': 'Two young, White males are outside near many bushes.', 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.'}

Input-outputs derived from the first data point:

Input: [DE] Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche. [EN]
Output: Two

Input: [DE] Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche. [EN] Two
Output: young,

Input: [DE] Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche. [EN] Two young,
Output: White

Input: [DE] Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche. [EN] Two young, White
Output: males

Input: [DE] Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche. [EN] Two young, White males
Output: are

Input: [DE] Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche. [EN] Two young, White males are
Output: outside

Input: [DE] Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche. [EN] Two young, White males are outside
Output: near

Input: [DE] Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche. [EN] Two young, White males are outside near
Output: many

Input: [DE] Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche. [EN] Two young, White males are outside near many
Output: bushes.

1. Introduction¶

Different architectures¶

Over the past decades, machine learning has produced many different architectures, each suited to different kinds of data:

  • Multilayer perceptrons (MLPs) were the earliest general-purpose neural networks, used for tasks like basic classification and regression.
  • Convolutional neural networks (CNNs) then became the workhorse for computer vision, learning how to recognize patterns in images by exploiting local structure.
  • For sequential data, recurrent neural networks (RNNs) were widely adopted, because they could process variable-length sequences by maintaining a hidden state over time. A prominent example of RNNs in practice was machine translation, where models such as the Sequence-to-Sequence architecture (Sutskever et al., 2014, figure below) with attention mechanisms (Bahdanau et al., 2015) achieved major breakthroughs.

A cartoon illustration of RNN-based machine translation (image from https://www.interdb.jp/dl/part03/ch13.html)

Attention¶

The idea of attention itself was revolutionary. It allowed models to dynamically focus on the most relevant parts of the input, instead of compressing everything into a single fixed-size vector. Early applications included neural machine translation, image captioning, and speech recognition. But then came a turning point: the landmark paper “Attention Is All You Need” (Vaswani et al., 2017). This work proposed the Transformer, a model built entirely around attention layers, without recurrence or convolutions. It introduced the now-famous encoder–decoder architecture, where the encoder transforms an input sequence (like a German sentence) into a sequence of contextual representations, and the decoder generates the output sequence (like the English translation) step by step, attending both to the encoder’s outputs and to its own past predictions.

Attention is all you need architecture (image from Vaswani et al., 2017)

Causal language models¶

The Transformer quickly became the dominant architecture for language tasks, and soon spread to vision, speech, and beyond. But in this tutorial, our focus will not be on faithfully reproducing the exact “Attention Is All You Need” setup. Instead, we’ll study the core building blocks — attention, masking, autoregression — through the lens of causal language models. These models, like GPT-2 (Radford et al., 2019) and its descendants, use the Transformer in a simpler “decoder-only” form, where the goal is always the same: predict the next token given the previous ones. By starting here, we’ll build a clear intuition for how modern translation systems and large-scale language models are trained.

Causal vs masked language models (image from https://www.holisticai.com/blog/from-transformer-architecture-to-prompt-engineering)


What are we going to implement?¶

1. A Transformer network¶

We will implement, from scratch, a transformer network. This network will be quite generic - so much so that it could be used even for computer vision as we demonstrate at the end of this notebook.

2. Auto-regressive (causal) next item predictors¶

A central idea in sequence modeling is auto-regression: given a sequence of items, a model learns to predict the next item based on the context of all previous ones. Here, the model could be any ML model, and we will consider MLPs and Transformers. This approach is powerful because it transforms the problem of modeling complex sequences into a series of simple, repeated predictions. For text, this means predicting the next word given the words that came before; for time series, it could mean predicting tomorrow’s value from past values. Once trained, such a model can be used to generate arbitrarily long sequences: start with an initial context, predict the next token, append it to the context, and repeat. This is the principle behind modern large language models like GPT.
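The generation loop described above is only a few lines of code. Below is a minimal greedy sketch, assuming a hypothetical model that maps a (B, T) tensor of token ids to (B, T, V) next-token logits; a fuller version with temperature and top-k sampling appears in Section 3.5.

In [ ]:
@torch.no_grad()
def greedy_generate(model, context, num_new_tokens):
    # context: (B, T) tensor of token ids; model(context) is assumed to return (B, T, V) logits
    for _ in range(num_new_tokens):
        logits = model(context)                                       # (B, T, V)
        next_token_logits = logits[:, -1, :]                          # only the last position matters here
        next_token = next_token_logits.argmax(dim=-1, keepdim=True)   # (B, 1), greedy choice
        context = torch.cat([context, next_token], dim=1)             # append and repeat
    return context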

3. Data loaders and training modules¶

Of course, we would like to see whether our method actually works. For this, we will use the Multi30k German-to-English translation dataset.

2. Transformer network implementation¶

2.1. Transformer network¶

Now we begin implementing the Transformer architecture. The figure below shows the overall structure.

Transformer network with positional embeddings and transformer blocks. (image from https://ieeexplore.ieee.org/document/9438286)

In [3]:
T = 256     # Context length (sequence length)
V = 10000   # Vocabulary size (for embeddings)
N = 12      # Number of Transformer blocks
d = 256     # Embedding dimension
H = 8       # Number of attention heads

class TransformerNetwork(nn.Module):
    def __init__(self, T, N, input_emb, pos_emb, transf_block, norm_layer, output_layer):
        ''' Transformer network with N blocks, context length T, and given components. 
        Args:
            T: Context length (sequence length).
            N: Number of Transformer blocks.
            input_emb: Input embedding layer/module.
            pos_emb: Positional embedding layer/module.
            transf_block: A single Transformer block (to be repeated N times).
            norm_layer: Normalization layer/module (e.g., LayerNorm).
            output_layer: Output layer/module (e.g., for classification or regression).
        '''
        super(TransformerNetwork, self).__init__()
        self.T = T
        self.input_emb = input_emb
        self.pos_emb = pos_emb
        # note: the same block instance is repeated, so all N layers share one set of weights
        self.transf_blocks = nn.ModuleList([transf_block for _ in range(N)])
        self.norm_layer = norm_layer
        self.output_layer = output_layer
    
    @property
    def device(self):
        return next(self.parameters()).device
    
    @property
    def dtype(self):
        return next(self.parameters()).dtype
    
    @property
    def num_layers(self):
        return len(self.transf_blocks)
    
    @property
    def d(self):
        return self.input_emb.embedding_dim

    def forward(self, X):
        '''
        Forward pass through the Transformer network.
        Args:
            X: Input tensor of shape (B, T) with token indices (or (B, T, input_dim) if input_emb is a continuous projection).
        Returns:
            Output tensor after passing through the Transformer network.
        '''
        device = X.device
        B,T = X.shape[0:2]
        # assert T == self.T, f'Input sequence length {T} does not match model context length {self.T}.'

        # Input embedding
        X  = self.input_emb(X)  # B,T,d

        # positional embeddings
        P  = self.pos_emb(torch.arange(T, device=device)) # T,d
        P  = P.unsqueeze(0).expand(B, T, -1)  # B,T,d 

        # adding positional embeddings
        X = X + P  # B,T,d

        # passing through Transformer blocks
        for block in self.transf_blocks:
            X = block(X)  # B,T,d
        
        # final normalization
        X = self.norm_layer(X)  # B,T,d

        # output layer
        out = self.output_layer(X)  # B,T,output_dim (depends on output_layer)

        return out

2.2. Embeddings and positional encodings¶

At the input layer, tokens are first mapped into embeddings, and we add positional encodings so the model can take the token order into account. The sequence of embeddings is then passed through a stack of Transformer blocks:

  • Embeddings: Neural networks cannot work directly with discrete symbols (like words or image patches), so we map them into continuous vectors called embeddings. These vectors capture semantic or structural properties in a way that the model can process. In NLP, embeddings represent tokens such that similar words are close in vector space; in vision, embeddings represent image patches or regions in a comparable way.
  • Positional encodings: Transformers do not have an inherent sense of order, since self-attention treats all tokens symmetrically. To provide the model with information about sequence or spatial position, we add positional encodings to the embeddings. In text, this means encoding word positions within a sentence; in vision, it means encoding the 2D location of a patch within an image. These positional signals let the model distinguish, for example, “the first word in a sentence” or “a patch in the top-left corner of an image.”
In [4]:
def create_embedding(input_dim, output_dim):
    ''' Creates an embedding layer for input tokens. '''
    return nn.Embedding(input_dim, output_dim)

input_emb = create_embedding(V, d) # Input embedding
pos_emb   = create_embedding(T, d) # Positional embedding

# example usage of positional embeddings
P = pos_emb(torch.arange(T)) # Example usage to get positional embeddings for positions 0 to T-1
plt.figure(figsize=(8,4))
plt.imshow(P.detach().cpu().numpy(), aspect='auto')
plt.colorbar()
plt.title('Positional Embeddings (not trained yet)')
plt.xlabel('Embedding Dimension')
plt.ylabel('Position')
plt.show()
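Here we use learned positional embeddings (an nn.Embedding indexed by position). The original Transformer paper instead used fixed sinusoidal encodings; the cell below is a small sketch of that alternative for comparison, reusing T and d from above.

In [ ]:
def sinusoidal_positional_encoding(T, d):
    ''' Fixed sin/cos positional encodings from Vaswani et al. (2017); returns a (T, d) tensor. '''
    position = torch.arange(T).unsqueeze(1)                                   # (T, 1)
    div_term = torch.exp(torch.arange(0, d, 2) * (-np.log(10000.0) / d))      # (d/2,)
    P = torch.zeros(T, d)
    P[:, 0::2] = torch.sin(position * div_term)
    P[:, 1::2] = torch.cos(position * div_term)
    return P

P_sin = sinusoidal_positional_encoding(T, d)
plt.figure(figsize=(8,4))
plt.imshow(P_sin.numpy(), aspect='auto')
plt.colorbar()
plt.title('Sinusoidal Positional Encodings (fixed, not learned)')
plt.xlabel('Embedding Dimension')
plt.ylabel('Position')
plt.show()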

2.3. Transformer blocks¶

Each Transformer block has two main components:

  • Multi-Head Self-Attention (MHSA), which allows every token to attend to other tokens in the sequence, gathering relevant context.
  • Feed-Forward Network (FFN), a small MLP applied to each token representation independently.

Both parts are wrapped with residual connections and layer normalization for stable training.

  • Layer normalization works by normalizing the activations of each token representation across its hidden dimensions, ensuring that values have consistent scale and distribution. This helps gradients flow more reliably, reduces sensitivity to initialization, and allows the model to train deeper networks without diverging. In practice, it is usually applied together with residual connections, so that the model learns new information while preserving the original signal.
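To make this concrete, the small check below (toy tensors, no learnable affine parameters) verifies that nn.LayerNorm matches a manual normalization over the hidden dimension.

In [ ]:
x = torch.randn(2, 4, 8)                          # (B, T, d) toy activations

ln = nn.LayerNorm(8, elementwise_affine=False)    # no learnable scale/shift, to compare against the formula
manual = (x - x.mean(dim=-1, keepdim=True)) / torch.sqrt(x.var(dim=-1, unbiased=False, keepdim=True) + 1e-5)

print(torch.allclose(ln(x), manual, atol=1e-5))   # True: each token is normalized across its d features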
In [5]:
class Block(nn.Module):
    def __init__(self, mhsa, d, d_ff, dropout=0.1):
        ''' A single Transformer block with Multi-Head Self-Attention and Feed-Forward Network.
        Args:
            mhsa: Multi-Head Self-Attention module.
            d: Embedding dimension.
            d_ff: Dimension of the feed-forward network.
            dropout: Dropout rate.
        '''
        super(Block, self).__init__()
        self.mhsa = mhsa
        self.ffn  = nn.Sequential(
            nn.Linear(d, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d)
        )
        self.norm1 = nn.LayerNorm(d)
        self.norm2 = nn.LayerNorm(d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, X):
        '''
        Forward pass through the Transformer block.
        Args:
            X: Input tensor of shape (B, T, d).
        Returns:
            Output tensor of shape (B, T, d).
        '''
        X = X + self.dropout(self.mhsa(self.norm1(X)))
        X = X + self.dropout(self.ffn(self.norm2(X)))
        return X

2.4. Self-attention implementation¶

Within attention, the terms query, key, and value have special meaning:

  • A query vector represents the current token asking “what information do I need?”
  • Key vectors represent all tokens as potential “addresses” to look up.
  • Value vectors hold the actual information to be passed along.

Attention works by comparing queries against keys (using dot products), producing weights that determine how strongly each token’s value contributes to the output. This mechanism lets the model learn what to pay attention to at every step.

For an input $X \in \mathbb{R}^{n \times d}$, whose rows are the token embeddings $x_{1:n}$ with each $x_i \in \mathbb{R}^d$, we first obtain queries, keys and values:

\begin{align} Q &= X W_q \in \mathbb{R}^{n\times d_q}\\ K &= X W_k \in \mathbb{R}^{n\times d_q}\\ V &= X W_v \in \mathbb{R}^{n\times d_v} \end{align}

For each query $q_i \in \mathbb{R}^{d_q}$, we compute its (unnormalized) similarity to each key via \begin{align*} \text{sim}(q_i,k_j) \sim \exp(q_i^\top k_j) \end{align*} Then we normalize across all keys, which is implemented as $\texttt{softmax}$ below: \begin{align*} \text{sim}(q_i,k_j) = \frac{ \exp(q_i^\top k_j) } {\sum_m \exp(q_i^\top k_m) } \end{align*} The final output is a weighted average of the values: \begin{align*} z_i = \sum_j \text{sim}(q_i,k_j) v_j \end{align*}

Self-attention calculations (image from https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html)
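Before the multi-head implementation below, here is a minimal single-head sketch with random toy weights that follows the equations above literally (and, like the implementation below, does not apply the usual $1/\sqrt{d_q}$ scaling):

In [ ]:
n, d_model, d_q, d_v = 6, 16, 8, 8        # toy sizes
X  = torch.randn(n, d_model)              # one sequence of n token embeddings

Wq = torch.randn(d_model, d_q)
Wk = torch.randn(d_model, d_q)
Wv = torch.randn(d_model, d_v)

Q, K, V = X @ Wq, X @ Wk, X @ Wv          # (n, d_q), (n, d_q), (n, d_v)

sim = (Q @ K.T).softmax(dim=-1)           # sim[i, j] = exp(q_i^T k_j) / sum_m exp(q_i^T k_m)
Z = sim @ V                               # z_i = sum_j sim(q_i, k_j) v_j
print(Z.shape)                            # torch.Size([6, 8])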

In [42]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d, H):
        ''' Multi-Head Self-Attention module.
        Args:
            d: Embedding dimension.
            H: Number of attention heads.
        '''
        super(MultiHeadSelfAttention, self).__init__()
        assert d % H == 0, "Embedding dimension must be divisible by number of heads."
        self.d = d
        self.H = H
        self.d_k = d // H  # Dimension per head

        self.qkv = nn.Linear(d, 3*d)  # Combined query, key, value projection
        self.out = nn.Linear(d, d)  # Output projection
    
    def forward(self, X):
        '''
        Forward pass through the Multi-Head Self-Attention module.
        Args:
            X: Input tensor of shape (B, T, d).
        Returns:
            Output tensor of shape (B, T, d).
        '''
        B, T, d = X.shape
        assert d == self.d, f'Input embedding dimension {d} does not match model dimension {self.d}.'

        # Linear projections
        qkv = self.qkv(X)
        q,k,v = qkv.chunk(3, dim=-1)  # Each is (B, T, d) ---> form queries, keys, values

        # split into heads
        q_reshaped = q.view(B, T, self.H, self.d_k).transpose(1,2)  # (B, H, T, d_k)
        k_reshaped = k.view(B, T, self.H, self.d_k).transpose(1,2)  # (B, H, T, d_k)
        v_reshaped = v.view(B, T, self.H, self.d_k).transpose(1,2)  # (B, H, T, d_k)

        # Dot-product attention scores (note: the usual 1/sqrt(d_k) scaling is not applied here)
        k_transpose = k_reshaped.transpose(-2, -1)  # (B, H, d_k, T)
        QK = q_reshaped @ k_transpose # (B, H, T, T) ---> QK[b,h,i] stores the similarity of token i to all tokens in the sequence
        attn = QK.softmax(dim=-1) # (B, H, T, T) ---> attn[b,h,i] stores how much token i attends to all tokens in the sequence

        # alternative computation of softmax
        # alt_computation = QK.exp() / QK.exp().sum(dim=-1, keepdim=True)
        # assert torch.allclose(attn, alt_computation, atol=1e-5)

        # output calculation
        out = attn @ v_reshaped  # (B, H, T, d_k)
        out = out.transpose(1,2).contiguous().view(B, T, d) # (B, T, d) ---> put together all the heads
        out = self.out(out)
        return out

3. Training¶

3.1. Preparing the dataset¶

In [7]:
class DeEnPairs(Dataset):
    def __init__(self, dataset):
        self.de = dataset["de"]
        self.en = dataset["en"]

    def __len__(self): 
        return len(self.de)
    
    def __getitem__(self, i):
        return self.de[i], self.en[i]

def make_causal_collate(tok, block_size=256, train_on_source=False):

    PAD_id = tok.pad_token_id
    BOS_id = tok.bos_token_id
    EOS_id = tok.eos_token_id

    def encode_plain(text):
        # Raw subword IDs, no BOS/EOS added automatically
        return tok(text, add_special_tokens=False)["input_ids"]

    def collate(batch):
        input_ids_list, labels_list, attn_list, end_of_DE_list = [], [], [], []

        for de_txt, en_txt in batch:
            de_ids = encode_plain(de_txt)
            en_ids = encode_plain(en_txt)

            # Build: [de, BOS, en, EOS]
            seq = de_ids + [BOS_id] + en_ids + [EOS_id]

            x = torch.tensor(seq, dtype=torch.long)

            # find the position of the BOS token we inserted as the German/English boundary;
            # below it is used to mask the source part unless train_on_source=True
            bos_pos = (x == BOS_id).nonzero(as_tuple=True)[0][0].item()

            # Labels = next-token targets: y[t] is the token that should follow x[t].
            # We shift explicitly here (y = x[1:]) and later compare against logits[:, :-1].
            y = x.clone()[1:]

            if not train_on_source:
                y[:bos_pos] = -100  # ignore the German part and the BOS boundary token

            # note: y has one element fewer than x, so the last input position has no target;
            # its logits are dropped via the logits[:, :-1] slice during training

            attn = torch.ones_like(x, dtype=torch.long)  # 1 for real tokens; padding will be 0

            input_ids_list.append(x)
            labels_list.append(y)
            attn_list.append(attn)
            end_of_DE_list.append(bos_pos+1)

        # Pad to batch max length
        input_ids = pad_sequence(input_ids_list, batch_first=True, padding_value=PAD_id)
        labels    = pad_sequence(labels_list,    batch_first=True, padding_value=-100)
        attention = pad_sequence(attn_list,      batch_first=True, padding_value=0)
        end_of_DE_list = torch.tensor(end_of_DE_list, dtype=torch.long)

        # If the padded batch still exceeds block_size (rare), keep only the last block_size positions:
        if input_ids.size(1) > block_size:
            input_ids = input_ids[:, -block_size:]
            labels    = labels[:,    -block_size:]
            attention = attention[:, -block_size:]

        return {"input_ids": input_ids, "attention_mask": attention, "labels": labels, "end_of_DE": end_of_DE_list}

    return collate

train_ds = DeEnPairs(ds_train)
valid_ds = DeEnPairs(ds_valid)

collate = make_causal_collate(tok, block_size=256, train_on_source=True)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, collate_fn=collate)
batch = next(iter(train_dl))

# 1 if there is a real token, 0 if padding
print('Attention masks:\n',batch['attention_mask'][0],'\n')

# indices of tokens in the vocabulary. remember 58101 is <bos>, 0 is <eos>, 58100 is <pad>
print('Input ids:\n',batch['input_ids'][0],'\n')

# the beginning idx of the English part (after German + <bos>)
print('End of German sentences:\n',batch['end_of_DE'][0],'\n')

# the id of the next token to predict at each position; -100 marks ignored positions (here only padding, since train_on_source=True)
print('Labels:\n',batch['labels'][0],'\n')

# print the actual text
for i in range(batch['attention_mask'][0].sum()):
    print(batch['input_ids'][0][i].item(), tok.decode(batch['input_ids'][0][i].item()), sep='\t')
Attention masks:
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0]) 

Input ids:
 tensor([  525, 17593,   222, 41377,    15, 19943, 16778,    37,    94, 16167,
            2,   145,  2229, 12317,  4691, 22799,     3, 58101,    93,    17,
         7565,    18,     7,    39,    17,  6409,    15,  1730,   175,    17,
            6,  9401,    32,    14,    17, 12662,     5,    17,  9691,     7,
           88,  1032,    17, 19376,    17, 27499,    17, 16391,     6,     3,
            0, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100,
        58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100,
        58100, 58100, 58100, 58100, 58100]) 

End of German sentences:
 tensor(18) 

Labels:
 tensor([17593,   222, 41377,    15, 19943, 16778,    37,    94, 16167,     2,
          145,  2229, 12317,  4691, 22799,     3, 58101,    93,    17,  7565,
           18,     7,    39,    17,  6409,    15,  1730,   175,    17,     6,
         9401,    32,    14,    17, 12662,     5,    17,  9691,     7,    88,
         1032,    17, 19376,    17, 27499,    17, 16391,     6,     3,     0,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100]) 

525	Eine
17593	Figur
222	eines
41377	orientalisch
15	en
19943	Mannes
16778	sitzt
37	auf
94	einer
16167	Mauer
2	,
145	vor
2229	einigen
12317	weißen
4691	Wohn
22799	häusern
3	.
58101	<bos>
93	A
17	
7565	figur
18	e
7	of
39	an
17	
6409	ori
15	en
1730	tal
175	man
17	
6	s
9401	its
32	on
14	a
17	
12662	wall
5	in
17	
9691	front
7	of
88	so
1032	me
17	
19376	white
17	
27499	apartment
17	
16391	building
6	s
3	.
0	</s>

3.2. Our training pipeline¶

In practice, we train these models efficiently by predicting many next tokens in parallel. Suppose our target sequence has length D. Instead of creating D separate training examples (each with a different context), we process the whole sequence at once. We feed the model the input [x₁, x₂, …, x_{D-1}], and we ask it to predict [x₂, x₃, …, x_D]. In other words, for every position, the model tries to guess “what comes next.” This produces D-1 predictions from a single forward pass. During inference, the model still generates one token at a time, but during training we can supervise all positions simultaneously, which makes learning much faster. The important distinction is: training = parallel next-token prediction across the sequence, generation = sequential prediction, one token at a time.
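Here is a tiny sketch of this shift-by-one construction on toy token ids; the collate function above produces exactly this pattern, with -100 marking positions the loss should ignore:

In [ ]:
seq = torch.tensor([11, 12, 13, 14, 15])   # x_1, ..., x_D (toy token ids)

inputs  = seq[:-1]                         # [x_1, ..., x_{D-1}] is fed to the model
targets = seq[1:]                          # [x_2, ..., x_D] is predicted, all positions in parallel
print(inputs.tolist(), '->', targets.tolist())

# given logits of shape (D-1, V), the whole supervision is a single cross-entropy call:
# F.cross_entropy(logits, targets, ignore_index=-100)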

So now we implement a causal attention mechanism, where position $i$ can only attend to tokens up to position $i$, so that the prediction of the next token is conditioned only on the past.
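The only change relative to the earlier MultiHeadSelfAttention is a lower-triangular mask applied to the attention scores before the softmax: masked entries are set to $-\infty$ and therefore receive zero attention weight. A quick sketch on a toy score matrix:

In [ ]:
T_demo = 4
scores = torch.randn(T_demo, T_demo)                   # toy attention scores (rows: queries, columns: keys)

causal_mask = torch.tril(torch.ones(T_demo, T_demo))   # 1 where attending is allowed (j <= i)
masked = scores.masked_fill(causal_mask == 0, float('-inf'))
attn = masked.softmax(dim=-1)

print(attn)   # row i has non-zero weights only for columns j <= i, and each row sums to 1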

In [8]:
class MultiHeadCausalSelfAttention(nn.Module):
    def __init__(self, d, H):
        ''' Multi-Head Causal Self-Attention module.
        Args:
            d: Embedding dimension.
            H: Number of attention heads.
        '''
        super(MultiHeadCausalSelfAttention, self).__init__()
        assert d % H == 0, "Embedding dimension must be divisible by number of heads."
        self.d = d
        self.H = H
        self.d_k = d // H  # Dimension per head

        self.qkv = nn.Linear(d, 3*d)  # Combined query, key, value projection
        self.out = nn.Linear(d, d)  # Output projection
    
    def forward(self, X):
        '''
        Forward pass through the Multi-Head Causal Self-Attention module.
        Args:
            X: Input tensor of shape (B, T, d).
        Returns:
            Output tensor of shape (B, T, d).
        '''
        B, T, d = X.shape
        assert d == self.d, f'Input embedding dimension {d} does not match model dimension {self.d}.'

        # Linear projections
        qkv = self.qkv(X)
        q,k,v = qkv.chunk(3, dim=-1)  # Each is (B, T, d) ---> form queries, keys, values

        # split into heads
        q_reshaped = q.view(B, T, self.H, self.d_k).transpose(1,2)  # (B, H, T, d_k)
        k_reshaped = k.view(B, T, self.H, self.d_k).transpose(1,2)  # (B, H, T, d_k)
        v_reshaped = v.view(B, T, self.H, self.d_k).transpose(1,2)  # (B, H, T, d_k)

        # Dot-product attention scores (note: the usual 1/sqrt(d_k) scaling is not applied here)
        k_transpose = k_reshaped.transpose(-2, -1)  # (B, H, d_k, T)
        QK = q_reshaped @ k_transpose # (B, H, T, T) ---> QK[b,h,i] stores the similarity of token i to all tokens in the sequence

        # create the causal mask
        causal_mask = torch.tril(torch.ones((T, T), device=X.device)).unsqueeze(0).unsqueeze(0)  # (1, 1, T, T)
        QK = QK.masked_fill(causal_mask == 0, float('-inf'))
        attn = QK.softmax(dim=-1) # (B, H, T, T) ---> attn[b,h,i] stores how much token i attends to all tokens in the sequence

        # output calculation
        out = attn @ v_reshaped  # (B, H, T, d_k)
        out = out.transpose(1,2).contiguous().view(B, T, d) # (B, T, d) ---> put together all the heads
        out = self.out(out)
        return out

3.3. Build the network¶

In [ ]:
T = 256         # Context length (sequence length)
V = len(tok)    # Vocabulary size (for embeddings)
N = 12          # Number of Transformer blocks
d = 256         # Embedding dimension
H = 8           # Number of attention heads

input_emb = create_embedding(V, d) # Input embedding
pos_emb   = create_embedding(T, d) # Positional embedding

mhsa = MultiHeadCausalSelfAttention(d, H)
transf_block = Block(mhsa, d, 4*d, dropout=0.1)
norm_layer = nn.LayerNorm(d)
output_layer = nn.Linear(d, V)  # Output layer to project to vocabulary size 

model = TransformerNetwork(
    T=T, # max seq length
    N=N, # num transformer blocks
    input_emb=input_emb, # Input embedding layer/module.
    pos_emb=pos_emb, # Positional embedding layer/module.
    transf_block=transf_block, # A single Transformer block (to be repeated N times).
    norm_layer=norm_layer, # Normalization layer/module (e.g., LayerNorm).
    output_layer=output_layer # Output layer/module (e.g., to map to data space).
)

print(model)
print('Number of parameters in the model:', sum(p.numel() for p in model.parameters()) )

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-3, betas=(0.9, 0.99), eps=1e-8)

model = model.to(device)

# state = torch.load('transformer_deen_T256_N12_d256_H8_can_generate_german.pth', map_location=device)
# model.load_state_dict(state)
TransformerNetwork(
  (input_emb): Embedding(58102, 256)
  (pos_emb): Embedding(256, 256)
  (transf_blocks): ModuleList(
    (0-11): 12 x Block(
      (mhsa): MultiHeadCausalSelfAttention(
        (qkv): Linear(in_features=256, out_features=768, bias=True)
        (out): Linear(in_features=256, out_features=256, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Linear(in_features=1024, out_features=256, bias=True)
      )
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (norm_layer): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (output_layer): Linear(in_features=256, out_features=58102, bias=True)
)
Number of parameters in the model: 30662134

3.4. Train the network¶

In [ ]:
# train the network
for epoch in range(20):
    model.train()
    total_loss = 0.0
    for i,batch in enumerate(train_dl):
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        eng_start_idx = batch['end_of_DE'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids)  # (B, T, V)
        logits = outputs[:, :-1, :].contiguous()  # (B, T-1, V)

        loss = F.cross_entropy(logits.view(-1, V), labels.view(-1), ignore_index=-100)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if i % 100 == 0:
            print(f'Processing batch {i}/{len(train_dl)}. Mean loss so far: {total_loss/(i+1):.4f}')

    avg_loss = total_loss / len(train_dl)
    print(f'Epoch {epoch+1}, Training Loss: {avg_loss:.4f}')

torch.save(model.state_dict(), 'transformer_deen.pth')
Processing batch 0/454. Mean loss so far: 11.1316
Processing batch 100/454. Mean loss so far: 6.0472
Processing batch 200/454. Mean loss so far: 5.1690
Processing batch 300/454. Mean loss so far: 4.7178
Processing batch 400/454. Mean loss so far: 4.4251
Epoch 1, Training Loss: 4.3016
Processing batch 0/454. Mean loss so far: 3.5583
Processing batch 100/454. Mean loss so far: 3.2004
Processing batch 200/454. Mean loss so far: 3.1463
Processing batch 300/454. Mean loss so far: 3.1020
Processing batch 400/454. Mean loss so far: 3.0596
Epoch 2, Training Loss: 3.0323
Processing batch 0/454. Mean loss so far: 2.8784
Processing batch 100/454. Mean loss so far: 2.7206
Processing batch 200/454. Mean loss so far: 2.6985
Processing batch 300/454. Mean loss so far: 2.6770
Processing batch 400/454. Mean loss so far: 2.6592
Epoch 3, Training Loss: 2.6491
Processing batch 0/454. Mean loss so far: 2.4186
Processing batch 100/454. Mean loss so far: 2.4207
Processing batch 200/454. Mean loss so far: 2.4191
Processing batch 300/454. Mean loss so far: 2.4177
Processing batch 400/454. Mean loss so far: 2.4111
Epoch 4, Training Loss: 2.4055
Processing batch 0/454. Mean loss so far: 2.1825
Processing batch 100/454. Mean loss so far: 2.2375
Processing batch 200/454. Mean loss so far: 2.2403
Processing batch 300/454. Mean loss so far: 2.2430
Processing batch 400/454. Mean loss so far: 2.2384
Epoch 5, Training Loss: 2.2337
Processing batch 0/454. Mean loss so far: 2.2638
Processing batch 100/454. Mean loss so far: 2.0781
Processing batch 200/454. Mean loss so far: 2.0903
Processing batch 300/454. Mean loss so far: 2.0961
Processing batch 400/454. Mean loss so far: 2.1039
Epoch 6, Training Loss: 2.1036
Processing batch 0/454. Mean loss so far: 2.0014
Processing batch 100/454. Mean loss so far: 1.9786
Processing batch 200/454. Mean loss so far: 1.9838
Processing batch 300/454. Mean loss so far: 1.9911
Processing batch 400/454. Mean loss so far: 1.9965
Epoch 7, Training Loss: 1.9970
Processing batch 0/454. Mean loss so far: 1.9444
Processing batch 100/454. Mean loss so far: 1.8606
Processing batch 200/454. Mean loss so far: 1.8752
Processing batch 300/454. Mean loss so far: 1.8913
Processing batch 400/454. Mean loss so far: 1.8993
Epoch 8, Training Loss: 1.9066
Processing batch 0/454. Mean loss so far: 1.6859
Processing batch 100/454. Mean loss so far: 1.7814
Processing batch 200/454. Mean loss so far: 1.7992
Processing batch 300/454. Mean loss so far: 1.8133
Processing batch 400/454. Mean loss so far: 1.8245
Epoch 9, Training Loss: 1.8286
Processing batch 0/454. Mean loss so far: 1.7428
Processing batch 100/454. Mean loss so far: 1.7061
Processing batch 200/454. Mean loss so far: 1.7296
Processing batch 300/454. Mean loss so far: 1.7408
Processing batch 400/454. Mean loss so far: 1.7550
Epoch 10, Training Loss: 1.7606
Processing batch 0/454. Mean loss so far: 1.6196
Processing batch 100/454. Mean loss so far: 1.6392
Processing batch 200/454. Mean loss so far: 1.6609
Processing batch 300/454. Mean loss so far: 1.6790
Processing batch 400/454. Mean loss so far: 1.6942
Epoch 11, Training Loss: 1.7006
Processing batch 0/454. Mean loss so far: 1.5412
Processing batch 100/454. Mean loss so far: 1.5813
Processing batch 200/454. Mean loss so far: 1.6037
Processing batch 300/454. Mean loss so far: 1.6211
Processing batch 400/454. Mean loss so far: 1.6373
Epoch 12, Training Loss: 1.6448
Processing batch 0/454. Mean loss so far: 1.5117
Processing batch 100/454. Mean loss so far: 1.5385
Processing batch 200/454. Mean loss so far: 1.5619
Processing batch 300/454. Mean loss so far: 1.5775
Processing batch 400/454. Mean loss so far: 1.5889
Epoch 13, Training Loss: 1.5935
Processing batch 0/454. Mean loss so far: 1.3743
Processing batch 100/454. Mean loss so far: 1.4866
Processing batch 200/454. Mean loss so far: 1.5096
Processing batch 300/454. Mean loss so far: 1.5286
Processing batch 400/454. Mean loss so far: 1.5441
Epoch 14, Training Loss: 1.5500
Processing batch 0/454. Mean loss so far: 1.4222
Processing batch 100/454. Mean loss so far: 1.4365
Processing batch 200/454. Mean loss so far: 1.4626
Processing batch 300/454. Mean loss so far: 1.4873
Processing batch 400/454. Mean loss so far: 1.5038
Epoch 15, Training Loss: 1.5086
Processing batch 0/454. Mean loss so far: 1.3181
Processing batch 100/454. Mean loss so far: 1.4103
Processing batch 200/454. Mean loss so far: 1.4276
Processing batch 300/454. Mean loss so far: 1.4486
Processing batch 400/454. Mean loss so far: 1.4640
Epoch 16, Training Loss: 1.4719
Processing batch 0/454. Mean loss so far: 1.3642
Processing batch 100/454. Mean loss so far: 1.3767
Processing batch 200/454. Mean loss so far: 1.3968
Processing batch 300/454. Mean loss so far: 1.4167
Processing batch 400/454. Mean loss so far: 1.4323
Epoch 17, Training Loss: 1.4374
Processing batch 0/454. Mean loss so far: 1.3088
Processing batch 100/454. Mean loss so far: 1.3422
Processing batch 200/454. Mean loss so far: 1.3636
Processing batch 300/454. Mean loss so far: 1.3824
Processing batch 400/454. Mean loss so far: 1.3981
Epoch 18, Training Loss: 1.4056
Processing batch 0/454. Mean loss so far: 1.2703
Processing batch 100/454. Mean loss so far: 1.3162
Processing batch 200/454. Mean loss so far: 1.3311
Processing batch 300/454. Mean loss so far: 1.3505
Processing batch 400/454. Mean loss so far: 1.3692
Epoch 19, Training Loss: 1.3765
Processing batch 0/454. Mean loss so far: 1.2532
Processing batch 100/454. Mean loss so far: 1.2835
Processing batch 200/454. Mean loss so far: 1.3037
Processing batch 300/454. Mean loss so far: 1.3248
Processing batch 400/454. Mean loss so far: 1.3416
Epoch 20, Training Loss: 1.3480

3.5. Example model outputs¶

Finally, we take a test input and let the model translate it. We give the German sentence (up to and including the <bos> boundary token) as input and let the model predict the next tokens greedily.

In [ ]:
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
    """
        Parameters:
        -----------
        model: The trained Transformer model.
        idx: Tensor of shape (B, T) containing the initial sequence of token indices.
        max_new_tokens: Number of new tokens to generate.
        temperature: Controls the randomness of predictions by scaling the logits.
        do_sample: If True, sample from the probability distribution; if False, take the most likely token.
        top_k: If specified, restricts sampling to the top_k most probable tokens
    """
    for _ in range(max_new_tokens):
        # forward the model to get the logits for the index in the sequence
        logits = model(idx)
        # pluck the logits at the final step and scale by desired temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = F.softmax(logits, dim=-1)
        # either sample from the distribution or take the most likely element
        if do_sample:
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            _, idx_next = torch.topk(probs, k=1, dim=-1)
        # append sampled index to the running sequence and continue
        idx = torch.cat((idx, idx_next), dim=1)

    return idx

# generate starting from a test sentence
valid_dl = DataLoader(valid_ds, batch_size=64, shuffle=True, collate_fn=collate)
batch = next(iter(valid_dl))

i = 0
input_ids, attention_mask, end_of_de = batch['input_ids'][i:i+1], batch['attention_mask'][i:i+1], batch['end_of_DE'][i:i+1]
input_ids, attention_mask, end_of_de = input_ids.to(device), attention_mask.to(device), end_of_de.to(device)

# translate
print('Input:\t\t', tok.decode(input_ids[0,:end_of_de]))
print('Output:\t\t', tok.decode(input_ids[0][end_of_de:]))
model_output = generate(model, input_ids[:,:end_of_de], 50, temperature=1.0, do_sample=False, top_k=None)
print('Generated:\t', tok.decode(model_output[0][end_of_de:]))

# generate German sentence
# print('Input:\t\t', tok.decode(input_ids[0,:end_of_de//2]))
# print('Output:\t\t', tok.decode(input_ids[0][end_of_de//2:]))
# model_output = generate(model, input_ids[:,:end_of_de//2], 50, temperature=1.0, do_sample=False, top_k=None)
# print('Generated:\t', tok.decode(model_output[0][end_of_de//2:]))
Input:		 Zwei Männer sind an einem hellen, sonnigen Tag unterwegs, um ein paar Fische im See zu fangen.<bos>
Output:		 Two men are out on a bright, sunny day attempting to catch some fish on the lake.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
Generated:	 Two men are enjoying a bright pond down a sideway toy in lake a lake.</s> e store's lake.</s> e store.</s> es.</s> es

4. Vision transformers¶

Vision Transformers (ViTs) bring the transformer architecture—originally designed for natural language processing—into computer vision. Instead of processing images with convolutional filters, a ViT splits an image into fixed-size patches (for example, 16×16 pixels), flattens them, and linearly projects each patch into an embedding vector. These embeddings form a sequence, much like a sentence in NLP, and are fed into a transformer encoder. A special “class token” is prepended to the sequence and its final hidden state serves as the representation for the whole image, which can then be used for classification or downstream tasks. Positional encodings are also added, since transformers themselves have no inherent sense of order or spatial arrangement.

A demonstration of vision transformers (image from https://encord.com/blog/vision-transformers/)

What makes ViTs powerful is their ability to model long-range dependencies: any patch can attend to any other patch, allowing the network to capture global structure in an image. Unlike convolutional networks, which focus on local receptive fields and build hierarchical features layer by layer, ViTs start with a global receptive field from the very first layer. This flexibility comes at the cost of needing large amounts of data (or pretraining on large datasets like ImageNet-21k) and careful regularization. Once pretrained, ViTs transfer very well across tasks, and they’ve become state-of-the-art not just in classification, but also in detection, segmentation, and even generative vision models.
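For the CIFAR-10 setup used later in this section (32x32 RGB images split into 8x8 patches), the resulting "sentence" per image is short. A quick sketch of the bookkeeping:

In [ ]:
image_size, patch, channels = 32, 8, 3

num_patches = (image_size // patch) ** 2   # 16 patches per image = sequence length T
patch_dim   = channels * patch ** 2        # 192 values per flattened patch

print('Sequence length (patches per image):', num_patches)
print('Flattened patch dimension:', patch_dim)
# with the CLS token prepended, the Transformer sees 17 tokens per image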

4.1. ViT implementation¶

The main difference from the previous implementation is the addition of the CLS token, whose final representation is used for classification.

In [30]:
class VisionTransformer(nn.Module):
    def __init__(self, T, N, input_emb, pos_emb, transf_block, norm_layer, output_layer):
        ''' Transformer network with N blocks, context length T, and given components. 
        Args:
            T: Context length (sequence length).
            N: Number of Transformer blocks.
            input_emb: Input embedding layer/module.
            pos_emb: Positional embedding layer/module.
            transf_block: A single Transformer block (to be repeated N times).
            norm_layer: Normalization layer/module (e.g., LayerNorm).
            output_layer: Output layer/module (e.g., for classification or regression).
        '''
        super(VisionTransformer, self).__init__()
        self.T = T
        self.input_emb = input_emb
        self.pos_emb = pos_emb
        # note: the same block instance is repeated, so all N layers share one set of weights
        self.transf_blocks = nn.ModuleList([transf_block for _ in range(N)])
        self.norm_layer = norm_layer
        self.output_layer = output_layer
        embed_dim = pos_emb.embedding_dim
        self.cls_token = nn.Parameter(torch.randn(1,1,embed_dim))
    
    @property
    def device(self):
        return next(self.parameters()).device
    
    @property
    def dtype(self):
        return next(self.parameters()).dtype
    
    @property
    def num_layers(self):
        return len(self.transf_blocks)
    
    @property
    def d(self):
        return self.input_emb.embedding_dim

    def forward(self, X):
        '''
        Forward pass through the Transformer network.
        Args:
            X: Input tensor of flattened image patches, shape (B, T, input_dim) with input_dim = 3*patch_size**2.
        Returns:
            Output tensor after passing through the Transformer network - [B,output_dim].
        '''
        device = X.device
        B,T = X.shape[0:2]

        # Input embedding
        X  = self.input_emb(X)  # B,T,d

        # add CLS token
        cls_tokens = self.cls_token.repeat(B,1,1)
        X = torch.cat((cls_tokens, X), dim=1)
        T = T+1

        # positional embeddings
        P  = self.pos_emb(torch.arange(T, device=device)) # T,d
        P  = P.unsqueeze(0).expand(B, T, -1)  # B,T,d 

        # adding positional embeddings
        X = X + P  # B,T,d

        # passing through Transformer blocks
        for block in self.transf_blocks:
            X = block(X)  # B,T,d

        class_token_final = X[:,0,:]
        
        # final normalization
        class_token_final = self.norm_layer(class_token_final)  # B,d

        # output layer
        out = self.output_layer(class_token_final)  # B,output_dim (depends on output_layer)

        return out

4.2. Build the network¶

This time,

  • the input embedding is a simple linear projection of flattened image patches (rather than a token lookup table).
  • the input length (number of patches) depends on the image size and the patch size.
In [47]:
patch_size = 8
T = int((32/patch_size)**2)     # Context length (sequence length)
N = 12      # Number of Transformer blocks
d = 256     # Embedding dimension
H = 16       # Number of attention heads
num_classes = 10 # number of output classes (for CIFAR-10)

input_emb = nn.Linear(3*(patch_size**2), d) # Input embedding
pos_emb   = nn.Embedding(T+1, d) # create_embedding(T+1, d) # Positional embedding

mhsa = MultiHeadSelfAttention(d, H)
transf_block = Block(mhsa, d, 4*d, dropout=0.1)
norm_layer = nn.LayerNorm(d)
output_layer = nn.Linear(d, num_classes)  # Output layer to project to the number of classes

model = VisionTransformer(
    T=T, # max seq length
    N=N, # num transformer blocks
    input_emb=input_emb, # Input embedding layer/module.
    pos_emb=pos_emb, # Positional embedding layer/module.
    transf_block=transf_block, # A single Transformer block (to be repeated N times).
    norm_layer=norm_layer, # Normalization layer/module (e.g., LayerNorm).
    output_layer=output_layer # Output layer/module (e.g., to map to data space).
)

print(model)
print('Number of parameters in the model:', sum(p.numel() for p in model.parameters()) )

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3, betas=(0.9, 0.99), eps=1e-8)
model = model.to(device)
VisionTransformer(
  (input_emb): Linear(in_features=192, out_features=256, bias=True)
  (pos_emb): Embedding(17, 256)
  (transf_blocks): ModuleList(
    (0-11): 12 x Block(
      (mhsa): MultiHeadSelfAttention(
        (qkv): Linear(in_features=256, out_features=768, bias=True)
        (out): Linear(in_features=256, out_features=256, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Linear(in_features=1024, out_features=256, bias=True)
      )
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (norm_layer): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (output_layer): Linear(in_features=256, out_features=10, bias=True)
)
Number of parameters in the model: 846858

4.3. Training functions¶

We train the model on the CIFAR-10 dataset.

In [ ]:
test_transform = transforms.Compose([transforms.ToTensor(),
                                     transforms.Normalize([0.49139968, 0.48215841, 0.44653091], [0.24703223, 0.24348513, 0.26158784])
                                     ])
# For training, we add some augmentation; otherwise the network would easily overfit.
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.RandomResizedCrop((32,32),scale=(0.8,1.0),ratio=(0.9,1.1)),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.49139968, 0.48215841, 0.44653091], [0.24703223, 0.24348513, 0.26158784])
                                     ])

# Loading the training dataset. We need to split it into a training and validation part.
# We need to do a little trick because the validation set should not use the augmentation:
# we load CIFAR10 twice (with and without augmentation) and split both with the same seed.
train_dataset = CIFAR10(root='data/', train=True, transform=train_transform, download=True)
val_dataset   = CIFAR10(root='data/', train=True, transform=test_transform,  download=True)

train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(42))
_, val_set   = torch.utils.data.random_split(val_dataset,   [45000, 5000], generator=torch.Generator().manual_seed(42))

train_dl = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
val_dl   = DataLoader(val_set, batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

def img_to_patch(x, patch_size):
    """
    Inputs:
        x - torch.Tensor representing the image of shape [B, C, H, W]
        patch_size - Number of pixels per dimension of the patches (integer)
    """
    B, C, H, W = x.shape
    x = x.reshape(B, C, H//patch_size, patch_size, W//patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5) # [B, H', W', C, p_H, p_W]
    x = x.flatten(1,2)              # [B, H'*W', C, p_H, p_W]
    x = x.flatten(2,4)          # [B, H'*W', C*p_H*p_W]
    return x

# X = img_to_patch(X, self.patch_size)  # B,T,d
Files already downloaded and verified
Files already downloaded and verified
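As a quick sanity check (assuming the loaders and img_to_patch defined above), one batch of images turns into a batch of 16 flattened patches per image:

In [ ]:
X_demo, y_demo = next(iter(train_dl))
print('Image batch shape:', X_demo.shape)                            # (128, 3, 32, 32)
print('Patch batch shape:', img_to_patch(X_demo, patch_size).shape)  # (128, 16, 192)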

4.4. Training loop¶

In [ ]:
# train the network
for epoch in range(20):
    model.train()
    total_accuracy = 0.0
    for i,batch in enumerate(train_dl):
        X, y = batch
        X, y = X.to(device), y.to(device)

        X = img_to_patch(X, patch_size)
        optimizer.zero_grad()
        logits = model(X)  # (B, num_classes)
        loss = F.cross_entropy(logits, y)
        accuracy = (logits.argmax(dim=-1) == y).float().mean()
        loss.backward()
        optimizer.step()
        total_accuracy += accuracy.item()
        if i % 100 == 0:
            print(f'Processing batch {i}/{len(train_dl)}. Mean accuracy so far: {total_accuracy/(i+1):.4f}')
    
    # now validate
    model.eval()
    total_val_accuracy = 0.0
    with torch.no_grad():
        for i,batch in enumerate(val_dl):
            X, y = batch
            X, y = X.to(device), y.to(device)

            X = img_to_patch(X, patch_size)
            logits = model(X)  # (B, num_classes)
            accuracy = (logits.argmax(dim=-1) == y).float().mean()
            total_val_accuracy += accuracy.item()
        print(f'Epoch {epoch+1}, Training Accuracy: {total_accuracy/len(train_dl):.4f}, Validation Accuracy: {total_val_accuracy/len(val_dl):.4f}')