diff --git a/problems/nvidia/eval_better_bench_grouped_gemm.py b/problems/nvidia/eval_better_bench_grouped_gemm.py
index 09b5279..3e9b552 100644
--- a/problems/nvidia/eval_better_bench_grouped_gemm.py
+++ b/problems/nvidia/eval_better_bench_grouped_gemm.py
@@ -1,6 +1,7 @@
 import base64
 import dataclasses
 import multiprocessing
+import random
 import re
 import time
 import os
@@ -242,10 +243,14 @@ def _run_single_benchmark(
     data_list = []
 
     # generate input data once
+    local_seed = test.args.get("seed", None)
     for i in range(NUM_ITERATIONS_PER_BENCHMARK):
-        if "seed" in test.args:
-            test.args["seed"] += 42
-        data = generate_input(**test.args)
+        if local_seed is not None:
+            local_seed += 42
+            args = {**test.args, "seed": local_seed}
+        else:
+            args = test.args
+        data = generate_input(**args)
         data_list.append(data)
 
     check_copy = _clone_data(data_list)
@@ -270,6 +275,13 @@ def _run_single_benchmark(
 
     bm_start_time = time.perf_counter_ns()
     for i in range(max_repeats):
+        # Clone and shuffle data before timing to prevent both
+        # object-identity caching and call-order caching exploits
+        iteration_data = _clone_data(data_list)
+        shuffle_order = list(range(len(iteration_data)))
+        random.shuffle(shuffle_order)
+        iteration_data = [iteration_data[j] for j in shuffle_order]
+
         torch.cuda.synchronize()
 
         outputs = []
@@ -277,7 +289,7 @@ def _run_single_benchmark(
         start_event = torch.cuda.Event(enable_timing=True)
         end_event = torch.cuda.Event(enable_timing=True)
         start_event.record()
-        for data in data_list:
+        for data in iteration_data:
             output = custom_kernel(data)
             outputs.append(output)
         end_event.record()
@@ -287,10 +299,10 @@ def _run_single_benchmark(
         ) * 1e6  # Convert ms to ns
 
         if recheck:
-            for reference_output, custom_output in zip(check_copy, outputs):
-                good, message = check_implementation(reference_output, custom_output)
-                if not good:
-                    return message
+            for j, custom_output in zip(shuffle_order, outputs):
+                good, message = check_implementation(check_copy[j], custom_output)
+                if not good:
+                    return message
 
         durations.append(duration)