
MegaQwen: Hitting the Architectural Ceiling of Cooperative CUDA Megakernels

A deep dive into building and exhaustively optimizing a fused transformer megakernel. We achieved 530 tok/s on Qwen3-0.6B - then discovered why that's the limit.

MegaQwen - Fused transformer block megakernel

Results

Backend      | Short Context | Long Context | vs HuggingFace
Megakernel   | 530 tok/s     | 158 tok/s    | 3.9x
TensorRT-LLM | 355 tok/s     | 355 tok/s    | 2.6x
vLLM         | -             | 107 tok/s    | 1.8x
HuggingFace  | 136 tok/s     | 59 tok/s     | 1.0x

The megakernel beats TensorRT-LLM at short contexts (530 vs 355 tok/s) with zero compilation overhead. But more interesting than the result is why we couldn't go faster.

The Architecture

A megakernel fuses an entire transformer block into one CUDA kernel launch. All 28 layers of Qwen3-0.6B run as one kernel invocation, with intermediate activations staying in registers and L2 cache instead of global memory.

The challenge: transformer layers have dependencies. Attention can't start until QKV projection finishes. We solve this with cooperative groups - CUDA's mechanism for grid-wide synchronization:

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void megakernel(...) {
    cg::grid_group grid = cg::this_grid();
    const int block_id = blockIdx.x;  // grid is launched with 82 blocks

    for (int layer = 0; layer < 28; layer++) {
        // All 82 blocks compute QKV projection
        compute_qkv(block_id, ...);
        grid.sync();  // Wait for all blocks

        // 16 blocks compute attention (one per Q head)
        // Other 66 blocks prefetch MLP weights
        if (block_id < 16) {
            compute_attention(block_id, ...);
        } else {
            prefetch_mlp_weights(block_id - 16, ...);
        }
        grid.sync();

        // All blocks compute MLP
        compute_mlp(block_id, ...);
        grid.sync();
    }
}
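
One detail the listing glosses over: grid.sync() is only valid inside a kernel launched through the cooperative launch API, which guarantees that all blocks are resident on the GPU at once. A minimal host-side sketch, with the argument list as a placeholder rather than the kernel's real signature:

// Host-side cooperative launch (sketch). 82 blocks matches the RTX 3090's
// SM count; if the grid were too large to be co-resident, the launch would
// fail with cudaErrorCooperativeLaunchTooLarge rather than silently deadlock.
dim3 grid(82), block(256);            // 82 blocks of 8 warps each
void* args[] = { nullptr };           // replace with pointers to the kernel's real arguments
cudaLaunchCooperativeKernel((void*)megakernel, grid, block, args,
                            /*sharedMem=*/0, /*stream=*/0);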

Each decode step requires ~225 grid.sync() calls (8 per layer x 28 layers + extras). This becomes important later.

The Optimization Journey

After the initial implementation hit 170 tok/s, we spent weeks trying every optimization we could think of.

What Worked

Optimization                   | Speedup          | Description
Block Divergence + L2 Prefetch | +2x              | During attention (16 blocks working), 66 idle blocks prefetch MLP weights into L2 cache
Redundant RMSNorm              | +42% (short ctx) | All blocks compute RMSNorm redundantly, eliminating 56 syncs
128-bit Vectorized Loads       | +3.5%            | uint4 loads instead of uint2 for better coalescing
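
The redundant RMSNorm trick is the simplest of the three to illustrate: at batch=1 the normalization is a cheap reduction over the hidden dimension, so letting every block recompute it locally costs far less than the grid-wide barrier that would otherwise publish the result. A rough sketch, assuming fp32 and a hidden size of 1024 for readability (the real kernel works in bf16 with compile-time constants):

// Sketch: each block redundantly recomputes RMSNorm of the activation vector
// into its own block-local buffer, so no grid.sync() is needed before the
// projection that consumes it. Names and the 1e-6 epsilon are assumptions.
__device__ void rmsnorm_redundant(const float* __restrict__ x,      // activations (in L2)
                                  const float* __restrict__ gamma,  // RMSNorm weights
                                  float* x_norm,                    // block-local output
                                  int hidden) {
    __shared__ float ssum;
    if (threadIdx.x == 0) ssum = 0.0f;
    __syncthreads();

    // Block-local sum of squares over the hidden dimension.
    float local = 0.0f;
    for (int i = threadIdx.x; i < hidden; i += blockDim.x) {
        float v = x[i];
        local += v * v;
    }
    atomicAdd(&ssum, local);
    __syncthreads();

    float inv_rms = rsqrtf(ssum / hidden + 1e-6f);
    for (int i = threadIdx.x; i < hidden; i += blockDim.x)
        x_norm[i] = x[i] * inv_rms * gamma[i];
    __syncthreads();
}

Every block pays for the same reduction, but that duplicated work is much cheaper than the grid-wide barrier it replaces.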

The block divergence optimization was the breakthrough. During attention, only 16 of the 82 blocks actually compute (one per Q head); the other 66 were sitting idle at grid.sync(). Having them prefetch MLP weights through __ldg() instead means those weights are already resident in L2 cache when the MLP stage starts.
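
A possible shape for the prefetch_mlp_weights() helper from the listing above, combining the __ldg() read-only path with the 128-bit uint4 loads from the table (pointer names, strides, and the sink trick are illustrative, not the kernel's actual layout):

// Sketch: blocks 16..81 stream the upcoming MLP weights through __ldg() in
// 128-bit chunks. The values themselves are discarded; the only goal is to
// have them resident in L2 before compute_mlp() runs.
__device__ void prefetch_mlp_weights(int prefetch_block,
                                     const uint4* __restrict__ w_mlp,  // weights viewed as uint4
                                     size_t n_vec,                     // number of uint4 elements
                                     unsigned* sink) {
    const int n_prefetch_blocks = 66;                      // 82 total - 16 attention blocks
    size_t i = (size_t)prefetch_block * blockDim.x + threadIdx.x;
    size_t stride = (size_t)n_prefetch_blocks * blockDim.x;

    unsigned acc = 0;
    for (; i < n_vec; i += stride) {
        uint4 v = __ldg(&w_mlp[i]);   // 128-bit load via the read-only path; fills L2
        acc ^= v.x ^ v.y ^ v.z ^ v.w;
    }
    if (acc == 0xFFFFFFFFu) *sink = acc;  // practically never true; keeps the loads from
                                          // being optimized away in this standalone sketch
}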

What Didn't Work

Optimization                 | Result       | Why It Failed
Warp Producer/Consumer Split | 0%           | Reducing compute warps hurt more than prefetching helped
Shared Memory Caching        | 0%           | L1/L2 cache already effective; extra __syncthreads() overhead
cp.async Double-Buffering    | +1%          | Can't overlap enough compute with memory loads
Atomic Counter Sync          | Inconclusive | Compilation too slow to benchmark

The warp specialization experiment was particularly instructive. We hypothesized that dedicating some warps to prefetching while others compute would help:

Producer:Consumer Ratio | Avg tok/s
0:8 (all compute)       | 509.9
1:7                     | 510.8
2:6                     | 489.9
4:4 (half and half)     | 478.6

Dedicating warps to prefetching never paid off: the 0:8 and 1:7 configurations are within noise of each other, and every additional producer warp costs throughput. This was the first hint that we weren't bandwidth-bound.
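
For reference, the experiment itself is simple to express: warps below a threshold become producers that only prefetch, the rest do the math. A hedged sketch of the 2:6 configuration (the two helper functions are hypothetical stand-ins for the real stage code):

// Sketch of the warp producer/consumer split. With 8 warps (256 threads) per
// block, the first kProducerWarps warps only warm L2 with the next stage's
// weights while the remaining warps carry the compute.
__device__ void prefetch_next_weights(int block_id, int producer_warp);                 // hypothetical
__device__ void compute_stage(int block_id, int consumer_warp, int n_consumer_warps);   // hypothetical

constexpr int kProducerWarps = 2;   // the 2:6 row in the table

__device__ void stage_with_warp_split(int block_id) {
    const int warp_id = threadIdx.x / 32;
    if (warp_id < kProducerWarps) {
        prefetch_next_weights(block_id, warp_id);
    } else {
        // The matmul now runs on only 6 of 8 warps, which is exactly why the
        // 2:6 and 4:4 rows lose throughput relative to 0:8.
        compute_stage(block_id, warp_id - kProducerWarps, 8 - kProducerWarps);
    }
    __syncthreads();                // rejoin before the next grid.sync()
}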

The Root Cause

After exhaustive optimization, we finally profiled properly:

Metric                      | Value
Effective memory bandwidth  | ~47 GB/s
Peak memory bandwidth       | 936 GB/s
Bandwidth utilization       | 5%
grid.sync() calls per token | 140+
Sync latency each           | ~0.7 us

We're using only 5% of available memory bandwidth.

The kernel isn't memory-bound. It's latency-bound by grid synchronization. With 140+ grid.sync() calls at ~0.7us each, we spend ~100us per token just waiting at barriers.

This explains why warp-level prefetching didn't help: we're not waiting on memory, we're waiting on synchronization. Adding more prefetching when the bottleneck is sync overhead is pointless.
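
The ~0.7 us figure is easy to reproduce with a kernel that does nothing but barrier in a loop. A sketch under the same 82-block cooperative launch (the kernel name and timing scaffold are illustrative):

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Empty sync benchmark: n_sync grid-wide barriers and nothing else.
// Kernel time divided by n_sync gives the per-sync latency.
__global__ void sync_bench(int n_sync) {
    cg::grid_group grid = cg::this_grid();
    for (int i = 0; i < n_sync; i++)
        grid.sync();
}

// Host side (error handling omitted):
//   int n_sync = 10000;
//   void* args[] = { &n_sync };
//   cudaEvent_t t0, t1;  cudaEventCreate(&t0);  cudaEventCreate(&t1);
//   cudaEventRecord(t0);
//   cudaLaunchCooperativeKernel((void*)sync_bench, dim3(82), dim3(256), args, 0, 0);
//   cudaEventRecord(t1);  cudaEventSynchronize(t1);
//   float ms;  cudaEventElapsedTime(&ms, t0, t1);   // ms * 1000 / n_sync -> us per sync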

Cooperative Groups vs CUDA Graphs

A natural question: why not split the kernel at grid.sync() points and use CUDA graphs instead?

Approach                      | Time     | Per-op Cost
Cooperative + 225 grid.sync() | 167.3 us | 0.73 us/sync
CUDA graph (225 kernels)      | 186.9 us | 0.83 us/kernel
Regular kernel launches       | 347.5 us | 1.54 us/launch

Cooperative groups actually wins on pure sync overhead, by roughly 20 us over CUDA graphs. But the bigger issue is that splitting the kernel would force every intermediate buffer out to global memory between launches: roughly 2.7 MB of extra memory traffic per token, about 6x more costly than the sync overhead difference.
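
The CUDA graph row can be reproduced roughly as follows: capture 225 tiny launches once, instantiate, then time replays of the graph with CUDA events as in the sync benchmark above (a sketch; noop_kernel is a stand-in for the per-stage kernels a split megakernel would need):

__global__ void noop_kernel() {}

void bench_graph(cudaStream_t stream) {
    cudaGraph_t graph;
    cudaGraphExec_t exec;

    // Capture 225 back-to-back launches into a graph.
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    for (int i = 0; i < 225; i++)
        noop_kernel<<<82, 256, 0, stream>>>();
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&exec, graph, 0);   // CUDA 12 signature

    // Replay the whole 225-kernel chain with a single launch call;
    // measured time / 225 gives the per-kernel cost in the table.
    cudaGraphLaunch(exec, stream);
    cudaStreamSynchronize(stream);
}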

The Architectural Ceiling

~530 tok/s is the architectural ceiling for batch=1 bf16 cooperative megakernels on RTX 3090.

To exceed this limit, you need to fundamentally change the approach:

Approach                     | Expected Gain | Difficulty
INT4 Quantization            | ~4x           | Medium
Non-cooperative architecture | Unknown       | High (major rewrite)
Speculative decoding         | ~2-4x         | Medium

Fair Comparison (Devil's Advocate)

Credit where it's due: TensorRT-LLM, vLLM, SGLang, and other frameworks are heavily optimized for production workloads with dynamic shapes, variable batch sizes, and long contexts. This megakernel exploits several assumptions they intentionally don't make:

  1. Static shapes: All dimensions (hidden size, head count, MLP width) are compile-time constants. Production frameworks must handle arbitrary model architectures at runtime.
  2. Short context bias: The benchmarks favor positions 1-100, where KV cache overhead is minimal. At longer contexts, TensorRT-LLM's consistent 355 tok/s beats the megakernel's degradation to 158 tok/s.
  3. Single model, single GPU: No tensor parallelism, no continuous batching, no dynamic memory allocation. Real serving systems need all of these.
  4. Learning exercise: This project was built to understand GPU optimization, not to replace production inference engines.

The speedup is real, but it comes from exploiting a narrow regime (batch=1, short context, static shapes) where the texture cache (__ldg()) provides massive benefits by keeping weights in the read-only cache path while L1/L2 handles activations. Production frameworks can't make these assumptions.

TL;DR: Use TensorRT-LLM or vLLM for production. Use this to learn how GPUs actually work.

Lessons Learned

  1. Profile before optimizing. We assumed memory bandwidth was the issue. It wasn't. The 5% bandwidth utilization revealed we're latency-bound.
  2. Cooperative kernels have inherent limits. grid.sync() overhead dominates at high sync counts. For 140+ syncs per token, this is unavoidable.
  3. Block-level parallelism matters. The +2x from L2 prefetching came from utilizing idle blocks during attention. This was the only substantial win.
  4. GPU caches are better than you think. L1/L2 handle repeated reads effectively. Explicit shared memory caching often adds overhead without benefit.
  5. SASS analysis is essential. Looking at actual assembly revealed whether optimizations (like 128-bit loads) were actually being applied.

Try It

git clone https://github.com/Infatoshi/megakernels.git
cd megakernels
uv venv && source .venv/bin/activate
uv pip install torch --index-url https://download.pytorch.org/whl/cu121
uv pip install transformers triton

# Interactive chat
python chat.py

# Run benchmarks
python experiments/framework_bench/benchmark_suite.py

# Verify correctness
python verify_correctness.py

February 2026