diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py index 0136012..3396269 100644 --- a/Ironwood/src/benchmark_utils.py +++ b/Ironwood/src/benchmark_utils.py @@ -178,7 +178,13 @@ def multiple_iteration_timeit_from_trace( if trace_dir and not is_local_directory_path(trace_dir): tmp_trace_dir = f"{local_trace_dir}/{trace_name}" # data_args = data_generator() - with jax.profiler.trace(tmp_trace_dir): + options = jax.profiler.ProfileOptions() + options.advanced_configuration = { + "tpu_trace_mode" : "TRACE_ONLY_XLA", + "tpu_num_sparse_cores_to_trace": 0, + "tpu_num_sparse_core_tiles_to_trace": 0, + } + with jax.profiler.trace(tmp_trace_dir, profiler_options=options): for i in range(tries): if i % 10 == 0: print(