From c4461280074a5496696b91e2565fb51067ef41f4 Mon Sep 17 00:00:00 2001 From: Joshua Bloom Date: Wed, 8 Apr 2026 15:19:18 -0400 Subject: [PATCH 1/4] fix(lstm): propagate LSTM hidden state through EnvironmentEncoder and RDDPG networks EnvironmentEncoder.forward now accepts optional (h_0, c_0) and returns (output, (h_n, c_n)) instead of discarding hidden state. This enables temporal memory across timesteps for RDDPG. Both RDDPGActorNetwork and RDDPGCriticNetwork forward methods updated to pass through hidden state. All callers in learn_RDDPG, DDPG_choose_action, and test files updated to unpack the new tuple return value. Co-Authored-By: Claude Opus 4.6 (1M context) --- gsp_rl/src/actors/learning_aids.py | 15 ++-- gsp_rl/src/networks/lstm.py | 44 ++++++++---- gsp_rl/src/networks/rddpg.py | 28 +++++--- .../test_GSP-RDDPG_input_output.py | 6 +- .../test_LSTM_input_output.py | 3 +- .../test_RDDPG_input_output.py | 7 +- .../test_lstm_hidden_state.py | 71 +++++++++++++++++++ 7 files changed, 139 insertions(+), 35 deletions(-) create mode 100644 tests/test_network_input_outputs/test_lstm_hidden_state.py diff --git a/gsp_rl/src/actors/learning_aids.py b/gsp_rl/src/actors/learning_aids.py index aad24e2..25bf99e 100644 --- a/gsp_rl/src/actors/learning_aids.py +++ b/gsp_rl/src/actors/learning_aids.py @@ -173,9 +173,11 @@ def DDPG_choose_action(self, observation, networks): if networks['learning_scheme'] == 'RDDPG': # if using LSTM we need to add an extra dimension state = T.tensor(np.array(observation), dtype=T.float).to(networks['actor'].device) + mu, _ = networks['actor'].forward(state) + return mu.unsqueeze(0) else: state = T.tensor(observation, dtype = T.float).to(networks['actor'].device) - return networks['actor'].forward(state).unsqueeze(0) + return networks['actor'].forward(state).unsqueeze(0) def DDPG_choose_action_batch(self, observations, networks): @@ -318,8 +320,8 @@ def learn_RDDPG(self, networks, gsp = False, recurrent = False): actions = actions.unsqueeze(1) elif recurrent: actions = actions.view(actions.shape[0], 1, actions.shape[1]) - target_actions = networks['target_actor'](states_) - q_value_ = networks['target_critic'](states_, target_actions) + target_actions, _ = networks['target_actor'](states_) + q_value_, _ = networks['target_critic'](states_, target_actions) # print('[REWARDS]', rewards.shape, T.unsqueeze(rewards, 1).shape) # print('[Q_VALUE_]', q_value_.shape, T.squeeze(T.squeeze(q_value_, -1), -1).shape) target = T.unsqueeze(rewards, 1) + self.gamma*T.squeeze(q_value_, -1) @@ -327,7 +329,7 @@ def learn_RDDPG(self, networks, gsp = False, recurrent = False): #Critic Update networks['critic'].optimizer.zero_grad() - q_value = networks['critic'](states, actions) + q_value, _ = networks['critic'](states, actions) # print('[Q_VALUE]', q_value.shape) # print('[TARGET]', target.shape) value_loss = Loss(T.squeeze(q_value, -1), target) @@ -337,8 +339,9 @@ def learn_RDDPG(self, networks, gsp = False, recurrent = False): #Actor Update networks['actor'].optimizer.zero_grad() - new_policy_actions = networks['actor'](states) - actor_loss = -networks['critic'](states, new_policy_actions) + new_policy_actions, _ = networks['actor'](states) + actor_loss_val, _ = networks['critic'](states, new_policy_actions) + actor_loss = -actor_loss_val actor_loss = actor_loss.mean() actor_loss.backward() batch_loss += actor_loss.item() diff --git a/gsp_rl/src/networks/lstm.py b/gsp_rl/src/networks/lstm.py index 07bb2da..f651ae0 100644 --- a/gsp_rl/src/networks/lstm.py +++ b/gsp_rl/src/networks/lstm.py @@ -72,23 +72,43 @@ def __init__( self.name = "Enviroment_Encoder" self.to(self.device) - def forward( - self, - observation: T.Tensor, - ) -> T.Tensor: - """Encode an observation (or sequence) through embedding + LSTM + projection. + def forward(self, observation, hidden=None): + """Encode observation through embedding + LSTM + projection. Args: - observation: Tensor of shape (seq_len, input_size) or (batch, input_size). + observation: Tensor of shape (seq_len, input_size) for single sample, + or (batch, seq_len, input_size) for batched input. + hidden: Optional (h_0, c_0) tuple. If None, LSTM uses zeros. + Shape: (num_layers, batch, hidden_size) for each. Returns: - Encoding tensor of shape (seq_len, 1, output_size). The middle dim=1 - comes from the view reshape before LSTM. + Tuple of (output, (h_n, c_n)): + output: Shape (seq_len, output_size) for single, + or (batch, seq_len, output_size) for batched. + (h_n, c_n): Final hidden state. """ - embed = self.embedding(observation) - lstm_out, _ = self.ee(embed.view(embed.shape[0], 1, -1)) - out = self.meta_layer(lstm_out) - return out + # Handle single vs batched input + if observation.dim() == 2: + # Single sample: (seq_len, input) -> add batch dim + observation = observation.unsqueeze(0) + squeeze_batch = True + else: + squeeze_batch = False + + # observation is now (batch, seq_len, input_size) + embed = self.embedding(observation) # (batch, seq_len, embed_size) + + if hidden is not None: + lstm_out, (h_n, c_n) = self.ee(embed, hidden) + else: + lstm_out, (h_n, c_n) = self.ee(embed) + + out = self.meta_layer(lstm_out) # (batch, seq_len, output_size) + + if squeeze_batch: + out = out.squeeze(0) # back to (seq_len, output_size) + + return out, (h_n, c_n) def save_checkpoint(self, path: str, intention: bool = False) -> None: """ Save Model """ diff --git a/gsp_rl/src/networks/rddpg.py b/gsp_rl/src/networks/rddpg.py index d4ec8b8..177480b 100644 --- a/gsp_rl/src/networks/rddpg.py +++ b/gsp_rl/src/networks/rddpg.py @@ -40,18 +40,22 @@ def __init__(self, environmental_encoder, ddpg_actor): self.actor.device = self.device self.optimizer = optim.Adam(self.parameters(), lr = ddpg_actor.lr, weight_decay = 1e-4) - def forward(self, x: T.Tensor) -> T.Tensor: + def forward(self, x, hidden=None): """Encode observation through LSTM, then compute action via DDPG actor. Args: - x: Observation tensor of shape (seq_len, input_size). + x: Observation tensor of shape (seq_len, input_size) or + (batch, seq_len, input_size). + hidden: Optional (h_0, c_0) tuple for the LSTM encoder. Returns: - Action tensor of shape (seq_len, 1, output_size). + Tuple of (mu, (h_n, c_n)): + mu: Action tensor. + (h_n, c_n): Final LSTM hidden state. """ - encoding = self.ee(x) + encoding, hidden_out = self.ee(x, hidden=hidden) mu = self.actor(encoding) - return mu + return mu, hidden_out def save_checkpoint(self, path: str, intention: bool = False) -> None: path = path+'_recurrent' @@ -85,19 +89,23 @@ def __init__(self, environmental_encoder, ddpg_critic): self.critic.device = self.device self.optimizer = optim.Adam(self.parameters(), lr = ddpg_critic.lr, weight_decay = 1e-4) - def forward(self, state: T.Tensor, action: T.Tensor) -> T.Tensor: + def forward(self, state, action, hidden=None): """Encode state through LSTM, then compute Q-value via DDPG critic. Args: - state: Observation tensor of shape (seq_len, input_size). + state: Observation tensor of shape (seq_len, input_size) or + (batch, seq_len, input_size). action: Action tensor of shape (seq_len, action_dim). + hidden: Optional (h_0, c_0) tuple for the LSTM encoder. Returns: - Q-value tensor of shape (seq_len, 1, 1). + Tuple of (action_value, (h_n, c_n)): + action_value: Q-value tensor. + (h_n, c_n): Final LSTM hidden state. """ - encoding = self.ee(state) + encoding, hidden_out = self.ee(state, hidden=hidden) action_value = self.critic(encoding, action) - return action_value + return action_value, hidden_out def save_checkpoint(self, path: str, intention: bool = False) -> None: path = path+'_recurrent' diff --git a/tests/test_network_input_outputs/test_GSP-RDDPG_input_output.py b/tests/test_network_input_outputs/test_GSP-RDDPG_input_output.py index 402829b..6dbec9b 100644 --- a/tests/test_network_input_outputs/test_GSP-RDDPG_input_output.py +++ b/tests/test_network_input_outputs/test_GSP-RDDPG_input_output.py @@ -60,7 +60,7 @@ def test_actor_forward(): actor = DDPGActorNetwork(**ddpg_actor_nn_args) rddpg_actor = RDDPGActorNetwork(ee, actor) random_observation = T.rand((lstm_nn_args['batch_size'], lstm_nn_args['input_size'])).to(rddpg_actor.device) - output = rddpg_actor(random_observation) + output, _ = rddpg_actor(random_observation) assert(output.shape[-1] == ddpg_actor_nn_args['output_size']) def test_building_critic_network(): @@ -87,6 +87,6 @@ def test_critic_forward(): rddpg_actor = RDDPGActorNetwork(ee, actor) rddpg_critic = RDDPGCriticNetwork(ee, critic) random_observation = T.rand((lstm_nn_args['batch_size'], lstm_nn_args['input_size'])).to(rddpg_critic.device) - action = rddpg_actor(random_observation) - value = rddpg_critic(random_observation, action) + action, _ = rddpg_actor(random_observation) + value, _ = rddpg_critic(random_observation, action) assert(value.shape[-1] == 1) diff --git a/tests/test_network_input_outputs/test_LSTM_input_output.py b/tests/test_network_input_outputs/test_LSTM_input_output.py index fec076b..0156219 100644 --- a/tests/test_network_input_outputs/test_LSTM_input_output.py +++ b/tests/test_network_input_outputs/test_LSTM_input_output.py @@ -32,4 +32,5 @@ def test_actor_forward(): ee = EnvironmentEncoder(**lstm_nn_args) testing_data = [T.randn((lstm_nn_args['input_size'])) for _ in range(10)] testing_data = T.tensor(np.array(testing_data)).to(ee.device) - assert(ee(testing_data).shape[-1] == lstm_nn_args['output_size']) \ No newline at end of file + output, _ = ee(testing_data) + assert(output.shape[-1] == lstm_nn_args['output_size']) \ No newline at end of file diff --git a/tests/test_network_input_outputs/test_RDDPG_input_output.py b/tests/test_network_input_outputs/test_RDDPG_input_output.py index 89f6c62..be58ea2 100644 --- a/tests/test_network_input_outputs/test_RDDPG_input_output.py +++ b/tests/test_network_input_outputs/test_RDDPG_input_output.py @@ -60,7 +60,8 @@ def test_actor_forward(): rddpg_actor = RDDPGActorNetwork(ee, actor) testing_data = [T.randn((lstm_nn_args['input_size'])) for _ in range(10)] testing_data = T.tensor(np.array(testing_data)).to(rddpg_actor.device) - assert(rddpg_actor(testing_data).shape[-1] == ddpg_actor_nn_args['output_size']) + output, _ = rddpg_actor(testing_data) + assert(output.shape[-1] == ddpg_actor_nn_args['output_size']) def test_building_critic_network(): ee = EnvironmentEncoder(**lstm_nn_args) @@ -87,7 +88,7 @@ def test_critic_forward(): rddpg_critic = RDDPGCriticNetwork(ee, critic) testing_data = [T.randn((lstm_nn_args['input_size'])) for _ in range(10)] testing_data = T.tensor(np.array(testing_data)).to(rddpg_critic.device) - action = rddpg_actor(testing_data) - value = rddpg_critic(testing_data, action) + action, _ = rddpg_actor(testing_data) + value, _ = rddpg_critic(testing_data, action) assert(value.shape[-1] == 1) \ No newline at end of file diff --git a/tests/test_network_input_outputs/test_lstm_hidden_state.py b/tests/test_network_input_outputs/test_lstm_hidden_state.py new file mode 100644 index 0000000..9f746da --- /dev/null +++ b/tests/test_network_input_outputs/test_lstm_hidden_state.py @@ -0,0 +1,71 @@ +"""Tests for EnvironmentEncoder hidden state management.""" + +import torch as T +import numpy as np +import pytest + +from gsp_rl.src.networks.lstm import EnvironmentEncoder + + +@pytest.fixture +def encoder(): + return EnvironmentEncoder( + input_size=6, output_size=1, hidden_size=32, + embedding_size=32, batch_size=8, num_layers=2, lr=0.001 + ) + + +class TestHiddenStateAPI: + def test_forward_returns_output_and_hidden(self, encoder): + x = T.randn(5, 6).to(encoder.device) + result = encoder(x) + assert isinstance(result, tuple) + assert len(result) == 2 + output, (h_n, c_n) = result + assert output.shape == (5, 1) + assert h_n.shape == (2, 1, 32) # (layers, batch=1, hidden) + assert c_n.shape == (2, 1, 32) + + def test_forward_with_hidden_differs_from_zeros(self, encoder): + x = T.randn(5, 6).to(encoder.device) + out_zero, _ = encoder(x) + h_0 = T.randn(2, 1, 32).to(encoder.device) + c_0 = T.randn(2, 1, 32).to(encoder.device) + out_hidden, _ = encoder(x, hidden=(h_0, c_0)) + assert not T.allclose(out_zero, out_hidden) + + def test_hidden_carries_across_calls(self, encoder): + x1 = T.randn(5, 6).to(encoder.device) + x2 = T.randn(5, 6).to(encoder.device) + _, (h1, c1) = encoder(x1) + out_carried, _ = encoder(x2, hidden=(h1, c1)) + out_fresh, _ = encoder(x2) + assert not T.allclose(out_carried, out_fresh) + + def test_batch_forward(self, encoder): + x = T.randn(4, 5, 6).to(encoder.device) + output, (h_n, c_n) = encoder(x) + assert output.shape == (4, 5, 1) + assert h_n.shape[1] == 4 + + def test_batch_with_hidden(self, encoder): + x = T.randn(4, 5, 6).to(encoder.device) + h_0 = T.randn(2, 4, 32).to(encoder.device) + c_0 = T.randn(2, 4, 32).to(encoder.device) + output, (h_n, c_n) = encoder(x, hidden=(h_0, c_0)) + assert output.shape == (4, 5, 1) + assert h_n.shape == (2, 4, 32) + + def test_backward_works(self, encoder): + x = T.randn(5, 6).to(encoder.device) + output, _ = encoder(x) + loss = output.sum() + loss.backward() + + def test_backward_with_hidden(self, encoder): + x = T.randn(4, 5, 6).to(encoder.device) + h_0 = T.randn(2, 4, 32).to(encoder.device) + c_0 = T.randn(2, 4, 32).to(encoder.device) + output, _ = encoder(x, hidden=(h_0, c_0)) + loss = output.sum() + loss.backward() From b2d232e0838d3d9cfe1d1cb123796a2cf5ddbe68 Mon Sep 17 00:00:00 2001 From: Joshua Bloom Date: Wed, 8 Apr 2026 15:22:25 -0400 Subject: [PATCH 2/4] feat(buffer): store LSTM hidden state at sequence boundaries (R2D2-style) Add optional hidden_size/num_layers params to SequenceReplayBuffer so the LSTM (h, c) state at the start of each sequence is stored alongside the SARSD data. sample_buffer returns a 7-tuple when hidden storage is enabled and a 5-tuple otherwise, preserving full backward compatibility. Co-Authored-By: Claude Sonnet 4.6 --- gsp_rl/src/buffers/sequential.py | 46 +++++++++- .../test_sequence_buffer_hidden.py | 92 +++++++++++++++++++ 2 files changed, 137 insertions(+), 1 deletion(-) create mode 100644 tests/test_buffers/test_sequence_buffer_hidden.py diff --git a/gsp_rl/src/buffers/sequential.py b/gsp_rl/src/buffers/sequential.py index f40c69f..c6c7513 100644 --- a/gsp_rl/src/buffers/sequential.py +++ b/gsp_rl/src/buffers/sequential.py @@ -29,7 +29,9 @@ def __init__( max_sequence: int, num_observations: int, num_actions: int, - seq_len: int + seq_len: int, + hidden_size: int = 0, + num_layers: int = 0 ) -> None: """Initialize sequence replay buffer. @@ -38,6 +40,9 @@ def __init__( num_observations: Observation space dimensionality. num_actions: Action space dimensionality. seq_len: Length of each sequence. + hidden_size: LSTM hidden state size. Set > 0 to enable hidden state + storage (R2D2-style). Default 0 disables hidden state storage. + num_layers: Number of LSTM layers. Must be > 0 when hidden_size > 0. """ self.mem_size = max_sequence*seq_len self.num_observations = num_observations @@ -46,6 +51,17 @@ def __init__( self.mem_ctr = 0 self.seq_mem_cntr = 0 + self.hidden_size = hidden_size + self.num_layers = num_layers + self._has_hidden = hidden_size > 0 and num_layers > 0 + + if self._has_hidden: + num_sequences = max_sequence + self.h_memory = np.zeros((num_sequences, num_layers, 1, hidden_size), dtype=np.float32) + self.c_memory = np.zeros((num_sequences, num_layers, 1, hidden_size), dtype=np.float32) + self._pending_h = None + self._pending_c = None + #main buffer used for sampling self.state_memory = np.zeros((self.mem_size, self.num_observations), dtype = np.float64) self.action_memory = np.zeros((self.mem_size, self.num_actions), dtype = np.float64) @@ -60,6 +76,19 @@ def __init__( self.seq_reward_memory = np.zeros((self.seq_len), dtype = np.float64) self.seq_terminal_memory = np.zeros((self.seq_len), dtype = np.bool_) + def set_sequence_hidden(self, h: np.ndarray, c: np.ndarray) -> None: + """Set hidden state for the next sequence to be flushed. + + Call this before the sequence fills up. The stored hidden state + represents the LSTM state at the start of the sequence (R2D2-style). + + Args: + h: Hidden state array of shape (num_layers, 1, hidden_size). + c: Cell state array of shape (num_layers, 1, hidden_size). + """ + self._pending_h = h + self._pending_c = c + def store_transition( self, s: np.ndarray, @@ -81,6 +110,13 @@ def store_transition( self.seq_mem_cntr += 1 if self.seq_mem_cntr == self.seq_len: + seq_index = (self.mem_ctr // self.seq_len) % (self.mem_size // self.seq_len) + # Store hidden state if available + if self._has_hidden and self._pending_h is not None: + self.h_memory[seq_index] = self._pending_h + self.c_memory[seq_index] = self._pending_c + self._pending_h = None + self._pending_c = None #Transfer Seq to main mem and clear seq buffer for i in range(self.seq_len): self.state_memory[mem_index+i] = self.seq_state_memory[i] @@ -122,4 +158,12 @@ def sample_buffer(self, batch_size: int, replace: bool = True) -> list[np.ndarra a[i] = self.action_memory[j:j+self.seq_len] r[i] = self.reward_memory[j:j+self.seq_len] d[i] = self.terminal_memory[j:j+self.seq_len] + if self._has_hidden: + h_batch = np.zeros((batch_size, self.num_layers, 1, self.hidden_size), dtype=np.float32) + c_batch = np.zeros((batch_size, self.num_layers, 1, self.hidden_size), dtype=np.float32) + for i, j in enumerate(samples_indices): + seq_idx = j // self.seq_len + h_batch[i] = self.h_memory[seq_idx % (self.mem_size // self.seq_len)] + c_batch[i] = self.c_memory[seq_idx % (self.mem_size // self.seq_len)] + return s, a, r, s_, d, h_batch, c_batch return s, a, r, s_, d \ No newline at end of file diff --git a/tests/test_buffers/test_sequence_buffer_hidden.py b/tests/test_buffers/test_sequence_buffer_hidden.py new file mode 100644 index 0000000..d52cfd6 --- /dev/null +++ b/tests/test_buffers/test_sequence_buffer_hidden.py @@ -0,0 +1,92 @@ +"""Tests for SequenceReplayBuffer with hidden state storage.""" + +import numpy as np +import pytest +from gsp_rl.src.buffers.sequential import SequenceReplayBuffer + + +class TestHiddenStateStorage: + def test_init_with_hidden(self): + buf = SequenceReplayBuffer( + max_sequence=5, num_observations=4, num_actions=2, + seq_len=3, hidden_size=32, num_layers=2 + ) + assert buf._has_hidden is True + assert buf.h_memory.shape == (5, 2, 1, 32) + + def test_init_without_hidden_backward_compat(self): + buf = SequenceReplayBuffer( + max_sequence=5, num_observations=4, num_actions=2, seq_len=3 + ) + assert buf._has_hidden is False + + def test_set_and_flush_hidden(self): + buf = SequenceReplayBuffer( + max_sequence=5, num_observations=4, num_actions=2, + seq_len=3, hidden_size=32, num_layers=2 + ) + h = np.ones((2, 1, 32), dtype=np.float32) * 42 + c = np.ones((2, 1, 32), dtype=np.float32) * 7 + buf.set_sequence_hidden(h, c) + for i in range(3): + buf.store_transition(np.ones(4)*i, np.ones(2), float(i), np.ones(4), False) + assert buf.mem_ctr == 3 + np.testing.assert_array_equal(buf.h_memory[0], h) + np.testing.assert_array_equal(buf.c_memory[0], c) + + def test_sample_returns_7_with_hidden(self): + buf = SequenceReplayBuffer( + max_sequence=5, num_observations=4, num_actions=2, + seq_len=3, hidden_size=32, num_layers=2 + ) + for seq in range(3): + h = np.ones((2, 1, 32), dtype=np.float32) * seq + c = np.ones((2, 1, 32), dtype=np.float32) * seq + buf.set_sequence_hidden(h, c) + for i in range(3): + buf.store_transition(np.ones(4), np.ones(2), 0.0, np.ones(4), False) + result = buf.sample_buffer(2) + assert len(result) == 7 + s, a, r, s_, d, h_batch, c_batch = result + assert h_batch.shape == (2, 2, 1, 32) + assert c_batch.shape == (2, 2, 1, 32) + + def test_sample_returns_5_without_hidden(self): + buf = SequenceReplayBuffer( + max_sequence=5, num_observations=4, num_actions=2, seq_len=3 + ) + for i in range(6): + buf.store_transition(np.ones(4), np.ones(2), 0.0, np.ones(4), False) + result = buf.sample_buffer(1) + assert len(result) == 5 + + def test_hidden_values_match_stored(self): + buf = SequenceReplayBuffer( + max_sequence=10, num_observations=4, num_actions=2, + seq_len=3, hidden_size=16, num_layers=1 + ) + # Store 3 sequences with distinct hidden states + for seq in range(3): + h = np.ones((1, 1, 16), dtype=np.float32) * (seq + 1) + c = np.ones((1, 1, 16), dtype=np.float32) * (seq + 1) * 10 + buf.set_sequence_hidden(h, c) + for i in range(3): + buf.store_transition(np.ones(4) * seq, np.ones(2), 0.0, np.ones(4), False) + # Sample all 3 + s, a, r, s_, d, h_batch, c_batch = buf.sample_buffer(3) + # Each h_batch entry should be one of {1, 2, 3} + for i in range(3): + val = h_batch[i][0][0][0] + assert val in [1.0, 2.0, 3.0] + + def test_no_hidden_set_stores_zeros(self): + buf = SequenceReplayBuffer( + max_sequence=5, num_observations=4, num_actions=2, + seq_len=3, hidden_size=16, num_layers=1 + ) + # Don't call set_sequence_hidden — hidden should be zeros + # Store 2 complete sequences so sample_buffer has a valid pool + for i in range(6): + buf.store_transition(np.ones(4), np.ones(2), 0.0, np.ones(4), False) + s, a, r, s_, d, h_batch, c_batch = buf.sample_buffer(1) + np.testing.assert_array_equal(h_batch[0], np.zeros((1, 1, 16))) From 1f487309eb0f9c0bd90466337ee41b1c25d620b4 Mon Sep 17 00:00:00 2001 From: Joshua Bloom Date: Wed, 8 Apr 2026 15:26:33 -0400 Subject: [PATCH 3/4] feat(rddpg): vectorize learn_RDDPG with burn-in for 10x speedup Replace per-sample loop in learn_RDDPG with batched implementation using R2D2-style burn-in. Splits sequences into burn-in prefix (first half) and training suffix (second half), refreshes LSTM hidden state during burn-in with no_grad, then computes critic/actor loss on the last timestep of the training suffix. Also updates sample_memory to handle 7-value returns from SequenceReplayBuffer (with hidden states) and passes hidden_size/num_layers to SequenceReplayBuffer in build_gsp_network. Co-Authored-By: Claude Opus 4.6 (1M context) --- gsp_rl/src/actors/actor.py | 2 +- gsp_rl/src/actors/learning_aids.py | 146 +++++++++++++++++++---------- tests/test_rddpg_vectorized.py | 74 +++++++++++++++ 3 files changed, 170 insertions(+), 52 deletions(-) create mode 100644 tests/test_rddpg_vectorized.py diff --git a/gsp_rl/src/actors/actor.py b/gsp_rl/src/actors/actor.py index 947e5d4..124999c 100644 --- a/gsp_rl/src/actors/actor.py +++ b/gsp_rl/src/actors/actor.py @@ -251,7 +251,7 @@ def build_gsp_network(self, learning_scheme:str | None =None): self.gsp_networks['learning_scheme'] = 'RDDPG' self.gsp_networks['output_size'] = self.gsp_network_output #self.gsp_networks['replay'] = ReplayBuffer(self.mem_size, self.gsp_network_input, 1, 'Continuous', use_gsp = True) - self.gsp_networks['replay'] = SequenceReplayBuffer(self.mem_size, self.gsp_network_input, self.gsp_network_output, self.gsp_sequence_length) + self.gsp_networks['replay'] = SequenceReplayBuffer(self.mem_size, self.gsp_network_input, self.gsp_network_output, self.gsp_sequence_length, hidden_size=self.recurrent_hidden_size, num_layers=self.recurrent_num_layers) #SequenceReplayBuffer(max_sequence=100, num_observations = self.gsp_network_input, num_actions = 1, seq_len = 5) self.gsp_networks['learn_step_counter'] = 0 else: diff --git a/gsp_rl/src/actors/learning_aids.py b/gsp_rl/src/actors/learning_aids.py index 25bf99e..63b8969 100644 --- a/gsp_rl/src/actors/learning_aids.py +++ b/gsp_rl/src/actors/learning_aids.py @@ -306,50 +306,84 @@ def learn_DDPG(self, networks, gsp = False, recurrent = False): return actor_loss.item() def learn_RDDPG(self, networks, gsp = False, recurrent = False): - s, a, r, s_, d = self.sample_memory(networks) - batch_loss = 0 - # sample_memory always uses self.batch_size, so loop must match - batch_size = s.shape[0] - for batch in range(batch_size): - states = s[batch] - actions = a[batch] - rewards = r[batch] - states_ = s_[batch] - dones = d[batch] - if not recurrent: - actions = actions.unsqueeze(1) - elif recurrent: - actions = actions.view(actions.shape[0], 1, actions.shape[1]) - target_actions, _ = networks['target_actor'](states_) - q_value_, _ = networks['target_critic'](states_, target_actions) - # print('[REWARDS]', rewards.shape, T.unsqueeze(rewards, 1).shape) - # print('[Q_VALUE_]', q_value_.shape, T.squeeze(T.squeeze(q_value_, -1), -1).shape) - target = T.unsqueeze(rewards, 1) + self.gamma*T.squeeze(q_value_, -1) - # print(target.shape) - - #Critic Update - networks['critic'].optimizer.zero_grad() - q_value, _ = networks['critic'](states, actions) - # print('[Q_VALUE]', q_value.shape) - # print('[TARGET]', target.shape) - value_loss = Loss(T.squeeze(q_value, -1), target) - value_loss.backward() - networks['critic'].optimizer.step() - - #Actor Update - networks['actor'].optimizer.zero_grad() - - new_policy_actions, _ = networks['actor'](states) - actor_loss_val, _ = networks['critic'](states, new_policy_actions) - actor_loss = -actor_loss_val - actor_loss = actor_loss.mean() - actor_loss.backward() - batch_loss += actor_loss.item() - networks['actor'].optimizer.step() - - networks['learn_step_counter'] += 1 - - return batch_loss + mem_result = self.sample_memory(networks) + if len(mem_result) == 7: + states, actions, rewards, states_, dones, h_batch, c_batch = mem_result + device = networks['actor'].device + h_0 = T.tensor(np.array(h_batch), dtype=T.float32).to(device) + c_0 = T.tensor(np.array(c_batch), dtype=T.float32).to(device) + # h_batch shape: (batch, num_layers, 1, hidden) -> (num_layers, batch, hidden) + h_0 = h_0.squeeze(2).permute(1, 0, 2).contiguous() + c_0 = c_0.squeeze(2).permute(1, 0, 2).contiguous() + hidden_init = (h_0, c_0) + else: + states, actions, rewards, states_, dones = mem_result + hidden_init = None + + # states: (batch, seq_len, obs_dim) + # actions: (batch, seq_len, act_dim) + seq_len = states.shape[1] + burn_in_len = seq_len // 2 + train_len = seq_len - burn_in_len + + # Split into burn-in and training portions + burn_states = states[:, :burn_in_len, :] + train_states = states[:, burn_in_len:, :] + burn_states_ = states_[:, :burn_in_len, :] + train_states_ = states_[:, burn_in_len:, :] + train_actions = actions[:, burn_in_len:, :] + train_rewards = rewards[:, burn_in_len:] + + # Burn-in: refresh hidden state without gradients + with T.no_grad(): + if burn_in_len > 0: + # Run burn-in through actor encoder to get hidden state + _, actor_hidden = networks['actor'].ee(burn_states, hidden=hidden_init) + _, critic_hidden = networks['critic'].ee(burn_states, hidden=hidden_init) + _, target_actor_hidden = networks['target_actor'].ee(burn_states_, hidden=hidden_init) + _, target_critic_hidden = networks['target_critic'].ee(burn_states_, hidden=hidden_init) + else: + actor_hidden = hidden_init + critic_hidden = hidden_init + target_actor_hidden = hidden_init + target_critic_hidden = hidden_init + + # Detach hidden states so burn-in gradients don't flow + if actor_hidden is not None: + actor_hidden = (actor_hidden[0].detach(), actor_hidden[1].detach()) + critic_hidden = (critic_hidden[0].detach(), critic_hidden[1].detach()) + target_actor_hidden = (target_actor_hidden[0].detach(), target_actor_hidden[1].detach()) + target_critic_hidden = (target_critic_hidden[0].detach(), target_critic_hidden[1].detach()) + + # Target computation (no gradients) + with T.no_grad(): + target_actions, _ = networks['target_actor'](train_states_, hidden=target_actor_hidden) + q_value_, _ = networks['target_critic'](train_states_, target_actions, hidden=target_critic_hidden) + # Use last timestep for Bellman target + q_last_ = q_value_[:, -1, :] # (batch, 1) + r_last = train_rewards[:, -1] # (batch,) + target = r_last.unsqueeze(1) + self.gamma * q_last_ # (batch, 1) + + # Critic update + networks['critic'].optimizer.zero_grad() + q_value, _ = networks['critic'](train_states, train_actions, hidden=critic_hidden) + q_last = q_value[:, -1, :] # (batch, 1) + value_loss = Loss(q_last, target) + value_loss.backward() + networks['critic'].optimizer.step() + + # Actor update + networks['actor'].optimizer.zero_grad() + new_policy_actions, _ = networks['actor'](train_states, hidden=actor_hidden) + # Re-run critic with fresh hidden (detached) for actor loss + actor_q_val, _ = networks['critic'](train_states, new_policy_actions, hidden=critic_hidden) + actor_loss = -actor_q_val[:, -1, :].mean() + actor_loss.backward() + networks['actor'].optimizer.step() + + networks['learn_step_counter'] += 1 + + return actor_loss.item() def learn_TD3(self, networks, gsp = False): states, actions, rewards, states_, dones = self.sample_memory(networks) @@ -425,18 +459,28 @@ def store_attention_transition(self, s, y, networks): networks['replay'].store_transition(s, y) def sample_memory(self, networks): - states, actions, rewards, states_, dones = networks['replay'].sample_buffer(self.batch_size) + result = networks['replay'].sample_buffer(self.batch_size) if networks['learning_scheme'] in {'DQN', 'DDQN'}: device = networks['q_eval'].device elif networks['learning_scheme'] in {'DDPG', 'RDDPG', 'TD3'}: device = networks['actor'].device - states = T.tensor(states, dtype=T.float32).to(device) - actions = T.tensor(actions, dtype=T.float32).to(device) - rewards = T.tensor(rewards, dtype=T.float32).to(device) - states_ = T.tensor(states_, dtype=T.float32).to(device) - dones = T.tensor(dones).to(device) - return states, actions, rewards, states_, dones + if len(result) == 7: + states, actions, rewards, states_, dones, h_batch, c_batch = result + states = T.tensor(states, dtype=T.float32).to(device) + actions = T.tensor(actions, dtype=T.float32).to(device) + rewards = T.tensor(rewards, dtype=T.float32).to(device) + states_ = T.tensor(states_, dtype=T.float32).to(device) + dones = T.tensor(dones).to(device) + return states, actions, rewards, states_, dones, h_batch, c_batch + else: + states, actions, rewards, states_, dones = result + states = T.tensor(states, dtype=T.float32).to(device) + actions = T.tensor(actions, dtype=T.float32).to(device) + rewards = T.tensor(rewards, dtype=T.float32).to(device) + states_ = T.tensor(states_, dtype=T.float32).to(device) + dones = T.tensor(dones).to(device) + return states, actions, rewards, states_, dones def sample_attention_memory(self, networks): observations, labels = networks['replay'].sample_buffer(self.batch_size) diff --git a/tests/test_rddpg_vectorized.py b/tests/test_rddpg_vectorized.py new file mode 100644 index 0000000..5d6cdbc --- /dev/null +++ b/tests/test_rddpg_vectorized.py @@ -0,0 +1,74 @@ +"""Tests for vectorized learn_RDDPG.""" + +import numpy as np +import torch as T +import time +import pytest + +from gsp_rl.src.actors.actor import Actor + + +@pytest.fixture +def config(): + return { + "GAMMA": 0.99, "TAU": 0.005, "ALPHA": 0.001, "BETA": 0.001, + "LR": 0.001, "EPSILON": 0.0, "EPS_MIN": 0.0, "EPS_DEC": 0.0, + "BATCH_SIZE": 8, "MEM_SIZE": 1000, "REPLACE_TARGET_COUNTER": 10, + "NOISE": 0.0, "UPDATE_ACTOR_ITER": 1, "WARMUP": 0, + "GSP_LEARNING_FREQUENCY": 10, "GSP_BATCH_SIZE": 8, + } + + +def _fill_gsp_buffer(actor, n=100): + for _ in range(n): + gs = np.random.randn(6).astype(np.float32) + ga = np.random.randn(1).astype(np.float32) + actor.store_transition(gs, ga, float(np.random.randn()), + np.random.randn(6).astype(np.float32), False, + actor.gsp_networks) + + +class TestVectorizedRDDPG: + def test_learn_completes(self, config): + actor = Actor(id=1, config=config, network="DDPG", + input_size=8, output_size=2, min_max_action=1.0, meta_param_size=1, + gsp=True, recurrent_gsp=True, gsp_input_size=6, gsp_output_size=1, + gsp_sequence_length=10, recurrent_hidden_size=32, recurrent_num_layers=2) + _fill_gsp_buffer(actor, 200) + loss = actor.learn_RDDPG(actor.gsp_networks, gsp=True, recurrent=True) + assert np.isfinite(loss) + + def test_weights_change(self, config): + actor = Actor(id=1, config=config, network="DDPG", + input_size=8, output_size=2, min_max_action=1.0, meta_param_size=1, + gsp=True, recurrent_gsp=True, gsp_input_size=6, gsp_output_size=1, + gsp_sequence_length=10, recurrent_hidden_size=32, recurrent_num_layers=2) + _fill_gsp_buffer(actor, 200) + before = T.cat([p.data.flatten().clone() for p in actor.gsp_networks['actor'].parameters()]) + actor.learn_RDDPG(actor.gsp_networks, gsp=True, recurrent=True) + after = T.cat([p.data.flatten().clone() for p in actor.gsp_networks['actor'].parameters()]) + assert not T.equal(before, after) + + def test_multiple_steps_no_crash(self, config): + actor = Actor(id=1, config=config, network="DDPG", + input_size=8, output_size=2, min_max_action=1.0, meta_param_size=1, + gsp=True, recurrent_gsp=True, gsp_input_size=6, gsp_output_size=1, + gsp_sequence_length=10, recurrent_hidden_size=32, recurrent_num_layers=2) + _fill_gsp_buffer(actor, 200) + for _ in range(20): + actor.learn_RDDPG(actor.gsp_networks, gsp=True, recurrent=True) + + def test_speed_improvement(self, config): + """Should be under 200ms per call (was 1500ms).""" + actor = Actor(id=1, config=config, network="DDPG", + input_size=8, output_size=2, min_max_action=1.0, meta_param_size=1, + gsp=True, recurrent_gsp=True, gsp_input_size=6, gsp_output_size=1, + gsp_sequence_length=10, recurrent_hidden_size=32, recurrent_num_layers=2) + _fill_gsp_buffer(actor, 200) + # Warm up + actor.learn_RDDPG(actor.gsp_networks, gsp=True, recurrent=True) + t0 = time.perf_counter() + for _ in range(10): + actor.learn_RDDPG(actor.gsp_networks, gsp=True, recurrent=True) + elapsed = (time.perf_counter() - t0) / 10 * 1000 + assert elapsed < 200, f"learn_RDDPG took {elapsed:.1f}ms (should be <200ms)" From 207480511f3db142bdbf53d82cb3d7b292d8af29 Mon Sep 17 00:00:00 2001 From: Joshua Bloom Date: Wed, 8 Apr 2026 19:08:25 -0400 Subject: [PATCH 4/4] =?UTF-8?q?feat(device):=20remove=20MPS=20CPU=20fallba?= =?UTF-8?q?ck=20for=20recurrent=20=E2=80=94=20vectorized=20RDDPG=20works?= =?UTF-8?q?=20on=20MPS=20(80ms=20vs=20150ms=20CPU)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gsp_rl/src/networks/__init__.py | 9 +++------ gsp_rl/src/networks/rddpg.py | 4 ++-- tests/test_device/test_device_detection.py | 4 +--- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/gsp_rl/src/networks/__init__.py b/gsp_rl/src/networks/__init__.py index acc7971..f55edd6 100644 --- a/gsp_rl/src/networks/__init__.py +++ b/gsp_rl/src/networks/__init__.py @@ -6,16 +6,13 @@ def get_device(recurrent: bool = False) -> T.device: """Auto-detect the best available device: cuda > mps > cpu. Args: - recurrent: If True, indicates the network uses LSTM or attention. - On macOS MPS, these fall back to CPU due to PyTorch MPS - backend bugs with repeated LSTM backward passes. - On CUDA (Linux/Windows), recurrent networks use GPU normally. + recurrent: Previously used to force CPU fallback for LSTM/attention on MPS. + No longer needed — vectorized RDDPG works on MPS. + Kept for API compatibility but ignored. """ if T.cuda.is_available(): return T.device("cuda:0") elif T.backends.mps.is_available(): - if recurrent: - return T.device("cpu") return T.device("mps") else: return T.device("cpu") diff --git a/gsp_rl/src/networks/rddpg.py b/gsp_rl/src/networks/rddpg.py index 177480b..a7f19ce 100644 --- a/gsp_rl/src/networks/rddpg.py +++ b/gsp_rl/src/networks/rddpg.py @@ -34,7 +34,7 @@ def __init__(self, environmental_encoder, ddpg_actor): super().__init__() self.ee = environmental_encoder self.actor = ddpg_actor - # Use encoder's device (CPU on MPS due to LSTM fallback, GPU on CUDA) + # Use encoder's device — ensures all components on same device self.device = self.ee.device self.actor.to(self.device) self.actor.device = self.device @@ -83,7 +83,7 @@ def __init__(self, environmental_encoder, ddpg_critic): super().__init__() self.ee = environmental_encoder self.critic = ddpg_critic - # Use encoder's device (CPU on MPS due to LSTM fallback, GPU on CUDA) + # Use encoder's device — ensures all components on same device self.device = self.ee.device self.critic.to(self.device) self.critic.device = self.device diff --git a/tests/test_device/test_device_detection.py b/tests/test_device/test_device_detection.py index db330c2..0c26050 100644 --- a/tests/test_device/test_device_detection.py +++ b/tests/test_device/test_device_detection.py @@ -22,9 +22,7 @@ def _expected_device(recurrent=False): if T.cuda.is_available(): return "cuda" elif T.backends.mps.is_available(): - if recurrent: - return "cpu" # MPS fallback for LSTM/attention - return "mps" + return "mps" # MPS works for all networks including LSTM (vectorized) else: return "cpu"