Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion app/service/file_svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self):
self.encryptor = self._get_encryptor()
self.encrypt_output = False if self.get_config('encrypt_files') is False else True
self.packers = dict()
self._path_cache = {} # filename -> (plugin_name, file_path)

async def get_file(self, headers):
headers = CIMultiDict(headers)
Expand Down Expand Up @@ -74,6 +75,7 @@ async def save_file(self, filename, payload, target_dir, encrypt=True, encoding=
if encoding:
payload = await self._decode_contents(payload, encoding)
self._save(os.path.join(target_dir, filename), payload, encrypt)
self.invalidate_path_cache(filename)

async def create_exfil_sub_directory(self, dir_name):
path = os.path.join(self.get_config('exfil_dir'), dir_name)
Expand Down Expand Up @@ -108,16 +110,31 @@ async def save_multipart_file_upload(self, request, target_dir, encrypt=True):
except Exception as e:
self.log.debug('Exception uploading file: %s' % e)

def invalidate_path_cache(self, name=None):
"""Clear path cache entries. If name is given, clear only that entry."""
if name:
self._path_cache.pop(name, None)
else:
self._path_cache.clear()

async def find_file_path(self, name, location=''):
cache_key = f'{name}:{location}'
if cache_key in self._path_cache:
return self._path_cache[cache_key]
for plugin in await self.data_svc.locate('plugins', match=dict(enabled=True)):
for subd in ['', 'data']:
file_path = await self.walk_file_path(os.path.join('plugins', plugin.name, subd, location), name)
if file_path:
self._path_cache[cache_key] = (plugin.name, file_path)
return plugin.name, file_path
file_path = await self.walk_file_path(os.path.join('data', location), name)
if file_path:
self._path_cache[cache_key] = (None, file_path)
return None, file_path
return None, await self.walk_file_path('%s' % location, name)
result = (None, await self.walk_file_path('%s' % location, name))
if result[1]:
self._path_cache[cache_key] = result
return result

async def read_file(self, name, location='payloads'):
_, file_name = await self.find_file_path(name, location=location)
Expand Down
40 changes: 40 additions & 0 deletions tests/security/test_file_path_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import unittest
from unittest.mock import MagicMock, AsyncMock, patch
import asyncio


class TestFilePathCache(unittest.TestCase):
def test_cache_returns_cached_result(self):
from app.service.file_svc import FileSvc
svc = FileSvc.__new__(FileSvc)
svc._path_cache = {'test.ps1:': ('sandcat', '/path/test.ps1')}
svc.data_svc = MagicMock()
svc.log = MagicMock()

loop = asyncio.new_event_loop()
try:
result = loop.run_until_complete(svc.find_file_path('test.ps1'))
self.assertEqual(result, ('sandcat', '/path/test.ps1'))
# data_svc.locate should NOT have been called
svc.data_svc.locate.assert_not_called()
finally:
loop.close()

def test_invalidate_clears_cache(self):
from app.service.file_svc import FileSvc
svc = FileSvc.__new__(FileSvc)
svc._path_cache = {'test.ps1:': ('sandcat', '/path/test.ps1')}
svc.invalidate_path_cache()
self.assertEqual(len(svc._path_cache), 0)

def test_invalidate_specific_name(self):
from app.service.file_svc import FileSvc
svc = FileSvc.__new__(FileSvc)
svc._path_cache = {'a.ps1:': ('x', '/a'), 'b.ps1:': ('y', '/b')}
svc.invalidate_path_cache('a.ps1:')
self.assertNotIn('a.ps1:', svc._path_cache)
self.assertIn('b.ps1:', svc._path_cache)


if __name__ == '__main__':
unittest.main()
Loading