Implementing the No Propagation Paper (Part 2)
Whew…that was a very loooong break. I think more than a month has passed since my last article. I have been very busy, juggling work, study, and my passion for building and running as many AI experiments as I can.
But enough about me. You don’t open Medium or Substack to read about my personal grievances. You want to learn and dive deep into the world of artificial intelligence. So this time, we will pick up where we left off and continue with the No Propagation paper.
A bit of recap…
So, last time, we explored the overall concept of No Propagation. The idea is that a neural network can be trained without forward and backward propagation. Traditionally, backpropagation carries the entire error signal back through the network; each layer and each node has a different contribution to the error, and is updated based on that contribution.
So what does No Prop do? No Prop takes a different approach. Instead of sending the error signal backwards through all the layers, No Prop trains each layer independently with its own local objective: every layer learns to denoise a noisy version of the target label, so its loss can be computed and minimized on the spot, without waiting for gradients to flow back from the layers above it.
Think of No Propagation as an office full of workers. Each worker represents a layer in the neural network, and each worker has a different task during training. On a production line, every worker has to wait for the person before them to finish before they can lay a hand on the product. No Propagation cuts out this sequential process: each worker is given a specific part of the training to complete, and completes it via local denoising steps. Kind of like diffusion inside a neural network.
We went extensively through the mathematics of No Propagation last time; that was for our foundational understanding. This time, we will go straight to implementation, using the examples provided on GitHub. You can check it out yourself here, as I have already put it in my repo:
https://github.com/maercaestro/megat-noprop
The Three Main Recipes
There are three main variations of No Prop training mentioned in the paper. They can be detailed as below:
No Prop Discrete Time (DT) — Performs the denoising steps in discrete time.
No Prop Continuous Time (CT) — A bit more complex. Performs the denoising steps in continuous time.
No Prop Flow Matching (FM) — An alternative approach that learns a flow from noise to the target label.
But the basic building block of all three variations is really the denoising block. As mentioned before, No Propagation begins with the process of adding noise to our target label. Each layer of the neural network gets a version of the target label with a specific amount of Gaussian diffusion noise added to it. Then, during training, every layer tries to denoise the target label until the error is as low as possible.
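To make that concrete, here is a tiny standalone sketch (purely illustrative, not taken from the repo) of what adding Gaussian diffusion noise to the target label looks like when the label is represented as a class embedding:

import torch

num_classes, embedding_dim, batch_size = 10, 32, 4

W_embed = torch.randn(num_classes, embedding_dim)   # one embedding per class
y = torch.randint(0, num_classes, (batch_size,))    # ground-truth labels
uy = W_embed[y]                                      # clean target embeddings

alpha_bar_t = torch.tensor(0.7)                      # signal fraction at this layer's time step
noise = torch.randn_like(uy)
z_t = torch.sqrt(alpha_bar_t) * uy + torch.sqrt(1 - alpha_bar_t) * noise
# Each layer is trained to map (image, z_t) back to the clean embedding uy.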
So, in order to do that, we need to build our denoising block. Since we are using the MNIST handwritten digit dataset for training, we will use a CNN to process the image. This helps extract the features that will guide the training process later. The denoising of the label will be handled by simple fully connected layers. Both of these will be combined later into a single tensor that holds both the image features and the noisy label features.
Let’s build the simple CNN first. It uses typical 2D convolutions, since we’re only handling grayscale images. It has three convolutional layers with max pooling in between, and an adaptive average pool at the end before we flatten the output. Let’s build this:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DenoiseBlock(nn.Module):
    def __init__(self, embedding_dim: int, num_classes: int):
        super().__init__()
        # --- Image Processing Path ---
        self.conv_block = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.2),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256)
        )
Let’s follow this by creating our simple fully connected layers. These handle our noisy label features. We use a Linear layer and batch normalization for each of them:
        # --- Noisy Label Processing Path ---
        self.fc_z1 = nn.Linear(embedding_dim, 256)
        self.bn_z1 = nn.BatchNorm1d(256)
        self.fc_z2 = nn.Linear(256, 256)
        self.bn_z2 = nn.BatchNorm1d(256)
        self.fc_z3 = nn.Linear(256, 256)
        self.bn_z3 = nn.BatchNorm1d(256)
The outputs of the convolution block and the FC layers will then be combined in a fusion path:
        # --- Fusion and Output Path ---
        self.fc_f1 = nn.Linear(256 + 256, 256)
        self.bn_f1 = nn.BatchNorm1d(256)
        self.fc_f2 = nn.Linear(256, 128)
        self.bn_f2 = nn.BatchNorm1d(128)
        self.fc_out = nn.Linear(128, num_classes)
Alright, let’s create our forward method to pass the data through in the correct order. First we get the image features and the noisy label features; then we combine them to predict the clean label. This will be used in our training script later.
    def forward(self, x: torch.Tensor, z_prev: torch.Tensor, W_embed: torch.Tensor) -> tuple:
        # 1. Get image features
        x_feat = self.conv_block(x)

        # 2. Get noisy label features
        h1 = F.relu(self.bn_z1(self.fc_z1(z_prev)))
        h2 = F.relu(self.bn_z2(self.fc_z2(h1)))
        h3 = self.bn_z3(self.fc_z3(h2))
        z_feat = h3 + h1  # Residual connection

        # 3. Combine features
        h_f = torch.cat([x_feat, z_feat], dim=1)

        # 4. Predict clean label
        h_f = F.relu(self.bn_f1(self.fc_f1(h_f)))
        h_f = F.relu(self.bn_f2(self.fc_f2(h_f)))
        logits = self.fc_out(h_f)
        p = F.softmax(logits, dim=1)
        z_next = p @ W_embed  # Weighted average of clean embeddings
        return z_next, logits
Alright. That’s the step-by-step construction of our denoising block. We can save this as blocks.py or model.py and call it later from our training script.
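Before we move on, here is a quick sanity check you could run (a sketch under my own assumptions about the embedding size, not part of the repo): instantiate the block, feed it a dummy MNIST-sized batch plus a random noisy label embedding, and confirm the output shapes.

# Assumes the DenoiseBlock class (and the torch imports) from above are in scope.
embedding_dim, num_classes, batch_size = 32, 10, 8

block = DenoiseBlock(embedding_dim, num_classes)
block.eval()  # eval mode so BatchNorm/Dropout behave deterministically

W_embed = torch.randn(num_classes, embedding_dim)    # clean class embeddings
x = torch.randn(batch_size, 1, 28, 28)               # dummy MNIST-shaped images
z_prev = torch.randn(batch_size, embedding_dim)      # noisy label embeddings

z_next, logits = block(x, z_prev, W_embed)
print(z_next.shape)   # torch.Size([8, 32])
print(logits.shape)   # torch.Size([8, 10])

With the block behaving as expected, let’s look at the training script.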
Training Run
As mentioned above, there are three ways to perform No Propagation training, and the key difference is how we handle the time steps. In dt mode, we train each individual block/layer at its own discrete time step. In code:
from tqdm import tqdm

# Device setup (assumed; defined once at the top of the training script)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train_nopropdt(model, train_loader, test_loader, epochs, lr, weight_decay):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    history = {'train_acc': [], 'val_acc': []}

    for epoch in range(1, epochs + 1):
        model.train()
        for t in tqdm(range(model.T)):
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                uy = model.W_embed[y]
                alpha_bar_t = model.alpha_bar[t]
                noise = torch.randn_like(uy)
                z_t = torch.sqrt(alpha_bar_t) * uy + torch.sqrt(1 - alpha_bar_t) * noise

                # Here's our DenoiseBlock in action!
                z_pred, _ = model.blocks[t](x, z_t, model.W_embed)
                loss_l2 = F.mse_loss(z_pred, uy)
                loss = 0.5 * model.eta * model.snr_delta[t] * loss_l2

                if t == model.T - 1:
                    logits = model.classifier(z_pred)
                    loss_ce = F.cross_entropy(logits, y)
                    loss_kl = 0.5 * uy.pow(2).sum(dim=1).mean()
                    loss = loss + loss_ce + loss_kl

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
As you can see, z_t is our noisy target. We feed it into our block to produce z_pred, and z_pred is then compared with uy to get our L2 loss. The full loss is that L2 loss weighted by the SNR (signal-to-noise ratio) increment for that step (snr_delta) and by eta, the weight for our denoising loss.
Both the noise schedule (which determines the SNR values) and eta are hyperparameters that can be configured and searched over to improve the training run.
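For context, here is a minimal sketch of how alpha_bar and snr_delta could be precomputed. This is my own illustration with an assumed sine-squared schedule and T = 10, not code from the repo, which may use a different schedule:

import math
import torch

T = 10  # number of denoising blocks / discrete time steps

# alpha_bar grows with t, so later blocks see less noisy label embeddings.
t_frac = torch.arange(1, T + 1, dtype=torch.float32) / (T + 1)
alpha_bar = torch.sin(0.5 * math.pi * t_frac) ** 2    # shape (T,), values in (0, 1)

# SNR(t) = alpha_bar_t / (1 - alpha_bar_t); the DT loss above weights each block
# by the SNR increase between consecutive steps.
snr = alpha_bar / (1 - alpha_bar)
snr_prev = torch.cat([torch.zeros(1), snr[:-1]])      # define SNR(0) = 0
snr_delta = snr - snr_prev                            # plays the role of model.snr_delta[t]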
The difference with ct is that instead of looping over fixed discrete time steps (the for t in tqdm(range(model.T)) loop above), we sample a random continuous time t ∈ [0, 1] for every example in the batch, and a single shared, time-conditioned block does the denoising. The noise is still regulated by the same kind of schedule hyperparameters we had with dt; it’s just that the SNR now carries the time-step information. The weighting used in the ct loss is the derivative of the SNR with respect to time, which we call snr_prime.
def train_nopropct(model, train_loader, test_loader, epochs, lr, weight_decay, inference_steps=1000):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    history = {'train_acc': [], 'val_acc': []}

    for epoch in range(1, epochs + 1):
        model.train()
        for x, y in tqdm(train_loader):
            x, y = x.to(device), y.to(device)
            B = x.size(0)
            uy = model.W_embed[y]

            # Sample t ∈ [0, 1]
            t = torch.rand(B, 1, device=device)

            # Compute ᾱ(t)
            alpha_bar_t = model.alpha_bar(t)

            # Sample z_t ~ N(√ᾱ(t) * uy, (1 - ᾱ(t)) * I)
            noise = torch.randn_like(uy)
            z_t = torch.sqrt(alpha_bar_t) * uy + torch.sqrt(1 - alpha_bar_t) * noise

            # Predict and compute loss using the shared block
            z_pred = model.forward_denoise(x, z_t, t)
            snr_prime_t = model.snr_prime(t)
            loss_l2 = F.mse_loss(z_pred, uy)
            loss = 0.5 * model.eta * snr_prime_t.mean() * loss_l2

            # Final classifier loss (optional)
            logits = model.classifier(z_pred)
            loss_ce = F.cross_entropy(logits, y)
            loss_kl = 0.5 * uy.pow(2).sum(dim=1).mean()
            loss += loss_ce + loss_kl

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
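For reference, here is one possible way alpha_bar(t) and snr_prime(t) could be defined. This is just an assumed linear schedule to show the relationship between the two; the actual repo may implement them differently.

import torch

def alpha_bar(t, eps=1e-4):
    # Linear schedule (assumed for illustration): the signal fraction grows with t in (0, 1)
    return t.clamp(eps, 1 - eps)

def snr(t, eps=1e-4):
    a = alpha_bar(t, eps)
    return a / (1 - a)

def snr_prime(t, eps=1e-4):
    # Derivative of SNR(t) = t / (1 - t) under this linear schedule: 1 / (1 - t)^2
    a = alpha_bar(t, eps)
    return 1.0 / (1 - a) ** 2

t = torch.rand(8, 1)              # same shape as in the training loop
weight = snr_prime(t).mean()      # plays the role of snr_prime_t.mean() above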
In terms of flow matching, the method is quite different. Instead of denoising, it tries to learn the latent trajectory between our noise and our target.
Think of it like a network where the noise is the input and the target is the output. The flow matching method learns the path we should follow to get as close to the target as possible.
So, instead of a denoising loss, our loss is now defined on the vector field: specifically, the difference between the target vector field (z1 - z0) and the predicted one. We can get the predicted vector field using the same denoising block we built, and from that compute our L2 loss.
def train_nopropfm(model, train_loader, test_loader, epochs, lr, weight_decay, inference_steps=1000):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    history = {'train_acc': [], 'val_acc': []}

    for epoch in range(1, epochs + 1):
        model.train()
        for x, y in tqdm(train_loader):
            x, y = x.to(device), y.to(device)
            B = x.size(0)
            z1 = model.W_embed[y]          # class embeddings (targets)
            z0 = torch.randn_like(z1)      # pure noise (starting point)
            t = torch.rand(B, 1, device=device)
            z_t = t * z1 + (1 - t) * z0    # point on the straight line from noise to target
            v_target = z1 - z0             # target velocity along that line

            # Vector field prediction
            v_pred = model.forward_vector_field(x, z_t, t)
            loss_l2 = F.mse_loss(v_pred, v_target)

            # Extrapolate z̃₁ and compute classification loss
            z1_hat = model.extrapolate_z1(z_t, v_pred, t)
            logits = model.classifier(z1_hat)
            loss_ce = F.cross_entropy(logits, y)

            loss = loss_l2 + loss_ce

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
Please also note that we have to compute the extrapolated endpoint of the trajectory, z1_hat. That z1_hat allows us to compute the cross-entropy loss: the difference between our current prediction and the target.
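The extrapolation itself is simple given the straight-line interpolation we used for z_t. Since z_t = t * z1 + (1 - t) * z0 and the target velocity is z1 - z0, following the predicted velocity for the remaining (1 - t) recovers the endpoint. Here is a sketch of what extrapolate_z1 could look like (my own reconstruction, not necessarily the exact repo code):

def extrapolate_z1(z_t, v_pred, t):
    # z_t = t * z1 + (1 - t) * z0 and v = z1 - z0,
    # so stepping along v_pred for the remaining (1 - t) lands on (an estimate of) z1.
    return z_t + (1 - t) * v_pred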
So, there it is: our full implementation walkthrough for No Propagation. As you can see, there is real ingenuity here, especially in providing an alternative to end-to-end backpropagation.
Alright, you guys can try this yourself; my repo is linked above. Anyway, good luck, and see you next time.