diff --git a/netsecgame/agents/base_agent.py b/netsecgame/agents/base_agent.py index 0c5ad189..1578f870 100644 --- a/netsecgame/agents/base_agent.py +++ b/netsecgame/agents/base_agent.py @@ -22,7 +22,7 @@ def __init__(self, host, port, role:str)->None: self._socket.connect((host, port)) except socket.error as e: self._logger.error(f"Socket error: {e}") - self.sock = None + self._socket = None self._logger.info("Agent created") def __del__(self): @@ -32,7 +32,7 @@ def __del__(self): self._socket.close() self._logger.info("Socket closed") except socket.error as e: - print(f"Error closing socket: {e}") + self._logger.error(f"Error closing socket: {e}") def terminate_connection(self)->None: """Method for graceful termination of connection. Should be used by any class extending the BaseAgent.""" @@ -42,7 +42,7 @@ def terminate_connection(self)->None: self._socket = None self._logger.info("Socket closed") except socket.error as e: - print(f"Error closing socket: {e}") + self._logger.error(f"Error closing socket: {e}") @property def socket(self)->socket.socket | None: return self._socket diff --git a/netsecgame/game/coordinator.py b/netsecgame/game/coordinator.py index 38af94e7..27fdc797 100644 --- a/netsecgame/game/coordinator.py +++ b/netsecgame/game/coordinator.py @@ -5,6 +5,8 @@ from typing import Optional import signal import os +import re +import uuid from netsecgame.game_components import Action, Observation, ActionType, GameStatus, GameState, AgentStatus, AgentRole from netsecgame.game.global_defender import GlobalDefender @@ -25,6 +27,16 @@ def convert_msg_dict_to_json(msg_dict: dict) -> str: raise TypeError(f"Error when converting msg to JSON:{e}") from e return output_message +def sanitize_agent_name(name:str)->str: + """ + Sanitizes the agent name to be used as a filename. + """ + safe_name = re.sub(r'[^a-zA-Z0-9_\-]', '_', name) + safe_name = re.sub(r'_+', '_', safe_name) + safe_name = safe_name.strip('_')[:200] + if not safe_name: + return f"agent_{uuid.uuid4().hex[:8]}" + return safe_name class GameCoordinator: """ @@ -312,10 +324,26 @@ async def run_game(self): self.logger.info(f"Coordinator received from agent {agent_addr}: {message}.") action = self._parse_action_message(agent_addr, message) - if action: + if action is not None: self._dispatch_action(agent_addr, action) + else: + self._spawn_task(self._respond_on_bad_request, agent_addr, "Malformed Action") self.logger.info("\tAction processing task stopped.") - + + async def _respond_on_bad_request(self, agent_addr: tuple, message: str)->None: + """ + Sends a response to the agent indicating that the request was bad. + """ + output_message_dict = { + "to_agent": agent_addr, + "status": str(GameStatus.BAD_REQUEST), + "observation": None, + "message": { + "message": f"Bad request received: {message}", + } + } + await self._agent_response_queues[agent_addr].put(convert_msg_dict_to_json(output_message_dict)) + async def _process_join_game_action(self, agent_addr: tuple, action: Action)->None: """ Method for processing Action of type ActionType.JoinGame @@ -327,7 +355,7 @@ async def _process_join_game_action(self, agent_addr: tuple, action: Action)->No try: self.logger.info(f"New Join request by {agent_addr}.") if agent_addr not in self.agents: - agent_name = action.parameters["agent_info"].name + agent_name = sanitize_agent_name(str(action.parameters["agent_info"].name)) agent_role = action.parameters["agent_info"].role if agent_role in AgentRole: # add agent to the world diff --git a/tests/game/test_coordinator_core.py b/tests/game/test_coordinator_core.py index 9058a6a0..4c44eeee 100644 --- a/tests/game/test_coordinator_core.py +++ b/tests/game/test_coordinator_core.py @@ -480,4 +480,42 @@ async def test_run_game_flow(self, mock_coordinator_core): await mock_coordinator_core.run_game() mock_parse.assert_called_once_with(agent_addr, valid_json) - mock_dispatch.assert_called_once_with(agent_addr, mock_action) \ No newline at end of file + mock_dispatch.assert_called_once_with(agent_addr, mock_action) + + @pytest.mark.asyncio + async def test_run_game_malformed_action(self, mock_coordinator_core): + """New test for refactored method: run_game flow with malformed action.""" + agent_addr = ("127.0.0.1", 12345) + invalid_json = '{"invalid": "json"}' + + # Setup queue + mock_coordinator_core._agent_action_queue.get.return_value = (agent_addr, invalid_json) + + with patch.object(mock_coordinator_core, '_parse_action_message') as mock_parse, \ + patch.object(mock_coordinator_core, '_spawn_task') as mock_spawn: + + mock_parse.return_value = None + + await mock_coordinator_core.run_game() + + mock_parse.assert_called_once_with(agent_addr, invalid_json) + mock_spawn.assert_called_once_with(mock_coordinator_core._respond_on_bad_request, agent_addr, "Malformed Action") + + @pytest.mark.asyncio + async def test_respond_on_bad_request(self, mock_coordinator_core): + """New test for _respond_on_bad_request.""" + mock_coordinator_core._respond_on_bad_request = GameCoordinator._respond_on_bad_request.__get__(mock_coordinator_core) + agent_addr = ("127.0.0.1", 12345) + mock_coordinator_core._agent_response_queues = {agent_addr: asyncio.Queue()} + + await mock_coordinator_core._respond_on_bad_request(agent_addr, "Malformed Action") + + # Ensure the response is in the queue + assert not mock_coordinator_core._agent_response_queues[agent_addr].empty() + + response_json = await mock_coordinator_core._agent_response_queues[agent_addr].get() + response_data = json.loads(response_json) + + assert response_data["status"] == str(GameStatus.BAD_REQUEST) + assert response_data["observation"] is None + assert "Malformed Action" in response_data["message"]["message"] \ No newline at end of file