Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions custom_ops/gpu_ops/cpp_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -941,9 +941,7 @@ void SpeculateScheduleCache(const paddle::Tensor& draft_tokens,
const int block_size,
const int max_draft_tokens);

void NgramMatch(const paddle::Tensor& input_ids,
const paddle::Tensor& input_ids_len,
const paddle::Tensor& token_ids_all,
void NgramMatch(const paddle::Tensor& token_ids_all,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& step_idx,
const paddle::Tensor& draft_token_num,
Expand Down
49 changes: 15 additions & 34 deletions custom_ops/gpu_ops/speculate_decoding/ngram_match.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@
// the tentative new seq_lens_this_time to a copy buffer.
// Phase 2 will decide which ones to keep (threshold logic).
// ============================================================
__global__ void ngram_match_search_kernel(const int64_t *input_ids,
const int64_t *input_ids_len,
const int64_t *token_ids_all,
__global__ void ngram_match_search_kernel(const int64_t *token_ids_all,
const int64_t *prompt_lens,
const int64_t *step_idx,
const int *draft_token_num,
Expand All @@ -38,7 +36,6 @@ __global__ void ngram_match_search_kernel(const int64_t *input_ids,
const int64_t *max_dec_len,
int64_t *draft_tokens_copy,
int32_t *seq_lens_this_time_copy,
int64_t input_ids_stride,
int64_t max_model_len,
int64_t draft_tokens_stride,
int64_t max_batch_size,
Expand All @@ -63,9 +60,9 @@ __global__ void ngram_match_search_kernel(const int64_t *input_ids,
// Active decoder item: at least the base token.
if (threadIdx.x == 0) seq_lens_this_time_copy[batch_idx] = 1;

const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
const int64_t prompt_len = prompt_lens[batch_idx];
const int64_t *cur_input_ids = token_ids_all + batch_idx * max_model_len;
const int64_t cur_input_ids_len = prompt_len;
const int64_t *cur_pre_ids =
token_ids_all + batch_idx * max_model_len + prompt_len;
const int64_t cur_step_idx = step_idx[batch_idx];
Expand All @@ -79,7 +76,7 @@ __global__ void ngram_match_search_kernel(const int64_t *input_ids,
for (int ngram_size = max_ngram_size; ngram_size >= 1; --ngram_size) {
if (cur_step_idx < ngram_size) continue;

const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);
const int64_t *ngram = cur_pre_ids + (cur_step_idx - ngram_size);

int64_t pos = parallel_ngram_search(
cur_input_ids, cur_input_ids_len, ngram, ngram_size, &s_min_pos);
Expand Down Expand Up @@ -235,9 +232,7 @@ static int sum_cpu(const int *value, int num) {
return sum_value;
}

static void find_candidate_pred_tokens(const int64_t *input_ids,
const int64_t *input_ids_len,
const int64_t *token_ids_all,
static void find_candidate_pred_tokens(const int64_t *token_ids_all,
const int64_t *prompt_lens,
const int64_t *step_idx,
const int *draft_token_num,
Expand All @@ -246,7 +241,6 @@ static void find_candidate_pred_tokens(const int64_t *input_ids,
int32_t *seq_lens_encoder,
int32_t *seq_lens_decoder,
int64_t *max_dec_len,
int64_t input_ids_stride,
int64_t max_model_len,
int64_t draft_tokens_stride,
int64_t max_batch_size,
Expand Down Expand Up @@ -274,12 +268,12 @@ static void find_candidate_pred_tokens(const int64_t *input_ids,
continue;
}

const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
const int64_t *cur_input_ids = token_ids_all + batch_idx * max_model_len;
const int64_t cur_input_ids_len = prompt_lens[batch_idx];
int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
const int64_t *cur_pre_ids =
token_ids_all + batch_idx * max_model_len + prompt_lens[batch_idx];
token_ids_all + batch_idx * max_model_len + cur_input_ids_len;
const int64_t cur_step_idx = step_idx[batch_idx];
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
seq_lens_this_time[batch_idx] = 1;
unprocessed_batch_size--;

Expand All @@ -301,7 +295,7 @@ static void find_candidate_pred_tokens(const int64_t *input_ids,
if (cur_step_idx < ngram_size) {
continue;
}
const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);
const int64_t *ngram = cur_pre_ids + (cur_step_idx - ngram_size);

bool match_input = false;
for (int64_t i = 0; i <= cur_input_ids_len - ngram_size; ++i) {
Expand Down Expand Up @@ -370,9 +364,7 @@ static void find_candidate_pred_tokens(const int64_t *input_ids,
// bsz × NGRAM_BLOCK_THREADS threads. Phase 2 is O(bsz) with scans.
// ============================================================

void NgramMatch(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &token_ids_all,
void NgramMatch(const paddle::Tensor &token_ids_all,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &step_idx,
const paddle::Tensor &draft_token_num,
Expand All @@ -383,9 +375,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
const paddle::Tensor &max_dec_len,
const int max_ngram_size,
const int max_draft_tokens) {
auto input_ids_shape = input_ids.shape();
const int64_t input_ids_stride = input_ids_shape[1];

const int64_t max_model_len = token_ids_all.shape()[1];

auto draft_tokens_shape = draft_tokens.shape();
Expand All @@ -399,8 +388,8 @@ void NgramMatch(const paddle::Tensor &input_ids,
threshold = std::stoi(env_var);
}

if (input_ids.is_gpu()) {
auto stream = input_ids.stream();
if (token_ids_all.is_gpu()) {
auto stream = token_ids_all.stream();

// Persistent scratch buffers for Phase 1 → Phase 2 communication.
// Cached across calls to avoid per-invocation allocation overhead.
Expand All @@ -416,9 +405,9 @@ void NgramMatch(const paddle::Tensor &input_ids,
draft_tokens_stride > s_scratch_stride) {
s_draft_copy = paddle::empty({max_batch_size, draft_tokens_stride},
paddle::DataType::INT64,
input_ids.place());
token_ids_all.place());
s_seqlens_copy = paddle::empty(
{max_batch_size}, paddle::DataType::INT32, input_ids.place());
{max_batch_size}, paddle::DataType::INT32, token_ids_all.place());
s_scratch_batch = max_batch_size;
s_scratch_stride = draft_tokens_stride;
}
Expand All @@ -435,8 +424,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
NGRAM_BLOCK_THREADS,
0,
stream>>>(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
token_ids_all.data<int64_t>(),
prompt_lens.data<int64_t>(),
step_idx.data<int64_t>(),
Expand All @@ -446,7 +433,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
max_dec_len.data<int64_t>(),
draft_tokens_copy.data<int64_t>(),
seq_lens_this_time_copy.data<int32_t>(),
input_ids_stride,
max_model_len,
draft_tokens_stride,
max_batch_size,
Expand All @@ -465,8 +451,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
threshold);
} else {
find_candidate_pred_tokens(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
token_ids_all.data<int64_t>(),
prompt_lens.data<int64_t>(),
step_idx.data<int64_t>(),
Expand All @@ -476,7 +460,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
const_cast<int32_t *>(seq_lens_encoder.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
input_ids_stride,
max_model_len,
draft_tokens_stride,
max_batch_size,
Expand All @@ -486,9 +469,7 @@ void NgramMatch(const paddle::Tensor &input_ids,
}

PD_BUILD_STATIC_OP(ngram_match)
.Inputs({"input_ids",
"input_ids_len",
"token_ids_all",
.Inputs({"token_ids_all",
"prompt_lens",
"step_idx",
"draft_token_num",
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1964,6 +1964,7 @@ def __init__(
in [
SpecMethod.MTP,
SpecMethod.SUFFIX,
SpecMethod.NGRAM,
]
)
else 0
Expand Down
13 changes: 0 additions & 13 deletions fastdeploy/spec_decode/ngram.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

from typing import TYPE_CHECKING

import paddle

from fastdeploy.model_executor.ops.gpu import ngram_match

from .base import Proposer
Expand All @@ -36,23 +34,12 @@ class NgramProposer(Proposer):
def __init__(self, fd_config: "FDConfig"):
super().__init__(fd_config)
self.max_ngram_size = self.speculative_config.max_ngram_size
self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()
self.input_ids_len_gpu = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cuda()

def update(self, bid: int, seq_len: int):
"""
update
"""
self.input_ids_len[bid] = seq_len
self.input_ids_len_gpu[bid] = seq_len

def _run_impl(self, share_inputs):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ Question: the update() method is deleted in this PR — please confirm that fastdeploy/worker/gpu_model_runner.py (e.g. in _postprocess and similar call sites) no longer contains any residual proposer.update(bid, seq_len) calls; otherwise NGRAM mode will raise an AttributeError.

"""
run
"""
ngram_match(
share_inputs["input_ids_cpu"].cuda(),
self.input_ids_len_gpu,
share_inputs["token_ids_all"],
share_inputs["prompt_lens"],
share_inputs["step_idx"],
Expand Down
6 changes: 5 additions & 1 deletion fastdeploy/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2094,7 +2094,11 @@ def capture_model(self) -> None:
logger.info(
f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
)
elif self.speculative_decoding and self.spec_method in [SpecMethod.MTP, SpecMethod.SUFFIX]:
elif self.speculative_decoding and self.spec_method in [
SpecMethod.MTP,
SpecMethod.SUFFIX,
SpecMethod.NGRAM,

This comment was marked as outdated.

]:
for capture_size in sorted(capture_sizes, reverse=True):
expected_decode_len = (self.speculative_config.num_speculative_tokens + 1) * 2
self._dummy_run(
Expand Down
36 changes: 13 additions & 23 deletions tests/operators/test_ngram_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,19 @@ def setUp(self):

def test_basic_match(self):
"""
Case 1: input_ids overlaps with token_ids_all, and can extract draft tokens.
Case 1: prompt overlaps with pre_ids, and can extract draft tokens.
"""
batch_size = 1
seq_len = 6

# Input IDs
input_ids = paddle.to_tensor([[10, 20, 30, 40, 50, 60]], dtype="int64")
# Length of input IDs
input_ids_len = paddle.to_tensor([6], dtype="int64")
# Previous IDs
token_ids_all = paddle.to_tensor([[10, 20, 30, 40, 0, 0]], dtype="int64")
prompt_lens = paddle.zeros([4, 1], dtype="int64")
# Current step index
step_idx = paddle.to_tensor([3], dtype="int64")

# Combined prompt and generated IDs: prompt=[10,20,30,40,50,60], generated=[10,20,30,40,0,0]
token_ids_all = paddle.to_tensor([[10, 20, 30, 40, 50, 60, 10, 20, 30, 40, 0, 0]], dtype="int64")
prompt_lens = paddle.to_tensor([[6]], dtype="int64")
# Current step index: 4 tokens generated (positions 0-3 valid)
step_idx = paddle.to_tensor([4], dtype="int64")
# Number of draft tokens
draft_token_num = paddle.to_tensor([3], dtype="int32")
# Placeholder for draft tokens
draft_tokens = paddle.zeros([batch_size, seq_len], dtype="int64")
draft_tokens = paddle.zeros([batch_size, 6], dtype="int64")

# Sequence lengths for this time step
seq_lens_this_time = paddle.zeros([batch_size], dtype="int32")
Expand All @@ -55,8 +50,6 @@ def test_basic_match(self):
max_dec_len = paddle.to_tensor([10], dtype="int64")

ngram_match(
input_ids,
input_ids_len,
token_ids_all,
prompt_lens,
step_idx,
Expand All @@ -80,14 +73,13 @@ def test_basic_match(self):

def test_no_match(self):
"""
Case 2: token_ids_all does not match input_ids, should only keep the current token.
Case 2: pre_ids does not match prompt, should only keep the current token.
"""
batch_size = 1
input_ids = paddle.to_tensor([[100, 200, 300, 400]], dtype="int64")
input_ids_len = paddle.to_tensor([4], dtype="int64")
token_ids_all = paddle.to_tensor([[1, 2, 3, 4]], dtype="int64")
prompt_lens = paddle.zeros([4, 1], dtype="int64")
step_idx = paddle.to_tensor([3], dtype="int64")
# Combined prompt and generated IDs: prompt=[100,200,300,400], generated=[1,2,3,4]
token_ids_all = paddle.to_tensor([[100, 200, 300, 400, 1, 2, 3, 4]], dtype="int64")
prompt_lens = paddle.to_tensor([4], dtype="int64")
step_idx = paddle.to_tensor([4], dtype="int64")
draft_token_num = paddle.to_tensor([2], dtype="int32")
draft_tokens = paddle.zeros([batch_size, 4], dtype="int64")

Expand All @@ -97,8 +89,6 @@ def test_no_match(self):
max_dec_len = paddle.to_tensor([6], dtype="int64")

ngram_match(
input_ids,
input_ids_len,
token_ids_all,
prompt_lens,
step_idx,
Expand Down
Loading
Loading