diff --git a/src/kernels.cu b/src/kernels.cu index 74312070..d629a292 100644 --- a/src/kernels.cu +++ b/src/kernels.cu @@ -1,55 +1,306 @@ #include #include - +#include +#include +#include +#include +#include +#include +#include #include "../tester/utils.h" -/** - * @brief Computes the trace of a matrix. - * - * The trace of a matrix is defined as the sum of its diagonal elements. - * This function expects a flattened row-major matrix stored in a - * std::vector. If the matrix is not square, the trace will sum up - * elements along the main diagonal up to the smaller of rows or cols. - * - * @tparam T The numeric type of matrix elements (e.g., float, int). - * @param h_input A flattened matrix of size rows * cols. - * @param rows Number of rows in the matrix. - * @param cols Number of columns in the matrix. - * @return The trace (sum of diagonal values) of the matrix. - */ + +// Error checking macro +#define CUDA_CHECK(call) \ +{ \ + cudaError_t err = call; \ + if (err != cudaSuccess) \ + { \ + std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ \ + << " - " << cudaGetErrorString(err) << "\n"; \ + exit(1); \ + } \ +} + +constexpr int WARP_SIZE = 32; + +template +__device__ __forceinline__ T warp_reduce_sum(T val) { + #pragma unroll + for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) { + val += __shfl_xor_sync(0xffffffff, val, mask); + } + return val; +} + +template +__global__ void trace_kernel(const T* d_input, int cols, int n_diag, T* d_sum) { + constexpr int NUM_THREADS = 256; + constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE; + __shared__ T reduce_smem[NUM_WARPS]; + + int tid = threadIdx.x; + int idx = blockIdx.x * NUM_THREADS + tid; + int warp = tid / WARP_SIZE; + int lane = tid % WARP_SIZE; + + T sum = (idx < n_diag) ? d_input[idx * cols + idx] : T(0); + sum = warp_reduce_sum(sum); + + if (lane == 0) { + reduce_smem[warp] = sum; + } + __syncthreads(); + + sum = (lane < NUM_WARPS) ? 
reduce_smem[lane] : T(0); + if (warp == 0) { + sum = warp_reduce_sum(sum); + } + + if (tid == 0) { + atomicAdd(d_sum, sum); + } +} + template T trace(const std::vector& h_input, size_t rows, size_t cols) { - // TODO: Implement the trace function - return T(-1); + if (h_input.empty() || rows == 0 || cols == 0) { + return T(0); + } + + const int n_diag = static_cast(std::min(rows, cols)); + const int block_size = 256; + const int grid_size = (n_diag + block_size - 1) / block_size; + + T* d_input = nullptr; + T* d_sum = nullptr; + + CUDA_CHECK(cudaMalloc(&d_input, h_input.size() * sizeof(T))); + CUDA_CHECK(cudaMalloc(&d_sum, sizeof(T))); + + CUDA_CHECK(cudaMemcpy(d_input, h_input.data(), h_input.size() * sizeof(T), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemset(d_sum, 0, sizeof(T))); + + trace_kernel<<>>( + d_input, static_cast(cols), n_diag, d_sum + ); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + + T h_sum = T(0); + CUDA_CHECK(cudaMemcpy(&h_sum, d_sum, sizeof(T), cudaMemcpyDeviceToHost)); + + CUDA_CHECK(cudaFree(d_input)); + CUDA_CHECK(cudaFree(d_sum)); + + return h_sum; +} + + +constexpr int FLASH_BLOCK_SIZE = 32; + +// Flash Attention v1 Kernel +template +__global__ void flashAttentionKernel( + const T* Q, const T* K, const T* V, T* O, + int batch_size, int target_seq_len, int src_seq_len, + int query_heads, int kv_heads, int head_dim, bool is_causal) { + + (void)batch_size; + int batch_idx = blockIdx.z; + int head_idx = blockIdx.y; + int row_idx = blockIdx.x; + int tid = threadIdx.x; + + if (row_idx >= target_seq_len) { + return; + } + + // 一个 block 对应一个 (batch_idx, row_idx, head_idx) 输出向量, + // 仅使用 thread 0 串行计算该向量,其他线程直接返回。 + if (tid != 0) { + return; + } + + // GQA 头映射: + // 1) 可整除时:每组 q_head 共享一个 kv_head + // 2) 不可整除时:退化为 q_head % kv_heads 的循环映射 + int kv_head_idx = 0; + if (query_heads % kv_heads == 0) { + int q_per_kv = query_heads / kv_heads; + kv_head_idx = head_idx / q_per_kv; + } else { + kv_head_idx = head_idx % kv_heads; + } + // causal=True 时只允许看见 [0, row_idx],否则看见全部。 + int effective_src = is_causal ? 
min(src_seq_len, row_idx + 1) : src_seq_len; + int out_base = ((batch_idx * target_seq_len + row_idx) * query_heads + head_idx) * head_dim; + + extern __shared__ float smem[]; + float* q_vec = smem; // [head_dim],缓存当前 query 向量 + float* out_accum = q_vec + head_dim; // [head_dim],累计 softmax 后的输出 + + if (effective_src <= 0) { + for (int d = 0; d < head_dim; ++d) { + O[out_base + d] = static_cast(0); + } + return; + } + + // 缩放因子:scores = (QK^T) / sqrt(d) + const float inv_sqrt_d = 1.0f / sqrtf(static_cast(head_dim)); + for (int d = 0; d < head_dim; ++d) { + int q_idx = ((batch_idx * target_seq_len + row_idx) * query_heads + head_idx) * head_dim + d; + q_vec[d] = static_cast(Q[q_idx]); + out_accum[d] = 0.0f; + } + + // Pass 1:求 m = max_j(score_j),用于数值稳定 softmax。 + float m = -FLT_MAX; + for (int s = 0; s < effective_src; ++s) { + float dot = 0.0f; + int k_base = ((batch_idx * src_seq_len + s) * kv_heads + kv_head_idx) * head_dim; + for (int d = 0; d < head_dim; ++d) { + dot = fmaf(q_vec[d], static_cast(K[k_base + d]), dot); + } + m = fmaxf(m, dot * inv_sqrt_d); + } + + // Pass 2:计算 + // denom = sum_j exp(score_j - m) + // out = sum_j exp(score_j - m) * V_j + float denom = 0.0f; + for (int s = 0; s < effective_src; ++s) { + float dot = 0.0f; + int k_base = ((batch_idx * src_seq_len + s) * kv_heads + kv_head_idx) * head_dim; + for (int d = 0; d < head_dim; ++d) { + dot = fmaf(q_vec[d], static_cast(K[k_base + d]), dot); + } + + float w = expf(dot * inv_sqrt_d - m); + denom += w; + + int v_base = ((batch_idx * src_seq_len + s) * kv_heads + kv_head_idx) * head_dim; + for (int d = 0; d < head_dim; ++d) { + out_accum[d] += w * static_cast(V[v_base + d]); + } + } + + // 最终归一化:out = out / denom + float inv_denom = (denom > 0.0f) ? (1.0f / denom) : 0.0f; + for (int d = 0; d < head_dim; ++d) { + O[out_base + d] = static_cast(out_accum[d] * inv_denom); + } } -/** - * @brief Computes flash attention for given query, key, and value tensors. - * - * @tparam T Data type (float) for input/output tensors - * @param[in] h_q Query tensor of shape [batch_size, tgt_seq_len, query_heads, head_dim] - * @param[in] h_k Key tensor of shape [batch_size, src_seq_len, kv_heads, head_dim] - * @param[in] h_v Value tensor of shape [batch_size, src_seq_len, kv_heads, head_dim] - * @param[out] h_o Output attention tensor of shape [batch_size, tgt_seq_len, query_heads, head_dim] - * @param[in] batch_size Batch dimension size - * @param[in] target_seq_len Target sequence length - * @param[in] src_seq_len Source sequence length - * @param[in] query_heads Number of query attention heads - * @param[in] kv_heads Number of key/value heads (supports grouped query attention) - * @param[in] head_dim Dimension size of each attention head - * @param[in] is_causal Whether to apply causal masking - */ template void flashAttention(const std::vector& h_q, const std::vector& h_k, const std::vector& h_v, std::vector& h_o, int batch_size, int target_seq_len, int src_seq_len, - int query_heads, int kv_heads, int head_dim, bool is_causal) { - // TODO: Implement the flash attention function + int query_heads, int kv_heads, int head_dim, bool is_causal) { + // 输出形状:[B, Tq, Hq, D] + const size_t o_elems = + static_cast(batch_size > 0 ? batch_size : 0) * + static_cast(target_seq_len > 0 ? target_seq_len : 0) * + static_cast(query_heads > 0 ? query_heads : 0) * + static_cast(head_dim > 0 ? 
head_dim : 0); + if (h_o.size() != o_elems) { + h_o.resize(o_elems); + } + + if (o_elems == 0) { + return; + } + if (batch_size <= 0 || target_seq_len <= 0 || src_seq_len <= 0 || + query_heads <= 0 || kv_heads <= 0 || head_dim <= 0) { + // 输入维度非法时,按全 0 输出处理,避免非法 launch/malloc。 + std::fill(h_o.begin(), h_o.end(), T(0)); + return; + } + + size_t elem_size = sizeof(T); + // 输入布局: + // Q: [B, Tq, Hq, D] + // K: [B, Tk, Hkv, D] + // V: [B, Tk, Hkv, D] + size_t q_elems = (size_t)batch_size * target_seq_len * query_heads * head_dim; + size_t k_elems = (size_t)batch_size * src_seq_len * kv_heads * head_dim; + size_t v_elems = (size_t)batch_size * src_seq_len * kv_heads * head_dim; + size_t q_size = q_elems * elem_size; + size_t k_size = k_elems * elem_size; + size_t v_size = v_elems * elem_size; + size_t o_size = static_cast(batch_size) * target_seq_len * query_heads * head_dim * elem_size; + + if (h_q.size() != q_elems || h_k.size() != k_elems || h_v.size() != v_elems) { + throw std::invalid_argument("flashAttention: input tensor sizes do not match provided dimensions."); + } + + T* d_q = nullptr; + T* d_k = nullptr; + T* d_v = nullptr; + T* d_o = nullptr; + + try { + CUDA_CHECK(cudaMalloc(&d_q, q_size)); + CUDA_CHECK(cudaMalloc(&d_k, k_size)); + CUDA_CHECK(cudaMalloc(&d_v, v_size)); + CUDA_CHECK(cudaMalloc(&d_o, o_size)); + + CUDA_CHECK(cudaMemcpy(d_q, h_q.data(), q_size, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_k, h_k.data(), k_size, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_v, h_v.data(), v_size, cudaMemcpyHostToDevice)); + + // 网格布局: + // grid.x -> query 位置 t + // grid.y -> query head h + // grid.z -> batch b + dim3 grid_dim( + target_seq_len, + query_heads, + batch_size + ); + + int block_size = FLASH_BLOCK_SIZE; + + // 动态共享内存:q_vec + out_accum(各 head_dim 个 float) + size_t smem_size = 0; + smem_size += (2 * static_cast(head_dim)) * sizeof(float); + + int device = 0; + CUDA_CHECK(cudaGetDevice(&device)); + + int max_smem = 0; + CUDA_CHECK(cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlock, device)); + if (smem_size > static_cast(max_smem)) { + throw std::invalid_argument("flashAttention: shared memory requirement exceeds device limit."); + } + + flashAttentionKernel<<>>( + d_q, d_k, d_v, d_o, + batch_size, target_seq_len, src_seq_len, + query_heads, kv_heads, head_dim, is_causal + ); + + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + + CUDA_CHECK(cudaMemcpy(h_o.data(), d_o, o_size, cudaMemcpyDeviceToHost)); + } catch (...) { + if (d_q != nullptr) cudaFree(d_q); + if (d_k != nullptr) cudaFree(d_k); + if (d_v != nullptr) cudaFree(d_v); + if (d_o != nullptr) cudaFree(d_o); + throw; + } + + CUDA_CHECK(cudaFree(d_q)); + CUDA_CHECK(cudaFree(d_k)); + CUDA_CHECK(cudaFree(d_v)); + CUDA_CHECK(cudaFree(d_o)); } // ********************************************************************* -// Explicit Template Instantiations (REQUIRED FOR LINKING WITH TESTER.O) -// DO NOT MODIFY THIS SECTION +// Explicit Template Instantiations // ********************************************************************* template int trace(const std::vector&, size_t, size_t); template float trace(const std::vector&, size_t, size_t); diff --git a/src/kernels.maca b/src/kernels.maca index 765e08d9..30b2a596 100644 --- a/src/kernels.maca +++ b/src/kernels.maca @@ -1,60 +1,276 @@ #include -#include +#include +#include +#include +#include +#include +#include +#include #include "../tester/utils.h" -/** - * @brief Computes the trace of a matrix. 
- * - * The trace of a matrix is defined as the sum of its diagonal elements. - * This function expects a flattened row-major matrix stored in a - * std::vector. If the matrix is not square, the trace will sum up - * elements along the main diagonal up to the smaller of rows or cols. - * - * @tparam T The numeric type of matrix elements (e.g., float, int). - * @param h_input A flattened matrix of size rows * cols. - * @param rows Number of rows in the matrix. - * @param cols Number of columns in the matrix. - * @return The trace (sum of diagonal values) of the matrix. - */ +// Error checking macro +#define MACA_CHECK(call) \ +{ \ + mcError_t err = call; \ + if (err != mcSuccess) \ + { \ + std::cerr << "MACA error at " << __FILE__ << ":" << __LINE__ \ + << " - " << mcGetErrorString(err) << "\n"; \ + exit(1); \ + } \ +} + +template +__host__ __device__ __forceinline__ float to_float(T x) { + return static_cast(x); +} + +template <> +__host__ __device__ __forceinline__ float to_float<__half>(__half x) { + return __half2float(x); +} + +template +__host__ __device__ __forceinline__ T from_float(float x) { + return static_cast(x); +} + +template <> +__host__ __device__ __forceinline__ __half from_float<__half>(float x) { + return __float2half_rn(x); +} + +// Warp归约模板 +template +__device__ __forceinline__ T warp_reduce_sum(T val) { + const auto mask = __activemask(); + #pragma unroll + for (int delta = warpSize >> 1; delta >= 1; delta >>= 1) { + val += __shfl_xor_sync(mask, val, delta); + } + return val; +} + +// trace_kernel +template +__global__ void trace_kernel(const T* d_input, int cols, int n_diag, T* d_sum) { + constexpr int kNumThreads = 256; + __shared__ T reduce_smem[kNumThreads]; + + int tid = threadIdx.x; + int idx = blockIdx.x * kNumThreads + tid; + int warp = tid / warpSize; + int lane = tid % warpSize; + int num_warps = (kNumThreads + warpSize - 1) / warpSize; + + T sum = (idx < n_diag) ? d_input[idx * cols + idx] : T(0); + sum = warp_reduce_sum(sum); + + if (lane == 0) { + reduce_smem[warp] = sum; + } + __syncthreads(); + + sum = (tid < num_warps) ? reduce_smem[tid] : T(0); + if (warp == 0) { + sum = warp_reduce_sum(sum); + if (lane == 0) { + atomicAdd(d_sum, sum); + } + } +} + template T trace(const std::vector& h_input, size_t rows, size_t cols) { - // TODO: Implement the trace function - return T(-1); -} - -/** - * @brief Computes flash attention for given query, key, and value tensors. 
- * - * @tparam T Data type (float) for input/output tensors - * @param[in] h_q Query tensor of shape [batch_size, tgt_seq_len, query_heads, head_dim] - * @param[in] h_k Key tensor of shape [batch_size, src_seq_len, kv_heads, head_dim] - * @param[in] h_v Value tensor of shape [batch_size, src_seq_len, kv_heads, head_dim] - * @param[out] h_o Output attention tensor of shape [batch_size, tgt_seq_len, query_heads, head_dim] - * @param[in] batch_size Batch dimension size - * @param[in] target_seq_len Target sequence length - * @param[in] src_seq_len Source sequence length - * @param[in] query_heads Number of query attention heads - * @param[in] kv_heads Number of key/value heads (supports grouped query attention) - * @param[in] head_dim Dimension size of each attention head - * @param[in] is_causal Whether to apply causal masking - */ + if (h_input.empty() || rows == 0 || cols == 0) { + return T(0); + } + + const int n_diag = static_cast(std::min(rows, cols)); + const int block_size = 256; + const int grid_size = (n_diag + block_size - 1) / block_size; + + T* d_input = nullptr; + T* d_sum = nullptr; + + MACA_CHECK(mcMalloc(&d_input, h_input.size() * sizeof(T))); + MACA_CHECK(mcMalloc(&d_sum, sizeof(T))); + + MACA_CHECK(mcMemcpy( + d_input, h_input.data(), h_input.size() * sizeof(T), mcMemcpyHostToDevice)); + MACA_CHECK(mcMemset(d_sum, 0, sizeof(T))); + + trace_kernel<<>>(d_input, static_cast(cols), n_diag, d_sum); + MACA_CHECK(mcGetLastError()); + MACA_CHECK(mcDeviceSynchronize()); + + T h_sum = T(0); + MACA_CHECK(mcMemcpy(&h_sum, d_sum, sizeof(T), mcMemcpyDeviceToHost)); + + MACA_CHECK(mcFree(d_input)); + MACA_CHECK(mcFree(d_sum)); + + return h_sum; +} + +constexpr int FLASH_BLOCK_SIZE = 32; + +// flash_attention_kernel +template +__global__ void flash_attention_kernel(const T* q, const T* k, const T* v, T* o, + int batch_size, int target_seq_len, int src_seq_len, + int query_heads, int kv_heads, int head_dim, + bool is_causal) { + (void)batch_size; + const int b = blockIdx.z; + const int qh = blockIdx.y; + const int t = blockIdx.x; + const int tid = threadIdx.x; + + if (t >= target_seq_len) { + return; + } + if (tid != 0) { + return; + } + + int kvh = 0; + if (query_heads % kv_heads == 0) { + const int q_per_kv = query_heads / kv_heads; + kvh = qh / q_per_kv; + } else { + kvh = qh % kv_heads; + } + + const int effective_src = is_causal ? 
min(src_seq_len, t + 1) : src_seq_len; + const int out_base = ((b * target_seq_len + t) * query_heads + qh) * head_dim; + + extern __shared__ float smem[]; + float* q_vec = smem; // [head_dim] + float* out_accum = q_vec + head_dim; // [head_dim] + + if (effective_src <= 0) { + for (int d = 0; d < head_dim; ++d) { + o[out_base + d] = from_float(0.0f); + } + return; + } + + const float inv_sqrt_d = 1.0f / sqrtf(static_cast(head_dim)); + for (int d = 0; d < head_dim; ++d) { + const int q_idx = ((b * target_seq_len + t) * query_heads + qh) * head_dim + d; + q_vec[d] = to_float(q[q_idx]); + out_accum[d] = 0.0f; + } + + float m = -FLT_MAX; + for (int s = 0; s < effective_src; ++s) { + float dot = 0.0f; + const int k_base = ((b * src_seq_len + s) * kv_heads + kvh) * head_dim; + for (int d = 0; d < head_dim; ++d) { + dot = fmaf(q_vec[d], to_float(k[k_base + d]), dot); + } + m = fmaxf(m, dot * inv_sqrt_d); + } + + float denom = 0.0f; + for (int s = 0; s < effective_src; ++s) { + float dot = 0.0f; + const int k_base = ((b * src_seq_len + s) * kv_heads + kvh) * head_dim; + for (int d = 0; d < head_dim; ++d) { + dot = fmaf(q_vec[d], to_float(k[k_base + d]), dot); + } + + const float w = expf(dot * inv_sqrt_d - m); + denom += w; + + const int v_base = ((b * src_seq_len + s) * kv_heads + kvh) * head_dim; + for (int d = 0; d < head_dim; ++d) { + out_accum[d] += w * to_float(v[v_base + d]); + } + } + + const float inv_denom = (denom > 0.0f) ? (1.0f / denom) : 0.0f; + for (int d = 0; d < head_dim; ++d) { + o[out_base + d] = from_float(out_accum[d] * inv_denom); + } +} + template void flashAttention(const std::vector& h_q, const std::vector& h_k, const std::vector& h_v, std::vector& h_o, - int batch_size, int target_seq_len, int src_seq_len, - int query_heads, int kv_heads, int head_dim, bool is_causal) { + int batch_size, int target_seq_len, int src_seq_len, + int query_heads, int kv_heads, int head_dim, bool is_causal) { + const size_t o_elems = + static_cast(batch_size > 0 ? batch_size : 0) * + static_cast(target_seq_len > 0 ? target_seq_len : 0) * + static_cast(query_heads > 0 ? query_heads : 0) * + static_cast(head_dim > 0 ? 
head_dim : 0); + if (h_o.size() != o_elems) { + h_o.resize(o_elems); + } + + if (o_elems == 0) { + return; + } + if (batch_size <= 0 || target_seq_len <= 0 || src_seq_len <= 0 || + query_heads <= 0 || kv_heads <= 0 || head_dim <= 0) { + std::fill(h_o.begin(), h_o.end(), from_float(0.0f)); + return; + } + + const size_t q_elems = + static_cast(batch_size) * target_seq_len * query_heads * head_dim; + const size_t k_elems = + static_cast(batch_size) * src_seq_len * kv_heads * head_dim; + const size_t v_elems = + static_cast(batch_size) * src_seq_len * kv_heads * head_dim; + + if (h_q.size() != q_elems || h_k.size() != k_elems || h_v.size() != v_elems) { + throw std::invalid_argument( + "flashAttention: input tensor sizes do not match provided dimensions."); + } + + T* d_q = nullptr; + T* d_k = nullptr; + T* d_v = nullptr; + T* d_o = nullptr; + + MACA_CHECK(mcMalloc(&d_q, q_elems * sizeof(T))); + MACA_CHECK(mcMalloc(&d_k, k_elems * sizeof(T))); + MACA_CHECK(mcMalloc(&d_v, v_elems * sizeof(T))); + MACA_CHECK(mcMalloc(&d_o, o_elems * sizeof(T))); + + MACA_CHECK(mcMemcpy(d_q, h_q.data(), q_elems * sizeof(T), mcMemcpyHostToDevice)); + MACA_CHECK(mcMemcpy(d_k, h_k.data(), k_elems * sizeof(T), mcMemcpyHostToDevice)); + MACA_CHECK(mcMemcpy(d_v, h_v.data(), v_elems * sizeof(T), mcMemcpyHostToDevice)); + + dim3 grid_dim(target_seq_len, query_heads, batch_size); + const int block_size = FLASH_BLOCK_SIZE; + const size_t smem_size = static_cast(2 * head_dim) * sizeof(float); + + flash_attention_kernel<<>>( + d_q, d_k, d_v, d_o, + batch_size, target_seq_len, src_seq_len, + query_heads, kv_heads, head_dim, is_causal); + + MACA_CHECK(mcGetLastError()); + MACA_CHECK(mcDeviceSynchronize()); + + MACA_CHECK(mcMemcpy(h_o.data(), d_o, o_elems * sizeof(T), mcMemcpyDeviceToHost)); + + MACA_CHECK(mcFree(d_q)); + MACA_CHECK(mcFree(d_k)); + MACA_CHECK(mcFree(d_v)); + MACA_CHECK(mcFree(d_o)); } -// ********************************************************************* -// Explicit Template Instantiations (REQUIRED FOR LINKING WITH TESTER.O) -// DO NOT MODIFY THIS SECTION -// ********************************************************************* template int trace(const std::vector&, size_t, size_t); template float trace(const std::vector&, size_t, size_t); template void flashAttention(const std::vector&, const std::vector&, const std::vector&, std::vector&, int, int, int, int, int, int, bool); -template void flashAttention(const std::vector&, const std::vector&, - const std::vector&, std::vector&, +template void flashAttention<__half>(const std::vector<__half>&, const std::vector<__half>&, + const std::vector<__half>&, std::vector<__half>&, int, int, int, int, int, int, bool); diff --git a/src/kernels.mu b/src/kernels.mu index 1fb87770..90e92be1 100644 --- a/src/kernels.mu +++ b/src/kernels.mu @@ -1,54 +1,299 @@ #include #include +#include +#include +#include +#include +#include +#include +#include + +// Error checking macro +#define MUSA_CHECK(call) \ +{ \ + musaError_t err = call; \ + if (err != musaSuccess) \ + { \ + std::cerr << "MUSA error at " << __FILE__ << ":" << __LINE__ \ + << " - " << musaGetErrorString(err) << "\n"; \ + exit(1); \ + } \ +} + + +constexpr int WARP_SIZE = 32; + +template +__device__ __forceinline__ T warp_reduce_sum(T val) { + #pragma unroll + for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) { + val += __shfl_xor_sync(0xffffffff, val, mask); + } + return val; +} + +template +__global__ void trace_kernel(const T* d_input, int cols, int n_diag, T* d_sum) { + constexpr int NUM_THREADS = 256; 
+ constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE; + __shared__ T reduce_smem[NUM_WARPS]; + + int tid = threadIdx.x; + int idx = blockIdx.x * NUM_THREADS + tid; + int warp = tid / WARP_SIZE; + int lane = tid % WARP_SIZE; + + T sum = (idx < n_diag) ? d_input[idx * cols + idx] : T(0); + sum = warp_reduce_sum(sum); + + if (lane == 0) { + reduce_smem[warp] = sum; + } + __syncthreads(); + + sum = (lane < NUM_WARPS) ? reduce_smem[lane] : T(0); + if (warp == 0) { + sum = warp_reduce_sum(sum); + } + + if (tid == 0) { + atomicAdd(d_sum, sum); + } +} -#include "../tester/utils.h" - -/** - * @brief Computes the trace of a matrix. - * - * The trace of a matrix is defined as the sum of its diagonal elements. - * This function expects a flattened row-major matrix stored in a - * std::vector. If the matrix is not square, the trace will sum up - * elements along the main diagonal up to the smaller of rows or cols. - * - * @tparam T The numeric type of matrix elements (e.g., float, int). - * @param h_input A flattened matrix of size rows * cols. - * @param rows Number of rows in the matrix. - * @param cols Number of columns in the matrix. - * @return The trace (sum of diagonal values) of the matrix. - */ template T trace(const std::vector& h_input, size_t rows, size_t cols) { - // TODO: Implement the trace function - return T(-1); + if (h_input.empty() || rows == 0 || cols == 0) { + return T(0); + } + + const int n_diag = static_cast(std::min(rows, cols)); + const int block_size = 256; + const int grid_size = (n_diag + block_size - 1) / block_size; + + T* d_input = nullptr; + T* d_sum = nullptr; + + MUSA_CHECK(musaMalloc(&d_input, h_input.size() * sizeof(T))); + MUSA_CHECK(musaMalloc(&d_sum, sizeof(T))); + + MUSA_CHECK(musaMemcpy(d_input, h_input.data(), h_input.size() * sizeof(T), musaMemcpyHostToDevice)); + MUSA_CHECK(musaMemset(d_sum, 0, sizeof(T))); + + trace_kernel<<>>( + d_input, static_cast(cols), n_diag, d_sum + ); + MUSA_CHECK(musaGetLastError()); + MUSA_CHECK(musaDeviceSynchronize()); + + T h_sum = T(0); + MUSA_CHECK(musaMemcpy(&h_sum, d_sum, sizeof(T), musaMemcpyDeviceToHost)); + + MUSA_CHECK(musaFree(d_input)); + MUSA_CHECK(musaFree(d_sum)); + + return h_sum; } -/** - * @brief Computes flash attention for given query, key, and value tensors. 
- * - * @tparam T Data type (float) for input/output tensors - * @param[in] h_q Query tensor of shape [batch_size, tgt_seq_len, query_heads, head_dim] - * @param[in] h_k Key tensor of shape [batch_size, src_seq_len, kv_heads, head_dim] - * @param[in] h_v Value tensor of shape [batch_size, src_seq_len, kv_heads, head_dim] - * @param[out] h_o Output attention tensor of shape [batch_size, tgt_seq_len, query_heads, head_dim] - * @param[in] batch_size Batch dimension size - * @param[in] target_seq_len Target sequence length - * @param[in] src_seq_len Source sequence length - * @param[in] query_heads Number of query attention heads - * @param[in] kv_heads Number of key/value heads (supports grouped query attention) - * @param[in] head_dim Dimension size of each attention head - * @param[in] is_causal Whether to apply causal masking - */ +constexpr int FLASH_BLOCK_SIZE = 32; + + +template +__global__ void flashAttentionKernel( + const T* Q, const T* K, const T* V, T* O, + int batch_size, int target_seq_len, int src_seq_len, + int query_heads, int kv_heads, int head_dim, bool is_causal) { + + (void)batch_size; + int batch_idx = blockIdx.z; + int head_idx = blockIdx.y; + int row_idx = blockIdx.x; + int tid = threadIdx.x; + + if (row_idx >= target_seq_len) { + return; + } + + if (tid != 0) { + return; + } + + int kv_head_idx = 0; + if (query_heads % kv_heads == 0) { + int q_per_kv = query_heads / kv_heads; + kv_head_idx = head_idx / q_per_kv; + } else { + kv_head_idx = head_idx % kv_heads; + } + int effective_src = is_causal ? min(src_seq_len, row_idx + 1) : src_seq_len; + int out_base = ((batch_idx * target_seq_len + row_idx) * query_heads + head_idx) * head_dim; + + extern __shared__ float smem[]; + float* q_vec = smem; // [head_dim],缓存当前 query 向量 + float* out_accum = q_vec + head_dim; // [head_dim],累计 softmax 后的输出 + + if (effective_src <= 0) { + for (int d = 0; d < head_dim; ++d) { + O[out_base + d] = static_cast(0); + } + return; + } + + // 缩放因子:scores = (QK^T) / sqrt(d) + const float inv_sqrt_d = 1.0f / sqrtf(static_cast(head_dim)); + for (int d = 0; d < head_dim; ++d) { + int q_idx = ((batch_idx * target_seq_len + row_idx) * query_heads + head_idx) * head_dim + d; + q_vec[d] = static_cast(Q[q_idx]); + out_accum[d] = 0.0f; + } + + // Pass 1:求 m = max_j(score_j),用于数值稳定 softmax。 + float m = -FLT_MAX; + for (int s = 0; s < effective_src; ++s) { + float dot = 0.0f; + int k_base = ((batch_idx * src_seq_len + s) * kv_heads + kv_head_idx) * head_dim; + for (int d = 0; d < head_dim; ++d) { + dot = fmaf(q_vec[d], static_cast(K[k_base + d]), dot); + } + m = fmaxf(m, dot * inv_sqrt_d); + } + + // Pass 2:计算 + // denom = sum_j exp(score_j - m) + // out = sum_j exp(score_j - m) * V_j + float denom = 0.0f; + for (int s = 0; s < effective_src; ++s) { + float dot = 0.0f; + int k_base = ((batch_idx * src_seq_len + s) * kv_heads + kv_head_idx) * head_dim; + for (int d = 0; d < head_dim; ++d) { + dot = fmaf(q_vec[d], static_cast(K[k_base + d]), dot); + } + + float w = expf(dot * inv_sqrt_d - m); + denom += w; + + int v_base = ((batch_idx * src_seq_len + s) * kv_heads + kv_head_idx) * head_dim; + for (int d = 0; d < head_dim; ++d) { + out_accum[d] += w * static_cast(V[v_base + d]); + } + } + + // 最终归一化:out = out / denom + float inv_denom = (denom > 0.0f) ? 
(1.0f / denom) : 0.0f; + for (int d = 0; d < head_dim; ++d) { + O[out_base + d] = static_cast(out_accum[d] * inv_denom); + } +} + +// Host function for Flash Attention v1 template void flashAttention(const std::vector& h_q, const std::vector& h_k, const std::vector& h_v, std::vector& h_o, int batch_size, int target_seq_len, int src_seq_len, - int query_heads, int kv_heads, int head_dim, bool is_causal) { + int query_heads, int kv_heads, int head_dim, bool is_causal) { + // 输出形状:[B, Tq, Hq, D] + const size_t o_elems = + static_cast(batch_size > 0 ? batch_size : 0) * + static_cast(target_seq_len > 0 ? target_seq_len : 0) * + static_cast(query_heads > 0 ? query_heads : 0) * + static_cast(head_dim > 0 ? head_dim : 0); + if (h_o.size() != o_elems) { + h_o.resize(o_elems); + } + + if (o_elems == 0) { + return; + } + if (batch_size <= 0 || target_seq_len <= 0 || src_seq_len <= 0 || + query_heads <= 0 || kv_heads <= 0 || head_dim <= 0) { + // 输入维度非法时,按全 0 输出处理,避免非法 launch/malloc。 + std::fill(h_o.begin(), h_o.end(), T(0)); + return; + } + + size_t elem_size = sizeof(T); + // 输入布局: + // Q: [B, Tq, Hq, D] + // K: [B, Tk, Hkv, D] + // V: [B, Tk, Hkv, D] + size_t q_elems = (size_t)batch_size * target_seq_len * query_heads * head_dim; + size_t k_elems = (size_t)batch_size * src_seq_len * kv_heads * head_dim; + size_t v_elems = (size_t)batch_size * src_seq_len * kv_heads * head_dim; + size_t q_size = q_elems * elem_size; + size_t k_size = k_elems * elem_size; + size_t v_size = v_elems * elem_size; + size_t o_size = static_cast(batch_size) * target_seq_len * query_heads * head_dim * elem_size; + + if (h_q.size() != q_elems || h_k.size() != k_elems || h_v.size() != v_elems) { + throw std::invalid_argument("flashAttention: input tensor sizes do not match provided dimensions."); + } + + T* d_q = nullptr; + T* d_k = nullptr; + T* d_v = nullptr; + T* d_o = nullptr; + + try { + MUSA_CHECK(musaMalloc(&d_q, q_size)); + MUSA_CHECK(musaMalloc(&d_k, k_size)); + MUSA_CHECK(musaMalloc(&d_v, v_size)); + MUSA_CHECK(musaMalloc(&d_o, o_size)); + + MUSA_CHECK(musaMemcpy(d_q, h_q.data(), q_size, musaMemcpyHostToDevice)); + MUSA_CHECK(musaMemcpy(d_k, h_k.data(), k_size, musaMemcpyHostToDevice)); + MUSA_CHECK(musaMemcpy(d_v, h_v.data(), v_size, musaMemcpyHostToDevice)); + + // 网格布局: + // grid.x -> query 位置 t + // grid.y -> query head h + // grid.z -> batch b + dim3 grid_dim( + target_seq_len, + query_heads, + batch_size + ); + + int block_size = FLASH_BLOCK_SIZE; + + // 动态共享内存:q_vec + out_accum(各 head_dim 个 float) + size_t smem_size = 0; + smem_size += (2 * static_cast(head_dim)) * sizeof(float); + + int device = 0; + MUSA_CHECK(musaGetDevice(&device)); + + int max_smem = 0; + MUSA_CHECK(musaDeviceGetAttribute(&max_smem, musaDevAttrMaxSharedMemoryPerBlock, device)); + if (smem_size > static_cast(max_smem)) { + throw std::invalid_argument("flashAttention: shared memory requirement exceeds device limit."); + } + + flashAttentionKernel<<>>( + d_q, d_k, d_v, d_o, + batch_size, target_seq_len, src_seq_len, + query_heads, kv_heads, head_dim, is_causal + ); + + MUSA_CHECK(musaGetLastError()); + MUSA_CHECK(musaDeviceSynchronize()); + + MUSA_CHECK(musaMemcpy(h_o.data(), d_o, o_size, musaMemcpyDeviceToHost)); + } catch (...) 
{ + if (d_q != nullptr) MUSA_CHECK(musaFree(d_q)); + if (d_k != nullptr) MUSA_CHECK(musaFree(d_k)); + if (d_v != nullptr) MUSA_CHECK(musaFree(d_v)); + if (d_o != nullptr) MUSA_CHECK(musaFree(d_o)); + throw; + } + + MUSA_CHECK(musaFree(d_q)); + MUSA_CHECK(musaFree(d_k)); + MUSA_CHECK(musaFree(d_v)); + MUSA_CHECK(musaFree(d_o)); } // ********************************************************************* -// Explicit Template Instantiations (REQUIRED FOR LINKING WITH TESTER.O) -// DO NOT MODIFY THIS SECTION +// Explicit Template Instantiations // ********************************************************************* template int trace<int>(const std::vector<int>&, size_t, size_t); template float trace<float>(const std::vector<float>&, size_t, size_t); @@ -57,4 +302,4 @@ template void flashAttention(const std::vector&, const std::vector int, int, int, int, int, int, bool); template void flashAttention(const std::vector&, const std::vector&, const std::vector&, std::vector&, - int, int, int, int, int, int, bool); + int, int, int, int, int, int, bool); \ No newline at end of file diff --git a/src/kernels.o b/src/kernels.o new file mode 100644 index 00000000..bea2689a Binary files /dev/null and b/src/kernels.o differ diff --git a/test_kernels b/test_kernels new file mode 100755 index 00000000..5a065eff Binary files /dev/null and b/test_kernels differ
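The attention math added in flashAttentionKernel / flash_attention_kernel above can be cross-checked on the host. The following is a minimal host-side sketch intended only as a reference check: it assumes the row-major layouts Q[B, Tq, Hq, D], K/V[B, Tk, Hkv, D], O[B, Tq, Hq, D] described in the kernel comments, the same GQA head mapping (grouped when query_heads is divisible by kv_heads, modulo otherwise) and the same causal rule; cpu_attention_reference is an illustrative name, not a function defined in the patch.

// Host-side reference for the attention math used by the kernels above.
// Assumes row-major Q[B,Tq,Hq,D], K/V[B,Tk,Hkv,D], O[B,Tq,Hq,D].
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstddef>
#include <vector>

inline void cpu_attention_reference(
    const std::vector<float>& q, const std::vector<float>& k,
    const std::vector<float>& v, std::vector<float>& o,
    int batch_size, int target_seq_len, int src_seq_len,
    int query_heads, int kv_heads, int head_dim, bool is_causal) {
  o.assign(static_cast<std::size_t>(batch_size) * target_seq_len * query_heads * head_dim, 0.0f);
  const float inv_sqrt_d = 1.0f / std::sqrt(static_cast<float>(head_dim));
  for (int b = 0; b < batch_size; ++b) {
    for (int t = 0; t < target_seq_len; ++t) {
      for (int h = 0; h < query_heads; ++h) {
        // GQA head mapping: grouped when divisible, modulo otherwise (mirrors the kernels).
        const int kvh = (query_heads % kv_heads == 0)
                            ? h / (query_heads / kv_heads)
                            : h % kv_heads;
        const int effective_src = is_causal ? std::min(src_seq_len, t + 1) : src_seq_len;
        const std::size_t q_base =
            ((static_cast<std::size_t>(b) * target_seq_len + t) * query_heads + h) * head_dim;
        const std::size_t o_base = q_base;  // O shares Q's [B,Tq,Hq,D] layout.
        if (effective_src <= 0) continue;   // row is already zero-filled
        // Pass 1: scaled scores and their row maximum for a stable softmax.
        std::vector<float> scores(effective_src);
        float m = -FLT_MAX;
        for (int s = 0; s < effective_src; ++s) {
          const std::size_t k_base =
              ((static_cast<std::size_t>(b) * src_seq_len + s) * kv_heads + kvh) * head_dim;
          float dot = 0.0f;
          for (int d = 0; d < head_dim; ++d) dot += q[q_base + d] * k[k_base + d];
          scores[s] = dot * inv_sqrt_d;
          m = std::max(m, scores[s]);
        }
        // Pass 2: softmax-weighted sum of V rows, then normalize.
        float denom = 0.0f;
        for (int s = 0; s < effective_src; ++s) {
          const float w = std::exp(scores[s] - m);
          denom += w;
          const std::size_t v_base =
              ((static_cast<std::size_t>(b) * src_seq_len + s) * kv_heads + kvh) * head_dim;
          for (int d = 0; d < head_dim; ++d) o[o_base + d] += w * v[v_base + d];
        }
        const float inv_denom = (denom > 0.0f) ? 1.0f / denom : 0.0f;
        for (int d = 0; d < head_dim; ++d) o[o_base + d] *= inv_denom;
      }
    }
  }
}

Comparing h_o from flashAttention<float> against this reference with a small absolute tolerance (for example 1e-3) is usually enough to catch indexing, GQA-mapping, or causal-masking mistakes.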
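The trace path admits the same kind of check; a minimal host-side sketch, assuming the flattened row-major layout and the min(rows, cols) diagonal bound stated in the removed doc comment (cpu_trace_reference is an illustrative name):

// Host-side reference for trace(): sum of the main diagonal of a flattened
// row-major rows x cols matrix, bounded by min(rows, cols) like the kernels.
#include <algorithm>
#include <cstddef>
#include <vector>

template <typename T>
T cpu_trace_reference(const std::vector<T>& h_input, std::size_t rows, std::size_t cols) {
  T sum = T(0);
  const std::size_t n_diag = std::min(rows, cols);
  for (std::size_t i = 0; i < n_diag; ++i) {
    sum += h_input[i * cols + i];  // element (i, i) in row-major layout
  }
  return sum;
}

// Example check (illustrative): for the 2 x 3 matrix {1, 2, 3, 4, 5, 6},
// the diagonal elements are 1 and 5, so both cpu_trace_reference and the
// GPU trace<float> should return 6.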