Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions netsecgame/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, host, port, role:str)->None:
self._socket.connect((host, port))
except socket.error as e:
self._logger.error(f"Socket error: {e}")
self.sock = None
self._socket = None
self._logger.info("Agent created")

def __del__(self):
Expand All @@ -32,7 +32,7 @@ def __del__(self):
self._socket.close()
self._logger.info("Socket closed")
except socket.error as e:
print(f"Error closing socket: {e}")
self._logger.error(f"Error closing socket: {e}")

def terminate_connection(self)->None:
"""Method for graceful termination of connection. Should be used by any class extending the BaseAgent."""
Expand All @@ -42,7 +42,7 @@ def terminate_connection(self)->None:
self._socket = None
self._logger.info("Socket closed")
except socket.error as e:
print(f"Error closing socket: {e}")
self._logger.error(f"Error closing socket: {e}")
@property
def socket(self)->socket.socket | None:
return self._socket
Expand Down
34 changes: 31 additions & 3 deletions netsecgame/game/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from typing import Optional
import signal
import os
import re
import uuid

from netsecgame.game_components import Action, Observation, ActionType, GameStatus, GameState, AgentStatus, AgentRole
from netsecgame.game.global_defender import GlobalDefender
Expand All @@ -25,6 +27,16 @@ def convert_msg_dict_to_json(msg_dict: dict) -> str:
raise TypeError(f"Error when converting msg to JSON:{e}") from e
return output_message

def sanitize_agent_name(name:str)->str:
"""
Sanitizes the agent name to be used as a filename.
"""
safe_name = re.sub(r'[^a-zA-Z0-9_\-]', '_', name)
safe_name = re.sub(r'_+', '_', safe_name)
safe_name = safe_name.strip('_')[:200]
if not safe_name:
return f"agent_{uuid.uuid4().hex[:8]}"
return safe_name

class GameCoordinator:
"""
Expand Down Expand Up @@ -312,10 +324,26 @@ async def run_game(self):
self.logger.info(f"Coordinator received from agent {agent_addr}: {message}.")

action = self._parse_action_message(agent_addr, message)
if action:
if action is not None:
self._dispatch_action(agent_addr, action)
else:
self._spawn_task(self._respond_on_bad_request, agent_addr, "Malformed Action")
self.logger.info("\tAction processing task stopped.")


async def _respond_on_bad_request(self, agent_addr: tuple, message: str)->None:
"""
Sends a response to the agent indicating that the request was bad.
"""
output_message_dict = {
"to_agent": agent_addr,
"status": str(GameStatus.BAD_REQUEST),
"observation": None,
"message": {
"message": f"Bad request received: {message}",
}
}
await self._agent_response_queues[agent_addr].put(convert_msg_dict_to_json(output_message_dict))

async def _process_join_game_action(self, agent_addr: tuple, action: Action)->None:
"""
Method for processing Action of type ActionType.JoinGame
Expand All @@ -327,7 +355,7 @@ async def _process_join_game_action(self, agent_addr: tuple, action: Action)->No
try:
self.logger.info(f"New Join request by {agent_addr}.")
if agent_addr not in self.agents:
agent_name = action.parameters["agent_info"].name
agent_name = sanitize_agent_name(str(action.parameters["agent_info"].name))
agent_role = action.parameters["agent_info"].role
if agent_role in AgentRole:
# add agent to the world
Expand Down
40 changes: 39 additions & 1 deletion tests/game/test_coordinator_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,4 +480,42 @@ async def test_run_game_flow(self, mock_coordinator_core):
await mock_coordinator_core.run_game()

mock_parse.assert_called_once_with(agent_addr, valid_json)
mock_dispatch.assert_called_once_with(agent_addr, mock_action)
mock_dispatch.assert_called_once_with(agent_addr, mock_action)

@pytest.mark.asyncio
async def test_run_game_malformed_action(self, mock_coordinator_core):
"""New test for refactored method: run_game flow with malformed action."""
agent_addr = ("127.0.0.1", 12345)
invalid_json = '{"invalid": "json"}'

# Setup queue
mock_coordinator_core._agent_action_queue.get.return_value = (agent_addr, invalid_json)

with patch.object(mock_coordinator_core, '_parse_action_message') as mock_parse, \
patch.object(mock_coordinator_core, '_spawn_task') as mock_spawn:

mock_parse.return_value = None

await mock_coordinator_core.run_game()

mock_parse.assert_called_once_with(agent_addr, invalid_json)
mock_spawn.assert_called_once_with(mock_coordinator_core._respond_on_bad_request, agent_addr, "Malformed Action")

@pytest.mark.asyncio
async def test_respond_on_bad_request(self, mock_coordinator_core):
"""New test for _respond_on_bad_request."""
mock_coordinator_core._respond_on_bad_request = GameCoordinator._respond_on_bad_request.__get__(mock_coordinator_core)
agent_addr = ("127.0.0.1", 12345)
mock_coordinator_core._agent_response_queues = {agent_addr: asyncio.Queue()}

await mock_coordinator_core._respond_on_bad_request(agent_addr, "Malformed Action")

# Ensure the response is in the queue
assert not mock_coordinator_core._agent_response_queues[agent_addr].empty()

response_json = await mock_coordinator_core._agent_response_queues[agent_addr].get()
response_data = json.loads(response_json)

assert response_data["status"] == str(GameStatus.BAD_REQUEST)
assert response_data["observation"] is None
assert "Malformed Action" in response_data["message"]["message"]