diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 204ea33e50b..191047c60da 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -941,9 +941,7 @@ void SpeculateScheduleCache(const paddle::Tensor& draft_tokens, const int block_size, const int max_draft_tokens); -void NgramMatch(const paddle::Tensor& input_ids, - const paddle::Tensor& input_ids_len, - const paddle::Tensor& token_ids_all, +void NgramMatch(const paddle::Tensor& token_ids_all, const paddle::Tensor& prompt_lens, const paddle::Tensor& step_idx, const paddle::Tensor& draft_token_num, diff --git a/custom_ops/gpu_ops/speculate_decoding/ngram_match.cu b/custom_ops/gpu_ops/speculate_decoding/ngram_match.cu index 2f4904ee26c..2c42ad49539 100644 --- a/custom_ops/gpu_ops/speculate_decoding/ngram_match.cu +++ b/custom_ops/gpu_ops/speculate_decoding/ngram_match.cu @@ -27,9 +27,7 @@ // the tentative new seq_lens_this_time to a copy buffer. // Phase 2 will decide which ones to keep (threshold logic). // ============================================================ -__global__ void ngram_match_search_kernel(const int64_t *input_ids, - const int64_t *input_ids_len, - const int64_t *token_ids_all, +__global__ void ngram_match_search_kernel(const int64_t *token_ids_all, const int64_t *prompt_lens, const int64_t *step_idx, const int *draft_token_num, @@ -38,7 +36,6 @@ __global__ void ngram_match_search_kernel(const int64_t *input_ids, const int64_t *max_dec_len, int64_t *draft_tokens_copy, int32_t *seq_lens_this_time_copy, - int64_t input_ids_stride, int64_t max_model_len, int64_t draft_tokens_stride, int64_t max_batch_size, @@ -63,9 +60,9 @@ __global__ void ngram_match_search_kernel(const int64_t *input_ids, // Active decoder item: at least the base token. 
if (threadIdx.x == 0) seq_lens_this_time_copy[batch_idx] = 1; - const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride; - const int64_t cur_input_ids_len = input_ids_len[batch_idx]; const int64_t prompt_len = prompt_lens[batch_idx]; + const int64_t *cur_input_ids = token_ids_all + batch_idx * max_model_len; + const int64_t cur_input_ids_len = prompt_len; const int64_t *cur_pre_ids = token_ids_all + batch_idx * max_model_len + prompt_len; const int64_t cur_step_idx = step_idx[batch_idx]; @@ -79,7 +76,7 @@ __global__ void ngram_match_search_kernel(const int64_t *input_ids, for (int ngram_size = max_ngram_size; ngram_size >= 1; --ngram_size) { if (cur_step_idx < ngram_size) continue; - const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size); + const int64_t *ngram = cur_pre_ids + (cur_step_idx - ngram_size); int64_t pos = parallel_ngram_search( cur_input_ids, cur_input_ids_len, ngram, ngram_size, &s_min_pos); @@ -235,9 +232,7 @@ static int sum_cpu(const int *value, int num) { return sum_value; } -static void find_candidate_pred_tokens(const int64_t *input_ids, - const int64_t *input_ids_len, - const int64_t *token_ids_all, +static void find_candidate_pred_tokens(const int64_t *token_ids_all, const int64_t *prompt_lens, const int64_t *step_idx, const int *draft_token_num, @@ -246,7 +241,6 @@ static void find_candidate_pred_tokens(const int64_t *input_ids, int32_t *seq_lens_encoder, int32_t *seq_lens_decoder, int64_t *max_dec_len, - int64_t input_ids_stride, int64_t max_model_len, int64_t draft_tokens_stride, int64_t max_batch_size, @@ -274,12 +268,12 @@ static void find_candidate_pred_tokens(const int64_t *input_ids, continue; } - const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride; + const int64_t *cur_input_ids = token_ids_all + batch_idx * max_model_len; + const int64_t cur_input_ids_len = prompt_lens[batch_idx]; int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride; const int64_t *cur_pre_ids = - token_ids_all + batch_idx * max_model_len + prompt_lens[batch_idx]; + token_ids_all + batch_idx * max_model_len + cur_input_ids_len; const int64_t cur_step_idx = step_idx[batch_idx]; - const int64_t cur_input_ids_len = input_ids_len[batch_idx]; seq_lens_this_time[batch_idx] = 1; unprocessed_batch_size--; @@ -301,7 +295,7 @@ static void find_candidate_pred_tokens(const int64_t *input_ids, if (cur_step_idx < ngram_size) { continue; } - const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size); + const int64_t *ngram = cur_pre_ids + (cur_step_idx - ngram_size); bool match_input = false; for (int64_t i = 0; i <= cur_input_ids_len - ngram_size; ++i) { @@ -370,9 +364,7 @@ static void find_candidate_pred_tokens(const int64_t *input_ids, // bsz × NGRAM_BLOCK_THREADS threads. Phase 2 is O(bsz) with scans. 
// ============================================================ -void NgramMatch(const paddle::Tensor &input_ids, - const paddle::Tensor &input_ids_len, - const paddle::Tensor &token_ids_all, +void NgramMatch(const paddle::Tensor &token_ids_all, const paddle::Tensor &prompt_lens, const paddle::Tensor &step_idx, const paddle::Tensor &draft_token_num, @@ -383,9 +375,6 @@ void NgramMatch(const paddle::Tensor &input_ids, const paddle::Tensor &max_dec_len, const int max_ngram_size, const int max_draft_tokens) { - auto input_ids_shape = input_ids.shape(); - const int64_t input_ids_stride = input_ids_shape[1]; - const int64_t max_model_len = token_ids_all.shape()[1]; auto draft_tokens_shape = draft_tokens.shape(); @@ -399,8 +388,8 @@ void NgramMatch(const paddle::Tensor &input_ids, threshold = std::stoi(env_var); } - if (input_ids.is_gpu()) { - auto stream = input_ids.stream(); + if (token_ids_all.is_gpu()) { + auto stream = token_ids_all.stream(); // Persistent scratch buffers for Phase 1 → Phase 2 communication. // Cached across calls to avoid per-invocation allocation overhead. @@ -416,9 +405,9 @@ void NgramMatch(const paddle::Tensor &input_ids, draft_tokens_stride > s_scratch_stride) { s_draft_copy = paddle::empty({max_batch_size, draft_tokens_stride}, paddle::DataType::INT64, - input_ids.place()); + token_ids_all.place()); s_seqlens_copy = paddle::empty( - {max_batch_size}, paddle::DataType::INT32, input_ids.place()); + {max_batch_size}, paddle::DataType::INT32, token_ids_all.place()); s_scratch_batch = max_batch_size; s_scratch_stride = draft_tokens_stride; } @@ -435,8 +424,6 @@ void NgramMatch(const paddle::Tensor &input_ids, NGRAM_BLOCK_THREADS, 0, stream>>>( - input_ids.data(), - input_ids_len.data(), token_ids_all.data(), prompt_lens.data(), step_idx.data(), @@ -446,7 +433,6 @@ void NgramMatch(const paddle::Tensor &input_ids, max_dec_len.data(), draft_tokens_copy.data(), seq_lens_this_time_copy.data(), - input_ids_stride, max_model_len, draft_tokens_stride, max_batch_size, @@ -465,8 +451,6 @@ void NgramMatch(const paddle::Tensor &input_ids, threshold); } else { find_candidate_pred_tokens( - input_ids.data(), - input_ids_len.data(), token_ids_all.data(), prompt_lens.data(), step_idx.data(), @@ -476,7 +460,6 @@ void NgramMatch(const paddle::Tensor &input_ids, const_cast(seq_lens_encoder.data()), const_cast(seq_lens_decoder.data()), const_cast(max_dec_len.data()), - input_ids_stride, max_model_len, draft_tokens_stride, max_batch_size, @@ -486,9 +469,7 @@ void NgramMatch(const paddle::Tensor &input_ids, } PD_BUILD_STATIC_OP(ngram_match) - .Inputs({"input_ids", - "input_ids_len", - "token_ids_all", + .Inputs({"token_ids_all", "prompt_lens", "step_idx", "draft_token_num", diff --git a/fastdeploy/config.py b/fastdeploy/config.py index f18a4c6ee0a..6231aeae767 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1964,6 +1964,7 @@ def __init__( in [ SpecMethod.MTP, SpecMethod.SUFFIX, + SpecMethod.NGRAM, ] ) else 0 diff --git a/fastdeploy/spec_decode/ngram.py b/fastdeploy/spec_decode/ngram.py index 2de823b36da..f55d5fe83a9 100644 --- a/fastdeploy/spec_decode/ngram.py +++ b/fastdeploy/spec_decode/ngram.py @@ -16,8 +16,6 @@ from typing import TYPE_CHECKING -import paddle - from fastdeploy.model_executor.ops.gpu import ngram_match from .base import Proposer @@ -36,23 +34,12 @@ class NgramProposer(Proposer): def __init__(self, fd_config: "FDConfig"): super().__init__(fd_config) self.max_ngram_size = self.speculative_config.max_ngram_size - self.input_ids_len = 
paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu() - self.input_ids_len_gpu = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cuda() - - def update(self, bid: int, seq_len: int): - """ - update - """ - self.input_ids_len[bid] = seq_len - self.input_ids_len_gpu[bid] = seq_len def _run_impl(self, share_inputs): """ run """ ngram_match( - share_inputs["input_ids_cpu"].cuda(), - self.input_ids_len_gpu, share_inputs["token_ids_all"], share_inputs["prompt_lens"], share_inputs["step_idx"], diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 1127d5c724e..0c00d6a6c43 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -2094,7 +2094,11 @@ def capture_model(self) -> None: logger.info( f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}" ) - elif self.speculative_decoding and self.spec_method in [SpecMethod.MTP, SpecMethod.SUFFIX]: + elif self.speculative_decoding and self.spec_method in [ + SpecMethod.MTP, + SpecMethod.SUFFIX, + SpecMethod.NGRAM, + ]: for capture_size in sorted(capture_sizes, reverse=True): expected_decode_len = (self.speculative_config.num_speculative_tokens + 1) * 2 self._dummy_run( diff --git a/tests/operators/test_ngram_match.py b/tests/operators/test_ngram_match.py index 139b487de53..d4f36f51f06 100644 --- a/tests/operators/test_ngram_match.py +++ b/tests/operators/test_ngram_match.py @@ -26,24 +26,19 @@ def setUp(self): def test_basic_match(self): """ - Case 1: input_ids overlaps with token_ids_all, and can extract draft tokens. + Case 1: prompt overlaps with pre_ids, and can extract draft tokens. """ batch_size = 1 - seq_len = 6 - - # Input IDs - input_ids = paddle.to_tensor([[10, 20, 30, 40, 50, 60]], dtype="int64") - # Length of input IDs - input_ids_len = paddle.to_tensor([6], dtype="int64") - # Previous IDs - token_ids_all = paddle.to_tensor([[10, 20, 30, 40, 0, 0]], dtype="int64") - prompt_lens = paddle.zeros([4, 1], dtype="int64") - # Current step index - step_idx = paddle.to_tensor([3], dtype="int64") + + # Combined prompt and generated IDs: prompt=[10,20,30,40,50,60], generated=[10,20,30,40,0,0] + token_ids_all = paddle.to_tensor([[10, 20, 30, 40, 50, 60, 10, 20, 30, 40, 0, 0]], dtype="int64") + prompt_lens = paddle.to_tensor([[6]], dtype="int64") + # Current step index: 4 tokens generated (positions 0-3 valid) + step_idx = paddle.to_tensor([4], dtype="int64") # Number of draft tokens draft_token_num = paddle.to_tensor([3], dtype="int32") # Placeholder for draft tokens - draft_tokens = paddle.zeros([batch_size, seq_len], dtype="int64") + draft_tokens = paddle.zeros([batch_size, 6], dtype="int64") # Sequence lengths for this time step seq_lens_this_time = paddle.zeros([batch_size], dtype="int32") @@ -55,8 +50,6 @@ def test_basic_match(self): max_dec_len = paddle.to_tensor([10], dtype="int64") ngram_match( - input_ids, - input_ids_len, token_ids_all, prompt_lens, step_idx, @@ -80,14 +73,13 @@ def test_basic_match(self): def test_no_match(self): """ - Case 2: token_ids_all does not match input_ids, should only keep the current token. + Case 2: pre_ids does not match prompt, should only keep the current token. 
""" batch_size = 1 - input_ids = paddle.to_tensor([[100, 200, 300, 400]], dtype="int64") - input_ids_len = paddle.to_tensor([4], dtype="int64") - token_ids_all = paddle.to_tensor([[1, 2, 3, 4]], dtype="int64") - prompt_lens = paddle.zeros([4, 1], dtype="int64") - step_idx = paddle.to_tensor([3], dtype="int64") + # Combined prompt and generated IDs: prompt=[100,200,300,400], generated=[1,2,3,4] + token_ids_all = paddle.to_tensor([[100, 200, 300, 400, 1, 2, 3, 4]], dtype="int64") + prompt_lens = paddle.to_tensor([4], dtype="int64") + step_idx = paddle.to_tensor([4], dtype="int64") draft_token_num = paddle.to_tensor([2], dtype="int32") draft_tokens = paddle.zeros([batch_size, 4], dtype="int64") @@ -97,8 +89,6 @@ def test_no_match(self): max_dec_len = paddle.to_tensor([6], dtype="int64") ngram_match( - input_ids, - input_ids_len, token_ids_all, prompt_lens, step_idx, diff --git a/tests/spec_decode/test_benchmark_ngram_kernel.py b/tests/spec_decode/test_benchmark_ngram_kernel.py index 7a698b3e25f..270f611a0cc 100644 --- a/tests/spec_decode/test_benchmark_ngram_kernel.py +++ b/tests/spec_decode/test_benchmark_ngram_kernel.py @@ -48,64 +48,71 @@ def _build_data(batch_size, seq_len, hit_type="low_input", seed=42): """ Build test tensors with controlled ngram hit placement. + token_ids_all layout: [:seq_len] = prompt (haystack), [seq_len:] = generated tokens. + prompt_lens = seq_len so the kernel splits correctly. + step_idx = count of generated tokens (positions 0..step_idx-1 are valid in pre_ids). + hit_type controls where the ngram match is found: - - high_input: match near start of input_ids (fast find) - - high_pre: match near start of token_ids_all gen tokens - - low_input: match near end of input_ids (worst-case scan) - - low_pre: match near end of token_ids_all gen tokens + - high_input: match near start of prompt (fast input scan) + - high_pre: match near start of pre_ids (fast pre scan) + - low_input: match near end of prompt (worst-case input scan) + - low_pre: match near end of pre_ids (worst-case pre scan) - none: no planted match (full scan, no hit) """ rng = np.random.RandomState(seed) step_idx_val = max(MAX_NGRAM_SIZE + 2, 20) - pre_len = step_idx_val + 1 - max_model_len = max(seq_len + 64, pre_len + 64) + pre_len = step_idx_val + 1 # space for step_idx_val generated tokens + # Prompt tokens (haystack) and generated tokens (pre_ids) are separate arrays + # then merged into a single token_ids_all buffer. 
input_ids = rng.randint(10, 500, (batch_size, seq_len)).astype(np.int64) - token_ids_all = rng.randint(10, 500, (batch_size, max_model_len)).astype(np.int64) + pre_ids_area = rng.randint(10, 500, (batch_size, pre_len)).astype(np.int64) + token_ids_all = np.concatenate([input_ids, pre_ids_area], axis=1) + + # Pattern = what the fixed kernel reads: pre_ids[step_idx - ngram_size : step_idx] + # = token_ids_all[:, seq_len + step_idx_val - MAX_NGRAM_SIZE : seq_len + step_idx_val] pattern = np.arange(1001, 1001 + MAX_NGRAM_SIZE, dtype=np.int64) + ng_start_in_pre = step_idx_val - MAX_NGRAM_SIZE # relative to pre_ids start for b in range(batch_size): - # Plant pattern in token_ids_all at step_idx alignment (the ngram to search for) - ng_start = step_idx_val + 1 - MAX_NGRAM_SIZE - token_ids_all[b, ng_start : step_idx_val + 1] = pattern + # Plant pattern at the pre_ids position the kernel reads (after the fix) + token_ids_all[b, seq_len + ng_start_in_pre : seq_len + step_idx_val] = pattern if hit_type == "high_input": pos = 5 if pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS <= seq_len: - input_ids[b, pos : pos + MAX_NGRAM_SIZE] = pattern - input_ids[b, pos + MAX_NGRAM_SIZE : pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS] = np.arange( + token_ids_all[b, pos : pos + MAX_NGRAM_SIZE] = pattern + token_ids_all[b, pos + MAX_NGRAM_SIZE : pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS] = np.arange( 2001, 2001 + MAX_DRAFT_TOKENS, dtype=np.int64 ) elif hit_type == "high_pre": pos = 5 - if pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS < ng_start: - token_ids_all[b, pos : pos + MAX_NGRAM_SIZE] = pattern - token_ids_all[b, pos + MAX_NGRAM_SIZE : pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS] = np.arange( - 2001, 2001 + MAX_DRAFT_TOKENS, dtype=np.int64 - ) + if pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS < ng_start_in_pre: + token_ids_all[b, seq_len + pos : seq_len + pos + MAX_NGRAM_SIZE] = pattern + token_ids_all[ + b, seq_len + pos + MAX_NGRAM_SIZE : seq_len + pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS + ] = np.arange(2001, 2001 + MAX_DRAFT_TOKENS, dtype=np.int64) elif hit_type == "low_input": pos = seq_len - MAX_NGRAM_SIZE - MAX_DRAFT_TOKENS - 5 if pos > 0: - input_ids[b, pos : pos + MAX_NGRAM_SIZE] = pattern - input_ids[b, pos + MAX_NGRAM_SIZE : pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS] = np.arange( + token_ids_all[b, pos : pos + MAX_NGRAM_SIZE] = pattern + token_ids_all[b, pos + MAX_NGRAM_SIZE : pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS] = np.arange( 2001, 2001 + MAX_DRAFT_TOKENS, dtype=np.int64 ) elif hit_type == "low_pre": pos = step_idx_val - MAX_NGRAM_SIZE - MAX_DRAFT_TOKENS - 5 - if pos > 0 and pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS < ng_start: - token_ids_all[b, pos : pos + MAX_NGRAM_SIZE] = pattern - token_ids_all[b, pos + MAX_NGRAM_SIZE : pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS] = np.arange( - 2001, 2001 + MAX_DRAFT_TOKENS, dtype=np.int64 - ) + if pos > 0 and pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS < ng_start_in_pre: + token_ids_all[b, seq_len + pos : seq_len + pos + MAX_NGRAM_SIZE] = pattern + token_ids_all[ + b, seq_len + pos + MAX_NGRAM_SIZE : seq_len + pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS + ] = np.arange(2001, 2001 + MAX_DRAFT_TOKENS, dtype=np.int64) - elif hit_type == "none": - pass # No match planted — random data only + # hit_type == "none": no match planted — random data only - input_ids_len = np.full((batch_size, 1), seq_len, dtype=np.int64) - prompt_lens = np.zeros((batch_size, 1), dtype=np.int64) + prompt_lens = np.full((batch_size, 1), seq_len, dtype=np.int64) step_idx = np.full((batch_size, 1), step_idx_val, dtype=np.int64) 
draft_token_num = np.full((batch_size, 1), MAX_DRAFT_TOKENS, dtype=np.int32) draft_tokens = np.zeros((batch_size, MAX_DRAFT_TOKENS + 1), dtype=np.int64) @@ -115,8 +122,6 @@ def _build_data(batch_size, seq_len, hit_type="low_input", seed=42): max_dec_len = np.full((batch_size, 1), 1048576, dtype=np.int64) return { - "input_ids": input_ids, - "input_ids_len": input_ids_len, "token_ids_all": token_ids_all, "prompt_lens": prompt_lens, "step_idx": step_idx, @@ -139,8 +144,6 @@ def _to_gpu(np_dict): def _run_gpu(ngram_match_fn, gpu_data): """Run GPU kernel (tensors already on GPU).""" ngram_match_fn( - gpu_data["input_ids"], - gpu_data["input_ids_len"], gpu_data["token_ids_all"], gpu_data["prompt_lens"], gpu_data["step_idx"], diff --git a/tests/spec_decode/test_ngram_gpu_kernel.py b/tests/spec_decode/test_ngram_gpu_kernel.py index ad1f5ea845b..83c88d78f12 100644 --- a/tests/spec_decode/test_ngram_gpu_kernel.py +++ b/tests/spec_decode/test_ngram_gpu_kernel.py @@ -90,7 +90,7 @@ def _cpu_ngram_match( for ngram_size in range(max_ngram_size, 0, -1): if cur_step < ngram_size: continue - ngram = cur_pre_ids[cur_step + 1 - ngram_size : cur_step + 1] + ngram = cur_pre_ids[cur_step - ngram_size : cur_step] # Search in input_ids match_input = False @@ -241,8 +241,8 @@ def _make_ngram_test_data(batch_size=4, input_len=64, max_model_len=256, max_dra gen_len = 20 src = rng.randint(0, max(1, input_len - gen_len)) token_ids_all[b, input_len : input_len + gen_len] = input_ids[b, src : src + gen_len] - # step_idx = last valid position (0-based index) - step_idx[b] = gen_len - 1 + # step_idx = count of generated tokens (positions 0..step_idx-1 are valid) + step_idx[b] = gen_len return { "input_ids": input_ids, @@ -281,7 +281,7 @@ def _make_mixed_test_data(batch_size=4, input_len=64, pre_ids_len=256, max_draft gen_len = 20 src = rng.randint(0, max(1, input_len - gen_len)) pre_ids[b, :gen_len] = input_ids[b, src : src + gen_len] - # step_idx = last valid position (0-based index) + # step_idx = last valid position (0-based index), matches hybrid kernel semantics step_idx[b] = gen_len - 1 return { @@ -349,8 +349,6 @@ def test_correctness_basic(self): # GPU kernel gpu_data = _to_gpu(data) self.ngram_match( - gpu_data["input_ids"], - gpu_data["input_ids_len"], gpu_data["token_ids_all"], gpu_data["prompt_lens"], gpu_data["step_idx"], @@ -395,8 +393,6 @@ def test_correctness_varied_seeds(self): ) gpu_data = _to_gpu(data) self.ngram_match( - gpu_data["input_ids"], - gpu_data["input_ids_len"], gpu_data["token_ids_all"], gpu_data["prompt_lens"], gpu_data["step_idx"], @@ -444,8 +440,6 @@ def test_large_batch_long_seq(self): os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = str(high_threshold) try: self.ngram_match( - gpu_data["input_ids"], - gpu_data["input_ids_len"], gpu_data["token_ids_all"], gpu_data["prompt_lens"], gpu_data["step_idx"], @@ -489,8 +483,6 @@ def test_single_batch_long_seq(self): ) gpu_data = _to_gpu(data) self.ngram_match( - gpu_data["input_ids"], - gpu_data["input_ids_len"], gpu_data["token_ids_all"], gpu_data["prompt_lens"], gpu_data["step_idx"], @@ -534,8 +526,6 @@ def test_many_short_seqs(self): os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = str(high_threshold) try: self.ngram_match( - gpu_data["input_ids"], - gpu_data["input_ids_len"], gpu_data["token_ids_all"], gpu_data["prompt_lens"], gpu_data["step_idx"], @@ -563,8 +553,6 @@ def test_latency(self): for _ in range(1): d = _to_gpu(_make_ngram_test_data(batch_size=32, input_len=512, seed=42)) self.ngram_match( - d["input_ids"], - 
d["input_ids_len"], d["token_ids_all"], d["prompt_lens"], d["step_idx"], @@ -587,8 +575,6 @@ def test_latency(self): t0 = time.perf_counter() for _ in range(n_runs): self.ngram_match( - gpu_data["input_ids"], - gpu_data["input_ids_len"], gpu_data["token_ids_all"], gpu_data["prompt_lens"], gpu_data["step_idx"], @@ -639,8 +625,6 @@ def test_latency_scaling(self): # Warmup for _ in range(1): self.ngram_match( - gpu_data["input_ids"], - gpu_data["input_ids_len"], gpu_data["token_ids_all"], gpu_data["prompt_lens"], gpu_data["step_idx"], @@ -660,8 +644,6 @@ def test_latency_scaling(self): t0 = time.perf_counter() for _ in range(n_runs): self.ngram_match( - gpu_data["input_ids"], - gpu_data["input_ids_len"], gpu_data["token_ids_all"], gpu_data["prompt_lens"], gpu_data["step_idx"], @@ -744,8 +726,6 @@ def test_latency_extreme(self): # Warmup for _ in range(1): self.ngram_match( - gpu_data["input_ids"], - gpu_data["input_ids_len"], gpu_data["token_ids_all"], gpu_data["prompt_lens"], gpu_data["step_idx"], @@ -765,8 +745,6 @@ def test_latency_extreme(self): t0 = time.perf_counter() for _ in range(n_runs): self.ngram_match( - gpu_data["input_ids"], - gpu_data["input_ids_len"], gpu_data["token_ids_all"], gpu_data["prompt_lens"], gpu_data["step_idx"], diff --git a/tests/spec_decode/test_ngram_proposer.py b/tests/spec_decode/test_ngram_proposer.py new file mode 100644 index 00000000000..cf54e5f28a5 --- /dev/null +++ b/tests/spec_decode/test_ngram_proposer.py @@ -0,0 +1,182 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import unittest + +import numpy as np +import paddle +from utils import FakeModelConfig, get_default_test_fd_config + +from fastdeploy.config import SpeculativeConfig +from fastdeploy.spec_decode.ngram import NgramProposer + + +class TestNgramProposer(unittest.TestCase): + def setUp(self): + if not paddle.is_compiled_with_cuda(): + raise unittest.SkipTest("CUDA not available") + paddle.set_device("gpu") + try: + import fastdeploy.model_executor.ops.gpu as _gpu_ops + + if getattr(_gpu_ops, "ngram_match", None) is None: + raise ImportError("ngram_match op not compiled") + except Exception as e: + raise unittest.SkipTest(f"Cannot import ngram_match op: {e}") + + fd_config = get_default_test_fd_config() + fd_config.model_config = FakeModelConfig() + fd_config.model_config.max_model_len = 256 + fd_config.speculative_config = SpeculativeConfig({}) + fd_config.speculative_config.method = "ngram" + fd_config.speculative_config.num_speculative_tokens = 5 + fd_config.speculative_config.max_ngram_size = 3 + fd_config.speculative_config.min_ngram_size = 1 + fd_config.scheduler_config.max_num_seqs = 2 + self.fd_config = fd_config + + bsz = fd_config.scheduler_config.max_num_seqs + max_draft = fd_config.speculative_config.num_speculative_tokens + max_model_len = fd_config.model_config.max_model_len + self.bsz = bsz + self.max_draft = max_draft + self.max_model_len = max_model_len + + self.share_inputs = { + "token_ids_all": paddle.zeros([bsz, max_model_len], dtype="int64"), + "prompt_lens": paddle.zeros([bsz, 1], dtype="int64"), + "step_idx": paddle.zeros([bsz, 1], dtype="int64"), + "actual_draft_token_num": paddle.full([bsz, 1], fill_value=max_draft, dtype="int32"), + "draft_tokens": paddle.zeros([bsz, max_draft + 1], dtype="int64"), + "seq_lens_this_time": paddle.ones([bsz], dtype="int32"), + "seq_lens_encoder": paddle.zeros([bsz], dtype="int32"), + "seq_lens_decoder": paddle.ones([bsz], dtype="int32"), + "max_dec_len": paddle.full([bsz, 1], fill_value=200, dtype="int64"), + } + + # Init / config binding + def test_init_config_binding(self): + """max_ngram_size and max_draft_token_num are correctly read from fd_config.""" + proposer = NgramProposer(self.fd_config) + self.assertEqual(proposer.max_ngram_size, 3) + self.assertEqual(proposer.max_draft_token_num, 5) + + # No-proposal scenarios + def test_run_no_proposal_step_idx_zero(self): + """step_idx=0 means no tokens generated; kernel cannot form any ngram pattern.""" + proposer = NgramProposer(self.fd_config) + self.share_inputs["step_idx"][:] = 0 + proposer.run(self.share_inputs) + paddle.device.synchronize() + + slt = self.share_inputs["seq_lens_this_time"].numpy() + np.testing.assert_array_equal(slt, [1, 1], err_msg="seq_lens_this_time should remain 1 when step_idx=0") + + # No-proposal scenarios + def test_run_no_proposal_tokens_not_in_prompt(self): + """Generated tokens never appear in the prompt → no match, no draft proposals.""" + proposer = NgramProposer(self.fd_config) + + prompt_len = 6 + prompt = [1, 2, 3, 4, 5, 6] # unique tokens + generated = [100, 200, 300] # tokens absent from prompt + + token_ids_all_np = np.zeros((self.bsz, self.max_model_len), dtype=np.int64) + for b in range(self.bsz): + token_ids_all_np[b, :prompt_len] = prompt + token_ids_all_np[b, prompt_len : prompt_len + len(generated)] = generated + + self.share_inputs["token_ids_all"] = paddle.to_tensor(token_ids_all_np, place=paddle.CUDAPlace(0)) + self.share_inputs["prompt_lens"] = paddle.full([self.bsz, 1], fill_value=prompt_len, dtype="int64") + 
self.share_inputs["step_idx"] = paddle.full([self.bsz, 1], fill_value=len(generated), dtype="int64") + + proposer.run(self.share_inputs) + paddle.device.synchronize() + + slt = self.share_inputs["seq_lens_this_time"].numpy() + np.testing.assert_array_equal( + slt, [1, 1], err_msg="No match expected when generated tokens absent from prompt" + ) + + # Successful proposal + def test_run_with_match_produces_draft_tokens(self): + """ + When the last ngram_size generated tokens reappear in the prompt, + the tokens following that match position become draft proposals. + + Setup (max_ngram_size=3, step_idx=3): + prompt = [10, 20, 30, 40, 50, 10, 20, 30] (prompt_len=8) + generated = [40, 50, 10] (step_idx=3) + + Pattern = generated[step_idx - 3 : step_idx] = [40, 50, 10] + Matches prompt at position 3 → proposals = prompt[6:8] = [20, 30] + Expected: seq_lens_this_time = 3, draft_tokens[:, 1:3] = [[20, 30], [20, 30]] + """ + proposer = NgramProposer(self.fd_config) + + prompt_len = 8 + prompt = [10, 20, 30, 40, 50, 10, 20, 30] + generated = [40, 50, 10] + + token_ids_all_np = np.zeros((self.bsz, self.max_model_len), dtype=np.int64) + for b in range(self.bsz): + token_ids_all_np[b, :prompt_len] = prompt + token_ids_all_np[b, prompt_len : prompt_len + len(generated)] = generated + + self.share_inputs["token_ids_all"] = paddle.to_tensor(token_ids_all_np, place=paddle.CUDAPlace(0)) + self.share_inputs["prompt_lens"] = paddle.full([self.bsz, 1], fill_value=prompt_len, dtype="int64") + self.share_inputs["step_idx"] = paddle.full([self.bsz, 1], fill_value=len(generated), dtype="int64") + + proposer.run(self.share_inputs) + paddle.device.synchronize() + + slt = self.share_inputs["seq_lens_this_time"].numpy() + dt = self.share_inputs["draft_tokens"].numpy() + + # 1 base token + 2 draft tokens = seq_len 3 + np.testing.assert_array_equal(slt, [3, 3], err_msg="seq_lens_this_time mismatch") + # Draft slots 1 and 2 should be [20, 30] for every batch item + np.testing.assert_array_equal(dt[:, 1:3], [[20, 30], [20, 30]], err_msg="draft_tokens mismatch") + + # Successful proposal + def test_run_with_match_respects_max_dec_len(self): + """Draft count is clipped when remaining budget (max_dec_len - step_idx) is exhausted.""" + proposer = NgramProposer(self.fd_config) + + prompt_len = 8 + prompt = [10, 20, 30, 40, 50, 10, 20, 30] + generated = [40, 50, 10] + + token_ids_all_np = np.zeros((self.bsz, self.max_model_len), dtype=np.int64) + for b in range(self.bsz): + token_ids_all_np[b, :prompt_len] = prompt + token_ids_all_np[b, prompt_len : prompt_len + len(generated)] = generated + + self.share_inputs["token_ids_all"] = paddle.to_tensor(token_ids_all_np, place=paddle.CUDAPlace(0)) + self.share_inputs["prompt_lens"] = paddle.full([self.bsz, 1], fill_value=prompt_len, dtype="int64") + self.share_inputs["step_idx"] = paddle.full([self.bsz, 1], fill_value=len(generated), dtype="int64") + # remaining = max_dec_len - step_idx - 1 = 4 - 3 - 1 = 0 → no draft tokens + self.share_inputs["max_dec_len"] = paddle.full([self.bsz, 1], fill_value=4, dtype="int64") + + proposer.run(self.share_inputs) + paddle.device.synchronize() + + slt = self.share_inputs["seq_lens_this_time"].numpy() + np.testing.assert_array_equal(slt, [1, 1], err_msg="No drafts expected when max_dec_len budget exhausted") + + +if __name__ == "__main__": + unittest.main(verbosity=2)