In modern large language models, one of the most frequent operations is Top-K sampling. After each forward pass, the model outputs a logits tensor with dimensions like [batch_size, vocab_size]. That vocab_size can be 50,000 or more. To sample the next token, we need to find the $k$ (typically 2,048) most probable tokens from that massive vector. This happens for every single token we generate, billions of times per day in production systems.
The performance of this one operation can make or break your inference latency. PyTorch's torch.topk takes about 1.15ms on a modern GPU (tested on an H100) for a typical workload. Today, we'll build a kernel that does it in 0.112ms, a 10x improvement. Along the way, we'll run into some interesting quirks of the NVIDIA CUDA compiler (nvcc) and see that even at -O3 the emitted assembly (PTX) is not always optimal.
Naive Sorting…
Let’s start with the obvious approach that every CPU programmer would reach for: sort the array and take the first k elements.
// DON'T DO THIS!!
thrust::sort(logits, logits + vocab_size, thrust::greater<float>());
// Take first 2048 elements
On a GPU, this is catastrophically slow. Our benchmarks (using Thrust for the naive sort) show this takes 23ms: over 200x slower than our optimized kernel. Why? You’re sorting 50,000 elements to keep 2,048 and throw away 47,952. That’s 96% wasted work. Worse, sorting requires extensive inter-thread communication and synchronization, which GPUs handle poorly.
Even PyTorch’s torch.topk, which uses a more sophisticated heap-based approach, takes 1.15ms. We can do much better.
Filter, Don’t Sort
We don’t really need a full sort. We just need to find the threshold value, the 2,048th largest element, and collect everything above it. Here’s our strategy:
- Build a histogram with 512 bins in shared memory (one pass over the data)
- Run a prefix sum to convert counts to offsets
- Find the threshold bin that contains the 2,048th element
- Collect elements above the threshold directly, and those in the threshold bin separately
- Sort only the threshold bin (typically just a few thousand elements)
This transforms an O(N log N) problem into two O(N) passes plus one tiny O(k’ log k’) sort, where k’ is the size of the threshold bin.
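To make the flow concrete before we dive into CUDA, here is a minimal CPU reference sketch of the same filter-don't-sort idea. It is illustrative only: the bin mapping is an FP32 analogue of the extractBinIdx function we'll meet shortly, and nothing here is the actual kernel.

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <vector>

// CPU reference for the filter-don't-sort strategy (assumes logits.size() >= k).
// binOf is an FP32 stand-in for the FP16-based extractBinIdx shown later.
std::vector<int> topkByHistogram(const std::vector<float>& logits, int k) {
    constexpr int kNumBins = 512;
    auto binOf = [](float x) {
        uint32_t u;
        std::memcpy(&u, &x, sizeof(u));
        u = (x < 0.f) ? ~u : (u | 0x80000000u); // Order-preserving key
        return 511 - static_cast<int>(u >> 23); // Top 9 bits: bin 0 = largest values
    };
    // 1. Histogram: one pass counting how many logits fall into each bin.
    std::vector<int> hist(kNumBins, 0);
    for (float v : logits) hist[binOf(v)]++;
    // 2-3. Prefix sum until we find the bin containing the k-th largest element.
    int thresholdBin = 0, seen = 0;
    while (seen + hist[thresholdBin] < k) seen += hist[thresholdBin++];
    // 4. Second pass: bins above the threshold are guaranteed top-k; the
    //    threshold bin's members are gathered separately.
    std::vector<int> result, thresholdItems;
    for (int i = 0; i < static_cast<int>(logits.size()); ++i) {
        int b = binOf(logits[i]);
        if (b < thresholdBin) result.push_back(i);
        else if (b == thresholdBin) thresholdItems.push_back(i);
    }
    // 5. Sort only the (small) threshold bin and top up to exactly k results.
    std::sort(thresholdItems.begin(), thresholdItems.end(),
              [&](int a, int b) { return logits[a] > logits[b]; });
    for (size_t i = 0; static_cast<int>(result.size()) < k && i < thresholdItems.size(); ++i)
        result.push_back(thresholdItems[i]);
    return result; // Indices of the k largest logits (unsorted above the threshold bin)
}

The GPU kernel parallelizes each of these steps across 512 threads per block; the rest of the post walks through how.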

FP16 Binning
One of the most important parts of this kernel is how it converts floating-point values into histogram bins while preserving order. This function deserves careful study:
static inline __device__ uint16_t extractBinIdx(float x) {
    union {
        __half h;
        uint16_t u16;
    } tmp;
    tmp.h = __float2half_rn(x); // Convert to FP16
    tmp.u16 = (x < 0.f) ? (~tmp.u16 & 0xffff) : (tmp.u16 | 0x8000);
    return 511 - (tmp.u16 >> 7);
}
This exploits a property of IEEE 754: positive floating-point numbers, when their bit patterns are read as unsigned integers, sort in the same order as their float values. The trick is handling negatives: they sort backwards as integers, so we flip their bits, and we set the sign bit on positives so every positive key lands above every negative one. The final 511 - (u16 >> 7) takes the top 9 bits of the 16-bit key, giving 512 bins with bin 0 holding the largest values.
Let’s examine the PTX emitted here:
{ cvt.rn.f16.f32 %rs7, %f6;} // Convert float32 to float16
setp.lt.f32 %p5, %f6, 0f00000000; // Check if negative
cvt.u32.u16 %r181, %rs7; // Convert to integer
not.b32 %r182, %r181; // Bitwise NOT (invert bits) if negative
or.b32 %r183, %r181, 32768; // Set sign bit if positive
selp.b32 %r184, %r182, %r183, %p5; // Select based on sign
The crucial instruction is selp.b32 (select with predicate). A naive compiler might generate a branch here, which would be devastating: divergent branches within a warp force serialization. Instead, selp keeps all threads in lockstep, resolving the choice with a single predicated instruction.
Note: For FP16 logits (common in optimized inference), you can adapt this by skipping the float-to-half conversion and adjusting bit manipulations accordingly.
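As a rough illustration of that adaptation, here is a hypothetical FP16 variant (not from the actual kernel); it tests the sign bit directly instead of comparing against 0.0f:

#include <cuda_fp16.h>

// Hypothetical FP16 variant of extractBinIdx: the logit is already __half,
// so no float-to-half conversion is needed.
static inline __device__ uint16_t extractBinIdxHalf(__half x) {
    union {
        __half h;
        uint16_t u16;
    } tmp;
    tmp.h = x;
    // Negative values (sign bit set, including -0.0): flip all bits.
    // Positive values: set the sign bit so they sort above all negatives.
    tmp.u16 = (tmp.u16 & 0x8000) ? static_cast<uint16_t>(~tmp.u16)
                                 : static_cast<uint16_t>(tmp.u16 | 0x8000);
    return 511 - (tmp.u16 >> 7);
}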
Why Do Bit Tricks Work?
The answer lies in the IEEE 754 standard, the most widely used specification for representing and manipulating floating-point numbers in computers. I won't bore you with every detail; check out its Wikipedia page.
At its core, IEEE 754 uses a binary scientific notation-like format to store numbers. A floating-point number is expressed as:
$ (-1)^{\text{sign}} \times \text{mantissa} \times 2^{\text{exponent}} $
This allows for a wide dynamic range while approximating continuous values. The standard defines several formats, but the most common are:
- Half-Precision (FP16 or Binary16): 16 bits total. Used in machine learning and graphics for reduced memory and computation, but with limited range and precision.
- Single-Precision (FP32 or Binary32): 32 bits total. The default for many applications, offering about 7 decimal digits of precision.
- Double-Precision (FP64 or Binary64): 64 bits total. Provides higher precision (about 15 decimal digits) for scientific computing.
- Other formats like Quad-Precision (128 bits) exist but are less common.
Each format divides the bits into three fields: sign bit, exponent, and mantissa (also called significand or fraction).
Bit Layout
- Sign Bit: Always the most significant bit (MSB). 0 for positive, 1 for negative. This allows simple sign flipping via bit manipulation.
- Exponent: A biased integer to represent both positive and negative powers of 2. The bias shifts the range to avoid signed exponents (e.g., bias = 127 for FP32, so stored exponent 127 represents actual exponent 0).
- Mantissa: The fractional part, normalized to start with an implied leading 1 (for most values), maximizing precision.
Specific layouts:
- FP16 (Half-Precision): 1 sign bit + 5 exponent bits (bias 15) + 10 mantissa bits. Range: ≈ ±6.1 × 10⁻⁵ to ±65504. Precision: ~3-4 decimal digits.
- FP32 (Single-Precision): 1 sign bit + 8 exponent bits (bias 127) + 23 mantissa bits. Range: ≈ ±1.4 × 10⁻⁴⁵ to ±3.4 × 10³⁸. Precision: ~7 decimal digits.
- FP64 (Double-Precision): 1 sign bit + 11 exponent bits (bias 1023) + 52 mantissa bits. Range: ≈ ±4.9 × 10⁻³²⁴ to ±1.8 × 10³⁰⁸. Precision: ~15 decimal digits.
For example, in FP32, the number 1.0 is represented as:
- Sign: 0
- Exponent: 01111111 (127 biased, actual 0)
- Mantissa: 00000000000000000000000 (implied 1.0)
- Binary: 00111111100000000000000000000000 (hex: 0x3F800000)
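If you want to poke at these fields yourself, here is a tiny host-side sketch (illustrative, not part of the kernel) that decodes 1.0f into its sign, exponent, and mantissa:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Decompose an FP32 value into its three IEEE 754 fields.
int main() {
    float x = 1.0f;
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));    // Bit-for-bit reinterpretation
    uint32_t sign     = bits >> 31;          // 1 bit
    uint32_t exponent = (bits >> 23) & 0xFF; // 8 bits, bias 127
    uint32_t mantissa = bits & 0x7FFFFF;     // 23 bits, implied leading 1
    printf("bits=0x%08X sign=%u exponent=%u (actual %d) mantissa=0x%06X\n",
           bits, sign, exponent, static_cast<int>(exponent) - 127, mantissa);
    // Prints: bits=0x3F800000 sign=0 exponent=127 (actual 0) mantissa=0x000000
    return 0;
}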
Normalization and Denormalized Numbers
In normalized form, the mantissa is in [1, 2), with an implied leading 1 not stored (saving a bit). This ensures efficient use of bits. For instance, 85.125 (binary 1010101.001) normalizes to 1.010101001 × 2⁶.
Denormalized (subnormal) numbers occur when the exponent is at its minimum (all zeros), losing the implied 1. They represent tiny values close to zero, like 1.4 × 10⁻⁴⁵ in FP32, but with reduced precision (gradual underflow).
Special Values
IEEE 754 handles edge cases gracefully:
- Zero: Exponent and mantissa all zeros. +0 (sign 0) and -0 (sign 1) are distinct but compare equal in most operations.
- Infinity (±∞): Exponent all ones, mantissa all zeros. Results from overflow (e.g., 1.0 / 0.0 = +∞).
- Not a Number (NaN): Exponent all ones, mantissa non-zero. Indicates invalid operations (e.g., 0/0 or √(-1)). Quiet NaNs (qNaN) propagate silently; signaling NaNs (sNaN) raise exceptions.
- Other specials like subnormals bridge zero and the smallest normalized number.
Why Bit Patterns Allow Ordering for Positive Floats
A key property exploited in our binning function is that for positive numbers (sign bit = 0), the IEEE 754 bit pattern, when interpreted as an unsigned integer, matches the numerical order. This happens because the exponent is the most significant field after the sign. Larger exponents mean larger numbers, and since exponents are biased positively, higher bit values in the exponent field correspond to higher powers of 2. For the same exponent, the mantissa acts as a fractional offset within the exponent’s “window” (e.g., [2^e, 2^{e+1})). Larger mantissa bits mean a larger offset, so higher integer values. Also, the bias ensures no negative exponents disrupt ordering—everything is shifted to positive integers.
Example:
- 2.0 (FP32): 0 10000000 00000000000000000000000 (integer: 1073741824)
- 4.0 (FP32): 0 10000001 00000000000000000000000 (integer: 1082130432)
- Since 1082130432 > 1073741824, 4.0 > 2.0 holds when comparing bits as ints.
This enables efficient sorting and comparisons without decoding—perfect for GPU kernels where bit ops are cheap.
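Here is a quick host-side sketch (again, just for illustration) that verifies the 2.0 vs. 4.0 example by comparing raw bit patterns:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Reinterpret a float's bits as an unsigned integer without changing them.
static uint32_t floatBits(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return u;
}

int main() {
    float a = 2.0f, b = 4.0f;
    printf("bits(2.0) = %u, bits(4.0) = %u\n", floatBits(a), floatBits(b));
    printf("float compare: %d, integer compare: %d\n", a < b, floatBits(a) < floatBits(b));
    // Both comparisons agree: for positive floats, bit order equals numeric order.
    return 0;
}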
Handling Signs and Negatives
Negative numbers have the sign bit set to 1, but because IEEE 754 uses sign-magnitude encoding rather than two's complement, their bit patterns sort in reverse of their numerical order when interpreted as integers. For example:
- +1.0: 0x3F800000
- -1.0: 0xBF800000 (sign bit set)
As unsigned ints, -1.0's pattern is larger, yet the value is numerically smaller. To order negatives correctly, our binning code inverts their bits (the ~ operator) and masks to 16 bits, which places every negative key below every positive key and makes the key increase monotonically with the float value; after the final 511 - (key >> 7) step, more negative values therefore land in higher-numbered bins (bin 0 holds the largest values).
In FP16, the same principles apply with fewer bits: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits. The ordering property still holds for positives, but the range is limited (max ~65504) and precision drops below the smallest normal value (~0.00006). Raw LLM logits are frequently negative (they are unnormalized scores; only after softmax do they become non-negative probabilities), so handling the sign correctly is essential for robustness.
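To see the sign trick in action, the following sketch applies the same transform to FP32 bit patterns (the kernel does it on FP16) and prints keys that increase monotonically with the float values:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Order-preserving key: negatives get all bits flipped, positives get the
// sign bit set, so unsigned key order matches numerical order.
static uint32_t orderPreservingKey(float x) {
    uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    return (x < 0.f) ? ~u : (u | 0x80000000u);
}

int main() {
    float vals[] = {-2.0f, -1.0f, -0.5f, 0.5f, 1.0f, 2.0f};
    for (float v : vals)
        printf("% .1f -> key 0x%08X\n", v, orderPreservingKey(v));
    // The printed keys increase monotonically with the float values.
    return 0;
}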
Practical Implications
Decimals like 0.1 have repeating binary expansions (0.000110011…) and cannot be represented exactly, leading to small errors (e.g., 0.1 + 0.2 ≠ 0.3). IEEE 754 defines rounding modes like round-to-nearest (the default), which affect conversions such as float-to-half. FP16 is popular for inference in most AI workloads, but watch for overflow. Alternatives like BF16 (1 sign + 8 exponent + 7 mantissa bits) preserve FP32's range with less precision, which is useful for deep learning.
Sorry for boring you with the details, but they are exactly what extractBinIdx above relies on: the bit-level ordering lets us bin values efficiently without full floating-point comparisons.
The Histogram Pass
Here’s where things get interesting. The histogram building loop is simple:
for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
     rowIt += kNumThreadsPerBlock) {
    float logit = logits[rowIdx * stride0 + rowIt * stride1];
    uint16_t idx = extractBinIdx(logit);
    atomicAdd(&smemHistogram[idx], 1);
}
The compiler, by default, unrolls this loop 4x. When we examine the PTX, we see:
$L__BB0_9:
ld.global.f32 %f7, [%rd59]; // Load 1
// ... binning logic ...
atom.shared.add.u32 %r203, [%r202], 1;
add.s64 %rd45, %rd59, %rd8;
ld.global.f32 %f8, [%rd45]; // Load 2
// ... three more iterations ...
add.s32 %r741, %r741, 2048; // Increment by 4 * 512
setp.lt.s32 %p12, %r741, %r2;
@%p12 bra $L__BB0_9;
The compiler chose 4x unrolling, presumably to hide memory latency. But here’s the surprise from our benchmarks:
- Histogram (1x, no unroll): 0.161 ms
- Histogram (4x, compiler default): 0.373 ms (2.3x slower than 1x!)
- Histogram (8x, forced unroll): 0.112 ms (fastest)
The compiler’s choice is suboptimal! Let’s try to understand why.
Why 8x Unrolling Wins
The key, I believe, is memory latency hiding. A global memory load takes 200-400 cycles on modern GPUs. With 1x unrolling, we have a tight dependency chain: load → compute bin → atomic add → next iteration. The GPU stalls waiting for each load.
With 8x unrolling, we can have 8 loads in flight simultaneously. While the memory subsystem fetches the first element, the SM can issue loads 2 through 8. By the time we’re computing the bin for element 1, elements 2-8 are already arriving in registers. This is called Memory-Level Parallelism (MLP).
But why is 4x slower than 1x? The compiler also performs an optimization when generating the 4x unrolled version. It fuses the array indexing calculation into the binning math:
// The compiler transforms (511 - (u16 >> 7)) * 4 into:
shr.u32 %r185, %r184, 5; // Shift by 5, not 7!
not.b32 %r186, %r185;
and.b32 %r187, %r186, 2044; // 2044 = 511 * 4
This algebraic transformation saves instructions but might introduce some hard-to-notice performance penalties. The different shift amount could affect instruction scheduling or create unexpected dependencies. Additionally, 4x unrolling might not provide enough in-flight operations to fully hide memory latency, while not being tight enough to benefit from cache locality like 1x.
To force 8x unrolling, we use:
#pragma unroll 8
for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
     rowIt += kNumThreadsPerBlock) {
    // ... same loop body ...
}
This achieves the optimal balance: enough in-flight loads to hide latency without overwhelming the register file or instruction cache. Pro tip: Use Nsight Compute to profile stalls and instruction throughput to confirm these hypotheses.
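If you want to reproduce these numbers, a simple CUDA-event harness is enough. The helper below is a generic sketch: the launch lambda and iteration count are up to you, and nothing in it is tied to the kernel in this post.

#include <cuda_runtime.h>
#include <functional>

// Generic timing helper: pass a lambda that launches whichever kernel variant
// you want to measure. Assumes launches go to the default stream.
inline float timeLaunchMs(const std::function<void()>& launch, int iters = 100) {
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    launch();                       // Warm-up launch (not timed)
    cudaEventRecord(start);
    for (int i = 0; i < iters; ++i)
        launch();
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);     // Wait for all timed launches to finish
    float ms = 0.f;
    cudaEventElapsedTime(&ms, start, stop);
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return ms / iters;              // Average per-launch time in milliseconds
}

Call it once per variant, e.g. timeLaunchMs([&]{ myKernel<<<grid, block>>>(args); }), where myKernel stands in for whichever unrolling flavor you are testing.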
Do Warps Beat Blocks?
After building the histogram, we need to convert counts to offsets using a prefix sum. We use CUB’s BlockScan:
using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
__shared__ typename Scan::TempStorage smemScan;
int prefixSum{0}, totalSum{0};
Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum);
The PTX shows us CUB’s sophisticated implementation:
// Warp-level reduction using shuffle
{ .reg .s32 r0; .reg .pred p;
shfl.sync.up.b32 r0|p, %r242, %r240, %r265, %r272;
@p add.s32 r0, r0, %r242;
mov.s32 %r238, r0;}
// Repeats with offsets: 1, 2, 4, 8, 16
Pay attention to the shfl.sync.up.b32 instruction. It lets threads within a warp exchange data directly through registers rather than shared memory. The warp-level scan needs only five such shuffle steps (log₂ 32 = 5), versus dozens of shared memory operations.
For the block-level scan, CUB uses an interesting workspace size:
mad.lo.s32 %r274, %r5, 68, %r235; // 68 bytes per thread!
Why 68 bytes? I'm not entirely sure, but if I had to guess, this unusual number (not a power of two) is chosen to avoid shared memory bank conflicts: 68 bytes is 17 four-byte words, and since 17 is coprime with the 32 banks, consecutive threads land in different banks instead of serializing.
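A quick way to convince yourself of that argument is to enumerate the mapping. This little host-side sketch (pure illustration) shows that a 17-word stride touches all 32 banks exactly once per warp:

#include <cstdio>

// With a 68-byte (17-word) stride and 32 four-byte banks, thread i of a warp
// touches bank (17 * i) % 32. Because gcd(17, 32) = 1, the mapping is a
// permutation of the banks, so no two threads in the warp collide.
int main() {
    for (int tid = 0; tid < 32; ++tid)
        printf("thread %2d -> bank %2d\n", tid, (tid * 17) % 32);
    return 0;
}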
Parallel Beats Sequential
We need to find which bin contains the 2,048th element. The CPU programmer’s instinct is binary search:
// DON'T DO THIS - Sequential binary search
int low = 0, high = kNumBins - 1;
while (low <= high) {
    int mid = (low + high) / 2;
    // ... binary search logic ...
}
// Requires log₂(512) = 9 synchronizations!
Our benchmarks show the binary-search variant runs at 0.377 ms, essentially tied with the otherwise-identical kernel that uses the parallel search (0.373 ms with the compiler's default unrolling): for a bin count as small as 512, the log₂(512) = 9 synchronizations don't cost much. Still, the GPU way is much simpler:
// DO THIS - Parallel O(1) search
if (threadIdx.x < kNumBins) {
    int nextPrefixSum = (threadIdx.x == kNumBins - 1)
                            ? totalSum
                            : smemHistogram[threadIdx.x + 1];
    if (prefixSum < kTopK && nextPrefixSum >= kTopK) {
        smemThresholdBinIdx[0] = threadIdx.x;
    }
}
All 512 threads check their bin simultaneously. One thread finds the threshold and writes it. The PTX shows perfect predication:
setp.lt.s32 %p19, %r744, 2048;
setp.gt.s32 %p20, %r21, 2047;
or.pred %p21, %p20, %p19;
@%p21 bra $L__BB0_22; // Skip if not threshold
st.shared.u32 [smemThresholdBinIdx], %r5;
No divergence, one pass, done. GPUs are cool.
Collection Phase
Now we make a second pass to collect elements:
for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
     rowIt += kNumThreadsPerBlock) {
    float logit = logits[...];
    uint16_t idx = extractBinIdx(logit);
    if (idx < thresholdBinIdx) {
        // Guaranteed top-k, collect directly
        int dstIdx = atomicAdd(&smemHistogram[idx], 1);
        smemIndices[dstIdx] = rowIt;
    } else if (idx == thresholdBinIdx) {
        // Threshold bin, needs sorting
        int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
        if (dstIdx < kNumFinalItems) { // CRITICAL!
            smemFinal.items.logits[dstIdx] = logit;
            smemFinal.items.indices[dstIdx] = rowIt;
        }
    }
}
The overflow check if (dstIdx < kNumFinalItems) is non-negotiable in production. What if the model outputs 5,000 identical logits that all fall into the same bin? Without this check, we'd corrupt memory and crash.
We allocate space for 3,072 elements (50% margin over 2,048), making overflow extremely unlikely while keeping memory usage reasonable.
Note: In our naive sort implementation (for benchmarking), there’s a potential race if multiple threads target the same minIdx—mitigated by per-thread min scans, but consider locks or reductions for robustness in similar code.
Radix Sort
For the threshold bin, we use CUB’s BlockRadixSort (I also didn’t want to spend too much time implementing that from scratch). The PTX shows some pretty advanced optimizations:
bfe.u32 %r456, %r457, %r773, %r497; // Bit field extract
The bfe instruction extracts a bit field in one cycle, perfect for radix sort’s digit extraction. This replaces the traditional shift-and-mask pattern.
CUB also uses “blocked-to-striped” layout transformation for better memory coalescing. The PTX even includes mysterious pragmas:
.pragma "used_bytes_mask 61695"; // 0xF0FF
This tells the hardware which bytes in a 128-byte cache line are actually used, allowing the memory subsystem to optimize transfers. This level of optimization is invisible from CUDA.
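For completeness, here is roughly what using cub::BlockRadixSort for the threshold bin looks like. The wrapper, block size, and items-per-thread below are illustrative, not the kernel's actual configuration:

#include <cub/cub.cuh>

// Hypothetical sketch: sorting the threshold-bin candidates (logits paired
// with their token indices) in descending order with cub::BlockRadixSort.
template <int kNumThreadsPerBlock, int kItemsPerThread>
__device__ void sortThresholdBin(float (&keys)[kItemsPerThread],
                                 int (&values)[kItemsPerThread]) {
    using Sort = cub::BlockRadixSort<float, kNumThreadsPerBlock, kItemsPerThread, int>;
    __shared__ typename Sort::TempStorage smemSort;
    // SortDescending keeps the largest logits first, matching top-k order.
    Sort(smemSort).SortDescending(keys, values);
}

Under the hood, CUB's float-key handling performs the same kind of order-preserving bit transformation we built by hand in extractBinIdx.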
Register Pressure
The PTX header reveals something alarming:
.reg .b32 %r<785>; // 785 32-bit registers requested!
To be fair, PTX registers are virtual; ptxas re-allocates them, so 785 doesn't map directly to physical registers, but it does signal heavy register pressure. Hopper GPUs allow at most 255 registers per thread, so if that pressure survives allocation, the kernel spills to local memory (cached in L1, roughly 20-30 cycles). This appears intentional (I hope!): the compiler and CUB have decided that spilling is worth it for maximum instruction-level parallelism, prioritizing throughput over occupancy. Test with --maxrregcount (and check the actual register count with -Xptxas -v) to tune if needed.
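Besides --maxrregcount, you can constrain register usage from the source with __launch_bounds__. The kernel name and bounds below are placeholders, just to show where the attribute goes:

// Sketch: __launch_bounds__(maxThreadsPerBlock, minBlocksPerMultiprocessor)
// tells the compiler how many threads and resident blocks to plan for, which
// bounds the registers it may assign per thread.
__global__ void __launch_bounds__(512, 2)
topKKernelBounded(const float* logits, int* outIndices, int vocabSize) {
    // ... kernel body (histogram, scan, collect, sort) ...
}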
Some Benchmarks
Here’s how our implementations compare on a typical workload (batch_size=1024, vocab_size=50000, k=2048). Note: Performance scales well to larger vocabs (e.g., 100k) or smaller k, but test your specific setup.
| Implementation | Time (ms) | Relative to Fastest |
|---|---|---|
| Naive Sort | 23.117 | 206x slower |
| PyTorch torch.topk | 1.150 | 10.3x slower |
| Histogram (4x compiler default) | 0.373 | 3.3x slower |
| Histogram (Binary search) | 0.377 | 3.4x slower |
| Histogram (1x no unroll) | 0.161 | 1.4x slower |
| Histogram (8x unroll) | 0.112 | Fastest |

Interesting findings!
- Our best kernel beats PyTorch by 10x
- The compiler’s default 4x unrolling is suboptimal
- Binary search offers no real advantage over the simple parallel search for small bin counts
- The naive sort is unusably slow
Lessons Learned
I probably don't need to say this outright, but you shouldn't take the compiler at face value. C and C++ programmers have grappled with this for decades, yet reading PTX, let alone writing inline PTX, is still rare in the ML space. The compiler isn't omniscient; always benchmark your code.
Think parallel-first. Algorithms that seem clever to a CPU-trained brain often buy you nothing on a GPU. When you have thousands of threads at your disposal, use them! Read the assembly! Profile everything!
Explore libraries like FlashInfer (which inspired parts of this) or Triton for even more optimized kernels.
Appendix A: Memory Bandwidth Analysis
Let’s calculate the theoretical efficiency. For 50,000 elements:
- Input: 50,000 × 4 bytes = 200 KB read twice = 400 KB
- Output: 2,048 × 4 bytes = 8 KB
- Total: 408 KB
At 0.112ms, that’s 3.64 GB/s. On an H100 with 3.35 TB/s bandwidth, we’re using only 0.1% of theoretical peak. This seems low, but remember:
- We’re latency-bound, not bandwidth-bound
- Atomic operations serialize some accesses
- The kernel is optimized for small batch sizes
For larger batches or streaming workloads, bandwidth utilization would improve dramatically. To get a fuller picture, factor in shared memory and atomic traffic in your effective bandwidth calc—tools like Nsight can help.
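If you want to redo this arithmetic for your own workload, it is a one-liner; the sketch below just plugs in the numbers above:

#include <cstdio>

// Back-of-the-envelope effective-bandwidth check matching the figures above:
// bytes moved divided by kernel time.
int main() {
    const double bytesRead  = 50000.0 * 4 * 2; // logits read in two passes
    const double bytesWrite = 2048.0 * 4;      // top-k indices written once
    const double seconds    = 0.112e-3;        // measured kernel time
    const double gbPerSec   = (bytesRead + bytesWrite) / seconds / 1e9;
    printf("Effective bandwidth: %.2f GB/s\n", gbPerSec); // ~3.6 GB/s
    return 0;
}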
Want to experiment with this kernel? The complete code is available here. Try varying the histogram bins (256, 512, 1024) or forcing different unrolling factors. The PTX will tell you if your changes actually improved anything, just look for those instruction patterns we’ve discussed.
Building production GPU kernels? Check out CUB for good parallel primitives and Nsight Compute for profiling. And remember: when in doubt, read the PTX.