2 changes: 1 addition & 1 deletion gsp_rl/src/actors/actor.py
@@ -251,7 +251,7 @@ def build_gsp_network(self, learning_scheme:str | None =None):
self.gsp_networks['learning_scheme'] = 'RDDPG'
self.gsp_networks['output_size'] = self.gsp_network_output
#self.gsp_networks['replay'] = ReplayBuffer(self.mem_size, self.gsp_network_input, 1, 'Continuous', use_gsp = True)
self.gsp_networks['replay'] = SequenceReplayBuffer(self.mem_size, self.gsp_network_input, self.gsp_network_output, self.gsp_sequence_length)
self.gsp_networks['replay'] = SequenceReplayBuffer(self.mem_size, self.gsp_network_input, self.gsp_network_output, self.gsp_sequence_length, hidden_size=self.recurrent_hidden_size, num_layers=self.recurrent_num_layers)
#SequenceReplayBuffer(max_sequence=100, num_observations = self.gsp_network_input, num_actions = 1, seq_len = 5)
self.gsp_networks['learn_step_counter'] = 0
else:
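The constructor change above threads the recurrent dimensions into the replay buffer. A minimal sketch of the new call, assuming the import path from this PR and hypothetical sizes in place of the actor's configured values:

```python
# Sketch only -- every size here is hypothetical; the real call uses
# self.mem_size, self.gsp_network_input, self.gsp_network_output,
# self.gsp_sequence_length, self.recurrent_hidden_size and
# self.recurrent_num_layers from the actor's config.
from gsp_rl.src.buffers.sequential import SequenceReplayBuffer

replay = SequenceReplayBuffer(
    max_sequence=100,     # number of sequences kept in memory
    num_observations=10,  # observation dimensionality
    num_actions=1,
    seq_len=5,
    hidden_size=256,      # > 0 enables R2D2-style hidden-state storage
    num_layers=2,         # must be > 0 whenever hidden_size > 0
)
```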
149 changes: 98 additions & 51 deletions gsp_rl/src/actors/learning_aids.py
@@ -173,9 +173,11 @@ def DDPG_choose_action(self, observation, networks):
if networks['learning_scheme'] == 'RDDPG':
# if using LSTM we need to add an extra dimension
state = T.tensor(np.array(observation), dtype=T.float).to(networks['actor'].device)
mu, _ = networks['actor'].forward(state)
return mu.unsqueeze(0)
else:
state = T.tensor(observation, dtype = T.float).to(networks['actor'].device)
return networks['actor'].forward(state).unsqueeze(0)


def DDPG_choose_action_batch(self, observations, networks):
@@ -304,49 +304,84 @@ def learn_DDPG(self, networks, gsp = False, recurrent = False):
return actor_loss.item()

def learn_RDDPG(self, networks, gsp = False, recurrent = False):
s, a, r, s_, d = self.sample_memory(networks)
batch_loss = 0
# sample_memory always uses self.batch_size, so loop must match
batch_size = s.shape[0]
for batch in range(batch_size):
states = s[batch]
actions = a[batch]
rewards = r[batch]
states_ = s_[batch]
dones = d[batch]
if not recurrent:
actions = actions.unsqueeze(1)
elif recurrent:
actions = actions.view(actions.shape[0], 1, actions.shape[1])
target_actions = networks['target_actor'](states_)
q_value_ = networks['target_critic'](states_, target_actions)
# print('[REWARDS]', rewards.shape, T.unsqueeze(rewards, 1).shape)
# print('[Q_VALUE_]', q_value_.shape, T.squeeze(T.squeeze(q_value_, -1), -1).shape)
target = T.unsqueeze(rewards, 1) + self.gamma*T.squeeze(q_value_, -1)
# print(target.shape)

#Critic Update
networks['critic'].optimizer.zero_grad()
q_value = networks['critic'](states, actions)
# print('[Q_VALUE]', q_value.shape)
# print('[TARGET]', target.shape)
value_loss = Loss(T.squeeze(q_value, -1), target)
value_loss.backward()
networks['critic'].optimizer.step()

#Actor Update
networks['actor'].optimizer.zero_grad()

new_policy_actions = networks['actor'](states)
actor_loss = -networks['critic'](states, new_policy_actions)
actor_loss = actor_loss.mean()
actor_loss.backward()
batch_loss += actor_loss.item()
networks['actor'].optimizer.step()

networks['learn_step_counter'] += 1

return batch_loss
mem_result = self.sample_memory(networks)
if len(mem_result) == 7:
states, actions, rewards, states_, dones, h_batch, c_batch = mem_result
device = networks['actor'].device
h_0 = T.tensor(np.array(h_batch), dtype=T.float32).to(device)
c_0 = T.tensor(np.array(c_batch), dtype=T.float32).to(device)
# h_batch shape: (batch, num_layers, 1, hidden) -> (num_layers, batch, hidden)
h_0 = h_0.squeeze(2).permute(1, 0, 2).contiguous()
c_0 = c_0.squeeze(2).permute(1, 0, 2).contiguous()
hidden_init = (h_0, c_0)
else:
states, actions, rewards, states_, dones = mem_result
hidden_init = None

# states: (batch, seq_len, obs_dim)
# actions: (batch, seq_len, act_dim)
seq_len = states.shape[1]
burn_in_len = seq_len // 2
train_len = seq_len - burn_in_len

# Split into burn-in and training portions
burn_states = states[:, :burn_in_len, :]
train_states = states[:, burn_in_len:, :]
burn_states_ = states_[:, :burn_in_len, :]
train_states_ = states_[:, burn_in_len:, :]
train_actions = actions[:, burn_in_len:, :]
train_rewards = rewards[:, burn_in_len:]

# Burn-in: refresh hidden state without gradients
with T.no_grad():
if burn_in_len > 0:
# Run burn-in through actor encoder to get hidden state
_, actor_hidden = networks['actor'].ee(burn_states, hidden=hidden_init)
_, critic_hidden = networks['critic'].ee(burn_states, hidden=hidden_init)
_, target_actor_hidden = networks['target_actor'].ee(burn_states_, hidden=hidden_init)
_, target_critic_hidden = networks['target_critic'].ee(burn_states_, hidden=hidden_init)
else:
actor_hidden = hidden_init
critic_hidden = hidden_init
target_actor_hidden = hidden_init
target_critic_hidden = hidden_init

# Detach hidden states so burn-in gradients don't flow
if actor_hidden is not None:
actor_hidden = (actor_hidden[0].detach(), actor_hidden[1].detach())
critic_hidden = (critic_hidden[0].detach(), critic_hidden[1].detach())
target_actor_hidden = (target_actor_hidden[0].detach(), target_actor_hidden[1].detach())
target_critic_hidden = (target_critic_hidden[0].detach(), target_critic_hidden[1].detach())

# Target computation (no gradients)
with T.no_grad():
target_actions, _ = networks['target_actor'](train_states_, hidden=target_actor_hidden)
q_value_, _ = networks['target_critic'](train_states_, target_actions, hidden=target_critic_hidden)
# Use last timestep for Bellman target
q_last_ = q_value_[:, -1, :] # (batch, 1)
r_last = train_rewards[:, -1] # (batch,)
target = r_last.unsqueeze(1) + self.gamma * q_last_ # (batch, 1)

# Critic update
networks['critic'].optimizer.zero_grad()
q_value, _ = networks['critic'](train_states, train_actions, hidden=critic_hidden)
q_last = q_value[:, -1, :] # (batch, 1)
value_loss = Loss(q_last, target)
value_loss.backward()
networks['critic'].optimizer.step()

# Actor update
networks['actor'].optimizer.zero_grad()
new_policy_actions, _ = networks['actor'](train_states, hidden=actor_hidden)
# Re-run critic with fresh hidden (detached) for actor loss
actor_q_val, _ = networks['critic'](train_states, new_policy_actions, hidden=critic_hidden)
actor_loss = -actor_q_val[:, -1, :].mean()
actor_loss.backward()
networks['actor'].optimizer.step()

networks['learn_step_counter'] += 1

return actor_loss.item()
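The hidden-state reshape at the top of `learn_RDDPG` can be verified in isolation; a standalone sketch with hypothetical sizes:

```python
# Stored hidden states arrive per sequence as (batch, num_layers, 1, hidden);
# torch.nn.LSTM expects (num_layers, batch, hidden).
import torch as T

batch, num_layers, hidden = 4, 2, 256   # hypothetical sizes
h_batch = T.zeros(batch, num_layers, 1, hidden)
h_0 = h_batch.squeeze(2).permute(1, 0, 2).contiguous()
assert h_0.shape == (num_layers, batch, hidden)
```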

def learn_TD3(self, networks, gsp = False):
states, actions, rewards, states_, dones = self.sample_memory(networks)
@@ -422,18 +459,28 @@ def store_attention_transition(self, s, y, networks):
networks['replay'].store_transition(s, y)

def sample_memory(self, networks):
states, actions, rewards, states_, dones = networks['replay'].sample_buffer(self.batch_size)
result = networks['replay'].sample_buffer(self.batch_size)
if networks['learning_scheme'] in {'DQN', 'DDQN'}:
device = networks['q_eval'].device
elif networks['learning_scheme'] in {'DDPG', 'RDDPG', 'TD3'}:
device = networks['actor'].device
states = T.tensor(states, dtype=T.float32).to(device)
actions = T.tensor(actions, dtype=T.float32).to(device)
rewards = T.tensor(rewards, dtype=T.float32).to(device)
states_ = T.tensor(states_, dtype=T.float32).to(device)
dones = T.tensor(dones).to(device)

return states, actions, rewards, states_, dones
if len(result) == 7:
states, actions, rewards, states_, dones, h_batch, c_batch = result
states = T.tensor(states, dtype=T.float32).to(device)
actions = T.tensor(actions, dtype=T.float32).to(device)
rewards = T.tensor(rewards, dtype=T.float32).to(device)
states_ = T.tensor(states_, dtype=T.float32).to(device)
dones = T.tensor(dones).to(device)
return states, actions, rewards, states_, dones, h_batch, c_batch
else:
states, actions, rewards, states_, dones = result
states = T.tensor(states, dtype=T.float32).to(device)
actions = T.tensor(actions, dtype=T.float32).to(device)
rewards = T.tensor(rewards, dtype=T.float32).to(device)
states_ = T.tensor(states_, dtype=T.float32).to(device)
dones = T.tensor(dones).to(device)
return states, actions, rewards, states_, dones

def sample_attention_memory(self, networks):
observations, labels = networks['replay'].sample_buffer(self.batch_size)
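The burn-in logic in `learn_RDDPG` splits each sampled sequence in half: the first half replays under `no_grad` purely to refresh the LSTM state, and only the second half receives gradients. A minimal sketch of that slicing, with hypothetical sizes:

```python
import torch as T

batch, seq_len, obs_dim = 4, 10, 8      # hypothetical sizes
states = T.randn(batch, seq_len, obs_dim)

burn_in_len = seq_len // 2              # 5 burn-in steps
burn_states = states[:, :burn_in_len, :]    # replayed under no_grad
train_states = states[:, burn_in_len:, :]   # gradients flow here
assert burn_states.shape[1] + train_states.shape[1] == seq_len
```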
46 changes: 45 additions & 1 deletion gsp_rl/src/buffers/sequential.py
@@ -29,7 +29,9 @@ def __init__(
max_sequence: int,
num_observations: int,
num_actions: int,
seq_len: int
seq_len: int,
hidden_size: int = 0,
num_layers: int = 0
) -> None:
"""Initialize sequence replay buffer.

@@ -38,6 +40,9 @@
num_observations: Observation space dimensionality.
num_actions: Action space dimensionality.
seq_len: Length of each sequence.
hidden_size: LSTM hidden state size. Set > 0 to enable hidden state
storage (R2D2-style). Default 0 disables hidden state storage.
num_layers: Number of LSTM layers. Must be > 0 when hidden_size > 0.
"""
self.mem_size = max_sequence*seq_len
self.num_observations = num_observations
@@ -46,6 +51,17 @@
self.mem_ctr = 0
self.seq_mem_cntr = 0

self.hidden_size = hidden_size
self.num_layers = num_layers
self._has_hidden = hidden_size > 0 and num_layers > 0

if self._has_hidden:
num_sequences = max_sequence
self.h_memory = np.zeros((num_sequences, num_layers, 1, hidden_size), dtype=np.float32)
self.c_memory = np.zeros((num_sequences, num_layers, 1, hidden_size), dtype=np.float32)
self._pending_h = None
self._pending_c = None

#main buffer used for sampling
self.state_memory = np.zeros((self.mem_size, self.num_observations), dtype = np.float64)
self.action_memory = np.zeros((self.mem_size, self.num_actions), dtype = np.float64)
@@ -60,6 +76,19 @@ def __init__(
self.seq_reward_memory = np.zeros((self.seq_len), dtype = np.float64)
self.seq_terminal_memory = np.zeros((self.seq_len), dtype = np.bool_)

def set_sequence_hidden(self, h: np.ndarray, c: np.ndarray) -> None:
"""Set hidden state for the next sequence to be flushed.

Call this before the sequence fills up. The stored hidden state
represents the LSTM state at the start of the sequence (R2D2-style).

Args:
h: Hidden state array of shape (num_layers, 1, hidden_size).
c: Cell state array of shape (num_layers, 1, hidden_size).
"""
self._pending_h = h
self._pending_c = c

def store_transition(
self,
s: np.ndarray,
@@ -81,6 +110,13 @@ def store_transition(
self.seq_mem_cntr += 1

if self.seq_mem_cntr == self.seq_len:
seq_index = (self.mem_ctr // self.seq_len) % (self.mem_size // self.seq_len)
# Store hidden state if available
if self._has_hidden and self._pending_h is not None:
self.h_memory[seq_index] = self._pending_h
self.c_memory[seq_index] = self._pending_c
self._pending_h = None
self._pending_c = None
#Transfer Seq to main mem and clear seq buffer
for i in range(self.seq_len):
self.state_memory[mem_index+i] = self.seq_state_memory[i]
@@ -122,4 +158,12 @@ def sample_buffer(self, batch_size: int, replace: bool = True) -> list[np.ndarray]:
a[i] = self.action_memory[j:j+self.seq_len]
r[i] = self.reward_memory[j:j+self.seq_len]
d[i] = self.terminal_memory[j:j+self.seq_len]
if self._has_hidden:
h_batch = np.zeros((batch_size, self.num_layers, 1, self.hidden_size), dtype=np.float32)
c_batch = np.zeros((batch_size, self.num_layers, 1, self.hidden_size), dtype=np.float32)
for i, j in enumerate(samples_indices):
seq_idx = j // self.seq_len
h_batch[i] = self.h_memory[seq_idx % (self.mem_size // self.seq_len)]
c_batch[i] = self.c_memory[seq_idx % (self.mem_size // self.seq_len)]
return s, a, r, s_, d, h_batch, c_batch
return s, a, r, s_, d
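A usage sketch for the new hidden-state hooks. The `replay` instance and `env_step` helper are hypothetical, and `store_transition`'s full signature is elided in the diff; it is assumed here to take `(s, a, r, s_, d)` to match the memories the buffer keeps:

```python
import numpy as np

num_layers, hidden_size, seq_len = 2, 256, 5   # assumed sizes

# LSTM state captured at the start of the sequence (R2D2-style)
h = np.zeros((num_layers, 1, hidden_size), dtype=np.float32)
c = np.zeros((num_layers, 1, hidden_size), dtype=np.float32)
replay.set_sequence_hidden(h, c)

for _ in range(seq_len):
    s, a, r, s_, d = env_step()        # hypothetical environment step
    replay.store_transition(s, a, r, s_, d)
# Once seq_len transitions are stored, the sequence is flushed to main
# memory together with the pending (h, c) pair.
```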
9 changes: 3 additions & 6 deletions gsp_rl/src/networks/__init__.py
@@ -6,16 +6,13 @@ def get_device(recurrent: bool = False) -> T.device:
"""Auto-detect the best available device: cuda > mps > cpu.

Args:
recurrent: If True, indicates the network uses LSTM or attention.
On macOS MPS, these fall back to CPU due to PyTorch MPS
backend bugs with repeated LSTM backward passes.
On CUDA (Linux/Windows), recurrent networks use GPU normally.
recurrent: Previously used to force CPU fallback for LSTM/attention on MPS.
No longer needed — vectorized RDDPG works on MPS.
Kept for API compatibility but ignored.
"""
if T.cuda.is_available():
return T.device("cuda:0")
elif T.backends.mps.is_available():
if recurrent:
return T.device("cpu")
return T.device("mps")
else:
return T.device("cpu")
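Since `recurrent` is now ignored, callers see the same result with or without the flag; a quick usage sketch:

```python
from gsp_rl.src.networks import get_device

device = get_device(recurrent=True)    # identical to get_device()
```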
44 changes: 32 additions & 12 deletions gsp_rl/src/networks/lstm.py
@@ -72,23 +72,43 @@ def __init__(
self.name = "Enviroment_Encoder"
self.to(self.device)

def forward(
self,
observation: T.Tensor,
) -> T.Tensor:
"""Encode an observation (or sequence) through embedding + LSTM + projection.
def forward(self, observation, hidden=None):
"""Encode observation through embedding + LSTM + projection.

Args:
observation: Tensor of shape (seq_len, input_size) or (batch, input_size).
observation: Tensor of shape (seq_len, input_size) for single sample,
or (batch, seq_len, input_size) for batched input.
hidden: Optional (h_0, c_0) tuple. If None, LSTM uses zeros.
Shape: (num_layers, batch, hidden_size) for each.

Returns:
Encoding tensor of shape (seq_len, 1, output_size). The middle dim=1
comes from the view reshape before LSTM.
Tuple of (output, (h_n, c_n)):
output: Shape (seq_len, output_size) for single,
or (batch, seq_len, output_size) for batched.
(h_n, c_n): Final hidden state.
"""
embed = self.embedding(observation)
lstm_out, _ = self.ee(embed.view(embed.shape[0], 1, -1))
out = self.meta_layer(lstm_out)
return out
# Handle single vs batched input
if observation.dim() == 2:
# Single sample: (seq_len, input) -> add batch dim
observation = observation.unsqueeze(0)
squeeze_batch = True
else:
squeeze_batch = False

# observation is now (batch, seq_len, input_size)
embed = self.embedding(observation) # (batch, seq_len, embed_size)

if hidden is not None:
lstm_out, (h_n, c_n) = self.ee(embed, hidden)
else:
lstm_out, (h_n, c_n) = self.ee(embed)

out = self.meta_layer(lstm_out) # (batch, seq_len, output_size)

if squeeze_batch:
out = out.squeeze(0) # back to (seq_len, output_size)

return out, (h_n, c_n)

def save_checkpoint(self, path: str, intention: bool = False) -> None:
""" Save Model """
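The reworked `forward` accepts both unbatched and batched input and now returns the final LSTM state alongside the encoding. A shape sketch, where `encoder` stands in for a constructed encoder instance and all sizes are hypothetical:

```python
import torch as T

seq_len, input_size, batch = 5, 8, 4           # hypothetical sizes
single = T.randn(seq_len, input_size)          # unbatched sample
batched = T.randn(batch, seq_len, input_size)  # batched input

out, (h_n, c_n) = encoder(single)              # out: (seq_len, output_size)
out_b, hidden = encoder(batched)               # out_b: (batch, seq_len, output_size)
out_b2, _ = encoder(batched, hidden=hidden)    # resume from a saved state
```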