39 changes: 34 additions & 5 deletions gsp_rl/src/actors/actor.py
@@ -112,6 +112,21 @@ def __init__(
if attention:
self.build_gsp_network('attention')
self.build_gsp_network('DDPG')

# Information-collapse diagnostic: last GSP learner training loss.
# NOTE: this is the loss returned by the GSP learner's inner learn step, which means:
# - For DDPG/RDDPG/TD3 GSP schemes: actor loss (a critic-derived policy-gradient
# signal), NOT the prediction MSE against delta-theta. A collapsed predictor may
# not produce an anomalous value here, since the critic's value landscape can
# support multiple policy solutions.
# - For the attention GSP scheme: genuine prediction MSE against the label.
# For prediction-collapse detection, prefer the raw per-step squared error captured
# in RL-CollectiveTransport as `gsp_squared_error` plus the episode-level
# `gsp_output_std` / `gsp_pred_target_corr` attrs computed in the Stelaris HDF5Logger.
# Populated by learn_gsp() whenever a GSP learning step fires; reset to None at the
# start of each learn() call so callers can distinguish "no GSP step this tick" from
# "GSP step ran".
self.last_gsp_loss: float | None = None

def build_networks(self, learning_scheme):
if learning_scheme == 'None':
@@ -402,8 +417,11 @@ def choose_actions_batch(self, observations, networks, test=False):
f"Use choose_action() for RDDPG/attention networks.")

def learn(self):
# Reset per-tick GSP loss signal. None means "no GSP step ran this tick".
self.last_gsp_loss = None

# TODO Not sure why we have n_agents*batch_size + batch_size
if self.networks['replay'].mem_ctr < self.batch_size: # (self.n_agents*self.batch_size + self.batch_size):
return

if self.gsp:
@@ -431,14 +449,25 @@ def learn(self):
def learn_gsp(self):
if self.gsp_networks['replay'].mem_ctr < self.gsp_batch_size:
return
# Capture the inner learn call's loss so callers can observe the GSP learner's
# training loss directly (needed for the information-collapse diagnostic; see
# the last_gsp_loss comment in __init__ for what this loss means per scheme).
loss = None
if self.gsp_networks['learning_scheme'] in {'DDPG'}:
self.learn_DDPG(self.gsp_networks, self.gsp, self.recurrent_gsp)
loss = self.learn_DDPG(self.gsp_networks, self.gsp, self.recurrent_gsp)
elif self.gsp_networks['learning_scheme'] in {'RDDPG'}:
self.learn_RDDPG(self.gsp_networks, self.gsp, self.recurrent_gsp)
loss = self.learn_RDDPG(self.gsp_networks, self.gsp, self.recurrent_gsp)
elif self.gsp_networks['learning_scheme'] == 'TD3':
self.learn_TD3(self.gsp_networks, self.gsp, self.recurrent_gsp)
loss = self.learn_TD3(self.gsp_networks, self.gsp, self.recurrent_gsp)
elif self.gsp_networks['learning_scheme'] == 'attention':
self.learn_attention(self.gsp_networks)
loss = self.learn_attention(self.gsp_networks)
if loss is not None:
# TD3's non-actor-update steps return (0, 0) (critic stepped, actor did not).
# Recording a legitimate 0.0 there would produce false collapse signals every
# `update_actor_iter - 1` ticks, so skip those entries entirely and leave
# last_gsp_loss at None, as if no GSP step ran this tick.
if isinstance(loss, tuple):
return
self.last_gsp_loss = float(loss)

def store_agent_transition(self, s, a, r, s_, d):
self.store_transition(s, a, r, s_, d, self.networks)
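Reviewer note: a minimal consumer-side sketch of the contract last_gsp_loss establishes. Only the reset-to-None / populate-on-GSP-step behavior comes from the diff above; the trainer loop and the `episode_gsp_losses` accumulator are hypothetical names for illustration.

episode_gsp_losses = []
for _ in range(steps_per_episode):  # hypothetical trainer loop
    actor.learn()
    # None: no GSP step fired this tick (buffer under-filled, frequency gate,
    # or a skipped TD3 non-actor-update step). A float: a GSP step ran.
    if actor.last_gsp_loss is not None:
        episode_gsp_losses.append(actor.last_gsp_loss)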
2 changes: 1 addition & 1 deletion gsp_rl/src/actors/learning_aids.py
@@ -405,7 +405,7 @@ def learn_RDDPG(self, networks, gsp = False, recurrent = False):

return actor_loss.item()

def learn_TD3(self, networks, gsp = False):
def learn_TD3(self, networks, gsp = False, recurrent = False):
states, actions, rewards, states_, dones = self.sample_memory(networks)

with T.no_grad():
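Reviewer note: the tuple guard added to learn_gsp() leans on a return-shape convention for learn_TD3 that this diff only implies. A sketch of that convention, assuming (unverified here) that learn_TD3 mirrors learn_RDDPG's `return actor_loss.item()` on actor-update steps and returns a (0, 0) tuple otherwise; `learn_step_counter` is a hypothetical attribute name.

def learn_TD3(self, networks, gsp=False, recurrent=False):
    ...  # critic update runs every call
    self.learn_step_counter += 1  # hypothetical counter attribute
    if self.learn_step_counter % self.update_actor_iter != 0:
        return 0, 0  # tuple sentinel: critic stepped, actor did not
    ...  # delayed actor update
    return actor_loss.item()  # scalar actor loss on update steps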
130 changes: 130 additions & 0 deletions tests/test_actor/test_gsp_loss_exposure.py
@@ -0,0 +1,130 @@
"""Tests for exposing per-step GSP prediction loss from Actor.learn().

Context: the information-collapse diagnostic (see Stelaris
docs/specs/2026-04-12-dispatcher-diagnostic-batch.md) requires logging the
GSP prediction network's training loss per learning step. The primary loss
already returned from learn() is the actor/critic loss, which stays normal
even when the GSP prediction head has collapsed to a near-constant output.
"""

import numpy as np
import pytest

from gsp_rl.src.actors.actor import Actor


BASE_CONFIG = {
"GAMMA": 0.99,
"TAU": 0.005,
"ALPHA": 0.001,
"BETA": 0.002,
"LR": 0.001,
"EPSILON": 0.0,
"EPS_MIN": 0.0,
"EPS_DEC": 0.0,
"BATCH_SIZE": 8,
"MEM_SIZE": 100,
"REPLACE_TARGET_COUNTER": 10,
"NOISE": 0.0,
"UPDATE_ACTOR_ITER": 1,
"WARMUP": 0,
# Fire GSP learning every primary learn() call so the test doesn't need to iterate 100 times.
"GSP_LEARNING_FREQUENCY": 1,
"GSP_BATCH_SIZE": 8,
}

INPUT_SIZE = 8
OUTPUT_SIZE = 4
GSP_INPUT_SIZE = 6
GSP_OUTPUT_SIZE = 1
N_TRANSITIONS = 30


def make_gsp_actor(network="DDPG"):
return Actor(
id=1,
config=BASE_CONFIG,
network=network,
input_size=INPUT_SIZE,
output_size=OUTPUT_SIZE,
min_max_action=1,
meta_param_size=1,
gsp=True,
gsp_input_size=GSP_INPUT_SIZE,
gsp_output_size=GSP_OUTPUT_SIZE,
)


def fill_primary_and_gsp_buffers(actor, n=N_TRANSITIONS):
for _ in range(n):
# Primary network sees inputs that include the GSP head output (network_input_size)
s = np.random.random(actor.network_input_size).astype(np.float32)
s_ = np.random.random(actor.network_input_size).astype(np.float32)
a = actor.choose_action(s, actor.networks, test=True)
r = float(np.random.random())
actor.store_transition(s, a, r, s_, False, actor.networks)

# GSP prediction network training transitions
gsp_s = np.random.random(GSP_INPUT_SIZE).astype(np.float32)
gsp_s_ = np.random.random(GSP_INPUT_SIZE).astype(np.float32)
gsp_a = np.random.uniform(-1, 1, size=GSP_OUTPUT_SIZE).astype(np.float32)
gsp_r = float(np.random.random())
actor.store_gsp_transition(gsp_s, gsp_a, gsp_r, gsp_s_, False)


class TestGSPLossExposure:
def test_last_gsp_loss_initialized_to_none(self):
actor = make_gsp_actor()
assert actor.last_gsp_loss is None

def test_last_gsp_loss_populated_after_learn_with_gsp(self):
"""After primary + GSP buffers are filled and learn() runs, last_gsp_loss is a float."""
actor = make_gsp_actor()
fill_primary_and_gsp_buffers(actor)
actor.learn()
assert actor.last_gsp_loss is not None
assert isinstance(actor.last_gsp_loss, float)

def test_last_gsp_loss_resets_between_ticks(self):
"""Each learn() call starts by resetting last_gsp_loss to None.

This is the load-bearing invariant of the field: consumers must be able to read
it after learn() and distinguish "no GSP step ran this tick" (None) from
"GSP step ran and returned a value" (float). If the reset fails, a stale value
from a previous tick bleeds into the current tick's reading.
"""
actor = make_gsp_actor()
fill_primary_and_gsp_buffers(actor)
actor.learn()
assert actor.last_gsp_loss is not None # populated after first learn

# Drain the GSP replay buffer below the batch size so the next learn_gsp early-returns.
# Simulate by zeroing the gsp replay buffer's memory counter. This is a white-box probe
# of the reset invariant rather than a full end-to-end run.
actor.gsp_networks['replay'].mem_ctr = 0

actor.learn()
assert actor.last_gsp_loss is None, (
"last_gsp_loss should reset to None when learn() runs but no GSP step fires"
)

def test_last_gsp_loss_remains_none_when_gsp_disabled(self):
"""Non-GSP actor never populates last_gsp_loss."""
actor = Actor(
id=1,
config=BASE_CONFIG,
network="DDPG",
input_size=INPUT_SIZE,
output_size=OUTPUT_SIZE,
min_max_action=1,
meta_param_size=1,
gsp=False,
)
for _ in range(N_TRANSITIONS):
s = np.random.random(INPUT_SIZE).astype(np.float32)
a = actor.choose_action(s, actor.networks, test=True)
r = float(np.random.random())
s_ = np.random.random(INPUT_SIZE).astype(np.float32)
actor.store_transition(s, a, r, s_, False, actor.networks)
actor.learn()
assert actor.last_gsp_loss is None
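Reviewer note: the actor.py comment defers prediction-collapse detection to `gsp_squared_error` plus the episode-level `gsp_output_std` / `gsp_pred_target_corr` attrs in the Stelaris HDF5Logger, none of which appears in this diff. A sketch of plausible semantics for those two attrs, assumed rather than taken from Stelaris:

import numpy as np

def episode_collapse_metrics(preds, targets):
    """Assumed semantics: near-zero output std, or near-zero correlation
    between predictions and targets, flags a predictor that has collapsed
    to a near-constant output."""
    preds = np.asarray(preds, dtype=np.float64).ravel()
    targets = np.asarray(targets, dtype=np.float64).ravel()
    std = float(np.std(preds))
    # Pearson correlation is undefined for a constant series; report 0.0.
    if std == 0.0 or np.std(targets) == 0.0:
        corr = 0.0
    else:
        corr = float(np.corrcoef(preds, targets)[0, 1])
    return {"gsp_output_std": std, "gsp_pred_target_corr": corr}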