Triton Puzzles Notes #
Triton Puzzles is a set of exercises for learning Triton.
Setup #
I used Triton-Puzzles-Lite.
The corresponding Triton version is triton==3.1.0.
When installing, make sure to pin numpy<2.0; otherwise things will break at runtime.
Puzzles #
Puzzle 1: Constant Add #
Add a constant to a vector. Uses one program id axis.
Block size B0 is always the same as vector x with length N0.
Just load the values, add the constant, and store the result back.
@triton.jit
def add_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    # We name the offsets of the pointers as "off_"
    off_x = tl.arange(0, B0)
    x = tl.load(x_ptr + off_x)
    output = x + 10
    tl.store(z_ptr + off_x, output)
    return
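For context, here is a minimal sketch of how this kernel could be launched outside the puzzle harness; the vector length and the CUDA device are my own assumptions, not part of the puzzle:
import torch
import triton
import triton.language as tl

N0 = 32                                  # hypothetical vector length
x = torch.randn(N0, device="cuda")
z = torch.empty_like(x)
# B0 == N0, so a single program instance covers the whole vector
add_kernel[(1,)](x, z, N0, B0=N0)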
Puzzle 2: Constant Add Block #
Add a constant to a vector. Uses one program block axis (no for loops yet).
Block size B0 is now smaller than the shape vector x which is N0.
Each program handles one block of B0 elements: load with a mask, add the constant, and store with the same mask.
@triton.jit
def add_mask2_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    pid = tl.program_id(0)
    offset = tl.arange(0, B0) + pid * B0
    x = tl.load(x_ptr + offset, offset < N0)
    output = x + 10
    tl.store(z_ptr + offset, output, offset < N0)
    return
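Since B0 < N0 here, the launch grid needs one program per block. A sketch of the grid computation (the sizes are illustrative, reusing the x and z tensors from the previous sketch):
import triton

N0, B0 = 200, 32                       # hypothetical sizes
grid = (triton.cdiv(N0, B0),)          # ceil(N0 / B0) program instances along axis 0
add_mask2_kernel[grid](x, z, N0, B0=B0)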
Puzzle 3: Outer Vector Add #
Add two vectors.
Uses one program block axis. Block size B0 is always the same as vector x length N0.
Block size B1 is always the same as vector y length N1.
In Triton, an operation on two tensors of different shapes triggers broadcasting, with dimensions aligned from the last one backwards. For each dimension:
- if the two sizes are equal: OK;
- if one of them is 1: it is broadcast to match the other;
- otherwise: the shapes are incompatible (a compile-time error).
This kernel relies exactly on this broadcasting mechanism.
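The same broadcasting rule can be checked quickly in PyTorch (the shapes are my own example):
import torch

a = torch.arange(4)[:, None]   # shape (4, 1)
b = torch.arange(3)[None, :]   # shape (1, 3)
c = a + b                      # broadcasts to (4, 3): c[i, j] = a[i] + b[j]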
@triton.jit
def add_vec_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    off_x = tl.arange(0, B0)[:, None]
    off_y = tl.arange(0, B1)[None, :]
    x = tl.load(x_ptr + off_x)
    y = tl.load(y_ptr + off_y)
    z = x + y
    # z has shape (N1, N0), so the linear offset of z[j, i] is j * N0 + i (B0 == N0 here)
    off_z = off_y * B0 + off_x
    tl.store(z_ptr + off_z, z)
    return
Puzzle 4: Outer Vector Add Block #
Add a row vector to a column vector.
Uses two program block axes. Block size B0 is always less than the vector x length N0.
Block size B1 is always less than vector y length N1.
A simple 2D tile: each program handles one B0 x B1 block, indexed by two program ids.
@triton.jit
def add_vec_block_kernel(
    x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_x = tl.program_id(0)
    block_id_y = tl.program_id(1)
    off_x = tl.arange(0, B0)[:, None] + B0 * block_id_x
    off_y = tl.arange(0, B1)[None, :] + B1 * block_id_y
    mask_x = off_x < N0
    mask_y = off_y < N1
    mask_z = mask_x & mask_y
    x = tl.load(x_ptr + off_x, mask_x)
    y = tl.load(y_ptr + off_y, mask_y)
    z = x + y
    off_z = off_y * N0 + off_x
    tl.store(z_ptr + off_z, z, mask_z)
    return
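With two program axes, the launch grid becomes two-dimensional; a sketch with illustrative sizes (not the harness values; x, y, z are assumed to be allocated already):
import triton

N0, N1, B0, B1 = 100, 90, 32, 32                   # hypothetical sizes
grid = (triton.cdiv(N0, B0), triton.cdiv(N1, B1))  # one program per (B0, B1) tile
add_vec_block_kernel[grid](x, y, z, N0, N1, B0=B0, B1=B1)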
Puzzle 5: Fused Outer Multiplication #
Multiply a row vector to a column vector and take a relu.
Uses two program block axes. Block size B0 is always less than the vector x length N0.
Block size B1 is always less than vector y length N1.
Still a 2D tile. Be careful with the mask when writing back: the tile is first zero-filled, and then only the positive products are stored.
@triton.jit
def mul_relu_block_kernel(
    x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_x = tl.program_id(0)
    block_id_y = tl.program_id(1)
    off_x = tl.arange(0, B0)[None, :] + B0 * block_id_x
    off_y = tl.arange(0, B1)[:, None] + B1 * block_id_y
    mask_x = off_x < N0
    mask_y = off_y < N1
    mask_z = mask_x & mask_y
    x = tl.load(x_ptr + off_x, mask_x)
    y = tl.load(y_ptr + off_y, mask_y)
    z = x * y
    off_z = off_y * N0 + off_x
    # relu: zero everything first, then overwrite the positive entries
    tl.store(z_ptr + off_z, 0, mask_z)
    tl.store(z_ptr + off_z, z, mask_z & (z > 0))
    return
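An equivalent way to express the relu is tl.where(z > 0, z, 0.0) followed by a single masked store; the same idea in PyTorch, for reference only:
import torch

z = torch.randn(4, 5)
relu_z = torch.where(z > 0, z, torch.zeros_like(z))   # same result as zero-fill then overwrite positives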
Puzzle 6: Fused Outer Multiplication - Backwards #
Backwards of a function that multiplies a matrix with a row vector and take a relu.
Uses two program blocks. Block size B0 is always less than the vector x length N0.
Block size B1 is always less than vector y length N1. Chain rule backward dz
is of shape N1 by N0
The backward gradient is
$$dx_{j, i} = \begin{cases} dz_{j, i} \times y_j, & x_{j, i} \times y_j > 0,\\ 0, & \text{otherwise.} \end{cases}$$
@triton.jit
def mul_relu_block_back_kernel(
    x_ptr, y_ptr, dz_ptr, dx_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_i = tl.program_id(0)
    block_id_j = tl.program_id(1)
    off_i = tl.arange(0, B0)[None, :] + B0 * block_id_i
    off_j = tl.arange(0, B1)[:, None] + B1 * block_id_j
    off_ij = off_j * N0 + off_i
    mask_i = off_i < N0
    mask_j = off_j < N1
    mask_ij = mask_i & mask_j
    x = tl.load(x_ptr + off_ij, mask_ij)
    y = tl.load(y_ptr + off_j, mask_j)
    z = x * y
    dz = tl.load(dz_ptr + off_ij, mask_ij)
    dx = dz * y
    # gradient flows only where the forward relu was active (x * y > 0)
    tl.store(dx_ptr + off_ij, 0, mask_ij)
    tl.store(dx_ptr + off_ij, dx, mask_ij & (z > 0))
    return
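To sanity-check the gradient formula, one can compare it against PyTorch autograd; a minimal sketch with made-up sizes:
import torch

N1, N0 = 5, 7
x = torch.randn(N1, N0, requires_grad=True)
y = torch.randn(N1, 1, requires_grad=True)
dz = torch.randn(N1, N0)
z = torch.relu(x * y)
z.backward(dz)
dx_manual = torch.where(x * y > 0, dz * y, torch.zeros_like(dz))  # the formula above
assert torch.allclose(x.grad, dx_manual)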
Puzzle 7: Long Sum #
Sum of a batch of numbers.
Uses one program blocks. Block size B0 represents a range of batches of x of length N0.
Each element is of length T. Process it B1 < T elements at a time.
Hint: You will need a for loop for this problem. These work and look the same as in Python.
Because GPU registers are limited, a long vector has to be split into small chunks that are processed one after another in a loop.
@triton.jit
def sum_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    tl.store(z_ptr + pid_0, 0)
    nblocks = tl.cdiv(T, B1)
    for block in range(nblocks):
        offset = tl.arange(0, B1) + block * B1
        mask = offset < T
        offset = offset + pid_0 * T
        # masked-off elements load as 0 so they do not affect the sum
        x = tl.load(x_ptr + offset, mask, 0.0)
        sum_x = tl.sum(x, axis=0)
        z = tl.load(z_ptr + pid_0)
        tl.store(z_ptr + pid_0, z + sum_x)
    return
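A common variant keeps the running sum in a register and writes it out once at the end, instead of loading and storing z on every iteration. A sketch under the same interface (same imports as the rest of the notes; the kernel name is my own):
@triton.jit
def sum_kernel_acc(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    acc = tl.full((), 0.0, tl.float32)          # running sum kept in a register
    for block in range(tl.cdiv(T, B1)):
        offset = tl.arange(0, B1) + block * B1
        mask = offset < T
        x = tl.load(x_ptr + pid_0 * T + offset, mask, 0.0)
        acc = acc + tl.sum(x, axis=0)
    tl.store(z_ptr + pid_0, acc)                # single store at the end
    return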
Puzzle 8: Long Softmax #
Softmax of a batch of logits.
Uses one program block axis. Block size B0 represents the batch of x of length N0.
Block logit length T. Process it B1 < T elements at a time.
Note softmax needs to be computed in numerically stable form as in Python. In addition, in Triton they recommend not using exp but instead using exp2. You need the identity
$$\exp(x) = 2^{\log_2(e)\, x}$$
Advanced: there is one way to do this with 3 loops. You can also do it with 2 loops if you are clever. Hint: you will find this identity useful:
$$\exp(x_i - m) = \exp(x_i - m/2 - m/2) = \exp(x_i - m/2) / \exp(m/2)$$
The trick here is to finish computing $\sum_i \exp(x_i - m)$ in a single loop (online softmax). Concretely,
$$\begin{aligned} \exp(x_i - m_\text{new}) &= \exp((x_i - m_\text{old}) - (m_\text{new}-m_\text{old})) \\ &= \exp(x_i - m_\text{old}) \times \exp(m_\text{old}-m_\text{new}) \end{aligned}$$
so while iterating we keep the previous maximum $m_\text{old}$, and whenever the running maximum changes we rescale the partial sum by $\exp(m_\text{old} - m_\text{new})$.
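The same recurrence written out in plain PyTorch, as a reference for what the first loop of the kernel maintains (the data and chunk size are my own example):
import torch

x = torch.randn(200)             # one row of logits
B1 = 32                          # chunk size
m = torch.tensor(float("-inf"))  # running maximum
d = torch.tensor(0.0)            # running sum of exp(x - m)
for start in range(0, x.numel(), B1):
    chunk = x[start:start + B1]
    m_new = torch.maximum(m, chunk.max())
    d = d * torch.exp(m - m_new) + torch.exp(chunk - m_new).sum()
    m = m_new
assert torch.allclose(torch.exp(x - m) / d, torch.softmax(x, dim=0), atol=1e-5)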
@triton.jit
def softmax_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    """2-loop version."""
    pid_0 = tl.program_id(0)
    log2_e = 1.44269504
    nblocks = tl.cdiv(T, B1)
    d = tl.full((), 0, tl.float32)      # running denominator
    x_max = tl.full((), 0, tl.float32)  # running maximum
    for block in range(nblocks):
        offset = tl.arange(0, B1) + block * B1
        mask = offset < T
        offset = offset + pid_0 * T
        # masked-off logits load as -inf so they contribute exp(-inf) = 0
        x = tl.load(x_ptr + offset, mask, -float("inf"))
        old_max = x_max
        x_max = tl.maximum(tl.max(x, axis=0), x_max)
        # rescale the partial sum when the running maximum increases
        exp_diff = tl.exp2((old_max - x_max) * log2_e)
        d = d * exp_diff
        d = d + tl.sum(tl.exp2((x - x_max) * log2_e), axis=0)
    for block in range(nblocks):
        offset = tl.arange(0, B1) + block * B1
        mask = offset < T
        offset = offset + pid_0 * T
        x = tl.load(x_ptr + offset, mask)
        x = tl.exp2((x - x_max) * log2_e)
        softmax = x / d
        tl.store(z_ptr + offset, softmax, mask)
    return
Puzzle 9: Simple FlashAttention #
A scalar version of FlashAttention.
Uses zero programs. Block size B0 represent the batches of q to process out of N0. Sequence length is T. Process it B1 < T elements (k, v) at a time for some B1.
This can be done in 1 loop using a similar trick from the last puzzle.
Hint: Use tl.where to mask q dot k to -inf to avoid overflow (NaN).
The same treatment as the previous puzzle: maintain a running maximum and running denominator online, then accumulate the softmax-weighted sum of v in a second loop.
@triton.jit
def flashatt_kernel(
    q_ptr, k_ptr, v_ptr, z_ptr, N0, T, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_i = tl.program_id(0)
    log2_e = 1.44269504
    myexp = lambda x: tl.exp2(log2_e * x)
    off_qz = tl.arange(0, B0)[:, None] + block_id_i * B0
    mask_qz = off_qz < T
    q = tl.load(q_ptr + off_qz, mask_qz)
    nblk_kv = tl.cdiv(T, B1)
    d = tl.zeros((B0,), dtype=tl.float32)[:, None]
    x_max = tl.zeros((B0,), dtype=tl.float32)[:, None]
    # first pass: running maximum and denominator (same trick as the softmax puzzle)
    for blk in range(nblk_kv):
        off_kv = tl.arange(0, B1)[None, :] + blk * B1
        mask_kv = off_kv < T
        k = tl.load(k_ptr + off_kv, mask_kv)
        # mask invalid scores to -inf, as the hint suggests
        x = tl.where(mask_kv, q * k, -float("inf"))
        old_x_max = x_max
        x_max = tl.maximum(x_max, tl.max(x, axis=1)[:, None])
        exp_diff = myexp(old_x_max - x_max)
        d = d * exp_diff
        d = d + tl.sum(myexp(x - x_max), axis=1)[:, None]
    z = tl.zeros((B0,), dtype=tl.float32)[:, None]
    # second pass: accumulate the softmax-weighted values
    for blk in range(nblk_kv):
        off_kv = tl.arange(0, B1)[None, :] + blk * B1
        mask_kv = off_kv < T
        k = tl.load(k_ptr + off_kv, mask_kv)
        v = tl.load(v_ptr + off_kv, mask_kv, 0.0)
        x = tl.where(mask_kv, q * k, -float("inf"))
        z = z + tl.sum(myexp(x - x_max) / d * v, axis=1)[:, None]
    tl.store(z_ptr + off_qz, z, mask_qz)
    return
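For reference, the scalar attention this kernel implements can be written directly in PyTorch (1D q, k, v of length T; my own illustration of the spec):
import torch

T = 200
q, k, v = torch.randn(3, T)
scores = q[:, None] * k[None, :]            # (T, T): scalar "q dot k"
weights = torch.softmax(scores, dim=1)      # softmax over the key axis
z_ref = (weights * v[None, :]).sum(dim=1)   # (T,) weighted sum of values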
Puzzle 10: Two Dimensional Convolution #
A batched 2D convolution.
Uses one program id axis. Block size B0 represent the batches to process out of N0.
Image x is size H by W with only 1 channel, and kernel k is size KH by KW.
Convolution: each program instance processes B0 images, looping over every output pixel.
@triton.jit
def conv2d_kernel(
    x_ptr, k_ptr, z_ptr, N0, H, W, KH: tl.constexpr, KW: tl.constexpr, B0: tl.constexpr
):
    block_id_i = tl.program_id(0)
    off_blk = block_id_i * B0 * H * W
    # the kernel weights are the same for every pixel, so load them once
    off_k = tl.arange(0, KH)[:, None] * KW + tl.arange(0, KW)[None, :]
    k = tl.load(k_ptr + off_k)
    for img in range(B0):
        off_img = img * H * W
        for i in range(H):
            for j in range(W):
                off_pixel = i * W + j
                off_total = off_blk + off_img + off_pixel
                # KH x KW window anchored at (i, j); out-of-range taps read as 0
                off_x_i = i + tl.arange(0, KH)[:, None]
                off_x_j = j + tl.arange(0, KW)[None, :]
                off_x = off_blk + off_img + off_x_i * W + off_x_j
                mask_x = (off_x_i < H) & (off_x_j < W)
                x = tl.load(x_ptr + off_x, mask_x, 0)
                tl.store(
                    z_ptr + off_total,
                    tl.sum(x * k),
                    off_total < N0 * H * W,
                )
    return
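For reference, what each output pixel is supposed to be, in plain PyTorch: the window is anchored at the output pixel and out-of-range taps count as zero (sizes are my own example):
import torch

H, W, KH, KW = 8, 8, 4, 4
x = torch.randn(H, W)
k = torch.randn(KH, KW)
z = torch.zeros(H, W)
for i in range(H):
    for j in range(W):
        win = x[i:i + KH, j:j + KW]                      # shrinks near the border
        z[i, j] = (win * k[:win.shape[0], :win.shape[1]]).sum()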
Puzzle 11: Matrix Multiplication #
A blocked matrix multiplication.
Uses three program id axes. Block size B2 represent the batches to process out of N2.
Block size B0 represent the rows of x to process out of N0. Block size B1 represent the cols
of y to process out of N1. The middle shape is MID.
You are allowed to use tl.dot which computes a smaller mat mul.
Hint: the main trick is that you can split a matmul into smaller parts.
$$z_{i, j, k} = \sum_{l=1}^{L/2} x_{i,j,l} \times y_{i,l,k} + \sum_{l=L/2+1}^{L} x_{i,j,l} \times y_{i,l,k}$$
Matrix multiplication with 2D tiles: each program computes one (B0, B1) tile of the output, accumulating partial products over the middle dimension chunk by chunk.
@triton.jit
def dot_kernel(
    x_ptr,
    y_ptr,
    z_ptr,
    N0,
    N1,
    N2,
    MID,
    B0: tl.constexpr,
    B1: tl.constexpr,
    B2: tl.constexpr,
    B_MID: tl.constexpr,
):
    block_id_j = tl.program_id(0)  # Row of X
    block_id_k = tl.program_id(1)  # Col of Y
    block_id_i = tl.program_id(2)  # Batch
    off_i = block_id_i * B2
    off_j = tl.arange(0, B0)[:, None] + block_id_j * B0
    off_k = tl.arange(0, B1)[None, :] + block_id_k * B1
    off_z = off_i * N0 * N1 + off_j * N1 + off_k
    mask_z = (off_j < N0) & (off_k < N1)
    z = tl.zeros((B0, B1), dtype=tl.float32)
    # walk over the middle dimension in chunks; tl.dot needs each dimension to be >= 16
    blk_size = 16
    nblk_l = tl.cdiv(MID, blk_size)
    for blk in range(nblk_l):
        off_blk = tl.arange(0, blk_size) + blk * blk_size
        off_x = off_i * N0 * MID + off_j * MID + off_blk[None, :]
        mask_x = (off_j < N0) & (off_blk[None, :] < MID)
        x = tl.load(x_ptr + off_x, mask_x, 0.0)  # out-of-range entries contribute 0 to the dot
        off_y = off_i * MID * N1 + off_blk[:, None] * N1 + off_k
        mask_y = (off_blk[:, None] < MID) & (off_k < N1)
        y = tl.load(y_ptr + off_y, mask_y, 0.0)
        z = z + tl.dot(x, y)
    tl.store(z_ptr + off_z, z, mask_z)
    return
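The three program axes correspond to a 3D launch grid; a sketch with illustrative sizes (the tensors x, y, z are assumed to be allocated already):
import triton

N0, N1, N2, MID = 32, 32, 4, 64
B0, B1, B2 = 16, 16, 1
grid = (triton.cdiv(N0, B0), triton.cdiv(N1, B1), triton.cdiv(N2, B2))
dot_kernel[grid](x, y, z, N0, N1, N2, MID, B0=B0, B1=B1, B2=B2, B_MID=16)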
Puzzle 12: Quantized Matrix Mult #
When doing matrix multiplication with quantized neural networks a common strategy is to store the weight matrix in lower precision, with a shift and scale term.
For this problem our weight will be stored in 4 bits. We can store FPINT of these in a 32 bit integer. In addition for every group weights in order we will store 1 scale float value and 1 shift 4 bit value. We store these for the column of weight. The activations are stored separately in standard floats.
Mathematically it looks like.
$$z_{j, k} = \sum_{l} sc_{j, \frac{l}{g}} (w_{j, l} - sh_{j, \frac{l}{g}}) \times y_{l, k} \text{ for } j = 1\ldots N_0, k = 1\ldots N_1$$
where g is the number of groups (GROUP).
However, it is a bit more complex since we need to also extract the 4-bit values into floats to begin.
Note:
- We don't consider batch size, i.e. i, in this puzzle.
- Remember to unpack the FPINT values into separate 4-bit values. This contains some shape manipulation.
The most complex one. The problem statement is a bit convoluted; my understanding is as follows. The original weights can be written as $x = s \times (w - o)$, where:
- $x$ is the original weight.
  - Size: N0 * MID.
- $s$ is the scale factor, in float32.
  - Input format is float32.
  - Every GROUP original weights share one scale.
  - Size: N0 * MID / GROUP.
- $w$ is the quantized weight, in int4.
  - Input format is int32.
  - One int32 can hold FPINT weights.
  - Size: N0 * MID / FPINT.
- $o$ is the offset, in int4.
  - Input format is int32.
  - One int32 can hold FPINT offsets.
  - Every GROUP original weights share one offset.
  - Size: N0 * MID / FPINT / GROUP.
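To make the 4-bit unpacking concrete, here is the same nibble extraction done on the CPU with PyTorch; the value of 8 packed nibbles per int32 mirrors the extract helper in the kernel below:
import torch

packed = torch.tensor([0x76543210], dtype=torch.int32)  # eight 4-bit values in one int32
shifts = torch.arange(8, dtype=torch.int32) * 4
nibbles = (packed >> shifts) & 0xF
print(nibbles)  # tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32)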
@triton.jit
def quant_dot_kernel(
    scale_ptr,
    offset_ptr,
    weight_ptr,
    activation_ptr,
    z_ptr,
    N0,
    N1,
    MID,
    B0: tl.constexpr,
    B1: tl.constexpr,
    B_MID: tl.constexpr,
):
    block_id_j = tl.program_id(0)
    block_id_k = tl.program_id(1)

    # unpack the eight 4-bit values packed into one int32
    def extract(x):
        over = tl.arange(0, 8) * 4
        mask = tl.full([8], 0xF, tl.int32)
        return (x >> over) & mask

    off_j = tl.arange(0, B0)[:, None] + block_id_j * B0
    off_k = tl.arange(0, B1)[None, :] + block_id_k * B1
    off_z = off_j * N1 + off_k
    mask_z = (off_j < N0) & (off_k < N1)
    z = tl.zeros((B0, B1), dtype=tl.float32)
    nblk_l = tl.cdiv(MID, B_MID)
    # per-row lengths of the scale / packed-weight / packed-offset arrays
    # (FPINT and GROUP are module-level constants defined by the puzzle)
    MID_S = MID // GROUP
    MID_W = MID // FPINT
    MID_O = MID // FPINT // GROUP
    B_MID_S = B_MID // GROUP
    B_MID_W = B_MID // FPINT
    B_MID_O = B_MID // FPINT // GROUP
    for blk in range(nblk_l):
        off_s_l = tl.arange(0, B_MID_S) + blk * B_MID_S
        off_s = off_j * MID_S + off_s_l[None, :]
        mask_s = (off_j < N0) & (off_s_l[None, :] < MID_S)
        s = tl.load(scale_ptr + off_s, mask_s)
        off_w_l = tl.arange(0, B_MID_W) + blk * B_MID_W
        off_w = off_j * MID_W + off_w_l[None, :]
        mask_w = (off_j < N0) & (off_w_l[None, :] < MID_W)
        w = tl.load(weight_ptr + off_w, mask_w)
        unpacked_w = extract(w[:, :, None])
        off_o_l = tl.arange(0, B_MID_O) + blk * B_MID_O
        off_o = off_j * MID_O + off_o_l[None, :]
        mask_o = (off_j < N0) & (off_o_l[None, :] < MID_O)
        o = tl.load(offset_ptr + off_o, mask_o)
        unpacked_o = extract(o)
        # dequantize: x = s * (w - o); the reshape relies on GROUP == FPINT,
        # so group boundaries line up with the packed int32 boundaries
        dequant_w = s[:, :, None] * (unpacked_w - unpacked_o[:, :, None])
        dequant_w = tl.reshape(dequant_w, (dequant_w.shape[0], B_MID))
        off_a_l = tl.arange(0, B_MID) + blk * B_MID
        off_a = off_a_l[:, None] * N1 + off_k
        mask_a = (off_a_l[:, None] < MID) & (off_k < N1)
        a = tl.load(activation_ptr + off_a, mask_a)
        z = z + tl.dot(dequant_w, a)
    tl.store(z_ptr + off_z, z, mask_z)
    return