diff --git a/docs/examples/tools/shell_example.py b/docs/examples/tools/shell_example.py new file mode 100644 index 000000000..1a8525ba7 --- /dev/null +++ b/docs/examples/tools/shell_example.py @@ -0,0 +1,236 @@ +# pytest: e2e, ollama, qualitative +"""Example usage patterns for bash_executor tool. + +Demonstrates multiple ways to use Mellea's bash execution capabilities: +1. Direct execution for local commands +2. Wrapping as a MelleaTool for agent use +3. LLM-based tool calling with forced tool use +4. Working directory and path restrictions +5. Integration with error handling + +⚠️ Security note: bash_executor runs commands locally with a conservative +safety denylist (recommended for typical agentic workflows). The denylist +enforces: no sudo, no rm -rf, no destructive git operations, no writes to +/etc, /sys, /proc, etc. Write operations can also be constrained with +``working_dir`` and explicit ``allowed_paths``. + +For higher isolation requirements (untrusted code, security research), +provide isolation at the application layer (containers, VMs). + +Note: Commands must use argv-friendly syntax (no pipes, redirects, or shell builtins). +Use individual commands and compose them in Python instead. +""" + +from mellea import MelleaSession, start_session +from mellea.backends import ModelOption +from mellea.backends.tools import MelleaTool +from mellea.stdlib.requirements import uses_tool +from mellea.stdlib.tools.shell import bash_executor + + +def example_1_direct_execution() -> None: + """Example 1: Execute bash commands locally (default).""" + print("=== Example 1: Local Execution (Default) ===") + + # Execute a simple command locally + result = bash_executor("echo 'Hello from Bash'") + print("Command: echo 'Hello from Bash'") + print(f"Success: {result.success}") + print(f"Output: {result.stdout}") + print() + + # Execute a command to list files (no pipes/redirects) + result = bash_executor("ls -la") + print("Command: ls -la") + print(f"Success: {result.success}") + if result.stdout: + # Show first few lines + lines = result.stdout.split("\n")[:3] + print("Output (first 3 lines):\n" + "\n".join(lines)) + print() + + # Demonstrate that pipes are blocked (for security) + result = bash_executor("ls -la | wc -l") + print("Command: ls -la | wc -l (pipe operator blocked)") + print(f"Rejected: {result.skipped}") + print(f"Reason: {result.skip_message}") + print() + + # Attempt a dangerous command (will be rejected) + result = bash_executor("sudo echo unsafe") + print("Command: sudo echo unsafe") + print(f"Skipped: {result.skipped}") + print(f"Reason: {result.skip_message}") + print() + + +def example_2_wrapped_as_tool() -> None: + """Example 2: Wrap bash executor as a MelleaTool for LLM use.""" + print("=== Example 2: Wrapped as MelleaTool ===") + + # Create tool from bash executor (local execution by default) + bash_tool = MelleaTool.from_callable(bash_executor) + print(f"Tool name: {bash_tool.name}") + print(f"Tool schema keys: {bash_tool.as_json_tool.keys()}") + print() + + # Invoke the tool directly (normally LLM would call this) + result = bash_tool.run("pwd") + print("Tool invocation result:") + print(f" Success: {result.success}") + print(f" Output: {result.stdout}") + print() + + +def example_3_llm_with_forced_tool_use(m: MelleaSession) -> None: + """Example 3: LLM generates bash commands with forced tool use (requires Ollama). + + This mirrors the Python interpreter pattern: ask the LLM to generate + a bash command, force it to use the tool, then execute the command. + + Requirements: + - Ollama running locally (or compatible LLM configured) + - Run: ollama serve + """ + print("=== Example 3: LLM-Generated Bash Commands with Forced Tool Use ===") + + result = m.instruct( + description="Use bash to find Python files in the current directory. " + "Generate a single command using find or ls (no pipes, redirects, or shell operators allowed).", + requirements=[uses_tool(bash_executor)], + model_options={ModelOption.TOOLS: [MelleaTool.from_callable(bash_executor)]}, + tool_calls=True, + ) + + if result.tool_calls is None: + raise ValueError("Expected tool_calls but got None") + + if "bash_executor" not in result.tool_calls: + available_tools = list(result.tool_calls.keys()) + raise ValueError( + f"Expected tool 'bash_executor' in tool_calls, but got: {available_tools}" + ) + + # Extract the bash command the LLM generated + tool_call = result.tool_calls["bash_executor"] + if "command" not in tool_call.args: + raise ValueError( + f"Expected 'command' argument in tool call args, " + f"but got: {list(tool_call.args.keys())}" + ) + + command = tool_call.args["command"] + print(f"LLM generated bash command:\n {command}\n") + + # Execute the command + exec_result = tool_call.call_func() + + print("Execution result:") + print(f" Success: {exec_result.success}") + print(f" Skipped: {exec_result.skipped}") + if exec_result.skip_message: + print(f" Skip reason: {exec_result.skip_message}") + print(f" Output: {exec_result.stdout}") + if exec_result.stderr: + print(f" Error: {exec_result.stderr}") + print() + + +def example_3_with_working_dir() -> None: + """Example 3: Restrict write validation and execution cwd to a directory.""" + print("=== Example 3: Working Directory Restriction ===") + + import os + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + print(f"Working directory: {tmpdir}") + + # Create a file using touch within the working directory (redirects blocked) + result = bash_executor("touch myfile.txt", working_dir=tmpdir) + print(f"Command: touch myfile.txt (relative path, executed in {tmpdir})") + print(f"Success: {result.success}") + print() + + # Verify the file was created + file_path = os.path.join(tmpdir, "myfile.txt") + if os.path.exists(file_path): + print(f"✓ File created at: {file_path}") + print() + + # Read it back + result = bash_executor("cat myfile.txt", working_dir=tmpdir) + print("Command: cat myfile.txt") + print(f"Output: {result.stdout}") + print() + + # Writing to /tmp is always allowed (temp directory exception) + result = bash_executor("touch /tmp/tmpfile.txt", working_dir=tmpdir) + print(f"Command: touch /tmp/tmpfile.txt (with working_dir={tmpdir})") + print(f"Success: {result.success} (note: /tmp is always allowed)") + print() + + # Attempt to write to system paths (will be rejected) + result = bash_executor("touch /etc/config.txt", working_dir=tmpdir) + print(f"Command: touch /etc/config.txt (with working_dir={tmpdir})") + print(f"Rejected: {result.skipped}") + print(f"Reason: {result.skip_message}") + print() + + +def example_4_safety_features() -> None: + """Example 4: Demonstrate safety features.""" + print("=== Example 4: Safety Features ===") + + dangerous_commands = [ + ("rm -rf /home", "Recursive force delete"), + ("git push --force", "Force git push"), + ("sudo whoami", "Privilege escalation"), + ("bash -i", "Interactive shell"), + ("touch /etc/config", "Write to system path"), + ] + + for cmd, description in dangerous_commands: + result = bash_executor(cmd) + print(f"{description}: {cmd}") + print(f" Rejected: {result.skipped}") + print(f" Reason: {result.skip_message}") + print() + + +def example_5_error_handling() -> None: + """Example 5: Handle execution errors gracefully.""" + print("=== Example 5: Error Handling ===") + + # Command that fails (returns non-zero exit code) + result = bash_executor("false") + print("Command: false (POSIX command that returns exit code 1)") + print(f"Success: {result.success}") + print(f"Return code indicates failure: {not result.success}") + print() + + # Command that doesn't exist + result = bash_executor("nonexistent_command_xyz") + print("Command: nonexistent_command_xyz") + print(f"Success: {result.success}") + if not result.success and result.stderr is not None: + print(f"Error output: {result.stderr[:100]}") + print() + + +if __name__ == "__main__": + example_1_direct_execution() + example_2_wrapped_as_tool() + + # Example 3: Run with LLM-based tool calling (requires Ollama or compatible LLM) + try: + m = start_session() + example_3_llm_with_forced_tool_use(m) + except Exception as e: + print(f"Example 3 skipped: {e!s}") + print(" Requires: Ollama running locally or compatible LLM configured") + print(" See: https://docs.ollama.ai/") + + example_3_with_working_dir() + example_4_safety_features() + example_5_error_handling() diff --git a/mellea/stdlib/tools/__init__.py b/mellea/stdlib/tools/__init__.py index 755dea583..b665cbb72 100644 --- a/mellea/stdlib/tools/__init__.py +++ b/mellea/stdlib/tools/__init__.py @@ -1,5 +1,12 @@ """Implementations of tools.""" from .interpreter import code_interpreter, local_code_interpreter +from .shell import BashEnvironment, StaticBashEnvironment, bash_executor -__all__ = ["code_interpreter", "local_code_interpreter"] +__all__ = [ + "BashEnvironment", + "StaticBashEnvironment", + "bash_executor", + "code_interpreter", + "local_code_interpreter", +] diff --git a/mellea/stdlib/tools/_bash_audit.py b/mellea/stdlib/tools/_bash_audit.py new file mode 100644 index 000000000..71e00d67c --- /dev/null +++ b/mellea/stdlib/tools/_bash_audit.py @@ -0,0 +1,269 @@ +"""Audit trail for bash guardrails violations. + +Records all command rejections with pattern, severity, category, and execution +context for compliance audits, security monitoring, and observability integration. +""" + +import threading +import time +from dataclasses import dataclass +from typing import Any + +from mellea.core.utils import MelleaLogger +from mellea.telemetry.context import get_current_context +from mellea.telemetry.metrics import create_counter + + +@dataclass +class BashViolation: + """Record of a guardrail violation. + + Attributes: + timestamp: Unix timestamp when violation occurred. + command: Original command string. + argv: Tokenized command arguments. + pattern: Pattern name that detected the violation (e.g., "DangerousCommandPattern"). + category: Violation category (e.g., "PRIVILEGE_ESCALATION", "DESTRUCTIVE"). + severity: Severity level ("CRITICAL", "HIGH", "MEDIUM", "LOW"). + reason: Human-readable explanation of why it was rejected. + working_dir: Working directory for execution context. + allowed_paths: Path restrictions that were in effect. + session_id: Session ID from context if available. + request_id: Request ID from context if available. + """ + + timestamp: float + command: str + argv: list[str] + pattern: str + category: str + severity: str + reason: str + working_dir: str | None = None + allowed_paths: list[str] | None = None + session_id: str | None = None + request_id: str | None = None + + +class BashAuditTrail: + """Singleton audit trail for bash guardrails violations. + + Records, queries, and exports metrics for all command rejections. + Thread-safe with in-memory storage suitable for typical workflows + where violations are rare. + """ + + _instance: "BashAuditTrail | None" = None + _lock = threading.Lock() + + def __init__(self) -> None: + self._violations: list[BashViolation] = [] + self._violations_by_session: dict[str, list[BashViolation]] = {} + self._violations_by_pattern: dict[str, list[BashViolation]] = {} + self._storage_lock = threading.Lock() + + @classmethod + def get_instance(cls) -> "BashAuditTrail": + """Get singleton audit trail instance (thread-safe).""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def record_violation(self, violation: BashViolation) -> None: + """Record a guardrail violation. + + Args: + violation: The violation to record. + """ + with self._storage_lock: + self._violations.append(violation) + + if violation.session_id: + if violation.session_id not in self._violations_by_session: + self._violations_by_session[violation.session_id] = [] + self._violations_by_session[violation.session_id].append(violation) + + if violation.pattern not in self._violations_by_pattern: + self._violations_by_pattern[violation.pattern] = [] + self._violations_by_pattern[violation.pattern].append(violation) + + _log_violation(violation) + _record_violation_metrics(violation.category, violation.severity) + + def get_violations( + self, + session_id: str | None = None, + pattern: str | None = None, + category: str | None = None, + severity: str | None = None, + limit: int | None = None, + ) -> list[BashViolation]: + """Query recorded violations with optional filters. + + Args: + session_id: Filter by session ID. + pattern: Filter by pattern name. + category: Filter by category. + severity: Filter by severity level. + limit: Maximum number of results to return. + + Returns: + List of matching violations. + """ + with self._storage_lock: + results = self._violations[:] + + for violation in results: + if session_id and violation.session_id != session_id: + results.remove(violation) + elif pattern and violation.pattern != pattern: + results.remove(violation) + elif category and violation.category != category: + results.remove(violation) + elif severity and violation.severity != severity: + results.remove(violation) + + if limit: + results = results[:limit] + + return results + + def export_metrics(self) -> dict[str, Any]: + """Export violation metrics. + + Returns: + Dictionary with counts by severity, category, and pattern. + """ + with self._storage_lock: + violations = self._violations[:] + + metrics: dict[str, Any] = {"total": len(violations)} + + severity_counts: dict[str, int] = {} + category_counts: dict[str, int] = {} + pattern_counts: dict[str, int] = {} + + for v in violations: + severity_counts[v.severity] = severity_counts.get(v.severity, 0) + 1 + category_counts[v.category] = category_counts.get(v.category, 0) + 1 + pattern_counts[v.pattern] = pattern_counts.get(v.pattern, 0) + 1 + + for severity, count in severity_counts.items(): + metrics[f"severity_{severity}"] = count + + for category, count in category_counts.items(): + metrics[f"category_{category}"] = count + + for pattern, count in pattern_counts.items(): + metrics[f"pattern_{pattern}"] = count + + return metrics + + def clear(self) -> None: + """Clear all recorded violations (primarily for testing).""" + with self._storage_lock: + self._violations.clear() + self._violations_by_session.clear() + self._violations_by_pattern.clear() + + +def record_bash_violation( + command: str, + argv: list[str], + pattern_name: str, + category: str, + severity: str, + reason: str, + working_dir: str | None = None, + allowed_paths: list[str] | None = None, +) -> None: + """Record a bash guardrail violation (public entry point). + + Args: + command: Original command string. + argv: Tokenized command arguments. + pattern_name: Pattern that detected the violation. + category: Violation category. + severity: Severity level. + reason: Why it was rejected. + working_dir: Working directory context. + allowed_paths: Path restrictions context. + """ + context = get_current_context() + session_id = context.get("session_id") if context else None + request_id = context.get("request_id") if context else None + + violation = BashViolation( + timestamp=time.time(), + command=command[:200], + argv=argv, + pattern=pattern_name, + category=category, + severity=severity, + reason=reason, + working_dir=working_dir, + allowed_paths=allowed_paths, + session_id=session_id, + request_id=request_id, + ) + + trail = BashAuditTrail.get_instance() + trail.record_violation(violation) + + +def get_bash_violations( + session_id: str | None = None, + pattern: str | None = None, + category: str | None = None, + severity: str | None = None, + limit: int | None = None, +) -> list[BashViolation]: + """Query recorded bash violations (public entry point). + + Args: + session_id: Filter by session ID. + pattern: Filter by pattern name. + category: Filter by category. + severity: Filter by severity level. + limit: Maximum number of results. + + Returns: + List of matching violations. + """ + trail = BashAuditTrail.get_instance() + return trail.get_violations( + session_id=session_id, + pattern=pattern, + category=category, + severity=severity, + limit=limit, + ) + + +def _log_violation(violation: BashViolation) -> None: + """Log violation using structured logging.""" + logger = MelleaLogger.get_logger() + logger.warning( + "Bash guardrail violation", + extra={ + "bash_violation": True, + "pattern": violation.pattern, + "category": violation.category, + "severity": violation.severity, + "reason": violation.reason, + "session_id": violation.session_id, + "request_id": violation.request_id, + }, + ) + + +def _record_violation_metrics(category: str, severity: str) -> None: + """Record violation metrics.""" + counter = create_counter( + "bash.guardrail.violations", + description="Count of bash guardrail violations", + unit="1", + ) + counter.add(1, {"category": category, "severity": severity}) diff --git a/mellea/stdlib/tools/_bash_guardrails.py b/mellea/stdlib/tools/_bash_guardrails.py new file mode 100644 index 000000000..4b5f0b8ac --- /dev/null +++ b/mellea/stdlib/tools/_bash_guardrails.py @@ -0,0 +1,292 @@ +"""Structured bash security guardrails framework. + +Provides systematic organization of bash command safety rules with categories, +severity levels, and rationale. This enables: +- Clear documentation of what is blocked and why +- Auditing of coverage across threat categories +- Principled decision-making about future allowlist expansions +- Test-driven verification of guardrail completeness +""" + +from dataclasses import dataclass +from enum import Enum + + +class CommandCategory(Enum): + """Categorization of dangerous commands by threat profile.""" + + PRIVILEGE_ESCALATION = "privilege_escalation" + INTERACTIVE = "interactive" + DESTRUCTIVE = "destructive" + ENVIRONMENT_CHANGING = "environment_changing" + FILE_PERMISSIONS = "file_permissions" + + +class Severity(Enum): + """Risk severity of a command if misused.""" + + CRITICAL = "critical" # Allows full system compromise + HIGH = "high" # Causes significant damage or data loss + MEDIUM = "medium" # Limited damage but noteworthy risk + LOW = "low" # Mostly informational; rarely legitimately needed + + +@dataclass +class CommandRule: + """Rule defining why a command is dangerous and how it's handled. + + Attributes: + category: Threat category this command belongs to. + severity: Risk severity if the command is executed unexpectedly. + rationale: Explanation of why this command is blocked. + safe_with: List of conditions under which the command might be safe + (informational only; not enforced yet). + """ + + category: CommandCategory + severity: Severity + rationale: str + safe_with: list[str] | None = None + + +# Canonical mapping of dangerous commands to their security rules. +# This is the authoritative reference for why each command is blocked. +COMMAND_RULES: dict[str, CommandRule] = { + # Privilege escalation: always dangerous + "sudo": CommandRule( + category=CommandCategory.PRIVILEGE_ESCALATION, + severity=Severity.CRITICAL, + rationale="Elevation to root requires human interaction or stored credentials. Cannot be automated safely in an untrusted pipeline.", + ), + "su": CommandRule( + category=CommandCategory.PRIVILEGE_ESCALATION, + severity=Severity.CRITICAL, + rationale="User switching (su) requires password input or stored credentials. Cannot be automated safely.", + ), + "doas": CommandRule( + category=CommandCategory.PRIVILEGE_ESCALATION, + severity=Severity.CRITICAL, + rationale="Alternative privilege escalation (OpenBSD/BSD). Same risk as sudo.", + ), + # Interactive shells: would block LLM workflow + "bash": CommandRule( + category=CommandCategory.INTERACTIVE, + severity=Severity.HIGH, + rationale="Interactive bash shells (-i flag) block the LLM workflow. Non-interactive usage (e.g., bash script.sh) is allowed.", + safe_with=["non_interactive_mode"], + ), + "sh": CommandRule( + category=CommandCategory.INTERACTIVE, + severity=Severity.HIGH, + rationale="Interactive sh shells block the LLM workflow. Non-interactive usage is allowed.", + safe_with=["non_interactive_mode"], + ), + "zsh": CommandRule( + category=CommandCategory.INTERACTIVE, + severity=Severity.HIGH, + rationale="Interactive zsh shells block the LLM workflow.", + ), + "ksh": CommandRule( + category=CommandCategory.INTERACTIVE, + severity=Severity.HIGH, + rationale="Interactive ksh shells block the LLM workflow.", + ), + "tcsh": CommandRule( + category=CommandCategory.INTERACTIVE, + severity=Severity.HIGH, + rationale="Interactive tcsh shells block the LLM workflow.", + ), + # User/group/password management: permanent system changes + "passwd": CommandRule( + category=CommandCategory.ENVIRONMENT_CHANGING, + severity=Severity.CRITICAL, + rationale="Password changes require user interaction and permanently alter system state.", + ), + "visudo": CommandRule( + category=CommandCategory.ENVIRONMENT_CHANGING, + severity=Severity.CRITICAL, + rationale="Sudo configuration changes require human validation and affect system security model.", + ), + "chsh": CommandRule( + category=CommandCategory.ENVIRONMENT_CHANGING, + severity=Severity.HIGH, + rationale="Change shell permanently alters user environment; rarely needed in automation.", + ), + "chfn": CommandRule( + category=CommandCategory.ENVIRONMENT_CHANGING, + severity=Severity.MEDIUM, + rationale="Change GECOS (user info) has low direct risk but indicates attempt to alter user identity.", + ), + "useradd": CommandRule( + category=CommandCategory.ENVIRONMENT_CHANGING, + severity=Severity.HIGH, + rationale="User creation permanently alters system and requires elevated privileges.", + ), + "userdel": CommandRule( + category=CommandCategory.ENVIRONMENT_CHANGING, + severity=Severity.HIGH, + rationale="User deletion permanently alters system and affects file ownership.", + ), + "usermod": CommandRule( + category=CommandCategory.ENVIRONMENT_CHANGING, + severity=Severity.HIGH, + rationale="User modification (groups, shells, etc.) permanently alters system state.", + ), + "groupadd": CommandRule( + category=CommandCategory.ENVIRONMENT_CHANGING, + severity=Severity.HIGH, + rationale="Group creation permanently alters system.", + ), + "groupdel": CommandRule( + category=CommandCategory.ENVIRONMENT_CHANGING, + severity=Severity.HIGH, + rationale="Group deletion permanently alters system.", + ), + "groupmod": CommandRule( + category=CommandCategory.ENVIRONMENT_CHANGING, + severity=Severity.HIGH, + rationale="Group modification permanently alters system.", + ), +} + + +class ShellOperatorCategory(Enum): + """Category of shell operators that bypass argv parsing.""" + + REDIRECTION = "redirection" # >, >>, <, >&, etc. + PIPE = "pipe" # |, |& + CHAINING = "chaining" # ;, &&, || + BACKGROUND = "background" # & + SUBSTITUTION = "substitution" # $(...), `...`, ${...} + + +@dataclass +class ShellOperatorRule: + """Rule for detecting and blocking shell operators. + + Attributes: + operator: The operator token (e.g., ">>", "&&"). + category: Category of operator (redirection, pipe, etc.). + rationale: Why this operator is blocked. + blocked_if: Description of when it's blocked (e.g., "always", "as standalone token"). + """ + + operator: str + category: ShellOperatorCategory + rationale: str + blocked_if: str = "always" + + +# Canonical shell operators that are always dangerous. +SHELL_OPERATOR_RULES: dict[str, ShellOperatorRule] = { + # Redirection operators + ">": ShellOperatorRule( + operator=">", + category=ShellOperatorCategory.REDIRECTION, + rationale="Output redirection allows writing arbitrary output to any file. Use subprocess or file operations instead.", + blocked_if="standalone or as prefix (e.g., >file, >&2)", + ), + ">>": ShellOperatorRule( + operator=">>", + category=ShellOperatorCategory.REDIRECTION, + rationale="Append redirection allows modifying arbitrary files.", + blocked_if="standalone or as prefix", + ), + "<": ShellOperatorRule( + operator="<", + category=ShellOperatorCategory.REDIRECTION, + rationale="Input redirection bypasses file access controls.", + blocked_if="standalone or as prefix", + ), + ">&": ShellOperatorRule( + operator=">&", + category=ShellOperatorCategory.REDIRECTION, + rationale="Stream redirection (stderr/stdout redirect) bypasses output controls.", + blocked_if="as prefix (e.g., >&2)", + ), + "<<": ShellOperatorRule( + operator="<<", + category=ShellOperatorCategory.REDIRECTION, + rationale="Heredoc redirection embeds multi-line input, reducing transparency.", + blocked_if="standalone or as prefix", + ), + # Pipe operators + "|": ShellOperatorRule( + operator="|", + category=ShellOperatorCategory.PIPE, + rationale="Pipes chain commands without explicit control flow, enabling complex attacks.", + blocked_if="standalone", + ), + "|&": ShellOperatorRule( + operator="|&", + category=ShellOperatorCategory.PIPE, + rationale="Coproc pipes (bash 4.0+) enable bidirectional command interaction.", + blocked_if="standalone", + ), + # Chaining operators + ";": ShellOperatorRule( + operator=";", + category=ShellOperatorCategory.CHAINING, + rationale="Semicolon chains commands regardless of success. Use Python control flow instead.", + blocked_if="substring (dangerous even in quoted contexts in some shells)", + ), + "&&": ShellOperatorRule( + operator="&&", + category=ShellOperatorCategory.CHAINING, + rationale="AND operator chains commands conditionally. Use Python if/else instead.", + blocked_if="standalone", + ), + "||": ShellOperatorRule( + operator="||", + category=ShellOperatorCategory.CHAINING, + rationale="OR operator chains commands on failure. Use Python try/except instead.", + blocked_if="standalone", + ), + # Background operator + "&": ShellOperatorRule( + operator="&", + category=ShellOperatorCategory.BACKGROUND, + rationale="Background execution reduces visibility and control over command lifetime.", + blocked_if="standalone", + ), +} + + +def get_command_rules_by_category(category: CommandCategory) -> dict[str, CommandRule]: + """Get all commands in a specific category. + + Args: + category: The category to filter by. + + Returns: + Dictionary of command -> rule for all commands in the category. + """ + return { + cmd: rule for cmd, rule in COMMAND_RULES.items() if rule.category == category + } + + +def get_high_severity_commands() -> dict[str, CommandRule]: + """Get all commands with high or critical severity. + + Returns: + Dictionary of command -> rule for high/critical severity commands. + """ + return { + cmd: rule + for cmd, rule in COMMAND_RULES.items() + if rule.severity in (Severity.CRITICAL, Severity.HIGH) + } + + +def audit_guardrails_coverage() -> dict[str, list[str]]: + """Audit the coverage of guardrails across threat categories. + + Returns: + Dictionary mapping category names to lists of commands in that category. + """ + coverage: dict[str, list[str]] = {} + for category in CommandCategory: + commands = list(get_command_rules_by_category(category).keys()) + coverage[category.value] = commands + return coverage diff --git a/mellea/stdlib/tools/_bash_patterns.py b/mellea/stdlib/tools/_bash_patterns.py new file mode 100644 index 000000000..b912b0af2 --- /dev/null +++ b/mellea/stdlib/tools/_bash_patterns.py @@ -0,0 +1,252 @@ +"""Extensible bash security pattern detection framework. + +Defines abstract base class and concrete implementations for bash security checks. +New patterns can be added without modifying core validation logic. +""" + +from abc import ABC, abstractmethod + +from ._bash_audit import record_bash_violation +from ._bash_guardrails import COMMAND_RULES, SHELL_OPERATOR_RULES + + +class BashSecurityPattern(ABC): + """Base class for pattern-based security checks. + + Each pattern detects a specific class of dangerous usage (e.g., shell operators, + code execution paths, dangerous commands). Patterns can be composed and registered + in a central registry for modular validation. + """ + + @abstractmethod + def check(self, argv: list[str]) -> tuple[bool, str]: + """Check if a command violates this security pattern. + + Args: + argv: Tokenized command arguments (from shlex.split()). + + Returns: + Tuple of (is_dangerous, reason_message). If is_dangerous is True, + reason_message explains why the pattern was violated. + """ + + +class DangerousCommandPattern(BashSecurityPattern): + """Detects usage of dangerous commands like sudo, passwd, useradd, etc. + + Uses COMMAND_RULES from _bash_guardrails for authoritative definitions. + """ + + def check(self, argv: list[str]) -> tuple[bool, str]: + """Check if command is in the dangerous commands list.""" + if not argv: + return False, "" + + cmd = argv[0].split("/")[-1] # Get basename + + if cmd in COMMAND_RULES: + # Special case: interactive shells are only dangerous with -i flag + if cmd in ("bash", "sh", "zsh", "ksh", "tcsh"): + if any(arg in ("-i", "--interactive", "-l", "-login") for arg in argv): + return True, f"Interactive shell '{cmd}' is not allowed" + else: + return True, f"Command '{cmd}' is not allowed" + + return False, "" + + +class ShellOperatorPattern(BashSecurityPattern): + """Detects shell operators: |, >, &&, ;, etc. + + These operators require shell interpretation and can enable complex attacks. + Detected after shlex.split(), so they appear as standalone tokens or prefixes. + """ + + def check(self, argv: list[str]) -> tuple[bool, str]: + """Check for shell operators in argv.""" + if not argv: + return False, "" + + shell_operators = {"<", ">", "|", ";", "&", "&&", "||", ">>", ">&", "<<", "|&"} + + for arg in argv: + # Exact match: standalone operators like "&&", "|" + if arg in shell_operators: + rule = SHELL_OPERATOR_RULES.get(arg) + reason = rule.rationale if rule else "Shell operator is not allowed" + return True, reason + + # Prefix match: operators with content like ">&2", ">file" + for op in shell_operators: + if arg.startswith(op) and len(arg) > len(op): + rule = SHELL_OPERATOR_RULES.get(op) + reason = rule.rationale if rule else "Shell operator is not allowed" + return True, reason + + # Semicolon: substring check (dangerous even in some quote contexts) + if ";" in arg: + return True, "Command chaining (;) is not allowed" + + return False, "" + + +class CommandSubstitutionPattern(BashSecurityPattern): + """Detects command substitution: $(cmd), `cmd`, ${var}, etc. + + These patterns allow arbitrary code execution and bypass argv parsing. + """ + + def check(self, argv: list[str]) -> tuple[bool, str]: + """Check for command substitution patterns.""" + if not argv: + return False, "" + + for arg in argv: + if "`" in arg or "$(" in arg: + return True, "Command substitution is not allowed" + if "${" in arg: + return True, "Variable expansion is not allowed" + + return False, "" + + +class CodeExecutionPattern(BashSecurityPattern): + """Detects interpreter code execution paths: python -c, bash -c, etc. + + These flags cause interpreters to treat arguments as source code, bypassing argv parsing. + """ + + def check(self, argv: list[str]) -> tuple[bool, str]: + """Check for interpreter indirection (code execution flags).""" + if not argv: + return False, "" + + cmd = argv[0].split("/")[-1] + + code_execution_interpreters = { + "python": ("-c", "-m"), + "python3": ("-c", "-m"), + "python2": ("-c", "-m"), + "perl": ("-e", "-E"), + "ruby": ("-e", "-E"), + "node": ("-e", "--eval"), + "bash": ("-c",), + "sh": ("-c",), + "zsh": ("-c",), + "ksh": ("-c",), + "tcsh": ("-c",), + } + + if cmd in code_execution_interpreters: + dangerous_flags = code_execution_interpreters[cmd] + if any(arg in dangerous_flags for arg in argv): + return ( + True, + f"Interpreter code execution ('{cmd} {' '.join(dangerous_flags)}') is not allowed", + ) + + return False, "" + + +class DestructiveGitPattern(BashSecurityPattern): + """Detects dangerous git operations: push --force, reset --hard, clean -f, etc. + + These operations have high regret cost (lost commits, data loss). + """ + + def check(self, argv: list[str]) -> tuple[bool, str]: + """Check for destructive git operations.""" + if not argv or argv[0].split("/")[-1] != "git": + return False, "" + + # git push --force + if "push" in argv and any(arg in ("--force", "-f") for arg in argv): + return True, "Destructive git operation is not allowed" + + # git reset --hard + if "reset" in argv and "--hard" in argv: + return True, "Destructive git operation is not allowed" + + # git clean -f/-d + if "clean" in argv: + for arg in argv: + if arg in ("-f", "-d", "-fd", "-df"): + return True, "Destructive git operation is not allowed" + if arg.startswith("-") and not arg.startswith("--"): + if "f" in arg or "d" in arg: + return True, "Destructive git operation is not allowed" + + return False, "" + + +class DestructiveRmPattern(BashSecurityPattern): + """Detects destructive rm operations: rm -rf, rm -r, etc. + + Recursive deletion is the highest-regret operation for filesystem safety. + """ + + def check(self, argv: list[str]) -> tuple[bool, str]: + """Check for destructive rm patterns.""" + if not argv or argv[0].split("/")[-1] != "rm": + return False, "" + + if any(flag in argv for flag in ("-r", "-rf", "--recursive")): + return True, "rm with -r or -rf flag is not allowed" + + return False, "" + + +# Registry of all security patterns. New patterns can be added here. +SECURITY_PATTERNS: list[BashSecurityPattern] = [ + DangerousCommandPattern(), + ShellOperatorPattern(), + CommandSubstitutionPattern(), + CodeExecutionPattern(), + DestructiveGitPattern(), + DestructiveRmPattern(), +] + + +def check_all_patterns( + argv: list[str], + working_dir: str | None = None, + allowed_paths: list[str] | None = None, +) -> tuple[bool, str]: + """Check command against all registered security patterns. + + Args: + argv: Tokenized command arguments. + working_dir: Working directory context for audit trail. + allowed_paths: Allowed paths context for audit trail. + + Returns: + Tuple of (is_dangerous, reason_message) from the first matching pattern, + or (False, "") if all patterns pass. + """ + for pattern in SECURITY_PATTERNS: + is_dangerous, reason = pattern.check(argv) + if is_dangerous: + pattern_name = type(pattern).__name__ + category = getattr(pattern, "category", "unknown") + severity = getattr(pattern, "severity", "MEDIUM") + record_bash_violation( + command=" ".join(argv), + argv=argv, + pattern_name=pattern_name, + category=category, + severity=severity, + reason=reason, + working_dir=working_dir, + allowed_paths=allowed_paths, + ) + return True, reason + return False, "" + + +def get_pattern_names() -> list[str]: + """Get names of all registered security patterns. + + Returns: + List of pattern class names. + """ + return [type(pattern).__name__ for pattern in SECURITY_PATTERNS] diff --git a/mellea/stdlib/tools/shell.py b/mellea/stdlib/tools/shell.py new file mode 100644 index 000000000..92a450b28 --- /dev/null +++ b/mellea/stdlib/tools/shell.py @@ -0,0 +1,796 @@ +"""Bash shell command execution tool and execution environments for agentic workflows. + +Provides ``BashEnvironment`` (abstract base for bash execution) and two concrete +implementations: ``StaticBashEnvironment`` (parse and safety-check only, no execution) +and ``_LocalBashEnvironment`` (subprocess execution in the current shell). All +environments enforce a conservative safety denylist (sudo, rm -rf, git push --force, +system paths, interactive shells). Write operations may also be constrained by +``working_dir`` and ``allowed_paths``. + +The top-level ``bash_executor`` (recommended entry point) executes commands locally +with denylist safety checks. Bash executor runs with access to the host environment; +isolation must be provided by the application layer (containers, VMs). + +The function is ready to be wrapped as a ``MelleaTool`` instance for ReACT or +other agentic loops. + +Security note: The denylist covers inline code execution (e.g., bash -c, python -e) and +dangerous commands in argv. However, it does not prevent execution of pre-existing +script files (e.g., bash script.sh, python script.py), which can execute arbitrary +code from the file. For untrusted inputs, ensure that script files are either absent +or come from a trusted source. +""" + +import shlex +import subprocess +from abc import ABC, abstractmethod +from pathlib import Path + +from ...core import MelleaLogger +from ._bash_audit import record_bash_violation +from .interpreter import ExecutionResult + +logger = MelleaLogger.get_logger() + +# Commands that are always dangerous +DANGEROUS_COMMANDS = { + # Privilege escalation + "sudo", + "su", + "doas", + # Interactive shells + "bash", + "sh", + "zsh", + "ksh", + "tcsh", + # User/group/password changes + "passwd", + "visudo", + "chsh", + "chfn", + "useradd", + "userdel", + "usermod", + "groupadd", + "groupdel", + "groupmod", +} + +# Safe wrapper commands that can invoke nested commands (e.g., env, timeout). +# Only commands in this set are allowed to have dangerous commands as nested arguments. +# For all other commands, dangerous commands are rejected regardless of nesting. +SAFE_WRAPPER_COMMANDS = { + "env", # set environment vars + "timeout", # limit execution time + "nice", # set process priority + "nohup", # ignore SIGHUP + "stdbuf", # modify buffering + "unbuffer", # alias for stdbuf + "ionice", # set I/O priority +} + +# System paths that are dangerous to write to +# Includes both standard Linux paths and macOS /private equivalents +# (on macOS, many system paths are symlinks to /private/*) +# NOTE: On macOS, /var/folders/* resolves to /private/var/folders/* (user temp dirs). +# We don't block the entire /private/var tree because that would block legitimate +# writes to temp directories. Instead, we block specific dangerous subdirectories. +DANGEROUS_PATHS = { + "/", + "/bin", + "/sbin", + "/usr", + "/lib", + "/lib64", + "/etc", + "/sys", + "/proc", + "/boot", + "/root", + "/var/log", + "/var/www", + # macOS /private equivalents (specific paths, not entire /private/var) + "/private/etc", + "/private/var/log", + "/private/var/www", + "/private/var/db", + "/private/var/root", +} + +# Dangerous flags that make commands unsafe +DANGEROUS_FLAGS = {"-rf", "-r", "--recursive", "--force", "-f", "--force-all"} + +# Maximum output size (10KB per stream) +MAX_OUTPUT_SIZE = 10 * 1024 + + +def _is_dangerous_command(argv: list[str]) -> tuple[bool, str]: + """Check if a command is dangerous based on argv analysis. + + Args: + argv: Tokenized command arguments. + + Returns: + A tuple of (is_dangerous, reason_message). + """ + if not argv: + return False, "" + + cmd = argv[0].split("/")[-1] # Get basename of command + + # Check for shell metacharacters that would need shell interpretation + # After shlex.split(), these characters in argv indicate shell operators (not quoted strings) + # These are only dangerous if they're standalone tokens (e.g., argv[i] == ">>"). + # Substring matches would cause false positives for legitimate patterns like "a&&b". + shell_operators = {"<", ">", "|", ";", "&", "&&", "||", ">>", ">&", "<<", "|&"} + + for arg in argv: + # Check for redirect/pipe/logic operators (these are shell operators). + # Operators can appear as: + # 1. Standalone tokens: arg == "&&" (caught by exact match) + # 2. Operator with argument: arg == ">&2" or arg == ">file" (start with operator) + # We don't check for substring matches (e.g., "a&&b") to avoid false positives + # for legitimate patterns like regex or AWK/sed code. + + # Exact match first (standalone operators like "&&", "|", etc.) + if arg in shell_operators: + return True, f"Shell operator '{arg}' is not allowed" + + # Check if argument starts with a shell operator (e.g., ">&2", ">file", "2>&1") + for op in shell_operators: + if arg.startswith(op) and len(arg) > len(op): + # Token starts with operator and has additional content (e.g., ">&2", ">file") + # This is a shell redirection/operator usage + return True, f"Shell operator '{op}' is not allowed" + + # Note: semicolon is already in shell_operators, but we check for it separately + # as a substring because semicolon can be dangerous even in patterns. + # Unlike && or ||, semicolon rarely appears legitimately in arguments. + if ";" in arg: + return True, "Command chaining (;) is not allowed" + + # Check for command substitution (backticks, $(...)) + for arg in argv: + if "`" in arg or "$(" in arg: + return True, "Command substitution is not allowed" + # Check for variable expansion patterns + if "${" in arg: + return True, "Variable expansion is not allowed" + + # Check for interpreter indirection (code execution via -c, -e, etc.) + # These allow arbitrary code execution and bypass argv parsing + code_execution_interpreters = { + "python": ("-c", "-m"), + "python3": ("-c", "-m"), + "python2": ("-c", "-m"), + "perl": ("-e", "-E"), + "ruby": ("-e", "-E"), + "node": ("-e", "--eval"), + "bash": ("-c",), + "sh": ("-c",), + "zsh": ("-c",), + "ksh": ("-c",), + "tcsh": ("-c",), + } + if cmd in code_execution_interpreters: + dangerous_flags = code_execution_interpreters[cmd] + if any(arg in dangerous_flags for arg in argv): + return ( + True, + f"Interpreter code execution ('{cmd} {' '.join(dangerous_flags)}') is not allowed", + ) + + # Check if any argument is a dangerous command (e.g., env sudo, timeout sudo) + # Only check positional arguments that are not paths or flag values. + # Known value-taking flags that consume the next argument (space-separated only). + # NOTE: -i / --input are intentionally not included. env(1)'s -i takes NO value; + # other commands' -i/--input (if present) should not mask dangerous commands. + flag_value_flags = { + "-c", + "--config", + "-f", + "--file", + "-o", + "--output", + "-d", + "--dir", + "-p", + "--path", + "-t", + "--timeout", + "-w", + "--wait", + } + for i, arg in enumerate(argv[1:], start=1): + # Skip if this argument is the value for a preceding flag (space-separated) + # E.g., in "timeout -t 10 sudo", skip "10" (it's the value for -t) + # But don't skip "sudo" when the flag uses = notation (e.g., --kill-after=1) + if i > 1 and argv[i - 1] in flag_value_flags: + continue + # Skip if argument contains / (it's a path, not a command name) + if "/" in arg: + continue + # Skip if argument starts with - (it's a flag) + if arg.startswith("-"): + continue + + arg_cmd = arg.split("/")[-1] + + # Only allow dangerous commands as nested arguments if the top-level + # command is in SAFE_WRAPPER_COMMANDS. This prevents bypasses like + # "env -i sudo" or "timeout -t 10 sudo". Without the wrapper allowlist, + # these would pass validation despite being dangerous. + if arg_cmd in DANGEROUS_COMMANDS: + if cmd not in SAFE_WRAPPER_COMMANDS or arg_cmd not in ( + "bash", + "sh", + "zsh", + "ksh", + "tcsh", + ): + # Allow shells as arguments in safe wrappers (e.g., timeout bash script.sh) + # but reject sudo/su/etc as nested commands + return True, f"Command '{arg_cmd}' is not allowed as an argument" + + # Check for dangerous nested commands that aren't in DANGEROUS_COMMANDS but + # have dangerous flags (e.g., "env rm -rf" or "timeout rm -rf"). + # Only apply this check if the wrapper is in SAFE_WRAPPER_COMMANDS. + if cmd in SAFE_WRAPPER_COMMANDS and arg_cmd == "rm": + # Check if rm has dangerous flags anywhere in argv after this command + if any(flag in argv for flag in ["-r", "-rf", "--recursive"]): + return ( + True, + "Command 'rm' with dangerous flags is not allowed as an argument", + ) + + # Check for dangerous commands + if cmd in DANGEROUS_COMMANDS: + if cmd in ("bash", "sh", "zsh", "ksh", "tcsh"): + if any(arg in ("-i", "--interactive", "-l", "-login") for arg in argv): + return True, f"Interactive shell '{cmd}' is not allowed" + else: + return True, f"Command '{cmd}' is not allowed" + + # Check for dangerous git operations + if cmd == "git": + # Check for destructive git operations: push --force, reset --hard, clean -f/-d + has_destructive_op = False + + # git push --force: check for exact tokens (not substrings) + if "push" in argv and any(arg == "--force" or arg == "-f" for arg in argv): + has_destructive_op = True + + # git reset --hard: both must be exact tokens + if "reset" in argv and "--hard" in argv: + has_destructive_op = True + + # git clean -f/-d: check for these dangerous flags (exact match or combined like -fd) + # Avoid false positives: --dry-run should not match -d (use exact or combined check) + if "clean" in argv: + for arg in argv: + # Exact matches: -f, -d, -fd, -df, etc. + if arg in ("-f", "-d", "-fd", "-df"): + has_destructive_op = True + break + # Also check arg startswith for combined flags containing d or f + # but NOT for things like --dry-run (those start with --) + if arg.startswith("-") and not arg.startswith("--"): + # Short flags: -f, -d, or combinations like -fd, -ddf, etc. + if "f" in arg or "d" in arg: + has_destructive_op = True + break + + if has_destructive_op: + return True, "Destructive git operation is not allowed" + + # Check for dangerous rm patterns + if cmd == "rm": + if "-r" in argv or "-rf" in argv or "--recursive" in argv: + return True, "rm with -r or -rf flag is not allowed" + + # Check for dangerous flags in specific commands where they are truly dangerous. + # Note: We don't check cp/mv/make here because: + # - cp -r: standard way to copy directories recursively + # - mv -r: standard way to move directories recursively + # - make -f: standard way to specify a makefile + # These are not "dangerous" operations in themselves. + # The real danger is rm -rf (covered above) and git --force (covered above). + for flag in DANGEROUS_FLAGS: + if flag in argv: + # Only apply DANGEROUS_FLAGS check to apt/yum (package managers) + # These with -f or -r can indeed be risky + if cmd in ("apt", "yum"): + return True, f"Command '{cmd}' with '{flag}' flag is not allowed" + + return False, "" + + +def _is_path_within(path_str: str, allowed_root: str) -> bool: + """Return whether a resolved path is equal to or nested under an allowed root.""" + if path_str == allowed_root: + return True + allowed_root_prefix = ( + allowed_root if allowed_root.endswith("/") else allowed_root + "/" + ) + return path_str.startswith(allowed_root_prefix) + + +def _normalize_allowed_path(allowed_path: str) -> str: + """Normalize an allowlisted path for string-prefix containment checks.""" + return str(Path(allowed_path).expanduser().resolve(strict=False)) + + +def _resolve_allowed_paths(allowed_paths: list[str]) -> list[str]: + """Normalize allowed path roots for prefix-based containment checks.""" + normalized_paths: list[str] = [] + for allowed_path in allowed_paths: + try: + normalized_paths.append(_normalize_allowed_path(allowed_path)) + except Exception: + logger.warning("Skipping invalid allowed path: %s", allowed_path) + return normalized_paths + + +def _is_default_safe_write_path(path_str: str) -> bool: + """Return whether a path is in a default safe write location.""" + home_dir = str(Path.home()) + return path_str.startswith(("/tmp", "/private/tmp")) or _is_path_within( + path_str, home_dir + ) + + +def _check_dangerous_paths( + argv: list[str], allowed_paths: list[str] | None = None +) -> tuple[bool, str]: + """Check if command targets dangerous or disallowed filesystem paths. + + Args: + argv: Tokenized command arguments. + allowed_paths: Optional additional resolved path roots where writes are allowed. + + Returns: + A tuple of (has_dangerous_paths, reason_message). + """ + write_commands = {"rm", "touch", "cp", "mv", "mkdir", "mkfifo", "mknod", "tee"} + if not argv or argv[0].split("/")[-1] not in write_commands: + return False, "" + + resolved_allowed_paths = _resolve_allowed_paths(allowed_paths or []) + + for arg in argv[1:]: + if arg.startswith("-"): + continue + + try: + # Resolve all paths consistently to catch symlink bypasses + if arg.startswith(("/", "~")): + resolved_arg = str(Path(arg).expanduser().resolve(strict=False)) + else: + # For relative paths, resolve against current directory + resolved_arg = str(Path(arg).expanduser().resolve(strict=False)) + except Exception as e: + # Fail closed: if we can't resolve a path, deny it. + # This prevents attackers from bypassing checks via crafted paths that fail to resolve. + logger.warning(f"Cannot resolve path '{arg}' in dangerous paths check: {e}") + return ( + True, + f"Cannot validate path '{arg}': path resolution failed ({type(e).__name__})", + ) + + for danger_path in DANGEROUS_PATHS: + if resolved_arg == danger_path or resolved_arg.startswith( + danger_path + "/" + ): + return True, f"Writing to '{resolved_arg}' is not allowed" + + if resolved_allowed_paths: + if not any( + _is_path_within(resolved_arg, allowed_root) + for allowed_root in resolved_allowed_paths + ): + return ( + True, + f"Path '{arg}' is outside explicitly allowed paths: {', '.join(allowed_paths or [])}", + ) + + return False, "" + + +def _check_working_dir_restriction( + argv: list[str], working_dir: str | None +) -> tuple[bool, str]: + """Check if command respects working directory restriction. + + Args: + argv: Tokenized command arguments. + working_dir: Allowed working directory, or None for no restriction. + + Returns: + A tuple of (violates_restriction, reason_message). + """ + if not working_dir: + return False, "" + + write_commands = {"rm", "touch", "cp", "mv", "mkdir", "mkfifo", "mknod", "tee"} + if not argv or argv[0].split("/")[-1] not in write_commands: + return False, "" + + try: + allowed_path_str = str(Path(working_dir).expanduser().resolve()) + except Exception as e: + # Fail closed: if we can't resolve working_dir, deny all writes in this directory. + logger.warning(f"Cannot resolve working_dir '{working_dir}': {e}") + return ( + True, + f"Cannot validate working directory: working_dir '{working_dir}' is not resolvable ({type(e).__name__})", + ) + + # Ensure the allowed path ends with / for prefix matching + if not allowed_path_str.endswith("/"): + allowed_path_str_prefix = allowed_path_str + "/" + else: + allowed_path_str_prefix = allowed_path_str + + for arg in argv[1:]: + if arg.startswith("-"): + continue + + # Try to resolve all paths (both absolute and relative) + try: + # For relative paths, resolve them relative to working_dir, not caller's cwd + if arg.startswith(("/", "~")): + resolved_path = str(Path(arg).expanduser().resolve()) + is_relative = False + else: + resolved_path = str( + Path(working_dir, arg).expanduser().resolve(strict=False) + ) + is_relative = True + + # Check if path is allowed: in working_dir, /tmp, or /private/tmp (macOS) + is_in_tmp = resolved_path.startswith(("/tmp", "/private/tmp")) + is_in_working_dir = ( + resolved_path == allowed_path_str + or resolved_path.startswith(allowed_path_str_prefix) + ) + + # For relative paths: must be within working_dir (not just /tmp) + # For absolute paths: can be in working_dir OR /tmp + if is_relative: + # Relative paths must stay within working_dir + if not is_in_working_dir: + return ( + True, + f"Path '{arg}' is outside allowed directory '{working_dir}'", + ) + else: + # Absolute paths can be in working_dir or /tmp + if not (is_in_tmp or is_in_working_dir): + return ( + True, + f"Path '{arg}' is outside allowed directory '{working_dir}'", + ) + except Exception as e: + # Fail closed: if we can't resolve an argument path, deny it. + # This prevents attackers from bypassing checks via crafted paths that fail to resolve. + logger.warning( + f"Cannot resolve argument path '{arg}' in working_dir check: {e}" + ) + return ( + True, + f"Cannot validate path '{arg}': path resolution failed ({type(e).__name__})", + ) + + return False, "" + + +def _truncate_output(output: str, max_size: int = MAX_OUTPUT_SIZE) -> tuple[str, bool]: + """Truncate output if it exceeds max size. + + Args: + output: The output string to potentially truncate. + max_size: Maximum allowed size in bytes. + + Returns: + A tuple of (truncated_output, was_truncated). The output string is clean + (no truncation message). The caller is responsible for appending any + truncation message. + """ + if len(output) > max_size: + return output[:max_size], True + return output, False + + +class BashEnvironment(ABC): + """Abstract environment for executing bash commands. + + Args: + allowed_paths (list[str] | None): Optional explicit write allowlist. When + provided, write-target paths must fall under one of these roots in + addition to passing the default dangerous-path checks. + working_dir (str | None): Optional directory restriction for write + operations. This is a host path where the command executes. When + specified, writes must remain within this directory or ``/tmp``. + timeout (int): Maximum number of seconds to allow command execution. + + Note: + Subclass ``StaticBashEnvironment`` returns ``success=True, skipped=True`` + to indicate that validation passed but the command was intentionally not + executed (analysis-only mode). Consumers that branch on ``success`` should + check ``skipped`` first to handle this state correctly. + + """ + + def __init__( + self, + allowed_paths: list[str] | None = None, + working_dir: str | None = None, + timeout: int = 60, + ): + """Initialize BashEnvironment with optional path allowlist and timeout.""" + self.allowed_paths = allowed_paths or [] + self.working_dir = working_dir + self.timeout = timeout + + def _validate_command(self, command: str) -> ExecutionResult | list[str]: + """Parse and validate a command before execution. + + The shared validation step performs argv parsing, rejects dangerous shell + constructs, applies path safety checks, and enforces ``allowed_paths`` and + ``working_dir`` restrictions for write operations. + + Args: + command: The bash command string to validate. + + Returns: + Either the validated argv list or a skipped ``ExecutionResult`` + describing why validation failed. + """ + try: + argv = shlex.split(command) + except ValueError as e: + return ExecutionResult( + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message=f"Failed to parse command: {e!s}", + ) + + if not argv: + return ExecutionResult( + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message="Empty command", + ) + + is_dangerous, reason = _is_dangerous_command(argv) + if is_dangerous: + record_bash_violation( + command=" ".join(argv), + argv=argv, + pattern_name="DangerousCommandPattern", + category="UNKNOWN", + severity="HIGH", + reason=reason, + working_dir=self.working_dir, + allowed_paths=self.allowed_paths, + ) + return ExecutionResult( + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message=reason, + ) + + has_dangerous, reason = _check_dangerous_paths(argv, self.allowed_paths) + if has_dangerous: + record_bash_violation( + command=" ".join(argv), + argv=argv, + pattern_name="DangerousPathPattern", + category="UNKNOWN", + severity="HIGH", + reason=reason, + working_dir=self.working_dir, + allowed_paths=self.allowed_paths, + ) + return ExecutionResult( + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message=reason, + ) + + violates_restriction, reason = _check_working_dir_restriction( + argv, self.working_dir + ) + if violates_restriction: + record_bash_violation( + command=" ".join(argv), + argv=argv, + pattern_name="WorkingDirRestrictionPattern", + category="UNKNOWN", + severity="MEDIUM", + reason=reason, + working_dir=self.working_dir, + allowed_paths=self.allowed_paths, + ) + return ExecutionResult( + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message=reason, + ) + + return argv + + @abstractmethod + def execute(self, command: str) -> ExecutionResult: + """Execute the given bash command and return the result. + + Args: + command (str): The bash command to execute. + + Returns: + ExecutionResult: Execution outcome including stdout, stderr, and + success flag. + """ + + +class StaticBashEnvironment(BashEnvironment): + """Safe environment that validates but does not execute bash commands. + + Returns ``success=True, skipped=True`` when validation passes (command is + syntactically valid and passes all safety checks), indicating the command + would be safe to execute but this environment intentionally does not run it. + Returns ``success=False, skipped=True`` when validation fails (safety check + rejection or parse error). + """ + + def execute(self, command: str) -> ExecutionResult: + """Parse and validate command without executing. + + Args: + command (str): The bash command to validate. + + Returns: + ExecutionResult: Result with ``skipped=True`` and parsed argv in + ``analysis_result`` on success, or a safety-check failure on rejection. + """ + validated = self._validate_command(command) + if isinstance(validated, ExecutionResult): + return validated + + argv = validated + + return ExecutionResult( + success=True, + stdout=None, + stderr=None, + skipped=True, + skip_message="Command passes safety checks; static analysis environment does not execute commands. To execute, use bash_executor().", + analysis_result=argv, + ) + + +class _LocalBashEnvironment(BashEnvironment): + """Environment that executes bash commands directly with subprocess. + + This is the primary execution environment for bash_executor(). Commands execute + in the current process with access to the host environment (working directory, + PATH, git repos, installed tools, environment variables). + + Safety model: Denylist-based (not isolation-based). The conservative denylist + covers dangerous commands, shell operators, code execution paths, and writes to + system directories. This is sufficient for typical agentic workflows where + the command source is trusted (e.g., LLM-generated code in a known pipeline). + + For higher isolation requirements (untrusted code, CTF challenges, or security + research), provide isolation at the application layer (containers, VMs). + """ + + def execute(self, command: str) -> ExecutionResult: + """Execute bash command after safety checks. + + Args: + command (str): The bash command to execute. + + Returns: + ExecutionResult: Execution outcome with captured stdout/stderr and + success flag, or a skipped result if safety checks fail. + """ + validated = self._validate_command(command) + if isinstance(validated, ExecutionResult): + return validated + + argv = validated + + # Execute command with shell=False to prevent shell metacharacter bypass + try: + result = subprocess.run( + argv, + shell=False, + capture_output=True, + text=True, + timeout=self.timeout, + cwd=self.working_dir, + ) + + stdout, stdout_truncated = _truncate_output(result.stdout.strip()) + stderr, stderr_truncated = _truncate_output(result.stderr.strip()) + + # Append truncation warnings if needed + if stdout_truncated: + stdout += "\n[Output truncated - stdout exceeded 10KB]" + if stderr_truncated: + stderr += "\n[Output truncated - stderr exceeded 10KB]" + + return ExecutionResult( + success=result.returncode == 0, stdout=stdout, stderr=stderr + ) + except subprocess.TimeoutExpired: + return ExecutionResult( + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message=f"Execution timed out after {self.timeout} seconds", + ) + except Exception as e: + return ExecutionResult( + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message=f"Subprocess execution error: {e!s}", + ) + + +def bash_executor( + command: str, working_dir: str | None = None, allowed_paths: list[str] | None = None +) -> ExecutionResult: + """Execute a bash command with denylist safety checks. + + This is the recommended entry point. Commands execute locally with access to + the host environment (working directory, PATH, git repos, installed tools). + + Safety model: Conservative denylist applied to all commands. The denylist + refuses sudo, interactive shells, destructive operations (rm -rf, git push + --force), shell operators (|, >, &&), code execution paths (python -c, bash + -c), and writes to system paths (/etc, /sys, /proc, etc.). + + Args: + command: The bash command to execute. + working_dir: Optional working directory for the command (host path). + allowed_paths: Optional explicit write allowlist. When provided, + write-target paths must fall under one of these roots (in addition + to passing the default dangerous-path checks). + + Returns: + An ``ExecutionResult`` with stdout, stderr, and success flag. If the + command was rejected for safety reasons, ``skipped=True`` and + ``skip_message`` contains the reason. + + Examples: + Basic execution: + >>> result = bash_executor("echo hello") + >>> assert result.success is True + >>> assert result.stdout == "hello" + + With working directory: + >>> result = bash_executor("pwd", working_dir="/tmp") + >>> assert "/tmp" in result.stdout + + With path restrictions: + >>> result = bash_executor("touch file.txt", allowed_paths=["/tmp"]) + >>> assert result.success is True + """ + env = _LocalBashEnvironment(allowed_paths=allowed_paths, working_dir=working_dir) + return env.execute(command) diff --git a/test/stdlib/tools/test_bash_guardrails.py b/test/stdlib/tools/test_bash_guardrails.py new file mode 100644 index 000000000..241a99e11 --- /dev/null +++ b/test/stdlib/tools/test_bash_guardrails.py @@ -0,0 +1,456 @@ +"""Tests for bash security guardrails framework. + +Verifies: +1. All dangerous commands are properly categorized and documented +2. All security patterns correctly identify violations +3. Guardrails coverage is complete across threat categories +4. New patterns can be added without breaking existing checks +""" + +from mellea.stdlib.tools._bash_guardrails import ( + COMMAND_RULES, + SHELL_OPERATOR_RULES, + CommandCategory, + Severity, + audit_guardrails_coverage, + get_command_rules_by_category, + get_high_severity_commands, +) +from mellea.stdlib.tools._bash_patterns import ( + SECURITY_PATTERNS, + CodeExecutionPattern, + CommandSubstitutionPattern, + DangerousCommandPattern, + DestructiveGitPattern, + DestructiveRmPattern, + ShellOperatorPattern, + check_all_patterns, + get_pattern_names, +) + + +class TestCommandRules: + """Tests for COMMAND_RULES structure and metadata.""" + + def test_all_dangerous_commands_have_rules(self) -> None: + """Every dangerous command should have a defined rule.""" + # From shell.py DANGEROUS_COMMANDS set + expected_commands = { + "sudo", + "su", + "doas", + "bash", + "sh", + "zsh", + "ksh", + "tcsh", + "passwd", + "visudo", + "chsh", + "chfn", + "useradd", + "userdel", + "usermod", + "groupadd", + "groupdel", + "groupmod", + } + rules_keys = set(COMMAND_RULES.keys()) + assert expected_commands.issubset(rules_keys), ( + f"Missing rules for: {expected_commands - rules_keys}" + ) + + def test_rules_have_required_fields(self) -> None: + """Each rule should have category, severity, and rationale.""" + for cmd, rule in COMMAND_RULES.items(): + assert rule.category in CommandCategory, f"{cmd}: invalid category" + assert rule.severity in Severity, f"{cmd}: invalid severity" + assert isinstance(rule.rationale, str) and len(rule.rationale) > 0, ( + f"{cmd}: rationale must be non-empty string" + ) + + def test_privilege_escalation_are_critical(self) -> None: + """Privilege escalation commands should be critical severity.""" + priv_esc = get_command_rules_by_category(CommandCategory.PRIVILEGE_ESCALATION) + for cmd, rule in priv_esc.items(): + assert rule.severity == Severity.CRITICAL, ( + f"{cmd} is privilege escalation but not critical" + ) + + def test_get_command_rules_by_category(self) -> None: + """get_command_rules_by_category should filter correctly.""" + get_command_rules_by_category(CommandCategory.DESTRUCTIVE) + for cmd in ["rm", "rmdir"]: + if cmd in COMMAND_RULES: + assert COMMAND_RULES[cmd].category == CommandCategory.DESTRUCTIVE + + def test_high_severity_commands_retrieved(self) -> None: + """get_high_severity_commands should return critical and high.""" + high_severity = get_high_severity_commands() + for cmd, rule in high_severity.items(): + assert rule.severity in (Severity.CRITICAL, Severity.HIGH), ( + f"{cmd} is in high_severity but not critical/high" + ) + + def test_audit_guardrails_coverage(self) -> None: + """audit_guardrails_coverage should return all categories.""" + coverage = audit_guardrails_coverage() + for category in CommandCategory: + assert category.value in coverage, f"Missing coverage for {category}" + assert isinstance(coverage[category.value], list) + + +class TestShellOperatorRules: + """Tests for SHELL_OPERATOR_RULES structure.""" + + def test_all_shell_operators_have_rules(self) -> None: + """All dangerous shell operators should have rules.""" + expected_operators = {"|", ">", "&&", "||", ";", "&", ">>", ">&", "<<", "|&"} + rules_keys = set(SHELL_OPERATOR_RULES.keys()) + assert expected_operators.issubset(rules_keys), ( + f"Missing rules for: {expected_operators - rules_keys}" + ) + + def test_operator_rules_have_required_fields(self) -> None: + """Each operator rule should have category and rationale.""" + for op, rule in SHELL_OPERATOR_RULES.items(): + assert hasattr(rule, "category"), f"{op}: missing category" + assert isinstance(rule.rationale, str) and len(rule.rationale) > 0 + + +class TestDangerousCommandPattern: + """Tests for DangerousCommandPattern detection.""" + + def test_sudo_rejected(self) -> None: + """sudo command should be rejected.""" + pattern = DangerousCommandPattern() + is_dangerous, reason = pattern.check(["sudo", "echo", "test"]) + assert is_dangerous is True + assert "sudo" in reason.lower() + + def test_interactive_bash_rejected(self) -> None: + """bash -i should be rejected.""" + pattern = DangerousCommandPattern() + is_dangerous, _ = pattern.check(["bash", "-i"]) + assert is_dangerous is True + + def test_non_interactive_bash_allowed(self) -> None: + """bash script.sh should pass (not rejected by pattern).""" + pattern = DangerousCommandPattern() + is_dangerous, _ = pattern.check(["bash", "script.sh"]) + assert is_dangerous is False + + def test_passwd_rejected(self) -> None: + """passwd should be rejected.""" + pattern = DangerousCommandPattern() + is_dangerous, _ = pattern.check(["passwd"]) + assert is_dangerous is True + + +class TestShellOperatorPattern: + """Tests for ShellOperatorPattern detection.""" + + def test_pipe_operator_rejected(self) -> None: + """Pipe operator should be rejected.""" + pattern = ShellOperatorPattern() + is_dangerous, reason = pattern.check(["cat", "file", "|", "grep", "pattern"]) + assert is_dangerous is True + assert len(reason) > 0 # Should have a reason + + def test_redirect_operator_rejected(self) -> None: + """Output redirect should be rejected.""" + pattern = ShellOperatorPattern() + is_dangerous, _ = pattern.check(["echo", "hello", ">", "file.txt"]) + assert is_dangerous is True + + def test_redirect_prefix_rejected(self) -> None: + """Redirect as prefix (>&2) should be rejected.""" + pattern = ShellOperatorPattern() + is_dangerous, _ = pattern.check(["echo", "error", ">&2"]) + assert is_dangerous is True + + def test_and_operator_rejected(self) -> None: + """AND operator should be rejected.""" + pattern = ShellOperatorPattern() + is_dangerous, _ = pattern.check(["cmd1", "&&", "cmd2"]) + assert is_dangerous is True + + def test_semicolon_rejected(self) -> None: + """Semicolon chaining should be rejected.""" + pattern = ShellOperatorPattern() + is_dangerous, _ = pattern.check(["cmd1", ";", "cmd2"]) + assert is_dangerous is True + + def test_quoted_pipe_in_string_allowed(self) -> None: + """Pipe inside quoted string should be allowed (after shlex.split).""" + pattern = ShellOperatorPattern() + # After shlex.split("grep 'a|b'"), we get ["grep", "a|b"] + # The pipe is part of the string, not a standalone operator + is_dangerous, _ = pattern.check(["grep", "a|b"]) + assert is_dangerous is False + + +class TestCommandSubstitutionPattern: + """Tests for CommandSubstitutionPattern detection.""" + + def test_backtick_substitution_rejected(self) -> None: + """Backtick command substitution should be rejected.""" + pattern = CommandSubstitutionPattern() + is_dangerous, _ = pattern.check(["echo", "`date`"]) + assert is_dangerous is True + + def test_dollar_paren_substitution_rejected(self) -> None: + """$(...) command substitution should be rejected.""" + pattern = CommandSubstitutionPattern() + is_dangerous, _ = pattern.check(["echo", "$(whoami)"]) + assert is_dangerous is True + + def test_variable_expansion_rejected(self) -> None: + """Variable expansion ${VAR} should be rejected.""" + pattern = CommandSubstitutionPattern() + is_dangerous, _ = pattern.check(["echo", "${HOME}"]) + assert is_dangerous is True + + +class TestCodeExecutionPattern: + """Tests for CodeExecutionPattern detection.""" + + def test_python_c_rejected(self) -> None: + """python -c should be rejected.""" + pattern = CodeExecutionPattern() + is_dangerous, _ = pattern.check(["python", "-c", "print('hello')"]) + assert is_dangerous is True + + def test_python_m_rejected(self) -> None: + """python -m should be rejected.""" + pattern = CodeExecutionPattern() + is_dangerous, _ = pattern.check(["python", "-m", "http.server"]) + assert is_dangerous is True + + def test_bash_c_rejected(self) -> None: + """bash -c should be rejected.""" + pattern = CodeExecutionPattern() + is_dangerous, _ = pattern.check(["bash", "-c", "rm -rf /"]) + assert is_dangerous is True + + def test_perl_e_rejected(self) -> None: + """perl -e should be rejected.""" + pattern = CodeExecutionPattern() + is_dangerous, _ = pattern.check(["perl", "-e", "system('rm -rf /')"]) + assert is_dangerous is True + + def test_python_script_allowed(self) -> None: + """python script.py should be allowed.""" + pattern = CodeExecutionPattern() + is_dangerous, _ = pattern.check(["python", "script.py"]) + assert is_dangerous is False + + +class TestDestructiveGitPattern: + """Tests for DestructiveGitPattern detection.""" + + def test_git_push_force_rejected(self) -> None: + """git push --force should be rejected.""" + pattern = DestructiveGitPattern() + is_dangerous, _ = pattern.check(["git", "push", "--force", "origin", "main"]) + assert is_dangerous is True + + def test_git_reset_hard_rejected(self) -> None: + """git reset --hard should be rejected.""" + pattern = DestructiveGitPattern() + is_dangerous, _ = pattern.check(["git", "reset", "--hard", "HEAD~1"]) + assert is_dangerous is True + + def test_git_clean_f_rejected(self) -> None: + """git clean -f should be rejected.""" + pattern = DestructiveGitPattern() + is_dangerous, _ = pattern.check(["git", "clean", "-f"]) + assert is_dangerous is True + + def test_git_log_allowed(self) -> None: + """git log should be allowed.""" + pattern = DestructiveGitPattern() + is_dangerous, _ = pattern.check(["git", "log", "--oneline"]) + assert is_dangerous is False + + +class TestDestructiveRmPattern: + """Tests for DestructiveRmPattern detection.""" + + def test_rm_rf_rejected(self) -> None: + """rm -rf should be rejected.""" + pattern = DestructiveRmPattern() + is_dangerous, _ = pattern.check(["rm", "-rf", "/home"]) + assert is_dangerous is True + + def test_rm_r_rejected(self) -> None: + """rm -r should be rejected.""" + pattern = DestructiveRmPattern() + is_dangerous, _ = pattern.check(["rm", "-r", "/home"]) + assert is_dangerous is True + + def test_rm_single_file_allowed(self) -> None: + """rm file.txt should be allowed.""" + pattern = DestructiveRmPattern() + is_dangerous, _ = pattern.check(["rm", "file.txt"]) + assert is_dangerous is False + + +class TestPatternRegistry: + """Tests for SECURITY_PATTERNS registry and composition.""" + + def test_all_patterns_registered(self) -> None: + """All pattern types should be in the registry.""" + pattern_types = {type(p).__name__ for p in SECURITY_PATTERNS} + expected = { + "DangerousCommandPattern", + "ShellOperatorPattern", + "CommandSubstitutionPattern", + "CodeExecutionPattern", + "DestructiveGitPattern", + "DestructiveRmPattern", + } + assert expected.issubset(pattern_types), ( + f"Missing patterns: {expected - pattern_types}" + ) + + def test_check_all_patterns_integration(self) -> None: + """check_all_patterns should integrate all registered patterns.""" + # Should catch sudo (DangerousCommandPattern) + is_dangerous, _ = check_all_patterns(["sudo", "echo"]) + assert is_dangerous is True + + # Should catch pipe (ShellOperatorPattern) + is_dangerous, _ = check_all_patterns(["cat", "file", "|", "grep"]) + assert is_dangerous is True + + # Should catch substitution (CommandSubstitutionPattern) + is_dangerous, _ = check_all_patterns(["echo", "$(date)"]) + assert is_dangerous is True + + # Should catch code execution (CodeExecutionPattern) + is_dangerous, _ = check_all_patterns(["python", "-c", "print(1)"]) + assert is_dangerous is True + + def test_get_pattern_names(self) -> None: + """get_pattern_names should return all pattern class names.""" + names = get_pattern_names() + assert isinstance(names, list) + assert len(names) == len(SECURITY_PATTERNS) + assert all(isinstance(n, str) for n in names) + + +class TestPatternExtensibility: + """Tests that new patterns can be added without breaking existing logic.""" + + def test_custom_pattern_can_be_created(self) -> None: + """New pattern subclasses should be creatable.""" + from mellea.stdlib.tools._bash_patterns import BashSecurityPattern + + class CustomPattern(BashSecurityPattern): + def check(self, argv: list[str]) -> tuple[bool, str]: + if argv and argv[0] == "dangerous_custom": + return True, "Custom dangerous command" + return False, "" + + pattern = CustomPattern() + is_dangerous, _ = pattern.check(["dangerous_custom", "arg"]) + assert is_dangerous is True + + def test_safe_command_passes_all_patterns(self) -> None: + """Safe commands should pass all patterns.""" + safe_commands = [ + ["echo", "hello"], + ["pwd"], + ["ls", "-la"], + ["cat", "file.txt"], + ["grep", "pattern", "file.txt"], + ] + + for cmd in safe_commands: + is_dangerous, reason = check_all_patterns(cmd) + assert is_dangerous is False, f"Safe command {cmd} failed with: {reason}" + + +class TestBashAuditTrail: + """Tests for audit trail recording and querying.""" + + def test_violation_recorded_on_pattern_rejection(self) -> None: + """Verify violation recorded when pattern rejects.""" + from mellea.stdlib.tools._bash_audit import BashAuditTrail + + trail = BashAuditTrail.get_instance() + trail.clear() + + check_all_patterns(["sudo", "echo"]) + violations = trail.get_violations() + assert len(violations) == 1 + assert violations[0].pattern == "DangerousCommandPattern" + + def test_violation_contains_correct_metadata(self) -> None: + """Verify violation has all required fields.""" + from mellea.stdlib.tools._bash_audit import BashAuditTrail + + trail = BashAuditTrail.get_instance() + trail.clear() + + check_all_patterns(["rm", "-rf", "/"]) + violations = trail.get_violations() + v = violations[0] + assert v.command == "rm -rf /" + assert v.severity in ("HIGH", "MEDIUM") + assert v.reason + assert v.timestamp > 0 + + def test_get_violations_filters_by_severity(self) -> None: + """Verify filter by severity works.""" + from mellea.stdlib.tools._bash_audit import BashAuditTrail + + trail = BashAuditTrail.get_instance() + trail.clear() + + check_all_patterns(["sudo", "ls"]) # CRITICAL + check_all_patterns(["rm", "-r", "/tmp"]) # HIGH + + critical = trail.get_violations(severity="CRITICAL") + assert len(critical) >= 1 + + def test_export_metrics_counts_violations(self) -> None: + """Verify metrics export includes violation counts.""" + from mellea.stdlib.tools._bash_audit import BashAuditTrail + + trail = BashAuditTrail.get_instance() + trail.clear() + + check_all_patterns(["sudo", "ls"]) + check_all_patterns(["rm", "-r", "/tmp"]) + + metrics = trail.export_metrics() + assert metrics["total"] == 2 + assert any(k.startswith("severity_") for k in metrics.keys()) + + def test_violations_cleared_between_tests(self) -> None: + """Verify clear() removes all violations.""" + from mellea.stdlib.tools._bash_audit import BashAuditTrail + + trail = BashAuditTrail.get_instance() + trail.clear() + + check_all_patterns(["sudo", "ls"]) + assert len(trail.get_violations()) == 1 + + trail.clear() + assert len(trail.get_violations()) == 0 + + def test_query_violations_with_pattern_filter(self) -> None: + """Verify filter by pattern name works.""" + from mellea.stdlib.tools._bash_audit import BashAuditTrail + + trail = BashAuditTrail.get_instance() + trail.clear() + + check_all_patterns(["sudo", "ls"]) # DangerousCommandPattern + check_all_patterns(["echo", "|", "grep"]) # ShellOperatorPattern + + dangerous = trail.get_violations(pattern="DangerousCommandPattern") + assert len(dangerous) >= 1 diff --git a/test/stdlib/tools/test_shell.py b/test/stdlib/tools/test_shell.py new file mode 100644 index 000000000..6ace6abf3 --- /dev/null +++ b/test/stdlib/tools/test_shell.py @@ -0,0 +1,987 @@ +"""Tests for bash shell execution environments.""" + +from unittest.mock import patch + +import pytest + +from mellea.stdlib.tools.shell import ( + StaticBashEnvironment, + _LocalBashEnvironment, + bash_executor, +) + + +class TestStaticBashEnvironment: + """Tests for static bash command parsing and validation.""" + + def test_parse_simple_command(self) -> None: + """Valid simple command should pass validation.""" + env = StaticBashEnvironment() + result = env.execute("echo hello") + + assert result.skipped is True + assert result.success is True + assert result.analysis_result == ["echo", "hello"] + + def test_parse_command_with_args(self) -> None: + """Command with quoted arguments should parse correctly.""" + env = StaticBashEnvironment() + result = env.execute('echo "hello world"') + + assert result.skipped is True + assert result.success is True + assert result.analysis_result == ["echo", "hello world"] + + def test_parse_empty_command(self) -> None: + """Empty command should be rejected.""" + env = StaticBashEnvironment() + result = env.execute("") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "Empty command" in result.skip_message + + def test_parse_invalid_quoting(self) -> None: + """Command with invalid quoting should fail to parse.""" + env = StaticBashEnvironment() + result = env.execute('echo "unclosed quote') + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "parse" in result.skip_message.lower() + + +class TestDangerousCommandDetection: + """Tests for dangerous command detection.""" + + @pytest.mark.parametrize( + "dangerous_cmd", + [ + "sudo echo hello", + "su - root", + "doas whoami", + "sudo -i", + "sudo -s", + "bash -i", + "sh -i", + "zsh -i", + "passwd", + "visudo", + "chsh", + "chfn", + "useradd testuser", + "userdel testuser", + "usermod -l newname testuser", + ], + ) + def test_dangerous_commands_rejected(self, dangerous_cmd: str) -> None: + """Dangerous commands should be rejected at parse time.""" + env = StaticBashEnvironment() + result = env.execute(dangerous_cmd) + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_safe_shell_commands_allowed(self) -> None: + """Non-interactive shell commands should be allowed.""" + env = StaticBashEnvironment() + + # bash/sh without -i flag should pass (might be used for scripting) + result = env.execute("bash script.sh") + assert result.success is True + assert result.skipped is True + + +class TestInterpreterIndirectionBypassAttempts: + """Tests for interpreter-indirection bypass attempts. + + Interpreter indirection occurs when a program (bash, python, env, timeout, etc.) + is used to run arbitrary code. These are separate from simple command execution + and need explicit testing to ensure the safety checks cover them. + """ + + def test_bash_c_string_rejected(self) -> None: + """bash -c with arbitrary code should be rejected.""" + env = StaticBashEnvironment() + # bash -c runs a command string, can bypass argv parsing + result = env.execute("bash -c 'rm -rf /'") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + # Should reject either the -c flag or command substitution + assert "not allowed" in result.skip_message.lower() + + def test_sh_c_string_rejected(self) -> None: + """sh -c with arbitrary code should be rejected.""" + env = StaticBashEnvironment() + result = env.execute("sh -c 'sudo echo'") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_env_with_sudo_rejected(self) -> None: + """env with sudo should be rejected (privilege escalation).""" + env = StaticBashEnvironment() + # env is sometimes used to set environment vars for sudo + result = env.execute("env sudo bash") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + # Should reject sudo + assert "not allowed" in result.skip_message.lower() + + def test_timeout_with_sudo_rejected(self) -> None: + """timeout with sudo should be rejected (privilege escalation).""" + env = StaticBashEnvironment() + result = env.execute("timeout 10 sudo whoami") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_timeout_with_flag_value_and_sudo_rejected(self) -> None: + """timeout with --kill-after=value and sudo should be rejected (checks value-taking flags).""" + env = StaticBashEnvironment() + # Regression test: ensure sudo is detected despite --kill-after=1 consuming the value + result = env.execute("timeout --kill-after=1 sudo whoami") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_env_i_with_sudo_rejected(self) -> None: + """env -i with sudo should be rejected (privilege escalation bypass attempt).""" + env = StaticBashEnvironment() + # Regression test for CVE-like: env -i (clear environment) + sudo + # -i is NOT a value-taking flag; it takes no argument. + # The skip logic must not incorrectly skip sudo. + result = env.execute("env -i sudo whoami") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_env_i_with_dangerous_rm_rejected(self) -> None: + """env -i with rm -rf should be rejected (destructive bypass attempt).""" + env = StaticBashEnvironment() + result = env.execute("env -i rm -rf /tmp/test") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_env_with_dangerous_rm_rejected(self) -> None: + """env with rm -rf should be rejected.""" + env = StaticBashEnvironment() + result = env.execute("env rm -rf /tmp/test") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_timeout_with_dangerous_rm_rejected(self) -> None: + """timeout with rm -rf should be rejected.""" + env = StaticBashEnvironment() + result = env.execute("timeout 10 rm -rf /tmp/test") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_env_with_safe_rm_allowed(self) -> None: + """env with rm (no -r/-rf) should be allowed.""" + env = StaticBashEnvironment() + result = env.execute("env rm file.txt") + + assert result.skipped is True + assert result.success is True + + def test_timeout_with_safe_rm_allowed(self) -> None: + """timeout with rm (no -r/-rf) should be allowed.""" + env = StaticBashEnvironment() + result = env.execute("timeout 10 rm file.txt") + + assert result.skipped is True + assert result.success is True + + def test_env_with_dangerous_command_in_middle_rejected(self) -> None: + """env with variable assignment followed by dangerous command should be rejected.""" + env = StaticBashEnvironment() + result = env.execute("env LD_LIBRARY_PATH=/lib sudo whoami") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_python_c_arbitrary_code_rejected(self) -> None: + """python -c with arbitrary code should be rejected.""" + env = StaticBashEnvironment() + result = env.execute("python3 -c 'import os; os.system(\"rm -rf /\")'") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + # Should reject the -c flag or command substitution + assert "not allowed" in result.skip_message.lower() + + def test_python_multiline_code_rejected(self) -> None: + """python with multiline code (using \\n) should be rejected.""" + env = StaticBashEnvironment() + # Even with \\n instead of ; to bypass semicolon check + result = env.execute("python3 -c 'import os\\nos.system(\"bad\")'") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_bash_script_file_allowed(self) -> None: + """bash with a script file (not -c) should be allowed for scripts.""" + env = StaticBashEnvironment() + result = env.execute("bash /path/to/script.sh") + + # Should be allowed (script execution is legitimate) + assert result.skipped is True + assert result.success is True + + def test_perl_e_code_rejected(self) -> None: + """perl -e with inline code should be rejected.""" + env = StaticBashEnvironment() + result = env.execute("perl -e 'system(\"sudo whoami\")'") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_ruby_e_code_rejected(self) -> None: + """ruby -e with inline code should be rejected.""" + env = StaticBashEnvironment() + result = env.execute("ruby -e 'system(\"rm -rf /\")'") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + @pytest.mark.parametrize( + "safe_cmd", + [ + "echo hello", + "pwd", + "ls -la", + "cat file.txt", + "grep pattern file.txt", + "find . -name '*.py'", + ], + ) + def test_safe_commands_allowed(self, safe_cmd: str) -> None: + """Safe commands should pass validation.""" + env = StaticBashEnvironment() + result = env.execute(safe_cmd) + + assert result.skipped is True + assert result.success is True + assert result.analysis_result is not None + + +class TestDestructivePatternDetection: + """Tests for detection of destructive operations.""" + + @pytest.mark.parametrize( + "destructive_cmd", + [ + "rm -rf /", + "rm -r /home/user", + "rm -rf .", + "git push --force origin main", + "git push -f", + "git reset --hard HEAD~1", + "git clean -fd", + ], + ) + def test_destructive_operations_rejected(self, destructive_cmd: str) -> None: + """Destructive operations should be rejected.""" + env = StaticBashEnvironment() + result = env.execute(destructive_cmd) + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_standard_operations_with_flags_allowed(self) -> None: + """Standard operations with force flags should be allowed. + + These are not inherently destructive: + - cp -f: copy with force overwrite (standard) + - mv -f: move with force overwrite (standard) + - make -f: specify makefile (standard) + """ + env = StaticBashEnvironment() + + result = env.execute("cp -f largefile /tmp") + assert result.skipped is True + assert result.success is True + + result = env.execute("mv -f file /tmp") + assert result.skipped is True + assert result.success is True + + result = env.execute("make -f Makefile clean") + assert result.skipped is True + assert result.success is True + + def test_safe_git_operations_allowed(self) -> None: + """Safe git operations without --force should be allowed.""" + env = StaticBashEnvironment() + + result = env.execute("git push origin main") + assert result.success is True + assert result.skipped is True + + def test_safe_rm_operations_allowed(self) -> None: + """rm without -r/-rf flags should be allowed.""" + env = StaticBashEnvironment() + + result = env.execute("rm file.txt") + assert result.success is True + assert result.skipped is True + + +class TestShellOperatorFalsePositives: + """Tests for legitimate patterns that were previously false positives. + + The shell operator detection originally used substring matching, + which blocked legitimate patterns like "a&&b" (regex). These tests + verify that the fix correctly allows such patterns while still + blocking actual shell operators. + """ + + def test_grep_with_and_in_pattern(self) -> None: + """grep with && in regex pattern should be allowed.""" + env = StaticBashEnvironment() + result = env.execute("grep 'a&&b' file.txt") + + assert result.skipped is True + assert result.success is True + + def test_grep_with_or_in_pattern(self) -> None: + """grep with || in regex pattern should be allowed.""" + env = StaticBashEnvironment() + result = env.execute("grep 'a||b' file.txt") + + assert result.skipped is True + assert result.success is True + + def test_echo_with_redirect_symbol_in_string(self) -> None: + """echo with >> in string should be allowed.""" + env = StaticBashEnvironment() + result = env.execute("echo 'a>>b'") + + assert result.skipped is True + assert result.success is True + + def test_grep_with_heredoc_symbol_in_pattern(self) -> None: + """grep with << in pattern should be allowed.""" + env = StaticBashEnvironment() + result = env.execute("grep 'x< None: + """awk code with >> in pattern should be allowed.""" + env = StaticBashEnvironment() + result = env.execute('awk "{print $1>>$2}"') + + assert result.skipped is True + assert result.success is True + + def test_sed_with_pipe_in_pattern(self) -> None: + """sed pattern with | should be allowed.""" + env = StaticBashEnvironment() + result = env.execute("sed 's/a|b/c/'") + + assert result.skipped is True + assert result.success is True + + def test_actual_shell_redirect_operator_blocked(self) -> None: + """Actual shell redirect operators with arguments should be blocked.""" + env = StaticBashEnvironment() + # >&2 is a shell redirect (stderr redirect) + result = env.execute("echo 'test' >&2") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_redirect_to_file_blocked(self) -> None: + """Redirect to file (>filename) should be blocked.""" + env = StaticBashEnvironment() + result = env.execute("echo 'test' >output.txt") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + +class TestShellMetacharacterDetection: + """Tests for detection of shell metacharacters that bypass argv parsing.""" + + @pytest.mark.parametrize( + "metacharacter_cmd", + [ + "echo error >&2", # stderr redirection + "echo hello > /tmp/file", # stdout redirection + "cat file | grep pattern", # pipe + "echo a; rm -rf /", # command chaining + "echo $(whoami)", # command substitution + "echo `date`", # backtick substitution + "echo ${HOME}", # variable expansion with braces + "ls &", # background execution + "find . -name '*.py' && echo done", # logical AND + "ls || echo failed", # logical OR + ], + ) + def test_shell_metacharacters_rejected(self, metacharacter_cmd: str) -> None: + """Shell metacharacters should be rejected to prevent bypass attacks.""" + env = StaticBashEnvironment() + result = env.execute(metacharacter_cmd) + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + +class TestMacOSPrivateVarHandling: + """Tests for correct handling of macOS /private/var paths. + + On macOS, tempfile.mkdtemp() returns /var/folders/... which resolves to + /private/var/folders/... . We should allow writes to /private/var/folders/* + (user temp directories) while blocking /private/var/log, /private/var/www, etc. + """ + + def test_private_var_log_blocked(self) -> None: + """Writing to /private/var/log should be blocked.""" + env = StaticBashEnvironment() + result = env.execute("touch /private/var/log/test.log") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_private_var_www_blocked(self) -> None: + """Writing to /private/var/www should be blocked.""" + env = StaticBashEnvironment() + result = env.execute("touch /private/var/www/index.html") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_private_var_db_blocked(self) -> None: + """Writing to /private/var/db should be blocked.""" + env = StaticBashEnvironment() + result = env.execute("touch /private/var/db/test.db") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_private_var_root_blocked(self) -> None: + """Writing to /private/var/root should be blocked.""" + env = StaticBashEnvironment() + result = env.execute("touch /private/var/root/test.txt") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_private_var_folders_allowed(self) -> None: + """Writing to /private/var/folders/* (macOS temp dirs) should be allowed.""" + env = StaticBashEnvironment() + # This simulates a resolved path from tempfile on macOS + result = env.execute("touch /private/var/folders/kl/tmpXXXX/test.txt") + + # Should pass validation (not marked as dangerous path) + assert result.skipped is True + assert result.success is True + + def test_macos_temp_directory_resolved_allowed(self) -> None: + """Resolved macOS temp directory paths should be allowed.""" + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + from pathlib import Path + + # Resolve the temp dir (on macOS this becomes /private/var/folders/...) + resolved = str(Path(tmpdir).resolve()) + + env = StaticBashEnvironment() + result = env.execute(f"touch {resolved}/test.txt") + + # Should pass (temp dir is safe) + assert result.skipped is True + assert result.success is True + + +class TestSystemPathDetection: + """Tests for detection of system path access.""" + + @pytest.mark.parametrize( + "system_path_cmd", + [ + "rm /etc/passwd", + "touch /etc/config.conf", + "cp file /sys/module", + "mkdir /proc/newdir", + ], + ) + def test_system_paths_rejected(self, system_path_cmd: str) -> None: + """Attempts to write to system paths should be rejected.""" + env = StaticBashEnvironment() + result = env.execute(system_path_cmd) + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + @pytest.mark.parametrize( + "safe_path_cmd", + [ + "cat /etc/passwd", # Reading is OK + "ls /sys", # Reading is OK + "touch ~/file.txt", # Writing to home is OK + "mkdir /tmp/tmpdir", # Writing to /tmp is OK + ], + ) + def test_safe_paths_allowed(self, safe_path_cmd: str) -> None: + """Safe path operations should be allowed.""" + env = StaticBashEnvironment() + result = env.execute(safe_path_cmd) + + assert result.skipped is True + assert result.success is True + + def test_symlink_to_dangerous_path_rejected(self) -> None: + """Symlinks pointing to dangerous paths should be rejected.""" + import os + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + # Create a symlink in /tmp pointing to /etc + symlink_path = os.path.join(tmpdir, "link_to_etc") + try: + os.symlink("/etc", symlink_path) + except OSError: + # Skip if symlink creation fails (e.g., on some filesystems) + pytest.skip("Cannot create symlinks on this system") + + # Try to write through the symlink + env = StaticBashEnvironment() + result = env.execute(f"touch {symlink_path}/config") + + # Should be rejected because symlink resolves to /etc + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + +class TestWorkingDirRestriction: + """Tests for working directory restrictions.""" + + def test_working_dir_restriction_blocks_outside_writes(self) -> None: + """Writing outside working_dir should be rejected by working_dir check.""" + env = StaticBashEnvironment(working_dir="/home/user/project") + # Use a safe path that is not in DANGEROUS_PATHS (so working_dir check fires first) + result = env.execute("touch /home/other/file.txt") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + # Must be rejected by working_dir check, not dangerous-path check + assert "outside" in result.skip_message.lower() + + def test_working_dir_allows_inside_writes(self) -> None: + """Writing inside working_dir should be allowed.""" + env = StaticBashEnvironment(working_dir="/home/user/project") + result = env.execute("touch /home/user/project/test.txt") + + assert result.skipped is True + assert result.success is True + + def test_working_dir_allows_tmp_writes(self) -> None: + """Writing to /tmp should always be allowed.""" + env = StaticBashEnvironment(working_dir="/home/user/project") + result = env.execute("touch /tmp/tmpfile") + + assert result.skipped is True + assert result.success is True + + def test_working_dir_relative_path_resolved_within_working_dir(self) -> None: + """Relative paths should be resolved relative to working_dir, not caller's cwd.""" + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + # Create a file relative to working_dir (not caller's cwd) + env = StaticBashEnvironment(working_dir=tmpdir) + result = env.execute("touch myfile.txt") + + # Should be allowed: relative path resolves to tmpdir/myfile.txt + assert result.skipped is True + assert result.success is True + + def test_working_dir_relative_path_blocks_outside(self) -> None: + """Relative paths that escape working_dir should be rejected.""" + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + # Try to write outside working_dir using relative path + env = StaticBashEnvironment(working_dir=tmpdir) + result = env.execute("touch ../outside.txt") + + # Should be rejected: ../outside.txt escapes working_dir + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "outside" in result.skip_message.lower() + + def test_working_dir_unresolvable_fails_closed(self) -> None: + """Unresolvable working_dir should fail closed (deny writes).""" + env = StaticBashEnvironment(working_dir="~invalid/nonexistent") + # Invalid home dir prefix causes RuntimeError in .resolve() + result = env.execute("touch /tmp/test.txt") + + # Should be rejected: can't resolve working_dir + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert ( + "not resolvable" in result.skip_message.lower() + or "cannot validate" in result.skip_message.lower() + ) + + def test_working_dir_unresolvable_blocks_even_etc(self) -> None: + """Unresolvable working_dir should block attempts to write to /etc.""" + env = StaticBashEnvironment(working_dir="~invalid/path") + result = env.execute("touch /etc/config") + + # Should be blocked, first by /etc check, but verify it fails + assert result.skipped is True + assert result.success is False + + +class TestPathResolutionFailures: + """Tests for fail-closed behavior when path resolution fails.""" + + def test_unresolvable_argument_path_fails_closed(self) -> None: + """Unresolvable argument paths should fail closed (deny writes).""" + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + env = StaticBashEnvironment(working_dir=tmpdir) + # Invalid home dir prefix in argument causes RuntimeError + result = env.execute("touch ~invalid/file.txt") + + # Should be rejected: can't resolve argument path + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + # Error should mention path resolution failure + assert ( + "cannot validate" in result.skip_message.lower() + or "resolution failed" in result.skip_message.lower() + ) + + +class TestAllowedPaths: + """Tests for explicit allowed path enforcement.""" + + def test_allowed_paths_allows_write_inside_explicit_path(self) -> None: + """Writing inside an explicit allowed path should be permitted.""" + env = StaticBashEnvironment(allowed_paths=["/home/user/project"]) + result = env.execute("touch /home/user/project/output.txt") + + assert result.skipped is True + assert result.success is True + + def test_allowed_paths_blocks_write_outside_explicit_path(self) -> None: + """Writing outside explicit allowed paths should be rejected.""" + env = StaticBashEnvironment(allowed_paths=["/home/user/project"]) + result = env.execute("touch /home/user/other/output.txt") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "outside explicitly allowed paths" in result.skip_message.lower() + + def test_allowed_paths_does_not_override_dangerous_paths(self) -> None: + """Explicit allowed paths must not permit writes to dangerous system paths.""" + env = StaticBashEnvironment(allowed_paths=["/etc"]) + result = env.execute("touch /etc/config.conf") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + +class TestLocalBashEnvironment: + """Tests for local bash environment execution (no isolation).""" + + def test_safe_command_execution(self) -> None: + """Safe commands should execute successfully.""" + env = _LocalBashEnvironment() + result = env.execute("echo hello") + + assert result.skipped is False + assert result.success is True + assert result.stdout is not None + assert "hello" in result.stdout + + def test_command_with_failing_exit_code(self) -> None: + """Commands with non-zero exit should fail.""" + env = _LocalBashEnvironment() + result = env.execute("false") + + assert result.skipped is False + assert result.success is False + + def test_shell_metacharacters_rejected(self) -> None: + """Shell redirections and pipes should be rejected for security.""" + env = _LocalBashEnvironment() + result = env.execute("echo error >&2") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_dangerous_command_rejected(self) -> None: + """Dangerous commands should be rejected even before execution.""" + env = _LocalBashEnvironment() + result = env.execute("sudo echo test") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_timeout_enforcement(self) -> None: + """Command exceeding timeout should be interrupted.""" + env = _LocalBashEnvironment(timeout=1) + result = env.execute("sleep 5") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "timed out" in result.skip_message.lower() + + def test_output_truncation(self) -> None: + """Very large output should be truncated.""" + from mellea.stdlib.tools.shell import MAX_OUTPUT_SIZE + + env = _LocalBashEnvironment() + # Generate output larger than MAX_OUTPUT_SIZE (10KB) to trigger truncation + import os + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + # Create a large file to cat (500 lines of 51 bytes each = ~25.5 KB) + large_file = os.path.join(tmpdir, "large.txt") + with open(large_file, "w") as f: + for _ in range(500): + f.write("x" * 50 + "\n") + + result = env.execute(f"cat {large_file}") + + assert result.success is True + # Check that output was truncated + assert result.stdout is not None + # Verify truncation marker is present + assert "Output truncated" in result.stdout + # Verify output is actually truncated (not full ~25KB content) + # Truncated output should be MAX_OUTPUT_SIZE + marker message + # Allow some slack for the exact message format + assert len(result.stdout) <= MAX_OUTPUT_SIZE + 100 + + def test_working_dir_parameter(self) -> None: + """working_dir should be passed to subprocess.""" + import tempfile + from pathlib import Path + + with tempfile.TemporaryDirectory() as tmpdir: + env = _LocalBashEnvironment(working_dir=tmpdir) + result = env.execute("pwd") + + assert result.success is True + assert result.stdout is not None + assert tmpdir in result.stdout + + +class TestBashExecutorFunctions: + """Tests for public bash_executor function.""" + + def test_bash_executor_uses_local_by_default(self) -> None: + """bash_executor should use _LocalBashEnvironment by default (no sandbox).""" + result = bash_executor("echo test") + + # bash_executor with no sandbox parameter should always succeed (uses local execution) + assert result.success is True + assert result.stdout is not None + assert "test" in result.stdout + + def test_bash_executor_with_working_dir(self) -> None: + """bash_executor should pass working_dir through to sandbox execution.""" + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + result = bash_executor("pwd", working_dir=tmpdir) + + if result.skip_message is not None and ( + "not installed" in result.skip_message + or "sandbox execution error" in result.skip_message.lower() + ): + assert result.skipped is True + else: + assert result.success is True + assert result.stdout is not None + assert tmpdir in result.stdout + + def test_bash_executor_with_allowed_paths(self) -> None: + """bash_executor should accept allowed_paths parameter.""" + # Just verify the parameter is accepted (actual execution may skip on sandbox) + result = bash_executor( + "echo test", allowed_paths=["/tmp", "/home/user/project"] + ) + + # Either executes or skips due to sandbox setup + if result.skip_message is not None and ( + "not installed" in result.skip_message + or "not a valid" in result.skip_message + ): + assert result.skipped is True + else: + assert result.success is True + + def test_bash_executor_local_execution(self) -> None: + """bash_executor should execute locally.""" + result = bash_executor("echo hello") + + assert result.success is True + assert result.stdout is not None + assert "hello" in result.stdout + + def test_dangerous_command_rejected(self) -> None: + """Dangerous commands should be rejected.""" + result = bash_executor("sudo echo test") + + assert result.skipped is True + assert result.success is False + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + +class TestCommandParsing: + """Tests for command parsing and quoting handling.""" + + def test_command_with_spaces_in_quotes(self) -> None: + """Arguments with spaces should be handled correctly.""" + env = StaticBashEnvironment() + result = env.execute('echo "hello world" "foo bar"') + + assert result.success is True + assert result.analysis_result == ["echo", "hello world", "foo bar"] + + def test_command_with_escaped_quotes(self) -> None: + """Escaped quotes should be handled.""" + env = StaticBashEnvironment() + result = env.execute(r'echo "say \"hello\""') + + assert result.success is True + assert result.analysis_result is not None + assert "echo" in result.analysis_result + + def test_command_with_equals_in_args(self) -> None: + """Arguments with = (like env vars) should parse correctly.""" + env = StaticBashEnvironment() + result = env.execute("grep FOO=bar file.txt") + + assert result.success is True + assert result.analysis_result is not None + assert "FOO=bar" in result.analysis_result + + +class TestErrorMessages: + """Tests for clear error messages.""" + + def test_sudo_rejection_message(self) -> None: + """sudo rejection should have clear message.""" + env = StaticBashEnvironment() + result = env.execute("sudo apt-get install package") + + assert result.skip_message is not None + assert "not allowed" in result.skip_message.lower() + + def test_dangerous_flag_rejection_message(self) -> None: + """Dangerous git operation rejection should mention the issue.""" + env = StaticBashEnvironment() + result = env.execute("git push --force") + + assert result.skip_message is not None + assert ( + "destructive" in result.skip_message.lower() + or "--force" in result.skip_message + or "force" in result.skip_message.lower() + ) + + def test_system_path_rejection_message(self) -> None: + """System path rejection should mention the path.""" + env = StaticBashEnvironment() + result = env.execute("rm /etc/passwd") + + assert result.skip_message is not None + assert "/etc" in result.skip_message or "allowed" in result.skip_message.lower() + + +@pytest.mark.integration +def test_tool_wrapping() -> None: + """Test that bash_executor can be wrapped as a MelleaTool.""" + try: + from mellea.backends.tools import MelleaTool + + tool = MelleaTool.from_callable(bash_executor) + + assert tool.name == "bash_executor" + # Check that the tool schema is generated correctly + schema = tool.as_json_tool + assert "parameters" in schema or "function" in schema # Schema format may vary + # The tool should be callable + assert callable(tool.run) + except ImportError: + pytest.skip("MelleaTool not available")