MegaQwen: Hitting the Architectural Ceiling of Cooperative CUDA Megakernels
A deep dive into building and exhaustively optimizing a fused transformer megakernel. We achieved 530 tok/s on Qwen3-0.6B - then discovered why that's the limit.

Results
| Backend | Short Context | Long Context | vs HuggingFace |
|---|---|---|---|
| Megakernel | 530 tok/s | 158 tok/s | 3.9x |
| TensorRT-LLM | 355 tok/s | 355 tok/s | 2.6x |
| vLLM | - | 107 tok/s | 1.8x |
| HuggingFace | 136 tok/s | 59 tok/s | 1.0x |
The megakernel beats TensorRT-LLM at short contexts (530 vs 355 tok/s) with zero compilation overhead. But more interesting than the result is why we couldn't go faster.
The Architecture
A megakernel fuses an entire transformer block into one CUDA kernel launch. All 28 layers of Qwen3-0.6B run as one kernel invocation, with intermediate activations staying in registers and L2 cache instead of global memory.
The challenge: transformer layers have dependencies. Attention can't start until QKV projection finishes. We solve this with cooperative groups - CUDA's mechanism for grid-wide synchronization:
```cuda
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void megakernel(...) {
    cg::grid_group grid = cg::this_grid();
    const int block_id = blockIdx.x;  // one block per SM; 82 blocks on an RTX 3090

    for (int layer = 0; layer < 28; layer++) {
        // All 82 blocks compute QKV projection
        compute_qkv(block_id, ...);
        grid.sync(); // Wait for all blocks

        // 16 blocks compute attention (one per Q head)
        // Other 66 blocks prefetch MLP weights
        if (block_id < 16) {
            compute_attention(block_id, ...);
        } else {
            prefetch_mlp_weights(block_id - 16, ...);
        }
        grid.sync();

        // All blocks compute MLP
        compute_mlp(block_id, ...);
        grid.sync();
    }
}
```
Each decode step requires ~225 grid.sync() calls (8 per layer x 28 layers + extras). This becomes important later.
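One detail the snippet glosses over: grid.sync() is only valid when the kernel is launched cooperatively and every block in the grid is resident on the GPU at once, so the grid size is capped by occupancy rather than by problem size. Below is a minimal launch sketch, with an assumed 256-thread block size and a simplified one-pointer kernel signature (not the project's exact launcher):
```cuda
#include <cuda_runtime.h>

// Stand-in with an empty body so this sketch compiles; the real kernel is the one above.
__global__ void megakernel(float* /*activations*/) {}

void launch_megakernel(float* activations) {
    int device = 0, sm_count = 0, blocks_per_sm = 0;
    const int threads_per_block = 256;  // assumed block size for this sketch

    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
    // How many blocks can be co-resident per SM at this block size?
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &blocks_per_sm, megakernel, threads_per_block, /*dynamicSmemBytes=*/0);

    // On an RTX 3090 (82 SMs), one resident block per SM gives the 82 blocks
    // referenced in the kernel comments above.
    dim3 grid(sm_count * blocks_per_sm);
    dim3 block(threads_per_block);

    void* args[] = { &activations };
    // grid.sync() requires a cooperative launch; a plain <<<...>>> launch does not
    // support the grid-wide barrier.
    cudaLaunchCooperativeKernel((void*)megakernel, grid, block, args,
                                /*sharedMem=*/0, /*stream=*/nullptr);
}
```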
The Optimization Journey
After the initial implementation hit 170 tok/s, we spent weeks trying every optimization we could think of.
What Worked
| Optimization | Speedup | Description |
|---|---|---|
| Block Divergence + L2 Prefetch | +2x | During attention (16 blocks working), 66 idle blocks prefetch MLP weights into L2 cache |
| Redundant RMSNorm | +42% (short ctx) | All blocks compute RMSNorm redundantly, eliminating 56 syncs (sketch after the table) |
| 128-bit Vectorized Loads | +3.5% | uint4 loads instead of uint2 for better coalescing |
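The Redundant RMSNorm row is worth unpacking: rather than having one stage compute the norm and every other block wait at a grid.sync() for the result, each block recomputes the same norm locally from the hidden state it can already read, so the barrier disappears. A minimal sketch, assuming a hidden width of 1024 and illustrative buffer names (the real kernel works in bf16 and uses a faster reduction):
```cuda
// Sketch of block-local, redundant RMSNorm. Every block computes the same inv_rms
// from the same hidden vector; the reduction never leaves the block, so no grid-wide
// barrier is needed before the following projection can start.
__device__ void rmsnorm_redundant(const float* __restrict__ hidden,  // [1024], same input for all blocks
                                  const float* __restrict__ gamma,   // RMSNorm weight
                                  float* __restrict__ normed,        // this block's output
                                  float eps = 1e-6f) {
    constexpr int DIM = 1024;  // assumed hidden width
    __shared__ float block_sum;

    // Each thread accumulates part of sum(x^2).
    float partial = 0.f;
    for (int i = threadIdx.x; i < DIM; i += blockDim.x) {
        float x = hidden[i];
        partial += x * x;
    }

    // Simplified block reduction via a shared-memory atomic.
    if (threadIdx.x == 0) block_sum = 0.f;
    __syncthreads();
    atomicAdd(&block_sum, partial);
    __syncthreads();

    float inv_rms = rsqrtf(block_sum / DIM + eps);
    for (int i = threadIdx.x; i < DIM; i += blockDim.x) {
        normed[i] = hidden[i] * inv_rms * gamma[i];
    }
}
```
Redoing the same ~1K-element reduction in every block is cheap next to the grid-wide barriers it removes.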
The block divergence optimization was the breakthrough. During attention, only 16 of the 82 blocks actually compute (one per Q head); the other 66 used to sit idle at grid.sync(). Having them prefetch the MLP weights with __ldg() means those weights are already in L2 cache when the MLP stage starts.
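A sketch of what those otherwise-idle blocks do (the partitioning and names are illustrative, not the project's exact code):
```cuda
#include <cuda_bf16.h>

// Illustrative prefetch for the 66 blocks that would otherwise idle during attention.
// Each block touches a contiguous slice of the upcoming MLP weights with __ldg(),
// which streams the lines into L2 through the read-only path; the values themselves
// are thrown away.
__device__ void prefetch_mlp_weights(int prefetch_block,                       // 0..65
                                     const __nv_bfloat16* __restrict__ weights,
                                     size_t n_elems) {
    constexpr int NUM_PREFETCH_BLOCKS = 66;
    const uint4* w = reinterpret_cast<const uint4*>(weights);   // 8 bf16 values per 128-bit load
    size_t n_vec = n_elems / 8;
    size_t per_block = (n_vec + NUM_PREFETCH_BLOCKS - 1) / NUM_PREFETCH_BLOCKS;
    size_t begin = (size_t)prefetch_block * per_block;
    size_t end = begin + per_block;
    if (end > n_vec) end = n_vec;

    unsigned sink = 0;
    for (size_t i = begin + threadIdx.x; i < end; i += blockDim.x) {
        uint4 v = __ldg(&w[i]);         // read-only load; the L2 fill is the whole point
        sink += v.x ^ v.y ^ v.z ^ v.w;  // fold the value so the load isn't dead code
    }
    asm volatile("" :: "r"(sink));      // keep the compiler from eliminating the loop
}
```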
What Didn't Work
| Optimization | Result | Why It Failed |
|---|---|---|
| Warp Producer/Consumer Split | 0% | Reducing compute warps hurt more than prefetching helped |
| Shared Memory Caching | 0% | L1/L2 cache already effective; extra __syncthreads() overhead |
| cp.async Double-Buffering | +1% | Can't overlap enough compute with memory loads |
| Atomic Counter Sync | Inconclusive | Compilation too slow to benchmark |
The warp specialization experiment was particularly instructive. We hypothesized that dedicating some warps to prefetching while others compute would help:
| Producer:Consumer Ratio | Avg tok/s |
|---|---|
| 0:8 (all compute) | 509.9 |
| 1:7 | 510.8 |
| 2:6 | 489.9 |
| 4:4 (half and half) | 478.6 |
Keeping all warps on compute is optimal (the 1:7 split is essentially tied; heavier splits only lose throughput). This was the first hint that we weren't bandwidth-bound.
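For concreteness, the producer/consumer split has roughly this shape (helper names and signatures are hypothetical; only the branch on warp ID matters):
```cuda
#include <cuda_bf16.h>

// Hypothetical helpers standing in for the real prefetch and matmul stages.
__device__ void prefetch_slice(const __nv_bfloat16* w, size_t n_elems,
                               int group, int n_groups);
__device__ void matmul_slice(const __nv_bfloat16* w, float* out,
                             int group, int n_groups);

// With 8 warps per block, the first NUM_PRODUCER warps only issue prefetches for the
// next weight matrix while the remaining warps carry the matmul. The table above shows
// why this loses: removing compute warps costs real throughput, and the prefetch buys
// nothing because the kernel isn't starved for bandwidth in the first place.
template <int NUM_PRODUCER>   // 0, 1, 2, or 4 in the table
__device__ void gemm_stage(const __nv_bfloat16* w_current, const __nv_bfloat16* w_next,
                           size_t next_elems, float* out) {
    const int warp_id = threadIdx.x / 32;
    if (warp_id < NUM_PRODUCER) {
        prefetch_slice(w_next, next_elems, warp_id, NUM_PRODUCER);
    } else {
        matmul_slice(w_current, out, warp_id - NUM_PRODUCER, 8 - NUM_PRODUCER);
    }
}
```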
The Root Cause
After exhaustive optimization, we finally profiled properly:
| Metric | Value |
|---|---|
| Effective memory bandwidth | ~47 GB/s |
| Peak memory bandwidth | 936 GB/s |
| Bandwidth utilization | 5% |
| grid.sync() calls per token | 140+ |
| Latency per grid.sync() | ~0.7 us |
We're using only 5% of available memory bandwidth.
The kernel isn't memory-bound. It's latency-bound by grid synchronization. With 140+ grid.sync() calls at ~0.7us each, we spend ~100us per token just waiting at barriers.
This explains why warp-level prefetching didn't help: we're not waiting on memory, we're waiting on synchronization. Adding more prefetching when the bottleneck is sync overhead is pointless.
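One way to pin down the per-sync cost is a microbenchmark: a cooperative kernel that does nothing but grid-wide barriers, timed with CUDA events. A sketch (launch setup as in the cooperative-launch example earlier; names are illustrative):
```cuda
#include <cooperative_groups.h>
#include <cstdio>
namespace cg = cooperative_groups;

// Empty cooperative kernel: n_syncs grid-wide barriers and nothing else.
__global__ void sync_only_kernel(int n_syncs) {
    cg::grid_group grid = cg::this_grid();
    for (int i = 0; i < n_syncs; ++i) {
        grid.sync();
    }
}

// Times n_syncs barriers across the given grid and prints the per-sync cost.
// One kernel-launch overhead is included but amortized over n_syncs barriers.
void time_grid_sync(int n_syncs, dim3 grid, dim3 block) {
    void* args[] = { &n_syncs };
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start);
    cudaLaunchCooperativeKernel((void*)sync_only_kernel, grid, block, args, 0, nullptr);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);

    float ms = 0.f;
    cudaEventElapsedTime(&ms, start, stop);
    printf("%.3f us per grid.sync()\n", ms * 1000.0f / n_syncs);
}
```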
Cooperative Groups vs CUDA Graphs
A natural question: why not split the kernel at grid.sync() points and use CUDA graphs instead?
| Approach | Time | Per-op Cost |
|---|---|---|
| Cooperative + 225 grid.sync() | 167.3 us | 0.73 us/sync |
| CUDA graph (225 kernels) | 186.9 us | 0.83 us/kernel |
| Regular kernel launches | 347.5 us | 1.54 us/launch |
Cooperative groups actually wins by 19.7 us over CUDA graphs for pure sync overhead. But the bigger issue: splitting would require writing all intermediate buffers to global memory between kernels. That's ~2.7 MB of extra memory traffic per token - 6x more costly than the sync overhead difference.
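For reference, one way to build the 225-kernel variant from the table is stream capture; the per-stage kernels below are empty stand-ins, and only the capture/instantiate/replay pattern is the point:
```cuda
#include <cuda_runtime.h>

// Empty stand-ins for the per-stage kernels; the real ones would pass activations
// through global memory between every launch (the ~2.7 MB of extra traffic per token).
__global__ void qkv_stage() {}
__global__ void attention_stage() {}
__global__ void mlp_stage() {}

// Record one full decode step as a CUDA graph, then replay it with a single launch
// per token. Stream ordering replaces grid.sync() as the dependency mechanism.
cudaGraphExec_t build_decode_graph(cudaStream_t stream) {
    cudaGraph_t graph;
    cudaGraphExec_t graph_exec;

    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    for (int layer = 0; layer < 28; ++layer) {
        qkv_stage<<<82, 256, 0, stream>>>();
        attention_stage<<<16, 256, 0, stream>>>();
        mlp_stage<<<82, 256, 0, stream>>>();
        // ...plus the remaining per-layer stages that the megakernel separates with syncs
    }
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&graph_exec, graph, /*flags=*/0);  // CUDA 12.x signature

    return graph_exec;
    // Per decode step: cudaGraphLaunch(graph_exec, stream);
}
```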
The Architectural Ceiling
~530 tok/s is the architectural ceiling for batch=1 bf16 cooperative megakernels on RTX 3090.
To exceed this limit, you need to fundamentally change the approach:
| Approach | Expected Gain | Difficulty |
|---|---|---|
| INT4 Quantization | ~4x | Medium |
| Non-cooperative architecture | Unknown | High (major rewrite) |
| Speculative decoding | ~2-4x | Medium |
Fair Comparison (Devil's Advocate)
Credit where it's due: TensorRT-LLM, vLLM, SGLang, and other frameworks are excellently optimized for production workloads with dynamic shapes, variable batch sizes, and long contexts. This megakernel exploits several advantages they intentionally don't:
- Static shapes: All dimensions (hidden size, head count, MLP width) are compile-time constants. Production frameworks must handle arbitrary model architectures at runtime.
- Short context bias: The benchmarks favor positions 1-100, where KV cache overhead is minimal. At longer contexts, TensorRT-LLM's consistent 355 tok/s beats the megakernel's degradation to 158 tok/s.
- Single model, single GPU: No tensor parallelism, no continuous batching, no dynamic memory allocation. Real serving systems need all of these.
- Learning exercise: This project was built to understand GPU optimization, not to replace production inference engines.
The speedup is real, but it comes from exploiting a narrow regime (batch=1, short context, static shapes) where __ldg() provides massive benefits by keeping weights in the read-only (texture) cache path while L1/L2 handle activations. Production frameworks can't make these assumptions.
TL;DR: Use TensorRT-LLM or vLLM for production. Use this to learn how GPUs actually work.
Lessons Learned
- Profile before optimizing. We assumed memory bandwidth was the issue. It wasn't. The 5% bandwidth utilization revealed we're latency-bound.
- Cooperative kernels have inherent limits. grid.sync() overhead dominates at high sync counts; with 140+ syncs per token, it is unavoidable.
- Block-level parallelism matters. The +2x from L2 prefetching came from utilizing idle blocks during attention. This was the only substantial win.
- GPU caches are better than you think. L1/L2 handle repeated reads effectively. Explicit shared memory caching often adds overhead without benefit.
- SASS analysis is essential. Looking at actual assembly revealed whether optimizations (like 128-bit loads) were actually being applied.
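Tying the last two lessons together: the 128-bit load optimization is nothing more than reading the bf16 weights through uint4, and the only way to confirm the compiler actually emitted 16-byte loads is to check the SASS for 128-bit LDG instructions. A minimal, hypothetical kernel showing the pattern (assumes 16-byte-aligned pointers and a length divisible by 8):
```cuda
#include <cuda_bf16.h>

// Reads bf16 data as uint4 so each iteration issues one 16-byte load. In the SASS dump,
// this should show up as a 128-bit LDG (e.g. LDG.E.128) rather than 64-bit loads; if it
// doesn't, the "optimization" never happened.
__global__ void copy_vectorized(const __nv_bfloat16* __restrict__ src,
                                __nv_bfloat16* __restrict__ dst,
                                size_t n_elems) {                 // must be a multiple of 8
    const uint4* src_v = reinterpret_cast<const uint4*>(src);     // 8 bf16 values per uint4
    uint4* dst_v = reinterpret_cast<uint4*>(dst);
    size_t n_vec = n_elems / 8;
    for (size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x; i < n_vec;
         i += (size_t)gridDim.x * blockDim.x) {
        dst_v[i] = __ldg(&src_v[i]);  // read-only 128-bit load
    }
}
```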
Try It
```bash
git clone https://github.com/Infatoshi/megakernels.git
cd megakernels
uv venv && source .venv/bin/activate
uv pip install torch --index-url https://download.pytorch.org/whl/cu121
uv pip install transformers triton

# Interactive chat
python chat.py

# Run benchmarks
python experiments/framework_bench/benchmark_suite.py

# Verify correctness
python verify_correctness.py
```
- GitHub: github.com/Infatoshi/megakernels
- Full optimization journey: DEVLOG.md
- Benchmark data: experiments/RESULTS.md
- Memory analysis: docs/MEMORY_ANALYSIS.md
February 2026