diff --git a/tests/event.py b/tests/event.py index 324e2ec..91dffd1 100644 --- a/tests/event.py +++ b/tests/event.py @@ -153,22 +153,34 @@ def container_id(self) -> str: def loginuid(self) -> int: return self._loginuid - @override - def __eq__(self, other: Any) -> bool: - if isinstance(other, ProcessSignal): - if self.pid is not None and self.pid != other.pid: - return False - - return ( - self.uid == other.uid and - self.gid == other.gid and - self.exe_path == other.exec_file_path and - self.args == other.args and - self.name == other.name and - self.container_id == other.container_id and - self.loginuid == other.login_uid - ) - raise NotImplementedError + def diff(self, other: ProcessSignal) -> dict | None: + """ + Compare this Process with a ProcessSignal protobuf message. + + Args: + other: ProcessSignal protobuf message to compare against + + Returns: + None if identical, dict of differences if not matching + """ + diff = {} + + # Compare each field + if self.pid is not None: + Event._diff_field(diff, 'pid', self.pid, other.pid) + + Event._diff_field(diff, 'uid', self.uid, other.uid) + Event._diff_field(diff, 'gid', self.gid, other.gid) + Event._diff_field(diff, 'exe_path', + self.exe_path, other.exec_file_path) + Event._diff_field(diff, 'args', self.args, other.args) + Event._diff_field(diff, 'name', self.name, other.name) + Event._diff_field(diff, 'container_id', + self.container_id, other.container_id) + Event._diff_field(diff, 'loginuid', + self.loginuid, other.login_uid) + + return diff if diff else None @override def __str__(self) -> str: @@ -178,13 +190,6 @@ def __str__(self) -> str: f'loginuid={self.loginuid})') -def cmp_path(p1: str | Pattern[str], p2: str) -> bool: - if isinstance(p1, Pattern): - return bool(p1.match(p2)) - else: - return p1 == p2 - - class Event: """ Represents a file activity event, associating a process with an @@ -249,38 +254,87 @@ def old_file(self) -> str | Pattern[str] | None: def old_host_path(self) -> str | Pattern[str] | None: return self._old_host_path - @override - def __eq__(self, other: Any) -> bool: - if isinstance(other, FileActivity): - if self.process != other.process or self.event_type.name.lower() != other.WhichOneof('file'): - return False - - if self.event_type == EventType.CREATION: - return cmp_path(self.file, other.creation.activity.path) and \ - cmp_path(self.host_path, other.creation.activity.host_path) - elif self.event_type == EventType.OPEN: - return cmp_path(self.file, other.open.activity.path) and \ - cmp_path(self.host_path, other.open.activity.host_path) - elif self.event_type == EventType.UNLINK: - return cmp_path(self.file, other.unlink.activity.path) and \ - cmp_path(self.host_path, other.unlink.activity.host_path) - elif self.event_type == EventType.PERMISSION: - return cmp_path(self.file, other.permission.activity.path) and \ - cmp_path(self.host_path, other.permission.activity.host_path) and \ - self.mode == other.permission.mode - elif self.event_type == EventType.OWNERSHIP: - return cmp_path(self.file, other.ownership.activity.path) and \ - cmp_path(self.host_path, other.ownership.activity.host_path) and \ - self.owner_uid == other.ownership.uid and \ - self.owner_gid == other.ownership.gid - elif self.event_type == EventType.RENAME: - return cmp_path(self.file, other.rename.new.path) and \ - cmp_path(self.host_path, other.rename.new.host_path) and \ - cmp_path(self.old_file, other.rename.old.path) and \ - cmp_path(self.old_host_path, other.rename.old.host_path) - - return False - raise NotImplementedError + @classmethod + def _diff_field(cls, diff, name, expected, actual): + if expected != actual: + diff[name] = { + 'expected': expected, + 'actual': actual, + } + + @classmethod + def _diff_path(cls, diff, name: str, expected: str | Pattern[str], actual: str): + """ + Compare paths with regex pattern support. + """ + if isinstance(expected, Pattern): + if not expected.match(actual): + diff[name] = { + 'expected': f'{expected}', + 'actual': actual + } + elif expected != actual: + diff[name] = { + 'expected': expected, + 'actual': actual + } + + def diff(self, other: FileActivity) -> dict | None: + """ + Compare this Event with a FileActivity protobuf message. + + Args: + other: FileActivity protobuf message to compare against + + Returns: + None if identical, dict of differences if not matching + """ + diff = {} + + # Check process differences first + process_diff = self.process.diff(other.process) + if process_diff is not None: + diff['process'] = process_diff + + # Check event type + event_type_expected = self.event_type.name.lower() + event_type_actual = other.WhichOneof('file') + + Event._diff_field(diff, 'event_type', + event_type_expected, event_type_actual) + if diff: + return diff + + # Get the appropriate event field based on type + event_field = getattr(other, event_type_expected) + + # Rename handling is a bit different to the rest, since it has + # new and old paths. + if self.event_type == EventType.RENAME: + Event._diff_path(diff, 'new_file', self.file, event_field.new.path) + Event._diff_path(diff, 'new_host_path', + self.host_path, event_field.new.host_path) + Event._diff_path(diff, 'old_file', self.old_file, + event_field.old.path) + Event._diff_path(diff, 'old_host_path', + self.old_host_path, event_field.old.host_path) + return diff if diff else None + + # Compare file and host_path (common to all event types) + # All event types have .activity.path and .activity.host_path except they're accessed differently + Event._diff_path(diff, 'file', self.file, event_field.activity.path) + Event._diff_path(diff, 'host_path', self.host_path, + event_field.activity.host_path) + + if self.event_type == EventType.PERMISSION: + Event._diff_field(diff, 'mode', self.mode, event_field.mode) + elif self.event_type == EventType.OWNERSHIP: + Event._diff_field(diff, 'owner_uid', + self.owner_uid, event_field.uid) + Event._diff_field(diff, 'owner_gid', + self.owner_gid, event_field.gid) + + return diff if diff else None @override def __str__(self) -> str: diff --git a/tests/server.py b/tests/server.py index d3aa8d1..004fa8e 100644 --- a/tests/server.py +++ b/tests/server.py @@ -1,9 +1,9 @@ from concurrent import futures from collections import deque +import json from threading import Event from time import sleep -from google.protobuf.json_format import MessageToJson import grpc from internalapi.sensor import sfa_iservice_pb2_grpc @@ -89,12 +89,15 @@ def _wait_events(self, events: list[Event], strict: bool): continue print(f'Got event: {msg}') - if msg in events: - events.remove(msg) + + # Check if msg matches the next expected event + diff = events[0].diff(msg) + if diff is None: + events.pop(0) if len(events) == 0: - break + return elif strict: - raise ValueError(f'Encountered unexpected event: {msg}') + raise ValueError(json.dumps(diff, indent=4)) def wait_events(self, events: list[Event], strict: bool = True): """