325 changes: 288 additions & 37 deletions src/kernels.cu
@@ -1,55 +1,306 @@
#include <vector>
#include <cuda_fp16.h>

#include <cfloat>
#include <cmath>
#include <algorithm>
#include <stdexcept>
#include <string>
#include <iostream>
#include <cstdlib>
#include "../tester/utils.h"

/**
* @brief Computes the trace of a matrix.
*
* The trace of a matrix is defined as the sum of its diagonal elements.
* This function expects a flattened row-major matrix stored in a
* std::vector. If the matrix is not square, the trace will sum up
* elements along the main diagonal up to the smaller of rows or cols.
*
* @tparam T The numeric type of matrix elements (e.g., float, int).
* @param h_input A flattened matrix of size rows * cols.
* @param rows Number of rows in the matrix.
* @param cols Number of columns in the matrix.
* @return The trace (sum of diagonal values) of the matrix.
*/
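// Illustrative example (not part of the kernel path): for the 2 x 3 row-major
// matrix {1, 2, 3,
//         4, 5, 6}
// the main diagonal spans min(rows, cols) = 2 elements, so
// trace(h_input, 2, 3) == 1 + 5 == 6.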

// Error checking macro
#define CUDA_CHECK(call) \
{ \
cudaError_t err = call; \
if (err != cudaSuccess) \
{ \
std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ \
<< " - " << cudaGetErrorString(err) << "\n"; \
exit(1); \
} \
}

constexpr int WARP_SIZE = 32;

template <typename T, const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ T warp_reduce_sum(T val) {
#pragma unroll
for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
val += __shfl_xor_sync(0xffffffff, val, mask);
}
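// After the loop every lane holds the full warp-wide sum: with the default
// kWarpSize of 32, the XOR masks walk 16, 8, 4, 2, 1, so at each step a lane
// adds the value held by the lane whose id differs in exactly that bit
// (a butterfly reduction), rather than funnelling everything into lane 0.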
return val;
}

template <typename T>
__global__ void trace_kernel(const T* d_input, int cols, int n_diag, T* d_sum) {
constexpr int NUM_THREADS = 256;
constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
__shared__ T reduce_smem[NUM_WARPS];

int tid = threadIdx.x;
int idx = blockIdx.x * NUM_THREADS + tid;
int warp = tid / WARP_SIZE;
int lane = tid % WARP_SIZE;

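// Block-level reduction: each thread loads one diagonal element (or 0 when out
// of range), every warp reduces its 32 values with shuffles, lane 0 of each
// warp parks its partial sum in shared memory, warp 0 reduces those partials,
// and finally thread 0 atomically adds the block's total into d_sum.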
T sum = (idx < n_diag) ? d_input[idx * cols + idx] : T(0);
sum = warp_reduce_sum<T, WARP_SIZE>(sum);

if (lane == 0) {
reduce_smem[warp] = sum;
}
__syncthreads();

sum = (lane < NUM_WARPS) ? reduce_smem[lane] : T(0);
if (warp == 0) {
sum = warp_reduce_sum<T, NUM_WARPS>(sum);
}

if (tid == 0) {
atomicAdd(d_sum, sum);
}
}

template <typename T>
T trace(const std::vector<T>& h_input, size_t rows, size_t cols) {
if (h_input.empty() || rows == 0 || cols == 0) {
return T(0);
}

const int n_diag = static_cast<int>(std::min(rows, cols));
const int block_size = 256;
const int grid_size = (n_diag + block_size - 1) / block_size;

T* d_input = nullptr;
T* d_sum = nullptr;

CUDA_CHECK(cudaMalloc(&d_input, h_input.size() * sizeof(T)));
CUDA_CHECK(cudaMalloc(&d_sum, sizeof(T)));

CUDA_CHECK(cudaMemcpy(d_input, h_input.data(), h_input.size() * sizeof(T), cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemset(d_sum, 0, sizeof(T)));

trace_kernel<T><<<grid_size, block_size>>>(
d_input, static_cast<int>(cols), n_diag, d_sum
);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());

T h_sum = T(0);
CUDA_CHECK(cudaMemcpy(&h_sum, d_sum, sizeof(T), cudaMemcpyDeviceToHost));

CUDA_CHECK(cudaFree(d_input));
CUDA_CHECK(cudaFree(d_sum));

return h_sum;
}
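// A minimal host-side reference for local sanity checks (an illustrative
// sketch, not used by the tester): it computes the same diagonal sum on the
// CPU so the kernel result can be compared against it on small inputs.
template <typename T>
static T trace_host_reference(const std::vector<T>& h_input, size_t rows, size_t cols) {
    T sum = T(0);
    const size_t n = std::min(rows, cols);
    for (size_t i = 0; i < n; ++i) {
        sum += h_input[i * cols + i];  // row-major main diagonal element (i, i)
    }
    return sum;
}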


constexpr int FLASH_BLOCK_SIZE = 32;

// Flash Attention v1 Kernel
template <typename T>
__global__ void flashAttentionKernel(
const T* Q, const T* K, const T* V, T* O,
int batch_size, int target_seq_len, int src_seq_len,
int query_heads, int kv_heads, int head_dim, bool is_causal) {

(void)batch_size;
int batch_idx = blockIdx.z;
int head_idx = blockIdx.y;
int row_idx = blockIdx.x;
int tid = threadIdx.x;

if (row_idx >= target_seq_len) {
return;
}

// One block handles one (batch_idx, row_idx, head_idx) output vector; only
// thread 0 computes that vector serially, all other threads return immediately.
if (tid != 0) {
return;
}

// GQA head mapping:
// 1) divisible case: each group of query heads shares one kv head
// 2) non-divisible case: fall back to a round-robin mapping q_head % kv_heads
int kv_head_idx = 0;
if (query_heads % kv_heads == 0) {
int q_per_kv = query_heads / kv_heads;
kv_head_idx = head_idx / q_per_kv;
} else {
kv_head_idx = head_idx % kv_heads;
}
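// Worked example: with query_heads = 8 and kv_heads = 2, q_per_kv = 4, so query
// heads 0-3 read kv head 0 and heads 4-7 read kv head 1; with a non-divisible
// pair such as query_heads = 6 and kv_heads = 4, head 5 maps to 5 % 4 = 1.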
// With causal masking, position row_idx may only attend to [0, row_idx]; otherwise it sees all source positions.
int effective_src = is_causal ? min(src_seq_len, row_idx + 1) : src_seq_len;
int out_base = ((batch_idx * target_seq_len + row_idx) * query_heads + head_idx) * head_dim;

extern __shared__ float smem[];
float* q_vec = smem; // [head_dim], caches the current query vector
float* out_accum = q_vec + head_dim; // [head_dim], accumulates the softmax-weighted output

if (effective_src <= 0) {
for (int d = 0; d < head_dim; ++d) {
O[out_base + d] = static_cast<T>(0);
}
return;
}

// Scaling factor: scores = (Q K^T) / sqrt(d)
const float inv_sqrt_d = 1.0f / sqrtf(static_cast<float>(head_dim));
for (int d = 0; d < head_dim; ++d) {
int q_idx = ((batch_idx * target_seq_len + row_idx) * query_heads + head_idx) * head_dim + d;
q_vec[d] = static_cast<float>(Q[q_idx]);
out_accum[d] = 0.0f;
}

// Pass 1: compute m = max_j(score_j), used for a numerically stable softmax.
float m = -FLT_MAX;
for (int s = 0; s < effective_src; ++s) {
float dot = 0.0f;
int k_base = ((batch_idx * src_seq_len + s) * kv_heads + kv_head_idx) * head_dim;
for (int d = 0; d < head_dim; ++d) {
dot = fmaf(q_vec[d], static_cast<float>(K[k_base + d]), dot);
}
m = fmaxf(m, dot * inv_sqrt_d);
}

// Pass 2: compute
// denom = sum_j exp(score_j - m)
// out = sum_j exp(score_j - m) * V_j
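// Together the two passes realize the stable softmax identity
//   softmax_j = exp(score_j - m) / sum_k exp(score_k - m)
//             = exp(score_j) / sum_k exp(score_k),
// since the common exp(-m) factor cancels; out_d = sum_j softmax_j * V[j][d].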
float denom = 0.0f;
for (int s = 0; s < effective_src; ++s) {
float dot = 0.0f;
int k_base = ((batch_idx * src_seq_len + s) * kv_heads + kv_head_idx) * head_dim;
for (int d = 0; d < head_dim; ++d) {
dot = fmaf(q_vec[d], static_cast<float>(K[k_base + d]), dot);
}

float w = expf(dot * inv_sqrt_d - m);
denom += w;

int v_base = ((batch_idx * src_seq_len + s) * kv_heads + kv_head_idx) * head_dim;
for (int d = 0; d < head_dim; ++d) {
out_accum[d] += w * static_cast<float>(V[v_base + d]);
}
}

// Final normalization: out = out / denom
float inv_denom = (denom > 0.0f) ? (1.0f / denom) : 0.0f;
for (int d = 0; d < head_dim; ++d) {
O[out_base + d] = static_cast<T>(out_accum[d] * inv_denom);
}
}

/**
* @brief Computes flash attention for given query, key, and value tensors.
*
* @tparam T Data type (float) for input/output tensors
* @param[in] h_q Query tensor of shape [batch_size, tgt_seq_len, query_heads, head_dim]
* @param[in] h_k Key tensor of shape [batch_size, src_seq_len, kv_heads, head_dim]
* @param[in] h_v Value tensor of shape [batch_size, src_seq_len, kv_heads, head_dim]
* @param[out] h_o Output attention tensor of shape [batch_size, tgt_seq_len, query_heads, head_dim]
* @param[in] batch_size Batch dimension size
* @param[in] target_seq_len Target sequence length
* @param[in] src_seq_len Source sequence length
* @param[in] query_heads Number of query attention heads
* @param[in] kv_heads Number of key/value heads (supports grouped query attention)
* @param[in] head_dim Dimension size of each attention head
* @param[in] is_causal Whether to apply causal masking
*/
template <typename T>
void flashAttention(const std::vector<T>& h_q, const std::vector<T>& h_k,
const std::vector<T>& h_v, std::vector<T>& h_o,
int batch_size, int target_seq_len, int src_seq_len,
int query_heads, int kv_heads, int head_dim, bool is_causal) {
// Output shape: [B, Tq, Hq, D]
const size_t o_elems =
static_cast<size_t>(batch_size > 0 ? batch_size : 0) *
static_cast<size_t>(target_seq_len > 0 ? target_seq_len : 0) *
static_cast<size_t>(query_heads > 0 ? query_heads : 0) *
static_cast<size_t>(head_dim > 0 ? head_dim : 0);
if (h_o.size() != o_elems) {
h_o.resize(o_elems);
}

if (o_elems == 0) {
return;
}
if (batch_size <= 0 || target_seq_len <= 0 || src_seq_len <= 0 ||
query_heads <= 0 || kv_heads <= 0 || head_dim <= 0) {
// If any dimension is invalid, fall back to an all-zero output to avoid an illegal launch/allocation.
std::fill(h_o.begin(), h_o.end(), T(0));
return;
}

size_t elem_size = sizeof(T);
// Input layouts:
// Q: [B, Tq, Hq, D]
// K: [B, Tk, Hkv, D]
// V: [B, Tk, Hkv, D]
size_t q_elems = (size_t)batch_size * target_seq_len * query_heads * head_dim;
size_t k_elems = (size_t)batch_size * src_seq_len * kv_heads * head_dim;
size_t v_elems = (size_t)batch_size * src_seq_len * kv_heads * head_dim;
size_t q_size = q_elems * elem_size;
size_t k_size = k_elems * elem_size;
size_t v_size = v_elems * elem_size;
size_t o_size = static_cast<size_t>(batch_size) * target_seq_len * query_heads * head_dim * elem_size;

if (h_q.size() != q_elems || h_k.size() != k_elems || h_v.size() != v_elems) {
throw std::invalid_argument("flashAttention: input tensor sizes do not match provided dimensions.");
}

T* d_q = nullptr;
T* d_k = nullptr;
T* d_v = nullptr;
T* d_o = nullptr;

try {
CUDA_CHECK(cudaMalloc(&d_q, q_size));
CUDA_CHECK(cudaMalloc(&d_k, k_size));
CUDA_CHECK(cudaMalloc(&d_v, v_size));
CUDA_CHECK(cudaMalloc(&d_o, o_size));

CUDA_CHECK(cudaMemcpy(d_q, h_q.data(), q_size, cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(d_k, h_k.data(), k_size, cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(d_v, h_v.data(), v_size, cudaMemcpyHostToDevice));

// Grid layout:
// grid.x -> query position t
// grid.y -> query head h
// grid.z -> batch b
dim3 grid_dim(
target_seq_len,
query_heads,
batch_size
);
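// This launch creates batch_size * target_seq_len * query_heads blocks, one per
// output vector; each block has FLASH_BLOCK_SIZE threads, though the current
// kernel only does work on thread 0 of each block.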

int block_size = FLASH_BLOCK_SIZE;

// Dynamic shared memory: q_vec + out_accum (head_dim floats each)
size_t smem_size = 0;
smem_size += (2 * static_cast<size_t>(head_dim)) * sizeof(float);
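// For example, head_dim = 128 needs 2 * 128 * sizeof(float) = 1 KiB of dynamic
// shared memory, comfortably below the per-block limit queried below
// (typically 48 KiB or more).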

int device = 0;
CUDA_CHECK(cudaGetDevice(&device));

int max_smem = 0;
CUDA_CHECK(cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlock, device));
if (smem_size > static_cast<size_t>(max_smem)) {
throw std::invalid_argument("flashAttention: shared memory requirement exceeds device limit.");
}

flashAttentionKernel<T><<<grid_dim, block_size, smem_size>>>(
d_q, d_k, d_v, d_o,
batch_size, target_seq_len, src_seq_len,
query_heads, kv_heads, head_dim, is_causal
);

CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());

CUDA_CHECK(cudaMemcpy(h_o.data(), d_o, o_size, cudaMemcpyDeviceToHost));
} catch (...) {
if (d_q != nullptr) cudaFree(d_q);
if (d_k != nullptr) cudaFree(d_k);
if (d_v != nullptr) cudaFree(d_v);
if (d_o != nullptr) cudaFree(d_o);
throw;
}

CUDA_CHECK(cudaFree(d_q));
CUDA_CHECK(cudaFree(d_k));
CUDA_CHECK(cudaFree(d_v));
CUDA_CHECK(cudaFree(d_o));
}
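// A naive host-side reference (an illustrative sketch, not called by the
// tester): it recomputes the same GQA attention in double precision on the
// CPU, which is handy for validating flashAttention on small shapes.
template <typename T>
static void attention_host_reference(const std::vector<T>& h_q, const std::vector<T>& h_k,
                                     const std::vector<T>& h_v, std::vector<T>& h_o,
                                     int batch_size, int target_seq_len, int src_seq_len,
                                     int query_heads, int kv_heads, int head_dim, bool is_causal) {
    h_o.assign(static_cast<size_t>(batch_size) * target_seq_len * query_heads * head_dim, T(0));
    for (int b = 0; b < batch_size; ++b) {
        for (int t = 0; t < target_seq_len; ++t) {
            for (int h = 0; h < query_heads; ++h) {
                // Same GQA head mapping and causal window as the kernel above.
                int kv_h = (query_heads % kv_heads == 0) ? h / (query_heads / kv_heads) : h % kv_heads;
                int s_end = is_causal ? std::min(src_seq_len, t + 1) : src_seq_len;
                size_t q_base = ((static_cast<size_t>(b) * target_seq_len + t) * query_heads + h) * head_dim;
                std::vector<double> scores(s_end > 0 ? s_end : 0, 0.0);
                double m = -1e30;
                for (int s = 0; s < s_end; ++s) {
                    size_t k_base = ((static_cast<size_t>(b) * src_seq_len + s) * kv_heads + kv_h) * head_dim;
                    double dot = 0.0;
                    for (int d = 0; d < head_dim; ++d) {
                        dot += static_cast<double>(h_q[q_base + d]) * static_cast<double>(h_k[k_base + d]);
                    }
                    scores[s] = dot / std::sqrt(static_cast<double>(head_dim));
                    m = std::max(m, scores[s]);
                }
                double denom = 0.0;
                for (int s = 0; s < s_end; ++s) denom += std::exp(scores[s] - m);
                std::vector<double> acc(head_dim, 0.0);
                for (int s = 0; s < s_end; ++s) {
                    double w = std::exp(scores[s] - m) / denom;
                    size_t v_base = ((static_cast<size_t>(b) * src_seq_len + s) * kv_heads + kv_h) * head_dim;
                    for (int d = 0; d < head_dim; ++d) acc[d] += w * static_cast<double>(h_v[v_base + d]);
                }
                for (int d = 0; d < head_dim; ++d) h_o[q_base + d] = static_cast<T>(acc[d]);
            }
        }
    }
}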

// *********************************************************************
// Explicit Template Instantiations (REQUIRED FOR LINKING WITH TESTER.O)
// DO NOT MODIFY THIS SECTION
// *********************************************************************
template int trace<int>(const std::vector<int>&, size_t, size_t);
template float trace<float>(const std::vector<float>&, size_t, size_t);