Skip to content

Commit 248b93a

Browse files
committed
enable profiler to report perstep benchmark
1 parent bc7f02a commit 248b93a

2 files changed

Lines changed: 127 additions & 0 deletions

File tree

scripts/benchmarks/benchmark_non_rl.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
log_task_start_time,
7474
log_total_start_time,
7575
)
76+
from scripts.benchmarks.step_profiler import install_env_profiler
7677

7778
imports_time_begin = time.perf_counter_ns()
7879

@@ -141,6 +142,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
141142
env = gym.wrappers.RecordVideo(env, **video_kwargs)
142143

143144
task_startup_time_end = time.perf_counter_ns()
145+
prof = install_env_profiler(env.unwrapped)
144146

145147
env.reset()
146148

@@ -194,6 +196,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
194196
log_runtime_step_times(benchmark, environment_step_times, compute_stats=True)
195197

196198
benchmark.stop()
199+
print(prof.render_table(len(step_times)))
197200

198201
# close the simulator
199202
env.close()
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
2+
# All rights reserved.
3+
#
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
6+
import time
7+
from prettytable import PrettyTable
8+
from collections import defaultdict
9+
10+
11+
class EnvStepProfiler:
12+
def __init__(self):
13+
self._in_step = False
14+
self.hist_ns = defaultdict(list)
15+
self._wrapped = []
16+
17+
def start_step(self):
18+
self._in_step = True
19+
20+
def end_step(self):
21+
self._in_step = False
22+
23+
def wrap(self, obj, method_name, label=None):
24+
"""Monkey-patch obj.method_name to record wall clock time into hist_ns[label]."""
25+
assert hasattr(obj, method_name), f"{obj} does not has method {method_name}"
26+
assert callable(getattr(obj, method_name)), f"{obj}'s method {method_name} is not callable"
27+
orig = getattr(obj, method_name)
28+
29+
def wrapped(*a, **kw):
30+
t0 = time.perf_counter_ns()
31+
try:
32+
return orig(*a, **kw)
33+
finally:
34+
if self._in_step:
35+
self.hist_ns[label or method_name].append(time.perf_counter_ns() - t0)
36+
37+
setattr(obj, method_name, wrapped)
38+
self._wrapped.append((obj, method_name, orig))
39+
40+
def wrap_env_step(self, env):
41+
"""Wrap the env's step itself to measure total time and delimit a step window."""
42+
orig_step = env.step
43+
44+
def step_wrapper(actions):
45+
t0 = time.perf_counter_ns()
46+
self.start_step()
47+
try:
48+
return orig_step(actions)
49+
finally:
50+
self.hist_ns["env.step_total"].append(time.perf_counter_ns() - t0)
51+
self.end_step()
52+
53+
env.step = step_wrapper
54+
self._wrapped.append((env, "step", orig_step))
55+
56+
def summary_ms_per_step(self, num_steps):
57+
# average ms spent per env.step (sums across multiple internal calls / decimation)
58+
sums_ms = {k: sum(v) / 1e6 for k, v in self.hist_ns.items()}
59+
avg_ms_per_step = {k: sums_ms[k] / max(num_steps, 1) for k in sums_ms}
60+
pct_of_total = {k: (sums_ms[k] / sums_ms["env.step_total"] * 100.0) for k in sums_ms}
61+
# compute "unaccounted" overhead inside step
62+
if "env.step_total" in sums_ms:
63+
unacct = sums_ms["env.step_total"] - sum(sums_ms[k] for k in sums_ms if k != "env.step_total")
64+
avg_ms_per_step["(unaccounted)"] = unacct / max(num_steps, 1)
65+
pct_of_total["(unaccounted)"] = (
66+
(unacct / sums_ms["env.step_total"] * 100.0) if sums_ms["env.step_total"] else 0.0
67+
)
68+
return avg_ms_per_step, pct_of_total
69+
70+
def summarize(self, num_steps):
71+
avg_ms, pct = self.summary_ms_per_step(num_steps)
72+
total_ms_series = [ns / 1e6 for ns in self.hist_ns["env.step_total"]]
73+
return avg_ms, pct, total_ms_series
74+
75+
def render_table(self, num_steps, title="env.step() breakdown"):
76+
avg_ms, pct, _ = self.summarize(num_steps)
77+
table = PrettyTable()
78+
table.title = title
79+
table.field_names = ["Section", "Avg ms/step", "% of step"]
80+
table.align["Section"] = "l"
81+
table.align["Avg ms/step"] = "r"
82+
table.align["% of step"] = "r"
83+
84+
for name, ms in sorted(avg_ms.items(), key=lambda kv: (-kv[1], kv[0])):
85+
table.add_row([name, f"{ms:,.3f}", f"{pct.get(name, 0.0):,.1f}%"])
86+
return table.get_string()
87+
88+
89+
def install_env_profiler(env):
90+
"""Call with the *inner* env (env.unwrapped) right after gym.make(...)"""
91+
p = EnvStepProfiler()
92+
93+
# wrap the high-level step first (RecordVideo wrappers will call down to this)
94+
p.wrap_env_step(env)
95+
96+
# sim/scene loop pieces
97+
p.wrap(env.sim, "step", "sim.step")
98+
p.wrap(env.sim, "render", "sim.render")
99+
p.wrap(env.scene, "write_data_to_sim", "scene.write_data_to_sim")
100+
p.wrap(env.scene, "update", "scene.update")
101+
p.wrap(env.sim, "forward", "sim.forward")
102+
103+
# managers in step()
104+
p.wrap(env.action_manager, "process_action", "action.process")
105+
p.wrap(env.action_manager, "apply_action", "action.apply")
106+
p.wrap(env.termination_manager, "compute", "termination.compute")
107+
p.wrap(env.reward_manager, "compute", "reward.compute")
108+
p.wrap(env.command_manager, "compute", "command.compute")
109+
p.wrap(env.observation_manager, "compute", "observation.compute")
110+
111+
# event/recorder (optional; harmless if not present/used)
112+
p.wrap(env.event_manager, "apply", "event.apply")
113+
if hasattr(env, "recorder_manager"):
114+
p.wrap(env.recorder_manager, "record_pre_step", "recorder.pre_step")
115+
p.wrap(env.recorder_manager, "record_post_step", "recorder.post_step")
116+
p.wrap(env.recorder_manager, "record_pre_reset", "recorder.pre_reset")
117+
p.wrap(env.recorder_manager, "record_post_reset", "recorder.post_reset")
118+
119+
# reset path pieces (will show up only on steps that reset)
120+
p.wrap(env, "_reset_idx", "env._reset_idx")
121+
122+
# expose for later
123+
env._profiler = p
124+
return p

0 commit comments

Comments
 (0)