Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions app/service/app_svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from aiohttp import web
import croniter

from app.objects.c_adversary import Adversary
from app.objects.c_plugin import Plugin
from app.service.interfaces.i_app_svc import AppServiceInterface
from app.utility.base_service import BaseService
Expand Down Expand Up @@ -219,6 +220,22 @@ async def watch_ability_files(self):
await self.get_service('data_svc').load_ability_file(filename=f, access=p.access)
await asyncio.sleep(int(self.get_config('ability_refresh')))

async def watch_adversary_files(self):
await asyncio.sleep(int(self.get_config('ability_refresh')))
plugins = [p for p in await self.get_service('data_svc').locate('plugins', dict(enabled=True)) if p.data_dir]
plugins.append(Plugin(data_dir='data'))
Comment on lines +225 to +226
while True:
for p in plugins:
Comment on lines +225 to +228
files = (os.path.join(rt, fle) for rt, _, f in os.walk(p.data_dir+'/adversaries') for fle in f if
time.time() - os.stat(os.path.join(rt, fle)).st_mtime < int(self.get_config('ability_refresh')))
Comment on lines +229 to +230
for f in files:
if not f.endswith(('.yml', '.yaml')):
self.log.debug('[%s] Skipping non YML file %s' % (p.name, f))
continue
self.log.debug('[%s] Reloading adversary %s' % (p.name, f))
await self.get_service('data_svc').load_yaml_file(Adversary, filename=f, access=p.access)
await asyncio.sleep(int(self.get_config('ability_refresh')))

def register_subapp(self, path: str, app: web.Application):
"""Registers a web application under the root application.

Expand Down
1 change: 1 addition & 0 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def _handle_sigterm(*args):
loop.create_task(app_svc.run_scheduler())
loop.create_task(learning_svc.build_model())
loop.create_task(app_svc.watch_ability_files())
loop.create_task(app_svc.watch_adversary_files())
loop.run_until_complete(start_server())
loop.run_until_complete(event_svc.fire_event(exchange="system", queue="ready"))
loop.run_until_complete(
Expand Down
136 changes: 136 additions & 0 deletions tests/services/test_watch_adversary_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""Tests for watch_adversary_files() in AppService.

Validates that new or modified adversary YAML files on disk are automatically
reloaded into data_svc.ram['adversaries'] by the periodic watcher task,
mirroring the existing watch_ability_files() behaviour.
"""
import ast
import os
import unittest

import pytest
import yaml

from app.objects.c_adversary import Adversary
from app.utility.base_service import BaseService


# ---------------------------------------------------------------------------
# AST-level structural check — verify watch_adversary_files is wired up
# ---------------------------------------------------------------------------

class TestWatchAdversaryFilesStructure(unittest.TestCase):
"""Verify, without importing server.py, that watch_adversary_files() is
launched as a background task inside run_tasks()."""

def _parse_server(self):
server_path = os.path.join(
os.path.dirname(__file__), '..', '..', 'server.py'
)
with open(os.path.normpath(server_path)) as fh:
return ast.parse(fh.read())

def _get_run_tasks_body(self):
tree = self._parse_server()
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef) and node.name == 'run_tasks':
return node
return None

def test_watch_adversary_files_task_registered(self):
run_tasks = self._get_run_tasks_body()
self.assertIsNotNone(run_tasks, "run_tasks() function not found in server.py")

# Look for loop.create_task(app_svc.watch_adversary_files())
found = False
for node in ast.walk(run_tasks):
if not isinstance(node, ast.Call):
continue
func = node.func
if isinstance(func, ast.Attribute) and func.attr == 'create_task':
for arg in node.args:
if isinstance(arg, ast.Call) and isinstance(arg.func, ast.Attribute):
if arg.func.attr == 'watch_adversary_files':
found = True
break
if found:
break

self.assertTrue(
found,
"loop.create_task(app_svc.watch_adversary_files()) not found inside run_tasks() in server.py"
)

def test_watch_adversary_files_method_exists(self):
"""AppService must define watch_adversary_files as an async method."""
app_svc_path = os.path.join(
os.path.dirname(__file__), '..', '..', 'app', 'service', 'app_svc.py'
)
with open(os.path.normpath(app_svc_path)) as fh:
tree = ast.parse(fh.read())

for node in ast.walk(tree):
if isinstance(node, ast.AsyncFunctionDef) and node.name == 'watch_adversary_files':
return

self.fail("watch_adversary_files() async method not found in app_svc.py")


# ---------------------------------------------------------------------------
# Functional test — adversary YAML reload
# ---------------------------------------------------------------------------

class TestAdversaryFileReload:

@pytest.fixture
def adversary_dir(self, tmp_path):
d = tmp_path / "data" / "adversaries"
d.mkdir(parents=True)
return d

async def test_load_yaml_file_stores_adversary(self, adversary_dir, data_svc):
"""Writing a new adversary YAML and calling load_yaml_file should
insert or update it in data_svc.ram['adversaries']."""
Comment on lines +91 to +93

adv_data = {
'id': 'test-adv-12345',
'name': 'Test Watcher Adversary',
'description': 'Created by test',
'atomic_ordering': [],
}
adv_file = adversary_dir / "test_watcher.yml"
adv_file.write_text(yaml.dump([adv_data]))

await data_svc.load_yaml_file(Adversary, str(adv_file), data_svc.Access.RED)

results = await data_svc.locate('adversaries', dict(adversary_id='test-adv-12345'))
assert len(results) == 1
assert results[0].name == 'Test Watcher Adversary'

async def test_reload_updates_existing_adversary(self, adversary_dir, data_svc):
"""Reloading a modified YAML should update the adversary in RAM."""
Comment on lines +110 to +111

adv_data = {
'id': 'test-adv-reload-001',
'name': 'Original Name',
'description': 'Original description',
'atomic_ordering': [],
}
adv_file = adversary_dir / "reload_test.yml"
adv_file.write_text(yaml.dump([adv_data]))
await data_svc.load_yaml_file(Adversary, str(adv_file), data_svc.Access.RED)

# Modify the file
adv_data['name'] = 'Updated Name'
adv_data['description'] = 'Updated description'
adv_file.write_text(yaml.dump([adv_data]))
await data_svc.load_yaml_file(Adversary, str(adv_file), data_svc.Access.RED)

results = await data_svc.locate('adversaries', dict(adversary_id='test-adv-reload-001'))
assert len(results) == 1
assert results[0].name == 'Updated Name'
assert results[0].description == 'Updated description'


if __name__ == '__main__':
unittest.main()
Loading