From 29774ee4881c355e95ace4644596c73792eabaf8 Mon Sep 17 00:00:00 2001 From: Zilong Date: Sat, 21 Feb 2026 09:16:16 -0600 Subject: [PATCH 1/4] Add LightsOut environment --- pufferlib/config/ocean/lightsout.ini | 15 ++ pufferlib/ocean/environment.py | 1 + pufferlib/ocean/lightsout/binding.c | 16 ++ pufferlib/ocean/lightsout/lightsout.c | 43 +++++ pufferlib/ocean/lightsout/lightsout.h | 223 +++++++++++++++++++++++++ pufferlib/ocean/lightsout/lightsout.py | 77 +++++++++ 6 files changed, 375 insertions(+) create mode 100644 pufferlib/config/ocean/lightsout.ini create mode 100644 pufferlib/ocean/lightsout/binding.c create mode 100644 pufferlib/ocean/lightsout/lightsout.c create mode 100644 pufferlib/ocean/lightsout/lightsout.h create mode 100644 pufferlib/ocean/lightsout/lightsout.py diff --git a/pufferlib/config/ocean/lightsout.ini b/pufferlib/config/ocean/lightsout.ini new file mode 100644 index 0000000000..c0b683b3f8 --- /dev/null +++ b/pufferlib/config/ocean/lightsout.ini @@ -0,0 +1,15 @@ +[base] +package = ocean +env_name = puffer_lightsout +policy_name = Policy + +[env] +num_envs = 1024 +grid_size = 5 +max_steps = 200 + +[policy] +hidden_size = 512 + +[train] +total_timesteps = 10_000_000 diff --git a/pufferlib/ocean/environment.py b/pufferlib/ocean/environment.py index 6c56a4ea20..77e9aa12fd 100644 --- a/pufferlib/ocean/environment.py +++ b/pufferlib/ocean/environment.py @@ -130,6 +130,7 @@ def make_multiagent(buf=None, **kwargs): 'moba': 'Moba', 'matsci': 'Matsci', 'memory': 'Memory', + 'lightsout': 'LightsOut', 'boids': 'Boids', 'drone': 'Drone', 'nmmo3': 'NMMO3', diff --git a/pufferlib/ocean/lightsout/binding.c b/pufferlib/ocean/lightsout/binding.c new file mode 100644 index 0000000000..db22983b3e --- /dev/null +++ b/pufferlib/ocean/lightsout/binding.c @@ -0,0 +1,16 @@ +#include "lightsout.h" + +#define Env LightsOut +#include "../env_binding.h" + +static int my_init(Env* env, PyObject* args, PyObject* kwargs) { + env->grid_size = unpack(kwargs, "grid_size"); + env->cell_size = unpack(kwargs, "cell_size"); + env->max_steps = unpack(kwargs, "max_steps"); + return 0; +} + +static int my_log(PyObject* dict, Log* log) { + assign_to_dict(dict, "score", log->score); + return 0; +} diff --git a/pufferlib/ocean/lightsout/lightsout.c b/pufferlib/ocean/lightsout/lightsout.c new file mode 100644 index 0000000000..6a08971d9c --- /dev/null +++ b/pufferlib/ocean/lightsout/lightsout.c @@ -0,0 +1,43 @@ +#include +#include +#include "lightsout.h" + +int demo(){ + srand((unsigned)time(NULL)); + LightsOut env = {.grid_size = 7, .cell_size = 100, .client = NULL}; + env.observations = (unsigned char*)calloc(env.grid_size * env.grid_size, sizeof(unsigned char)); + env.actions = (int*)calloc(1, sizeof(int)); + env.rewards = (float*)calloc(1, sizeof(float)); + env.terminals = (unsigned char*)calloc(1, sizeof(unsigned char)); + + c_reset(&env); + env.client = make_client(env.cell_size, env.grid_size); + + // printf("LightsOut template ran 10 placeholder steps.\n"); + while (!WindowShouldClose()) { + // User can take control of the first snake + if (IsKeyPressed(KEY_UP) || IsKeyPressed(KEY_W)) env.client->cursor_row = (env.client->cursor_row - 1 + env.grid_size) % env.grid_size; + if (IsKeyPressed(KEY_DOWN) || IsKeyPressed(KEY_S)) env.client->cursor_row = (env.client->cursor_row + 1) % env.grid_size; + if (IsKeyPressed(KEY_LEFT) || IsKeyPressed(KEY_A)) env.client->cursor_col = (env.client->cursor_col - 1 + env.grid_size) % env.grid_size; + if (IsKeyPressed(KEY_RIGHT) || IsKeyPressed(KEY_D)) env.client->cursor_col = (env.client->cursor_col + 1) % env.grid_size; + if (IsKeyPressed(KEY_SPACE)) { + int idx = env.client->cursor_row * env.grid_size + env.client->cursor_col; + env.actions[0] = idx; + c_step(&env); + } else if (IsKeyPressed(KEY_R)) { + c_reset(&env); + } + c_render(&env); + } + + free(env.observations); + free(env.actions); + free(env.rewards); + free(env.terminals); + c_close(&env); + return 0; +} +int main(void) { + demo(); + return 0; +} diff --git a/pufferlib/ocean/lightsout/lightsout.h b/pufferlib/ocean/lightsout/lightsout.h new file mode 100644 index 0000000000..8ef9618d06 --- /dev/null +++ b/pufferlib/ocean/lightsout/lightsout.h @@ -0,0 +1,223 @@ +#include +#include "raylib.h" + +// Only use floats. +typedef struct { + float score; + float n; // Required as the last field. +} Log; + +typedef struct Client { + int cell_size; + int grid_size; + int cursor_row; + int cursor_col; +} Client; + +typedef struct { + Log log; // Required field. + unsigned char* observations; // Required field. Ensure type matches in .py and .c. + int* actions; // Required field. Ensure type matches in .py and .c. + float* rewards; // Required field. + unsigned char* terminals; // Required field. + int grid_size; + int cell_size; + int max_steps; + int step_count; + float episode_return; + unsigned char* grid; + Client* client; +} LightsOut; + +int is_solved(LightsOut* env) { + for (int i = 0; i < env->grid_size * env->grid_size; i++) { + if (env->grid[i] == 1) return 0; // Not solved if any light is on. + } + return 1; // Solved if all lights are off. +} + +int count_lights_on(LightsOut* env) { + int on = 0; + for (int i = 0; i < env->grid_size * env->grid_size; i++) { + on += env->grid[i] != 0; + } + return on; +} + +void step_grid(LightsOut* env, int idx) { + if (idx < 0 || idx >= env->grid_size * env->grid_size) return; + int row = idx/env->grid_size; + int col = idx%env->grid_size; + + static const int dirs[5][2] = {{0,0}, {1,0}, {0,1}, {-1,0}, {0,-1}}; + for (int i = 0; i < 5; i++) { + int dr = dirs[i][0]; + int dc = dirs[i][1]; + int r = row + dr; + int c = col + dc; + if (r >= 0 && r < env->grid_size && c >= 0 && c < env->grid_size) { + int offset = r*env->grid_size + c; + env->grid[offset] = !env->grid[offset]; + } + } +} + +void init_lightsout(LightsOut* env) { + int n = env->grid_size * env->grid_size; + if (env->grid == NULL) { + env->grid = (unsigned char*)calloc(n, sizeof(unsigned char)); + } else { + for (int i = 0; i < n; i++) { + env->grid[i] = 0; + } + } + env->step_count = 0; + env->episode_return = 0.0f; + + float p = 0.5f; // scramble probability per cell + + for (int i = 0; i < n; i++) { + float u = (float)rand() / (float)RAND_MAX; // ~uniform in [0,1] + if (u < p) { + step_grid(env, i); + } + } +} + +void c_close(LightsOut* env) { + free(env->grid); + env->grid = NULL; + if (env->client != NULL) { + if (IsWindowReady()) { + CloseWindow(); + } + free(env->client); + env->client = NULL; + } +} + +void compute_observations(LightsOut* env) { + for (int i = 0; i < env->grid_size * env->grid_size; i++) { + env->observations[i] = env->grid[i]; + } +} + +void c_reset(LightsOut* env) { + env->rewards[0] = 0.0f; + env->terminals[0] = 0; + init_lightsout(env); + compute_observations(env); +} + +void c_step(LightsOut* env) { + // In manual mode, keep solved screen visible until user resets. + if (env->client != NULL && env->terminals[0]) { + env->rewards[0] = 0.0f; + compute_observations(env); + return; + } + + int num_cells = env->grid_size * env->grid_size; + int atn = env->actions[0]; + env->terminals[0] = 0; + + float reward = -0.02f; // Base step penalty. + int prev_on = count_lights_on(env); + if (atn < 0 || atn >= num_cells) { + reward -= 0.5f; // Invalid action penalty. + } else { + if (env->client != NULL) { + env->client->cursor_row = atn / env->grid_size; + env->client->cursor_col = atn % env->grid_size; + } + step_grid(env, atn); + int next_on = count_lights_on(env); + reward += 0.005f * (float)(prev_on - next_on); // Dense shaping: improve when lights decrease. + } + env->step_count += 1; + + if (is_solved(env)) { + reward = 1.0f; // Solved reward. + env->terminals[0] = 1; + } else if (env->client == NULL && env->step_count >= env->max_steps) { + reward -= 0.5f; // Timeout penalty during training. + env->terminals[0] = 1; + } + + env->rewards[0] = reward; + env->episode_return += reward; + if (env->terminals[0]) { + env->log.n += 1.0f; + env->log.score += env->episode_return; + if (env->client == NULL) { + init_lightsout(env); + } + } + compute_observations(env); +} + +// Raylib client +Color COLORS[] = { + (Color){6, 24, 24, 255}, + (Color){0, 0, 255, 255}, + (Color){0, 128, 255, 255}, + (Color){128, 128, 128, 255}, + (Color){255, 0, 0, 255}, + (Color){255, 255, 255, 255}, + (Color){255, 85, 85, 255}, + (Color){170, 170, 170, 255}, + (Color){0, 255, 255, 255}, + (Color){255, 255, 0, 255}, +}; + +Client* make_client(int cell_size, int grid_size) { + Client* client= (Client*)malloc(sizeof(Client)); + client->cell_size = cell_size; + client->grid_size = grid_size; + client->cursor_row = 0; + client->cursor_col = 0; + InitWindow(grid_size*cell_size, grid_size*cell_size, "PufferLib LightsOut"); + SetTargetFPS(3); + return client; +} + +void c_render(LightsOut* env) { + if (IsKeyDown(KEY_ESCAPE)) { + exit(0); + } + + if (env->client == NULL) { + env->client = make_client(env->cell_size, env->grid_size); + } + + Client* client = env->client; + + BeginDrawing(); + ClearBackground(COLORS[0]); + int sz = client->cell_size; + for (int y = 0; y < env->grid_size; y++) { + for (int x = 0; x < env->grid_size; x++){ + int tile = env->grid[y*env->grid_size + x]; + if (tile != 0) + DrawRectangle(x*sz, y*sz, sz, sz, COLORS[tile]); + } + } + DrawRectangleLinesEx( + (Rectangle){client->cursor_col * sz, client->cursor_row * sz, sz, sz}, + 3.0f, + COLORS[5] + ); + + if (env->terminals[0]) { + const char* msg = "Solved"; + int font_size = 48; + int text_w = MeasureText(msg, font_size); + int screen_w = env->grid_size * env->cell_size; + int screen_h = env->grid_size * env->cell_size; + + DrawRectangle(0, 0, screen_w, screen_h, (Color){0, 0, 0, 120}); // dim overlay + DrawText(msg, (screen_w - text_w) / 2, (screen_h - font_size) / 2, font_size, RAYWHITE); + } + + EndDrawing(); +} diff --git a/pufferlib/ocean/lightsout/lightsout.py b/pufferlib/ocean/lightsout/lightsout.py new file mode 100644 index 0000000000..d82385a7bf --- /dev/null +++ b/pufferlib/ocean/lightsout/lightsout.py @@ -0,0 +1,77 @@ +"""Scaffold for a future LightsOut ocean environment.""" + +import gymnasium +import numpy as np + +import pufferlib +from pufferlib.ocean.lightsout import binding + +import time + +class LightsOut(pufferlib.PufferEnv): + def __init__(self, num_envs=1, render_mode=None, log_interval=128, grid_size=5, max_steps=None, buf=None, seed=0): + self.single_observation_space = gymnasium.spaces.Box(low=0, high=1, shape=(grid_size * grid_size,), dtype=np.uint8) + self.single_action_space = gymnasium.spaces.Discrete(grid_size * grid_size) + self.render_mode = render_mode + self.num_agents = num_envs + self.log_interval = log_interval + self.tick = 0 + + if max_steps is None: + max_steps = grid_size * grid_size * 10 + + super().__init__(buf) + self.c_envs = binding.vec_init( + self.observations, + self.actions, + self.rewards, + self.terminals, + self.truncations, + num_envs, + seed, + grid_size=grid_size, + cell_size=int(np.ceil(1280 / grid_size)), + max_steps=max_steps, + ) + self.grid_size = grid_size + + def reset(self, seed=None): + self.tick = 0 + if seed is None: + seed = time.time_ns() & 0x7FFFFFFF + binding.vec_reset(self.c_envs, seed) + return self.observations, [] + + def step(self, actions): + self.actions[:] = actions + self.tick += 1 + binding.vec_step(self.c_envs) + info = [] + if self.tick % self.log_interval == 0: + info.append(binding.vec_log(self.c_envs)) + return self.observations, self.rewards, self.terminals, self.truncations, info + + def render(self): + binding.vec_render(self.c_envs, 0) + + def close(self): + binding.vec_close(self.c_envs) + + +if __name__ == "__main__": + n = 4096 + env = LightsOut(num_envs=n) + env.reset() + steps = 0 + + cache = 1024 + actions = np.zeros((cache, n), dtype=np.int32) + + import time + + start = time.time() + while time.time() - start < 10: + env.step(actions[steps % cache]) + steps += 1 + + print("LightsOut SPS:", int(env.num_agents * steps / (time.time() - start))) From f315f8b40e8cc91324279ff5abaa98c53859b90e Mon Sep 17 00:00:00 2001 From: Zilong Date: Sat, 21 Feb 2026 14:10:43 -0600 Subject: [PATCH 2/4] lightsout: curriculum, logging, cleanup, and train helper --- pufferlib/ocean/lightsout/binding.c | 3 + pufferlib/ocean/lightsout/lightsout.c | 26 +++++--- pufferlib/ocean/lightsout/lightsout.h | 92 +++++++++++++------------- pufferlib/ocean/lightsout/lightsout.py | 2 - pufferlib/ocean/lightsout/train.py | 51 ++++++++++++++ 5 files changed, 117 insertions(+), 57 deletions(-) create mode 100644 pufferlib/ocean/lightsout/train.py diff --git a/pufferlib/ocean/lightsout/binding.c b/pufferlib/ocean/lightsout/binding.c index db22983b3e..d46b34e136 100644 --- a/pufferlib/ocean/lightsout/binding.c +++ b/pufferlib/ocean/lightsout/binding.c @@ -7,10 +7,13 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) { env->grid_size = unpack(kwargs, "grid_size"); env->cell_size = unpack(kwargs, "cell_size"); env->max_steps = unpack(kwargs, "max_steps"); + env->ema = 0.0f; + env->scramble_prob = 0.25f; return 0; } static int my_log(PyObject* dict, Log* log) { assign_to_dict(dict, "score", log->score); + assign_to_dict(dict, "scramble_p", log->scramble_p); return 0; } diff --git a/pufferlib/ocean/lightsout/lightsout.c b/pufferlib/ocean/lightsout/lightsout.c index 6a08971d9c..33d574c0c4 100644 --- a/pufferlib/ocean/lightsout/lightsout.c +++ b/pufferlib/ocean/lightsout/lightsout.c @@ -2,9 +2,25 @@ #include #include "lightsout.h" +static LightsOut* g_env = NULL; + +static void demo_cleanup(void) { + if (g_env == NULL) { + return; + } + free(g_env->observations); + free(g_env->actions); + free(g_env->rewards); + free(g_env->terminals); + c_close(g_env); + g_env = NULL; +} + int demo(){ srand((unsigned)time(NULL)); - LightsOut env = {.grid_size = 7, .cell_size = 100, .client = NULL}; + LightsOut env = {.grid_size = 5, .cell_size = 100, .client = NULL}; + g_env = &env; + atexit(demo_cleanup); env.observations = (unsigned char*)calloc(env.grid_size * env.grid_size, sizeof(unsigned char)); env.actions = (int*)calloc(1, sizeof(int)); env.rewards = (float*)calloc(1, sizeof(float)); @@ -13,9 +29,7 @@ int demo(){ c_reset(&env); env.client = make_client(env.cell_size, env.grid_size); - // printf("LightsOut template ran 10 placeholder steps.\n"); while (!WindowShouldClose()) { - // User can take control of the first snake if (IsKeyPressed(KEY_UP) || IsKeyPressed(KEY_W)) env.client->cursor_row = (env.client->cursor_row - 1 + env.grid_size) % env.grid_size; if (IsKeyPressed(KEY_DOWN) || IsKeyPressed(KEY_S)) env.client->cursor_row = (env.client->cursor_row + 1) % env.grid_size; if (IsKeyPressed(KEY_LEFT) || IsKeyPressed(KEY_A)) env.client->cursor_col = (env.client->cursor_col - 1 + env.grid_size) % env.grid_size; @@ -30,11 +44,7 @@ int demo(){ c_render(&env); } - free(env.observations); - free(env.actions); - free(env.rewards); - free(env.terminals); - c_close(&env); + demo_cleanup(); return 0; } int main(void) { diff --git a/pufferlib/ocean/lightsout/lightsout.h b/pufferlib/ocean/lightsout/lightsout.h index 8ef9618d06..61a39fa036 100644 --- a/pufferlib/ocean/lightsout/lightsout.h +++ b/pufferlib/ocean/lightsout/lightsout.h @@ -1,15 +1,17 @@ #include +#include +#include #include "raylib.h" // Only use floats. typedef struct { float score; + float scramble_p; float n; // Required as the last field. } Log; typedef struct Client { int cell_size; - int grid_size; int cursor_row; int cursor_col; } Client; @@ -24,26 +26,15 @@ typedef struct { int cell_size; int max_steps; int step_count; + int lights_on; + int last_action; float episode_return; + float ema; + float scramble_prob; unsigned char* grid; Client* client; } LightsOut; -int is_solved(LightsOut* env) { - for (int i = 0; i < env->grid_size * env->grid_size; i++) { - if (env->grid[i] == 1) return 0; // Not solved if any light is on. - } - return 1; // Solved if all lights are off. -} - -int count_lights_on(LightsOut* env) { - int on = 0; - for (int i = 0; i < env->grid_size * env->grid_size; i++) { - on += env->grid[i] != 0; - } - return on; -} - void step_grid(LightsOut* env, int idx) { if (idx < 0 || idx >= env->grid_size * env->grid_size) return; int row = idx/env->grid_size; @@ -57,7 +48,9 @@ void step_grid(LightsOut* env, int idx) { int c = col + dc; if (r >= 0 && r < env->grid_size && c >= 0 && c < env->grid_size) { int offset = r*env->grid_size + c; - env->grid[offset] = !env->grid[offset]; + unsigned char old = env->grid[offset]; + env->grid[offset] = (unsigned char)!old; + env->lights_on += old ? -1 : 1; } } } @@ -67,18 +60,23 @@ void init_lightsout(LightsOut* env) { if (env->grid == NULL) { env->grid = (unsigned char*)calloc(n, sizeof(unsigned char)); } else { - for (int i = 0; i < n; i++) { - env->grid[i] = 0; - } + memset(env->grid, 0, n * sizeof(unsigned char)); + } + + if (env->ema > 0.65f) { + env->scramble_prob = fminf(0.5f, env->scramble_prob + 0.03f); // Increase scramble prob if EMA is high + } else if (env->ema < 0.35f) { + env->scramble_prob = fmaxf(0.25f, env->scramble_prob - 0.01f); // Decrease scramble prob if EMA is low } + env->step_count = 0; + env->lights_on = 0; + env->last_action = -1; env->episode_return = 0.0f; - float p = 0.5f; // scramble probability per cell - for (int i = 0; i < n; i++) { - float u = (float)rand() / (float)RAND_MAX; // ~uniform in [0,1] - if (u < p) { + float u = (float)rand() / (float)RAND_MAX; + if (u < env->scramble_prob) { step_grid(env, i); } } @@ -110,9 +108,11 @@ void c_reset(LightsOut* env) { } void c_step(LightsOut* env) { - // In manual mode, keep solved screen visible until user resets. - if (env->client != NULL && env->terminals[0]) { + // Defer reset by one step so terminal observation is preserved. + if (env->terminals[0]) { + init_lightsout(env); env->rewards[0] = 0.0f; + env->terminals[0] = 0; compute_observations(env); return; } @@ -121,26 +121,32 @@ void c_step(LightsOut* env) { int atn = env->actions[0]; env->terminals[0] = 0; - float reward = -0.02f; // Base step penalty. - int prev_on = count_lights_on(env); + float reward = -0.02 * (36.0 / (env->grid_size * env->grid_size)); // Base step penalty. + int prev_on = env->lights_on; if (atn < 0 || atn >= num_cells) { reward -= 0.5f; // Invalid action penalty. } else { + if (atn == env->last_action) { + reward -= 0.05f; // Penalty for pressing the same cell twice in a row. + } if (env->client != NULL) { env->client->cursor_row = atn / env->grid_size; env->client->cursor_col = atn % env->grid_size; } step_grid(env, atn); - int next_on = count_lights_on(env); + env->last_action = atn; + int next_on = env->lights_on; reward += 0.005f * (float)(prev_on - next_on); // Dense shaping: improve when lights decrease. } env->step_count += 1; - if (is_solved(env)) { - reward = 1.0f; // Solved reward. + if (env->lights_on == 0) { + reward = 2.0f; // Solved reward. + env->ema = 0.85f * env->ema + 0.15f; // Update EMA of steps to solve. env->terminals[0] = 1; } else if (env->client == NULL && env->step_count >= env->max_steps) { reward -= 0.5f; // Timeout penalty during training. + env->ema = 0.85f * env->ema; // Decay EMA since we failed to solve. env->terminals[0] = 1; } @@ -149,40 +155,32 @@ void c_step(LightsOut* env) { if (env->terminals[0]) { env->log.n += 1.0f; env->log.score += env->episode_return; - if (env->client == NULL) { - init_lightsout(env); - } + env->log.scramble_p += env->scramble_prob; } + compute_observations(env); } // Raylib client -Color COLORS[] = { +static const Color COLORS[] = { (Color){6, 24, 24, 255}, (Color){0, 0, 255, 255}, - (Color){0, 128, 255, 255}, - (Color){128, 128, 128, 255}, - (Color){255, 0, 0, 255}, - (Color){255, 255, 255, 255}, - (Color){255, 85, 85, 255}, - (Color){170, 170, 170, 255}, - (Color){0, 255, 255, 255}, - (Color){255, 255, 0, 255}, + (Color){255, 255, 255, 255} }; Client* make_client(int cell_size, int grid_size) { Client* client= (Client*)malloc(sizeof(Client)); client->cell_size = cell_size; - client->grid_size = grid_size; client->cursor_row = 0; client->cursor_col = 0; InitWindow(grid_size*cell_size, grid_size*cell_size, "PufferLib LightsOut"); - SetTargetFPS(3); + SetTargetFPS(15); return client; } void c_render(LightsOut* env) { - if (IsKeyDown(KEY_ESCAPE)) { + if (IsWindowReady() && (WindowShouldClose() || IsKeyPressed(KEY_ESCAPE))) { + c_close(env); exit(0); } @@ -205,7 +203,7 @@ void c_render(LightsOut* env) { DrawRectangleLinesEx( (Rectangle){client->cursor_col * sz, client->cursor_row * sz, sz, sz}, 3.0f, - COLORS[5] + COLORS[2] ); if (env->terminals[0]) { diff --git a/pufferlib/ocean/lightsout/lightsout.py b/pufferlib/ocean/lightsout/lightsout.py index d82385a7bf..6469779827 100644 --- a/pufferlib/ocean/lightsout/lightsout.py +++ b/pufferlib/ocean/lightsout/lightsout.py @@ -67,8 +67,6 @@ def close(self): cache = 1024 actions = np.zeros((cache, n), dtype=np.int32) - import time - start = time.time() while time.time() - start < 10: env.step(actions[steps % cache]) diff --git a/pufferlib/ocean/lightsout/train.py b/pufferlib/ocean/lightsout/train.py new file mode 100644 index 0000000000..7a5978502a --- /dev/null +++ b/pufferlib/ocean/lightsout/train.py @@ -0,0 +1,51 @@ +from pufferlib import pufferl + + +def train_until_target(env_name="puffer_lightsout", load_model_path=None): + args = pufferl.load_config(env_name) + + args["train"]["device"] = "cuda" + args["vec"]["backend"] = "PufferEnv" + args["vec"]["num_envs"] = 1 + args["env"]["num_envs"] = 4096 + args["env"]["grid_size"] = 7 + + # High cap; run stops early when target is stable. + args["train"]["total_timesteps"] = 1_000_000_000 + args["train"]["ent_coef"] = 0.005 + args["train"]["learning_rate"] = 0.015 + args["train"]["update_epochs"] = 2 + args["train"]["minibatch_size"] = 32768 + + if load_model_path is not None: + args["load_model_path"] = load_model_path + + target_score = 0.6 + target_scramble_p = 0.499 + target_min_n = 50.0 + target_streak = 3 + streak = 0 + + def stop_on_target(logs): + nonlocal streak + p = logs.get("environment/scramble_p") + score = logs.get("environment/score") + n = logs.get("environment/n", 0.0) + if p is None or score is None: + return False + + hit = p >= target_scramble_p and score >= target_score and n >= target_min_n + streak = streak + 1 if hit else 0 + if hit: + print( + f"target hit: scramble_p={p:.3f} score={score:.3f} n={n:.1f} " + f"streak={streak}/{target_streak}" + ) + + return streak >= target_streak + + pufferl.train(env_name, args=args, early_stop_fn=stop_on_target) + + +if __name__ == "__main__": + train_until_target("puffer_lightsout", load_model_path=None) \ No newline at end of file From b8d799f1971843b3eacbf9d2b6ff3773fa8af808 Mon Sep 17 00:00:00 2001 From: Zilong Date: Sat, 21 Feb 2026 19:46:29 -0600 Subject: [PATCH 3/4] hold the solve screen longer --- pufferlib/ocean/lightsout/binding.c | 2 +- pufferlib/ocean/lightsout/lightsout.h | 6 +++++- pufferlib/ocean/lightsout/train.py | 6 +++--- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pufferlib/ocean/lightsout/binding.c b/pufferlib/ocean/lightsout/binding.c index d46b34e136..f419f387e2 100644 --- a/pufferlib/ocean/lightsout/binding.c +++ b/pufferlib/ocean/lightsout/binding.c @@ -8,7 +8,7 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) { env->cell_size = unpack(kwargs, "cell_size"); env->max_steps = unpack(kwargs, "max_steps"); env->ema = 0.0f; - env->scramble_prob = 0.25f; + env->scramble_prob = 0.2f; return 0; } diff --git a/pufferlib/ocean/lightsout/lightsout.h b/pufferlib/ocean/lightsout/lightsout.h index 61a39fa036..17553e0c64 100644 --- a/pufferlib/ocean/lightsout/lightsout.h +++ b/pufferlib/ocean/lightsout/lightsout.h @@ -66,7 +66,7 @@ void init_lightsout(LightsOut* env) { if (env->ema > 0.65f) { env->scramble_prob = fminf(0.5f, env->scramble_prob + 0.03f); // Increase scramble prob if EMA is high } else if (env->ema < 0.35f) { - env->scramble_prob = fmaxf(0.25f, env->scramble_prob - 0.01f); // Decrease scramble prob if EMA is low + env->scramble_prob = fmaxf(0.15f, env->scramble_prob - 0.01f); // Decrease scramble prob if EMA is low } env->step_count = 0; @@ -218,4 +218,8 @@ void c_render(LightsOut* env) { } EndDrawing(); + + if (env->terminals[0]) { + WaitTime(0.5); // hold solved screen + } } diff --git a/pufferlib/ocean/lightsout/train.py b/pufferlib/ocean/lightsout/train.py index 7a5978502a..16029f92f8 100644 --- a/pufferlib/ocean/lightsout/train.py +++ b/pufferlib/ocean/lightsout/train.py @@ -8,10 +8,10 @@ def train_until_target(env_name="puffer_lightsout", load_model_path=None): args["vec"]["backend"] = "PufferEnv" args["vec"]["num_envs"] = 1 args["env"]["num_envs"] = 4096 - args["env"]["grid_size"] = 7 + args["env"]["grid_size"] = 8 # High cap; run stops early when target is stable. - args["train"]["total_timesteps"] = 1_000_000_000 + args["train"]["total_timesteps"] = 2_000_000_000 args["train"]["ent_coef"] = 0.005 args["train"]["learning_rate"] = 0.015 args["train"]["update_epochs"] = 2 @@ -20,7 +20,7 @@ def train_until_target(env_name="puffer_lightsout", load_model_path=None): if load_model_path is not None: args["load_model_path"] = load_model_path - target_score = 0.6 + target_score = 0.42 target_scramble_p = 0.499 target_min_n = 50.0 target_streak = 3 From 2845237a8dda0483778a897e8b66f78a3847ed75 Mon Sep 17 00:00:00 2001 From: Zilong Date: Sun, 22 Feb 2026 00:11:58 -0600 Subject: [PATCH 4/4] punish 2-step cycles --- pufferlib/ocean/lightsout/binding.c | 5 +++-- pufferlib/ocean/lightsout/lightsout.h | 15 +++++++++++---- pufferlib/ocean/lightsout/train.py | 3 ++- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/pufferlib/ocean/lightsout/binding.c b/pufferlib/ocean/lightsout/binding.c index f419f387e2..958dcc2e80 100644 --- a/pufferlib/ocean/lightsout/binding.c +++ b/pufferlib/ocean/lightsout/binding.c @@ -7,8 +7,9 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) { env->grid_size = unpack(kwargs, "grid_size"); env->cell_size = unpack(kwargs, "cell_size"); env->max_steps = unpack(kwargs, "max_steps"); - env->ema = 0.0f; - env->scramble_prob = 0.2f; + env->ema = 0.5f; + env->score_ema = 0.0f; + env->scramble_prob = 0.15f; return 0; } diff --git a/pufferlib/ocean/lightsout/lightsout.h b/pufferlib/ocean/lightsout/lightsout.h index 17553e0c64..0b3e1ce7f0 100644 --- a/pufferlib/ocean/lightsout/lightsout.h +++ b/pufferlib/ocean/lightsout/lightsout.h @@ -27,9 +27,11 @@ typedef struct { int max_steps; int step_count; int lights_on; + int prev_action; int last_action; float episode_return; float ema; + float score_ema; float scramble_prob; unsigned char* grid; Client* client; @@ -63,14 +65,15 @@ void init_lightsout(LightsOut* env) { memset(env->grid, 0, n * sizeof(unsigned char)); } - if (env->ema > 0.65f) { - env->scramble_prob = fminf(0.5f, env->scramble_prob + 0.03f); // Increase scramble prob if EMA is high - } else if (env->ema < 0.35f) { + if (env->ema > 0.7f && env->score_ema > 0.0f) { + env->scramble_prob = fminf(0.5f, env->scramble_prob + 0.01f); // Increase scramble prob if EMA is high + } else if (env->ema < 0.3f) { env->scramble_prob = fmaxf(0.15f, env->scramble_prob - 0.01f); // Decrease scramble prob if EMA is low } env->step_count = 0; env->lights_on = 0; + env->prev_action = -1; env->last_action = -1; env->episode_return = 0.0f; @@ -127,13 +130,16 @@ void c_step(LightsOut* env) { reward -= 0.5f; // Invalid action penalty. } else { if (atn == env->last_action) { - reward -= 0.05f; // Penalty for pressing the same cell twice in a row. + reward -= 0.03f; // Penalty for pressing the same cell twice in a row. + } else if (atn == env->prev_action) { + reward -= 0.02f; // Penalty for 2-step loop (A,B,A). } if (env->client != NULL) { env->client->cursor_row = atn / env->grid_size; env->client->cursor_col = atn % env->grid_size; } step_grid(env, atn); + env->prev_action = env->last_action; env->last_action = atn; int next_on = env->lights_on; reward += 0.005f * (float)(prev_on - next_on); // Dense shaping: improve when lights decrease. @@ -153,6 +159,7 @@ void c_step(LightsOut* env) { env->rewards[0] = reward; env->episode_return += reward; if (env->terminals[0]) { + env->score_ema = 0.9f * env->score_ema + 0.1f * env->episode_return; env->log.n += 1.0f; env->log.score += env->episode_return; env->log.scramble_p += env->scramble_prob; diff --git a/pufferlib/ocean/lightsout/train.py b/pufferlib/ocean/lightsout/train.py index 16029f92f8..35ecb7b2db 100644 --- a/pufferlib/ocean/lightsout/train.py +++ b/pufferlib/ocean/lightsout/train.py @@ -48,4 +48,5 @@ def stop_on_target(logs): if __name__ == "__main__": - train_until_target("puffer_lightsout", load_model_path=None) \ No newline at end of file + train_until_target("puffer_lightsout", load_model_path=None) + # train_until_target("puffer_lightsout", load_model_path="latest") \ No newline at end of file