From 9b38f302825db10077a444f9976c3dcbe962b96c Mon Sep 17 00:00:00 2001 From: Joshua Bloom Date: Sun, 12 Apr 2026 20:14:08 -0400 Subject: [PATCH 1/2] feat(actor): expose per-step GSP prediction loss via last_gsp_loss MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The GSP prediction network's training loss was never surfaced through Actor.learn(). Only the actor/critic loss was returned, which stays normal even when the GSP head collapses to a near-constant output. Add last_gsp_loss attribute populated by learn_gsp() whenever a GSP learning step fires, reset to None at the start of each learn() call so callers can distinguish "no GSP step this tick" from "GSP step ran". Needed for the information-collapse diagnostic (see Stelaris docs/specs/2026-04-12-dispatcher-diagnostic-batch.md) — without it we cannot tell whether non-recurrent GSP variants are learning or degenerate. Co-Authored-By: Claude Opus 4.6 (1M context) --- gsp_rl/src/actors/actor.py | 27 +++++- tests/test_actor/test_gsp_loss_exposure.py | 107 +++++++++++++++++++++ 2 files changed, 129 insertions(+), 5 deletions(-) create mode 100644 tests/test_actor/test_gsp_loss_exposure.py diff --git a/gsp_rl/src/actors/actor.py b/gsp_rl/src/actors/actor.py index 124999c..5c12736 100644 --- a/gsp_rl/src/actors/actor.py +++ b/gsp_rl/src/actors/actor.py @@ -112,6 +112,12 @@ def __init__( if attention: self.build_gsp_network('attention') self.build_gsp_network('DDPG') + + # Information-collapse diagnostic: last GSP prediction network training loss. + # Populated by learn_gsp() whenever a GSP learning step fires; reset to None at the + # start of each learn() call so callers can distinguish "no GSP step this tick" from + # "GSP step ran". + self.last_gsp_loss: float | None = None def build_networks(self, learning_scheme): if learning_scheme == 'None': @@ -402,8 +408,11 @@ def choose_actions_batch(self, observations, networks, test=False): f"Use choose_action() for RDDPG/attention networks.") def learn(self): + # Reset per-tick GSP loss signal. None means "no GSP step ran this tick". + self.last_gsp_loss = None + # TODO Not sure why we have n_agents*batch_size + batch_size - if self.networks['replay'].mem_ctr < self.batch_size: # (self.n_agents*self.batch_size + self.batch_size): + if self.networks['replay'].mem_ctr < self.batch_size: # (self.n_agents*self.batch_size + self.batch_size): return if self.gsp: @@ -431,14 +440,22 @@ def learn(self): def learn_gsp(self): if self.gsp_networks['replay'].mem_ctr < self.gsp_batch_size: return + # Capture the inner learn call's loss so callers can observe the GSP prediction + # network's training loss directly (needed for the information-collapse diagnostic). + loss = None if self.gsp_networks['learning_scheme'] in {'DDPG'}: - self.learn_DDPG(self.gsp_networks, self.gsp, self.recurrent_gsp) + loss = self.learn_DDPG(self.gsp_networks, self.gsp, self.recurrent_gsp) elif self.gsp_networks['learning_scheme'] in {'RDDPG'}: - self.learn_RDDPG(self.gsp_networks, self.gsp, self.recurrent_gsp) + loss = self.learn_RDDPG(self.gsp_networks, self.gsp, self.recurrent_gsp) elif self.gsp_networks['learning_scheme'] == 'TD3': - self.learn_TD3(self.gsp_networks, self.gsp, self.recurrent_gsp) + loss = self.learn_TD3(self.gsp_networks, self.gsp, self.recurrent_gsp) elif self.gsp_networks['learning_scheme'] == 'attention': - self.learn_attention(self.gsp_networks) + loss = self.learn_attention(self.gsp_networks) + if loss is not None: + # TD3's edge-case path returns (0, 0); normalize to a scalar for logging. + if isinstance(loss, tuple): + loss = loss[0] + self.last_gsp_loss = float(loss) def store_agent_transition(self, s, a, r, s_, d): self.store_transition(s, a, r, s_, d, self.networks) diff --git a/tests/test_actor/test_gsp_loss_exposure.py b/tests/test_actor/test_gsp_loss_exposure.py new file mode 100644 index 0000000..0cbfef6 --- /dev/null +++ b/tests/test_actor/test_gsp_loss_exposure.py @@ -0,0 +1,107 @@ +"""Tests for exposing per-step GSP prediction loss from Actor.learn(). + +Context: the information-collapse diagnostic (see Stelaris +docs/specs/2026-04-12-dispatcher-diagnostic-batch.md) requires logging the +GSP prediction network's training loss per learning step. The primary loss +already returned from learn() is the actor/critic loss, which stays normal +even when the GSP prediction head has collapsed to a near-constant output. +""" + +import numpy as np +import pytest + +from gsp_rl.src.actors.actor import Actor + + +BASE_CONFIG = { + "GAMMA": 0.99, + "TAU": 0.005, + "ALPHA": 0.001, + "BETA": 0.002, + "LR": 0.001, + "EPSILON": 0.0, + "EPS_MIN": 0.0, + "EPS_DEC": 0.0, + "BATCH_SIZE": 8, + "MEM_SIZE": 100, + "REPLACE_TARGET_COUNTER": 10, + "NOISE": 0.0, + "UPDATE_ACTOR_ITER": 1, + "WARMUP": 0, + # Fire GSP learning every primary learn() call so the test doesn't need to iterate 100 times. + "GSP_LEARNING_FREQUENCY": 1, + "GSP_BATCH_SIZE": 8, +} + +INPUT_SIZE = 8 +OUTPUT_SIZE = 4 +GSP_INPUT_SIZE = 6 +GSP_OUTPUT_SIZE = 1 +N_TRANSITIONS = 30 + + +def make_gsp_actor(network="DDPG"): + return Actor( + id=1, + config=BASE_CONFIG, + network=network, + input_size=INPUT_SIZE, + output_size=OUTPUT_SIZE, + min_max_action=1, + meta_param_size=1, + gsp=True, + gsp_input_size=GSP_INPUT_SIZE, + gsp_output_size=GSP_OUTPUT_SIZE, + ) + + +def fill_primary_and_gsp_buffers(actor, n=N_TRANSITIONS): + for _ in range(n): + # Primary network sees inputs that include the GSP head output (network_input_size) + s = np.random.random(actor.network_input_size).astype(np.float32) + s_ = np.random.random(actor.network_input_size).astype(np.float32) + a = actor.choose_action(s, actor.networks, test=True) + r = float(np.random.random()) + actor.store_transition(s, a, r, s_, False, actor.networks) + + # GSP prediction network training transitions + gsp_s = np.random.random(GSP_INPUT_SIZE).astype(np.float32) + gsp_s_ = np.random.random(GSP_INPUT_SIZE).astype(np.float32) + gsp_a = np.random.uniform(-1, 1, size=GSP_OUTPUT_SIZE).astype(np.float32) + gsp_r = float(np.random.random()) + actor.store_gsp_transition(gsp_s, gsp_a, gsp_r, gsp_s_, False) + + +class TestGSPLossExposure: + def test_last_gsp_loss_initialized_to_none(self): + actor = make_gsp_actor() + assert actor.last_gsp_loss is None + + def test_last_gsp_loss_populated_after_learn_with_gsp(self): + """After primary + GSP buffers are filled and learn() runs, last_gsp_loss is a float.""" + actor = make_gsp_actor() + fill_primary_and_gsp_buffers(actor) + actor.learn() + assert actor.last_gsp_loss is not None + assert isinstance(actor.last_gsp_loss, float) + + def test_last_gsp_loss_remains_none_when_gsp_disabled(self): + """Non-GSP actor never populates last_gsp_loss.""" + actor = Actor( + id=1, + config=BASE_CONFIG, + network="DDPG", + input_size=INPUT_SIZE, + output_size=OUTPUT_SIZE, + min_max_action=1, + meta_param_size=1, + gsp=False, + ) + for _ in range(N_TRANSITIONS): + s = np.random.random(INPUT_SIZE).astype(np.float32) + a = actor.choose_action(s, actor.networks, test=True) + r = float(np.random.random()) + s_ = np.random.random(INPUT_SIZE).astype(np.float32) + actor.store_transition(s, a, r, s_, False, actor.networks) + actor.learn() + assert actor.last_gsp_loss is None From a61d4b2b61fce6b66ba39310802c45fb62e98f21 Mon Sep 17 00:00:00 2001 From: Joshua Bloom Date: Sun, 12 Apr 2026 20:25:31 -0400 Subject: [PATCH 2/2] =?UTF-8?q?fix(actor):=20address=20PR=20#22=20review?= =?UTF-8?q?=20=E2=80=94=20TD3=20signature,=20tuple=20skip,=20docs,=20reset?= =?UTF-8?q?=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - learn_TD3 now accepts recurrent=False to match the DDPG/RDDPG signatures; the learn_gsp dispatch was passing 3 positional args to a 2-arg method. Latent bug today (GSP networks are built as DDPG/attention, not TD3) but removes the footgun before the diagnostic batch exercises TD3 variants. - TD3's non-actor-update step returns (0, 0); previously we unwrapped to 0.0 and logged it. That produces false collapse signals every update_actor_iter-1 ticks. Now we skip the entry entirely — leave last_gsp_loss at None as if no GSP step ran. - Doc the semantic: last_gsp_loss is the GSP learner's training loss, which is actor loss (policy-gradient signal) for DDPG/RDDPG/TD3 and genuine MSE only for attention. For prediction-collapse detection consumers should rely on gsp_squared_error and the HDF5Logger episode-level gsp_output_std / gsp_pred_target_corr attrs. - Add reset-between-ticks test covering the load-bearing invariant that last_gsp_loss returns to None when a learn() call runs but no GSP learning step fires. Co-Authored-By: Claude Opus 4.6 (1M context) --- gsp_rl/src/actors/actor.py | 18 ++++++++++++++--- gsp_rl/src/actors/learning_aids.py | 2 +- tests/test_actor/test_gsp_loss_exposure.py | 23 ++++++++++++++++++++++ 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/gsp_rl/src/actors/actor.py b/gsp_rl/src/actors/actor.py index 5c12736..5d1541b 100644 --- a/gsp_rl/src/actors/actor.py +++ b/gsp_rl/src/actors/actor.py @@ -113,7 +113,16 @@ def __init__( self.build_gsp_network('attention') self.build_gsp_network('DDPG') - # Information-collapse diagnostic: last GSP prediction network training loss. + # Information-collapse diagnostic: last GSP learner training loss. + # NOTE: this is the loss returned by the GSP learner's inner learn step, which means: + # - For DDPG/RDDPG/TD3 GSP schemes: actor loss (a critic-derived policy-gradient + # signal), NOT the prediction MSE against delta-theta. A collapsed predictor may + # not produce an anomalous value here, since the critic's value landscape can + # support multiple policy solutions. + # - For the attention GSP scheme: genuine prediction MSE against the label. + # For prediction-collapse detection, prefer the raw per-step squared error captured + # in RL-CollectiveTransport as `gsp_squared_error` plus the episode-level + # `gsp_output_std` / `gsp_pred_target_corr` attrs computed in the Stelaris HDF5Logger. # Populated by learn_gsp() whenever a GSP learning step fires; reset to None at the # start of each learn() call so callers can distinguish "no GSP step this tick" from # "GSP step ran". @@ -452,9 +461,12 @@ def learn_gsp(self): elif self.gsp_networks['learning_scheme'] == 'attention': loss = self.learn_attention(self.gsp_networks) if loss is not None: - # TD3's edge-case path returns (0, 0); normalize to a scalar for logging. + # TD3's non-actor-update steps return (0, 0) (critic stepped, actor did not). + # Recording a legitimate 0.0 there would produce false collapse signals every + # `update_actor_iter - 1` ticks, so skip those entries entirely — leave + # last_gsp_loss at None as if no GSP step ran this tick. if isinstance(loss, tuple): - loss = loss[0] + return self.last_gsp_loss = float(loss) def store_agent_transition(self, s, a, r, s_, d): diff --git a/gsp_rl/src/actors/learning_aids.py b/gsp_rl/src/actors/learning_aids.py index 100aeb8..4df433a 100644 --- a/gsp_rl/src/actors/learning_aids.py +++ b/gsp_rl/src/actors/learning_aids.py @@ -405,7 +405,7 @@ def learn_RDDPG(self, networks, gsp = False, recurrent = False): return actor_loss.item() - def learn_TD3(self, networks, gsp = False): + def learn_TD3(self, networks, gsp = False, recurrent = False): states, actions, rewards, states_, dones = self.sample_memory(networks) with T.no_grad(): diff --git a/tests/test_actor/test_gsp_loss_exposure.py b/tests/test_actor/test_gsp_loss_exposure.py index 0cbfef6..821ff5d 100644 --- a/tests/test_actor/test_gsp_loss_exposure.py +++ b/tests/test_actor/test_gsp_loss_exposure.py @@ -85,6 +85,29 @@ def test_last_gsp_loss_populated_after_learn_with_gsp(self): assert actor.last_gsp_loss is not None assert isinstance(actor.last_gsp_loss, float) + def test_last_gsp_loss_resets_between_ticks(self): + """Each learn() call starts by resetting last_gsp_loss to None. + + This is the load-bearing invariant of the field: consumers must be able to read + it after learn() and distinguish "no GSP step ran this tick" (None) from + "GSP step ran and returned a value" (float). If the reset fails, a stale value + from a previous tick bleeds into the current tick's reading. + """ + actor = make_gsp_actor() + fill_primary_and_gsp_buffers(actor) + actor.learn() + assert actor.last_gsp_loss is not None # populated after first learn + + # Drain the GSP replay buffer below the batch size so the next learn_gsp early-returns. + # Simulate by swapping in an empty gsp replay buffer. This is a white-box probe of the + # reset invariant rather than a full end-to-end run. + actor.gsp_networks['replay'].mem_ctr = 0 + + actor.learn() + assert actor.last_gsp_loss is None, ( + "last_gsp_loss should reset to None when learn() runs but no GSP step fires" + ) + def test_last_gsp_loss_remains_none_when_gsp_disabled(self): """Non-GSP actor never populates last_gsp_loss.""" actor = Actor(