From a5dce0ef72b6b9628bca7d63fc73187e6d0584cd Mon Sep 17 00:00:00 2001
From: Joshua Bloom
Date: Mon, 13 Apr 2026 09:51:00 -0400
Subject: [PATCH] fix(actor): train GSP predictor via direct MSE instead of
 DDPG actor-critic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Replaces the DDPG/RDDPG/TD3 actor-critic training path for GSP prediction
networks with direct supervised MSE regression against the ground-truth
delta-theta label. This is Option A (the minimal fix) from the analysis in
the Stelaris repo at docs/research/2026-04-13-gsp-information-collapse-analysis.md.

Root cause: training the GSP predictor as a DDPG agent on a clipped
negative-MSE reward r = clip(-|pred - label|^2, -2, 0) is a category error.
DDPG's deterministic policy gradient flows through a Q-critic whose value
landscape becomes flat when the reward is clipped or when the policy
converges to any constant output — the DPG update then vanishes and the
predictor freezes in place.

Live evidence from the diagnostic batch (episodes 50-150):
- DDQN+GSP predictor MSE 0.0601 vs trivial-mean MSE 0.0521 (worse than
  predicting a constant)
- DDPG+R-GSP-N prediction std collapsed to 0.00019 (near-constant zero)
- Correlation(pred, target) ~= 0 across all DDPG-trained variants
- A-GSP-N (the only variant trained by direct MSE in learn_attention) did
  NOT collapse — the controlled experiment that isolates the training
  mechanism as the cause.

This commit adds learn_gsp_mse(networks, recurrent) in learning_aids.py
and rewires actor.learn_gsp() to dispatch to it for the DDPG/RDDPG/TD3
schemes. The attention scheme is unchanged. The new method samples
(state, label) pairs from the replay buffer (the label rides in the action
field, per the RL-CT call-site convention) and minimizes MSE directly
against the label, which gives a non-vanishing gradient.

Tests: 2 new cases in tests/test_actor/test_gsp_direct_mse.py:
- test_learn_gsp_mse_beats_trivial_mean_baseline_on_linear_task verifies
  that after 200 learn steps on a linear state→label mapping, the
  predictor's MSE is lower than predicting the mean. Under the previous
  DDPG path this test fails, with the predictor's MSE landing ~0.5% above
  the trivial baseline, matching the live collapse signature.
- test_learn_gsp_populates_last_gsp_loss_for_ddpg_variant confirms the
  diagnostic field still populates under the new dispatch.

Full actor + learning_aids suite: 72/72 pass.

Companion change required in RL-CollectiveTransport: Main.py's
store_gsp_transition call sites must pass the ground-truth label (they
currently pass the previous prediction for non-attention variants).

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 gsp_rl/src/actors/actor.py              |  32 +++++++----
 gsp_rl/src/actors/learning_aids.py      |  56 ++++++++++++-
 tests/test_actor/test_gsp_direct_mse.py | 132 ++++++++++++++++++++++++++
 3 files changed, 206 insertions(+), 14 deletions(-)
 create mode 100644 tests/test_actor/test_gsp_direct_mse.py

diff --git a/gsp_rl/src/actors/actor.py b/gsp_rl/src/actors/actor.py
index 5d1541b..c649073 100644
--- a/gsp_rl/src/actors/actor.py
+++ b/gsp_rl/src/actors/actor.py
@@ -449,22 +449,28 @@ def learn(self):
     def learn_gsp(self):
         if self.gsp_networks['replay'].mem_ctr < self.gsp_batch_size:
             return
-        # Capture the inner learn call's loss so callers can observe the GSP prediction
-        # network's training loss directly (needed for the information-collapse diagnostic).
+        # HISTORICAL NOTE: this used to dispatch to learn_DDPG / learn_RDDPG /
+        # learn_TD3 with gsp=True for non-attention variants, training the GSP
+        # predictor as a DDPG actor-critic on a clipped negative-MSE reward.
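+        # (In DPG terms the update is E[grad_a Q(s,a)|a=mu(s) * grad_theta mu(s)];
+        # once the clipped reward flattens Q in a, grad_a Q is ~0 and the whole
+        # product vanishes no matter how large the prediction error is.)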
+        # That produced an information-collapsed predictor whose output was
+        # empirically worse than predicting the constant mean. Replaced on
+        # 2026-04-13 with direct supervised MSE for all non-attention variants.
+        # See Stelaris docs/research/2026-04-13-gsp-information-collapse-analysis.md
+        # for root cause analysis.
         loss = None
-        if self.gsp_networks['learning_scheme'] in {'DDPG'}:
-            loss = self.learn_DDPG(self.gsp_networks, self.gsp, self.recurrent_gsp)
-        elif self.gsp_networks['learning_scheme'] in {'RDDPG'}:
-            loss = self.learn_RDDPG(self.gsp_networks, self.gsp, self.recurrent_gsp)
-        elif self.gsp_networks['learning_scheme'] == 'TD3':
-            loss = self.learn_TD3(self.gsp_networks, self.gsp, self.recurrent_gsp)
-        elif self.gsp_networks['learning_scheme'] == 'attention':
+        scheme = self.gsp_networks['learning_scheme']
+        if scheme == 'attention':
             loss = self.learn_attention(self.gsp_networks)
+        elif scheme == 'RDDPG':
+            loss = self.learn_gsp_mse(self.gsp_networks, recurrent=True)
+        elif scheme in {'DDPG', 'TD3'}:
+            loss = self.learn_gsp_mse(self.gsp_networks, recurrent=False)
         if loss is not None:
-            # TD3's non-actor-update steps return (0, 0) (critic stepped, actor did not).
-            # Recording a legitimate 0.0 there would produce false collapse signals every
-            # `update_actor_iter - 1` ticks, so skip those entries entirely — leave
-            # last_gsp_loss at None as if no GSP step ran this tick.
+            # Keep the tuple-skip guard for safety in case learn_attention's
+            # return type ever changes; learn_gsp_mse returns a plain float.
             if isinstance(loss, tuple):
                 return
             self.last_gsp_loss = float(loss)
diff --git a/gsp_rl/src/actors/learning_aids.py b/gsp_rl/src/actors/learning_aids.py
index 4df433a..7deafbc 100644
--- a/gsp_rl/src/actors/learning_aids.py
+++ b/gsp_rl/src/actors/learning_aids.py
@@ -471,7 +471,61 @@ def learn_attention(self, networks):
         _check_nan(loss, f"Attention loss at step {networks['learn_step_counter']}")
         networks['attention'].optimizer.step()
         return loss.item()
-    
+
+    def learn_gsp_mse(self, networks, recurrent: bool = False):
+        """Train the GSP prediction network via direct supervised MSE.
+
+        Replaces the DDPG/RDDPG/TD3 actor-critic path for non-attention GSP
+        variants. Samples (state, label) pairs from `networks['replay']`,
+        forwards the state through the actor network, and minimizes MSE
+        against the label. The label is stored in the action field of the
+        replay buffer by convention — see RL-CollectiveTransport Main.py's
+        store_gsp_transition call sites.
+
+        Rationale: see docs/research/2026-04-13-gsp-information-collapse-analysis.md
+        in the Stelaris repo. Training the GSP predictor as a DDPG actor-critic
+        on a clipped negative-MSE reward produced an information-collapsed
+        predictor whose output was worse than predicting the constant mean.
+        Direct supervised MSE has a non-vanishing gradient `2(pred - label)` that
+        drives the predictor toward the label regardless of how flat the reward
+        landscape is.
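+
+        Returns the batch MSE as a plain float (recorded in last_gsp_loss by
+        the caller), or None while fewer than gsp_batch_size samples exist.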
+ """ + if networks['replay'].mem_ctr < self.gsp_batch_size: + return None + + if recurrent: + mem_result = self.sample_memory(networks) + if len(mem_result) == 7: + states, labels, _, _, _, _, _ = mem_result + else: + states, labels, _, _, _ = mem_result + networks['actor'].optimizer.zero_grad() + preds_out = networks['actor'](states, hidden=None) + preds = preds_out[0] if isinstance(preds_out, tuple) else preds_out + if preds.dim() == labels.dim() + 1: + labels_shaped = labels.unsqueeze(-1) + else: + labels_shaped = labels.view_as(preds) + loss = F.mse_loss(preds, labels_shaped) + else: + states, labels, _, _, _ = self.sample_memory(networks) + networks['actor'].optimizer.zero_grad() + preds = networks['actor'].forward(states) + # labels shape: (batch,) or (batch, 1). preds shape: (batch, 1). + if labels.dim() == preds.dim() - 1: + labels_shaped = labels.unsqueeze(-1) + else: + labels_shaped = labels.view_as(preds) + loss = F.mse_loss(preds, labels_shaped) + + loss.backward() + _check_nan(loss, f"GSP MSE loss at step {networks['learn_step_counter']}") + networks['actor'].optimizer.step() + networks['learn_step_counter'] += 1 + return loss.item() + def decrement_epsilon(self): self.epsilon = max(self.epsilon-self.eps_dec, self.eps_min) diff --git a/tests/test_actor/test_gsp_direct_mse.py b/tests/test_actor/test_gsp_direct_mse.py new file mode 100644 index 0000000..8e994c9 --- /dev/null +++ b/tests/test_actor/test_gsp_direct_mse.py @@ -0,0 +1,126 @@ +"""Tests for direct-MSE GSP training (Option A from the collapse analysis). + +Verifies that: +1. After training on a deterministic state→label mapping, the GSP predictor's + MSE is lower than a trivial "predict the mean" baseline. +2. `last_gsp_loss` is populated after learn() runs for a DDPG-GSP actor. + +See docs/research/2026-04-13-gsp-information-collapse-analysis.md in Stelaris +for the root cause and the rationale for switching from DDPG actor-critic to +direct supervised MSE. +""" + +import numpy as np +import pytest +import torch + +from gsp_rl.src.actors.actor import Actor + + +BASE_CONFIG = { + "GAMMA": 0.99, + "TAU": 0.005, + "ALPHA": 0.001, + "BETA": 0.002, + "LR": 0.001, + "EPSILON": 0.0, + "EPS_MIN": 0.0, + "EPS_DEC": 0.0, + "BATCH_SIZE": 16, + "MEM_SIZE": 1000, + "REPLACE_TARGET_COUNTER": 10, + "NOISE": 0.0, + "UPDATE_ACTOR_ITER": 1, + "WARMUP": 0, + "GSP_LEARNING_FREQUENCY": 1, + "GSP_BATCH_SIZE": 16, +} + + +INPUT_SIZE = 8 +OUTPUT_SIZE = 4 +GSP_INPUT_SIZE = 6 +GSP_OUTPUT_SIZE = 1 + + +def make_gsp_actor(network="DDPG"): + return Actor( + id=1, + config=BASE_CONFIG, + network=network, + input_size=INPUT_SIZE, + output_size=OUTPUT_SIZE, + min_max_action=1, + meta_param_size=1, + gsp=True, + gsp_input_size=GSP_INPUT_SIZE, + gsp_output_size=GSP_OUTPUT_SIZE, + ) + + +def _fill_gsp_buffer_with_linear_labels(actor, n_transitions=400, seed=0): + """Store (state, label) pairs where label = mean(state). + + A predictor trained with direct MSE should beat the trivial-mean baseline + within ~200 steps on this trivially-learnable mapping. + """ + rng = np.random.default_rng(seed) + stored_states = [] + stored_labels = [] + for _ in range(n_transitions): + state = rng.uniform(-1, 1, size=GSP_INPUT_SIZE).astype(np.float32) + label = float(np.mean(state)) + # Label carried in the action field under the new direct-MSE convention. 
+        actor.store_gsp_transition(
+            state, np.float32(label), 0.0, np.zeros_like(state), False
+        )
+        stored_states.append(state)
+        stored_labels.append(label)
+    return np.stack(stored_states), np.array(stored_labels, dtype=np.float32)
+
+
+def _fill_primary_buffer(actor, n=20):
+    """Primary replay buffer must have >= BATCH_SIZE transitions for learn()."""
+    rng = np.random.default_rng(42)
+    for _ in range(n):
+        s = rng.random(actor.network_input_size).astype(np.float32)
+        s_ = rng.random(actor.network_input_size).astype(np.float32)
+        a = actor.choose_action(s, actor.networks, test=True)
+        actor.store_transition(s, a, 0.0, s_, False, actor.networks)
+
+
+def test_learn_gsp_mse_beats_trivial_mean_baseline_on_linear_task():
+    """After training, the predictor's MSE is lower than predicting the constant mean."""
+    torch.manual_seed(0)
+    np.random.seed(0)
+    actor = make_gsp_actor(network="DDPG")
+    states, labels = _fill_gsp_buffer_with_linear_labels(actor, n_transitions=400, seed=0)
+    _fill_primary_buffer(actor)
+
+    for _ in range(200):
+        actor.learn()
+
+    net = actor.gsp_networks["actor"]
+    with torch.no_grad():
+        states_t = torch.from_numpy(states).to(net.device)
+        preds = net.forward(states_t).cpu().numpy().ravel()
+
+    pred_mse = float(np.mean((preds - labels) ** 2))
+    trivial_mse = float(np.mean((labels - labels.mean()) ** 2))
+    assert pred_mse < trivial_mse, (
+        f"Direct-MSE GSP training did not beat trivial baseline: "
+        f"pred_mse={pred_mse:.5f} trivial_mse={trivial_mse:.5f}"
+    )
+
+
+def test_learn_gsp_populates_last_gsp_loss_for_ddpg_variant():
+    """last_gsp_loss is populated after learn() for a DDPG-GSP actor."""
+    torch.manual_seed(0)
+    actor = make_gsp_actor(network="DDPG")
+    _fill_gsp_buffer_with_linear_labels(actor, n_transitions=100, seed=1)
+    _fill_primary_buffer(actor)
+
+    actor.learn()
+    assert actor.last_gsp_loss is not None
+    assert isinstance(actor.last_gsp_loss, float)
+    assert actor.last_gsp_loss >= 0.0
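+
+
+# Local run (assuming the repo's usual pytest layout):
+#   pytest tests/test_actor/test_gsp_direct_mse.py -q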