Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions pufferlib/config/ocean/drive.ini
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,19 @@ reward_bound_goal_radius_min = 2.0
reward_bound_goal_radius_max = 12.0

reward_bound_collision_min = -3.0
reward_bound_collision_max = -2.9
reward_bound_collision_max = -0.1

reward_bound_offroad_min = -3.0
reward_bound_offroad_max = -2.9
reward_bound_offroad_max = -0.1

reward_bound_comfort_min = -0.1
reward_bound_comfort_max = 0.0

reward_bound_lane_align_min = 0.0020
reward_bound_lane_align_min = 0.00020
reward_bound_lane_align_max = 0.0025

reward_bound_lane_center_min = -0.00075
reward_bound_lane_center_max = -0.00065
reward_bound_lane_center_max = -0.000065

reward_bound_velocity_min = 0.0
reward_bound_velocity_max = 0.005
Expand Down Expand Up @@ -116,7 +116,7 @@ reward_bound_acc_max = 1.5

[train]
seed=42
total_timesteps = 2_000_000_000
total_timesteps = 10_000_000_000
; learning_rate = 0.02
; gamma = 0.985
anneal_lr = True
Expand All @@ -141,11 +141,11 @@ vf_clip_coef = 0.1999999999999999
vf_coef = 2
vtrace_c_clip = 1
vtrace_rho_clip = 1
checkpoint_interval = 250
checkpoint_interval = 1000
; Rendering options
render = True
render_async = False # A render interval below 50 might cause process starvation and slow down training
render_interval = 250
render_interval = 1000
; If True, show exactly what the agent sees in agent observation
obs_only = True
; Show grid lines
Expand Down
7 changes: 4 additions & 3 deletions pufferlib/ocean/drive/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,9 @@ def __init__(

self.c_envs = binding.vectorize(*env_ids)

def reset(self, seed=0):
binding.vec_reset(self.c_envs, seed)
def reset(self, seed=0, parameters=None):
    """Reset every vectorized environment.

    Args:
        seed: Integer seed forwarded to the C-side reset.
        parameters: Optional dict of runtime parameter overrides; when
            omitted, an empty dict is forwarded instead.

    Returns:
        A ``(observations, infos)`` pair, where ``infos`` is always an
        empty list.
    """
    params = parameters or {}
    binding.vec_reset(self.c_envs, seed, params)
    self.tick = 0
    # Clear truncation flags so the fresh episode starts clean.
    self.truncations[:] = 0
    return self.observations, []
Expand Down Expand Up @@ -518,7 +519,7 @@ def step(self, actions):
env_ids.append(env_id)
self.c_envs = binding.vectorize(*env_ids)

binding.vec_reset(self.c_envs, seed)
binding.vec_reset(self.c_envs, seed, None)
self.terminals[:] = 1
return (self.observations, self.rewards, self.terminals, self.truncations, info)

Expand Down
59 changes: 57 additions & 2 deletions pufferlib/ocean/env_binding.h
Original file line number Diff line number Diff line change
Expand Up @@ -473,9 +473,53 @@ static PyObject *vectorize(PyObject *self, PyObject *args) {
return PyLong_FromVoidPtr(vec);
}

// Apply reward-bound overrides from a Python dict to every env in the vector.
//
// params_dict maps "reward_bound_<name>_{min,max}" keys to numeric values.
// Missing keys leave the corresponding bound untouched; a NULL or non-dict
// argument is a no-op. Both Python floats and ints are accepted
// (PyFloat_Check alone would silently ignore an int like -3, so we also
// accept PyLong values — PyFloat_AsDouble converts either).
void apply_parameters(VecEnv *vec, PyObject *params_dict) {
    if (params_dict == NULL || !PyDict_Check(params_dict)) {
        return;
    }

// Helper: look up the optional min/max entries for one reward coefficient
// and, when present and numeric, write them into every env's bound slot.
// The double conversion is hoisted so it runs once per key, not per env.
// Note: PyDict_GetItemString returns a borrowed reference — no decref needed.
#define APPLY_REWARD_BOUND(param_min_name, param_max_name, coef_index) \
    { \
        PyObject *val_min = PyDict_GetItemString(params_dict, param_min_name); \
        PyObject *val_max = PyDict_GetItemString(params_dict, param_max_name); \
        if (val_min != NULL && (PyFloat_Check(val_min) || PyLong_Check(val_min))) { \
            float bound_min = (float)PyFloat_AsDouble(val_min); \
            for (int i = 0; i < vec->num_envs; i++) { \
                Drive *drive = (Drive *)vec->envs[i]; \
                drive->reward_bounds[coef_index].min_val = bound_min; \
            } \
        } \
        if (val_max != NULL && (PyFloat_Check(val_max) || PyLong_Check(val_max))) { \
            float bound_max = (float)PyFloat_AsDouble(val_max); \
            for (int i = 0; i < vec->num_envs; i++) { \
                Drive *drive = (Drive *)vec->envs[i]; \
                drive->reward_bounds[coef_index].max_val = bound_max; \
            } \
        } \
    }

    APPLY_REWARD_BOUND("reward_bound_goal_radius_min", "reward_bound_goal_radius_max", REWARD_COEF_GOAL_RADIUS);
    APPLY_REWARD_BOUND("reward_bound_collision_min", "reward_bound_collision_max", REWARD_COEF_COLLISION);
    APPLY_REWARD_BOUND("reward_bound_offroad_min", "reward_bound_offroad_max", REWARD_COEF_OFFROAD);
    APPLY_REWARD_BOUND("reward_bound_comfort_min", "reward_bound_comfort_max", REWARD_COEF_COMFORT);
    APPLY_REWARD_BOUND("reward_bound_lane_align_min", "reward_bound_lane_align_max", REWARD_COEF_LANE_ALIGN);
    APPLY_REWARD_BOUND("reward_bound_lane_center_min", "reward_bound_lane_center_max", REWARD_COEF_LANE_CENTER);
    APPLY_REWARD_BOUND("reward_bound_velocity_min", "reward_bound_velocity_max", REWARD_COEF_VELOCITY);
    APPLY_REWARD_BOUND("reward_bound_traffic_light_min", "reward_bound_traffic_light_max", REWARD_COEF_TRAFFIC_LIGHT);
    APPLY_REWARD_BOUND("reward_bound_center_bias_min", "reward_bound_center_bias_max", REWARD_COEF_CENTER_BIAS);
    APPLY_REWARD_BOUND("reward_bound_vel_align_min", "reward_bound_vel_align_max", REWARD_COEF_VEL_ALIGN);
    APPLY_REWARD_BOUND("reward_bound_overspeed_min", "reward_bound_overspeed_max", REWARD_COEF_OVERSPEED);
    APPLY_REWARD_BOUND("reward_bound_timestep_min", "reward_bound_timestep_max", REWARD_COEF_TIMESTEP);
    APPLY_REWARD_BOUND("reward_bound_reverse_min", "reward_bound_reverse_max", REWARD_COEF_REVERSE);
    APPLY_REWARD_BOUND("reward_bound_throttle_min", "reward_bound_throttle_max", REWARD_COEF_THROTTLE);
    APPLY_REWARD_BOUND("reward_bound_steer_min", "reward_bound_steer_max", REWARD_COEF_STEER);
    APPLY_REWARD_BOUND("reward_bound_acc_min", "reward_bound_acc_max", REWARD_COEF_ACC);

#undef APPLY_REWARD_BOUND
}

static PyObject *vec_reset(PyObject *self, PyObject *args) {
if (PyTuple_Size(args) != 2) {
PyErr_SetString(PyExc_TypeError, "vec_reset requires 2 arguments");
if (PyTuple_Size(args) != 3) {
PyErr_SetString(PyExc_TypeError, "vec_reset requires 3 arguments");
return NULL;
}

Expand All @@ -484,6 +528,17 @@ static PyObject *vec_reset(PyObject *self, PyObject *args) {
return NULL;
}

PyObject *params = PyTuple_GetItem(args, 2);

if (params == Py_None) {
// skip parameter logic
} else if (!PyDict_Check(params)) {
PyErr_SetString(PyExc_TypeError, "parameters must be dict or None");
return NULL;
} else {
apply_parameters(vec, params);
}

PyObject *seed_arg = PyTuple_GetItem(args, 1);
if (!PyObject_TypeCheck(seed_arg, &PyLong_Type)) {
PyErr_SetString(PyExc_TypeError, "seed must be an integer");
Expand Down
8 changes: 4 additions & 4 deletions pufferlib/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ def _worker_process(

start = time.time()
if sem == RESET:
seed = recv_pipe.recv()
_, infos = envs.reset(seed=seed)
seed, parameters = recv_pipe.recv()
_, infos = envs.reset(seed=seed, parameters=parameters)
elif sem == STEP:
_, _, _, _, infos = envs.step(atn_arr)
elif sem == CLOSE:
Expand Down Expand Up @@ -503,7 +503,7 @@ def send(self, actions):
self.actions[idxs] = actions
self.buf["semaphores"][idxs] = STEP

def async_reset(self, seed=0):
def async_reset(self, seed=0, parameters=None):
# Flush any waiting workers
while self.waiting_workers:
worker = self.waiting_workers.pop(0)
Expand All @@ -528,7 +528,7 @@ def async_reset(self, seed=0):
for i in range(self.num_workers):
start = i * self.envs_per_worker
end = (i + 1) * self.envs_per_worker
self.send_pipes[i].send(seed + i)
self.send_pipes[i].send((seed + i, parameters))

def notify(self):
self.buf["notify"][:] = True
Expand Down