Merged
29 changes: 16 additions & 13 deletions gsp_rl/src/actors/actor.py
@@ -449,22 +449,25 @@ def learn(self):
     def learn_gsp(self):
         if self.gsp_networks['replay'].mem_ctr < self.gsp_batch_size:
             return
-        # Capture the inner learn call's loss so callers can observe the GSP prediction
-        # network's training loss directly (needed for the information-collapse diagnostic).
+        # HISTORICAL NOTE: this used to dispatch to learn_DDPG / learn_RDDPG /
+        # learn_TD3 with gsp=True for non-attention variants, training the GSP
+        # predictor as a DDPG actor-critic on a clipped negative-MSE reward.
+        # That produced an information-collapsed predictor whose output was
+        # empirically worse than predicting the constant mean. Replaced on
+        # 2026-04-13 with direct supervised MSE for all non-attention variants.
+        # See Stelaris docs/research/2026-04-13-gsp-information-collapse-analysis.md
+        # for root cause analysis.
         loss = None
-        if self.gsp_networks['learning_scheme'] in {'DDPG'}:
-            loss = self.learn_DDPG(self.gsp_networks, self.gsp, self.recurrent_gsp)
-        elif self.gsp_networks['learning_scheme'] in {'RDDPG'}:
-            loss = self.learn_RDDPG(self.gsp_networks, self.gsp, self.recurrent_gsp)
-        elif self.gsp_networks['learning_scheme'] == 'TD3':
-            loss = self.learn_TD3(self.gsp_networks, self.gsp, self.recurrent_gsp)
-        elif self.gsp_networks['learning_scheme'] == 'attention':
+        scheme = self.gsp_networks['learning_scheme']
+        if scheme == 'attention':
             loss = self.learn_attention(self.gsp_networks)
+        elif scheme == 'RDDPG':
+            loss = self.learn_gsp_mse(self.gsp_networks, recurrent=True)
+        elif scheme in {'DDPG', 'TD3'}:
+            loss = self.learn_gsp_mse(self.gsp_networks, recurrent=False)
         if loss is not None:
-            # TD3's non-actor-update steps return (0, 0) (critic stepped, actor did not).
-            # Recording a legitimate 0.0 there would produce false collapse signals every
-            # `update_actor_iter - 1` ticks, so skip those entries entirely — leave
-            # last_gsp_loss at None as if no GSP step ran this tick.
+            # Keep the tuple-skip guard for safety in case learn_attention's
+            # return type ever changes; learn_gsp_mse returns a plain float.
            if isinstance(loss, tuple):
                return
            self.last_gsp_loss = float(loss)
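The historical note above is the heart of this PR. As a standalone illustration (plain PyTorch, not code from this repo; `pred` and `label` are made-up names), here is why the old clipped negative-MSE reward can go gradient-flat while direct supervised MSE cannot. It collapses the full DDPG actor-critic pathway into a single differentiation step, so treat it as a sketch of the failure mode, not a reproduction of it:

import torch

pred = torch.tensor(3.0, requires_grad=True)
label = torch.tensor(0.5)

# Old-style shaped reward: clip(-MSE, -1, 0). Once the squared error
# exceeds 1, the clamp saturates, so d(reward)/d(pred) == 0 and nothing
# pushes the prediction back toward the label.
reward = torch.clamp(-(pred - label) ** 2, min=-1.0, max=0.0)
reward.backward()
print(pred.grad)  # tensor(0.) -- saturated, zero learning signal

# Direct supervised MSE: the gradient is 2 * (pred - label), nonzero
# whenever the prediction is wrong, no matter how flat the reward is.
pred.grad = None
loss = (pred - label) ** 2
loss.backward()
print(pred.grad)  # tensor(5.) == 2 * (3.0 - 0.5)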
53 changes: 52 additions & 1 deletion gsp_rl/src/actors/learning_aids.py
@@ -471,7 +471,58 @@ def learn_attention(self, networks):
_check_nan(loss, f"Attention loss at step {networks['learn_step_counter']}")
networks['attention'].optimizer.step()
return loss.item()


def learn_gsp_mse(self, networks, recurrent: bool = False):
"""Train the GSP prediction network via direct supervised MSE.

Replaces the DDPG/RDDPG actor-critic path for non-attention GSP variants.
Samples (state, label) pairs from `networks['replay']`, forwards the state
through the actor network, and minimizes MSE against the label. The label
is stored in the action field of the replay buffer by convention — see
RL-CollectiveTransport Main.py's store_gsp_transition call sites.

Rationale: see docs/research/2026-04-13-gsp-information-collapse-analysis.md
in the Stelaris repo. Training the GSP predictor as a DDPG actor-critic
on a clipped negative-MSE reward produced an information-collapsed
predictor whose output was worse than predicting the constant mean.
Direct supervised MSE has a non-vanishing gradient `2(pred-label)` that
drives the predictor toward the label regardless of how flat the reward
landscape is.
"""
if networks['replay'].mem_ctr < self.gsp_batch_size:
return None

if recurrent:
mem_result = self.sample_memory(networks)
if len(mem_result) == 7:
states, labels, _, _, _, _, _ = mem_result
else:
states, labels, _, _, _ = mem_result
networks['actor'].optimizer.zero_grad()
preds_out = networks['actor'](states, hidden=None)
preds = preds_out[0] if isinstance(preds_out, tuple) else preds_out
if preds.dim() == labels.dim() + 1:
labels_shaped = labels.unsqueeze(-1)
else:
labels_shaped = labels.view_as(preds)
loss = F.mse_loss(preds, labels_shaped)
else:
states, labels, _, _, _ = self.sample_memory(networks)
networks['actor'].optimizer.zero_grad()
preds = networks['actor'].forward(states)
# labels shape: (batch,) or (batch, 1). preds shape: (batch, 1).
if labels.dim() == preds.dim() - 1:
labels_shaped = labels.unsqueeze(-1)
else:
labels_shaped = labels.view_as(preds)
loss = F.mse_loss(preds, labels_shaped)

loss.backward()
_check_nan(loss, f"GSP MSE loss at step {networks['learn_step_counter']}")
networks['actor'].optimizer.step()
networks['learn_step_counter'] += 1
return loss.item()

def decrement_epsilon(self):
self.epsilon = max(self.epsilon-self.eps_dec, self.eps_min)

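One subtlety in `learn_gsp_mse` deserves a note: the label reshaping before `F.mse_loss` is not cosmetic. A minimal standalone sketch (plain PyTorch, not repo code) of the silent-broadcasting bug the dim check prevents:

import torch
import torch.nn.functional as F

preds = torch.zeros(16, 1)    # predictor output: (batch, 1)
labels = torch.arange(16.0)   # labels as stored in the buffer: (batch,)

# Without reshaping, F.mse_loss broadcasts (16, 1) against (16,) into a
# (16, 16) grid and averages over all 256 pairs -- a wrong loss that only
# surfaces as a UserWarning. Lifting labels to (batch, 1) keeps the
# comparison elementwise, which is what the training step relies on.
if labels.dim() == preds.dim() - 1:
    labels = labels.unsqueeze(-1)
loss = F.mse_loss(preds, labels)  # true per-sample MSE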
126 changes: 126 additions & 0 deletions tests/test_actor/test_gsp_direct_mse.py
@@ -0,0 +1,126 @@
"""Tests for direct-MSE GSP training (Option A from the collapse analysis).

Verifies that:
1. After training on a deterministic state→label mapping, the GSP predictor's
MSE is lower than a trivial "predict the mean" baseline.
2. `last_gsp_loss` is populated after learn() runs for a DDPG-GSP actor.

See docs/research/2026-04-13-gsp-information-collapse-analysis.md in Stelaris
for the root cause and the rationale for switching from DDPG actor-critic to
direct supervised MSE.
"""

import numpy as np
import pytest
import torch

from gsp_rl.src.actors.actor import Actor


BASE_CONFIG = {
"GAMMA": 0.99,
"TAU": 0.005,
"ALPHA": 0.001,
"BETA": 0.002,
"LR": 0.001,
"EPSILON": 0.0,
"EPS_MIN": 0.0,
"EPS_DEC": 0.0,
"BATCH_SIZE": 16,
"MEM_SIZE": 1000,
"REPLACE_TARGET_COUNTER": 10,
"NOISE": 0.0,
"UPDATE_ACTOR_ITER": 1,
"WARMUP": 0,
"GSP_LEARNING_FREQUENCY": 1,
"GSP_BATCH_SIZE": 16,
}


INPUT_SIZE = 8
OUTPUT_SIZE = 4
GSP_INPUT_SIZE = 6
GSP_OUTPUT_SIZE = 1


def make_gsp_actor(network="DDPG"):
return Actor(
id=1,
config=BASE_CONFIG,
network=network,
input_size=INPUT_SIZE,
output_size=OUTPUT_SIZE,
min_max_action=1,
meta_param_size=1,
gsp=True,
gsp_input_size=GSP_INPUT_SIZE,
gsp_output_size=GSP_OUTPUT_SIZE,
)


def _fill_gsp_buffer_with_linear_labels(actor, n_transitions=400, seed=0):
"""Store (state, label) pairs where label = mean(state).

A predictor trained with direct MSE should beat the trivial-mean baseline
within ~200 steps on this trivially-learnable mapping.
"""
rng = np.random.default_rng(seed)
stored_states = []
stored_labels = []
for _ in range(n_transitions):
state = rng.uniform(-1, 1, size=GSP_INPUT_SIZE).astype(np.float32)
label = float(np.mean(state))
# Label carried in the action field under the new direct-MSE convention.
actor.store_gsp_transition(
state, np.float32(label), 0.0, np.zeros_like(state), False
)
stored_states.append(state)
stored_labels.append(label)
return np.stack(stored_states), np.array(stored_labels, dtype=np.float32)


def _fill_primary_buffer(actor, n=20):
"""Primary replay buffer must have >= BATCH_SIZE transitions for learn()."""
rng = np.random.default_rng(42)
for _ in range(n):
s = rng.random(actor.network_input_size).astype(np.float32)
s_ = rng.random(actor.network_input_size).astype(np.float32)
a = actor.choose_action(s, actor.networks, test=True)
actor.store_transition(s, a, 0.0, s_, False, actor.networks)


def test_learn_gsp_mse_beats_trivial_mean_baseline_on_linear_task():
"""After training, the predictor's MSE is lower than predicting the constant mean."""
torch.manual_seed(0)
np.random.seed(0)
actor = make_gsp_actor(network="DDPG")
states, labels = _fill_gsp_buffer_with_linear_labels(actor, n_transitions=400, seed=0)
_fill_primary_buffer(actor)

for _ in range(200):
actor.learn()

net = actor.gsp_networks["actor"]
with torch.no_grad():
states_t = torch.from_numpy(states).to(net.device)
preds = net.forward(states_t).cpu().numpy().ravel()

pred_mse = float(np.mean((preds - labels) ** 2))
trivial_mse = float(np.mean((labels - labels.mean()) ** 2))
assert pred_mse < trivial_mse, (
f"Direct-MSE GSP training did not beat trivial baseline: "
f"pred_mse={pred_mse:.5f} trivial_mse={trivial_mse:.5f}"
)


def test_learn_gsp_populates_last_gsp_loss_for_ddpg_variant():
"""last_gsp_loss is populated after learn() for a DDPG-GSP actor."""
torch.manual_seed(0)
actor = make_gsp_actor(network="DDPG")
_fill_gsp_buffer_with_linear_labels(actor, n_transitions=100, seed=1)
_fill_primary_buffer(actor)

actor.learn()
assert actor.last_gsp_loss is not None
assert isinstance(actor.last_gsp_loss, float)
assert actor.last_gsp_loss >= 0.0
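Assuming a standard pytest setup (the path is taken from the diff header above), the new tests can be run on their own with:

pytest tests/test_actor/test_gsp_direct_mse.py -v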