Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
import subprocess
import sys
import time

from unittest import mock
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, MlException


def _print_command_results(test_passed, time_taken, output):
print("Command {} in {} seconds.".format("successful" if test_passed else "failed", time_taken))
print(
"Command {} in {} seconds.".format(
"successful" if test_passed else "failed", time_taken
)
)
print("Output: \n{}\n".format(output))


Expand All @@ -27,36 +31,51 @@ def run_cli_command(
if not custom_environment:
custom_environment = os.environ

# We do this join to construct a command because "shell=True" flag, used below, doesn't work with the vector
# argv form on a mac OS.
command_to_execute = " ".join(cmd_arguments)
# Use argv form with shell=False to avoid shell injection risks while keeping behavior
# consistent across platforms (including macOS).
# On Windows, many CLI tools (e.g., "code") are .cmd/.bat shims that require shell
# execution. We use subprocess.list2cmdline to safely quote the arguments before
# passing them to the shell, preventing command injection.

if not do_not_print: # Avoid printing the az login service principal password, for example
print("Preparing to run CLI command: \n{}\n".format(command_to_execute))
if (
not do_not_print
): # Avoid printing the az login service principal password, for example
print("Preparing to run CLI command: \n{}\n".format(" ".join(cmd_arguments)))
print("Current directory: {}".format(os.getcwd()))

start_time = time.time()
try:
# We redirect stderr to stdout, so that in the case of an error, especially in negative tests,
# we get the error reply back to check if the error is expected or not.
# We need "shell=True" flag so that the "az" wrapper works.

# We also pass the environment variables, because for some tests we modify
# the environment variables.

subprocess_args = {
"shell": True,
"stderr": subprocess.STDOUT,
"env": custom_environment,
}

if not stderr_to_stdout:
subprocess_args = {"shell": True, "env": custom_environment}
subprocess_args = {"env": custom_environment}

if sys.version_info[0] != 2:
subprocess_args["timeout"] = timeout

output = subprocess.check_output(command_to_execute, **subprocess_args).decode(encoding="UTF-8")
# On Windows, many CLI commands are provided as .cmd/.bat shims that require
# shell execution. Use list2cmdline to build a safely quoted command string
# when invoking via the shell.
if os.name == "nt":
command_to_execute = subprocess.list2cmdline(cmd_arguments)
subprocess_args["shell"] = True
cmd_to_run = command_to_execute
else:
subprocess_args["shell"] = False
cmd_to_run = cmd_arguments

output = subprocess.check_output(cmd_to_run, **subprocess_args).decode(
encoding="UTF-8"
)

time_taken = time.time() - start_time
if not do_not_print:
Expand Down Expand Up @@ -109,3 +128,53 @@ def exclude_warnings(cmd_output):
curr_index = curr_index + 1

return json_output


def _test_run_cli_command_stderr_to_stdout_true():
"""Internal test to validate subprocess arguments when stderr_to_stdout is True."""
cmd = ["echo", "hello"]
custom_env = {"FOO": "BAR"}
with mock.patch("subprocess.check_output") as check_output_mock:
check_output_mock.return_value = b""
run_cli_command(
cmd_arguments=cmd,
custom_environment=custom_env,
return_json=False,
timeout=None,
do_not_print=True,
stderr_to_stdout=True,
)
# Verify argv (first positional argument) is passed through unchanged.
assert check_output_mock.call_args is not None
called_args, called_kwargs = check_output_mock.call_args
assert called_args[0] == cmd
# Verify shell and stderr behavior.
assert called_kwargs.get("shell") is False
assert called_kwargs.get("stderr") is subprocess.STDOUT
# Verify environment is forwarded.
assert called_kwargs.get("env") == custom_env


def _test_run_cli_command_stderr_to_stdout_false():
"""Internal test to validate subprocess arguments when stderr_to_stdout is False."""
cmd = ["echo", "hello"]
custom_env = {"FOO": "BAR"}
with mock.patch("subprocess.check_output") as check_output_mock:
check_output_mock.return_value = b""
run_cli_command(
cmd_arguments=cmd,
custom_environment=custom_env,
return_json=False,
timeout=None,
do_not_print=True,
stderr_to_stdout=False,
)
# Verify argv (first positional argument) is passed through unchanged.
assert check_output_mock.call_args is not None
called_args, called_kwargs = check_output_mock.call_args
assert called_args[0] == cmd
# Verify shell behavior and absence of stderr redirection.
assert called_kwargs.get("shell") is False
assert "stderr" not in called_kwargs
# Verify environment is forwarded.
assert called_kwargs.get("env") == custom_env
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

"""Unit tests for commandline_utility module."""

import subprocess
import os
from unittest import mock

import pytest

from azure.ai.ml._local_endpoints.utilities.commandline_utility import run_cli_command


class TestRunCliCommand:
"""Tests for run_cli_command function."""

def test_stderr_to_stdout_true_passes_correct_args(self):
"""Verify stderr=STDOUT is passed when stderr_to_stdout=True."""
cmd = ["echo", "hello"]
custom_env = {"FOO": "BAR"}

with mock.patch("subprocess.check_output") as check_output_mock:
check_output_mock.return_value = b"hello"
run_cli_command(
cmd_arguments=cmd,
custom_environment=custom_env,
return_json=False,
timeout=None,
do_not_print=True,
stderr_to_stdout=True,
)

assert check_output_mock.call_args is not None
called_args, called_kwargs = check_output_mock.call_args

if os.name == "nt":
assert called_args[0] == subprocess.list2cmdline(cmd)
assert called_kwargs.get("shell") is True
else:
assert called_args[0] == cmd
assert called_kwargs.get("shell") is False

assert called_kwargs.get("stderr") is subprocess.STDOUT
assert called_kwargs.get("env") == custom_env

def test_stderr_to_stdout_false_omits_stderr(self):
"""Verify stderr is not passed when stderr_to_stdout=False."""
cmd = ["echo", "hello"]
custom_env = {"FOO": "BAR"}

with mock.patch("subprocess.check_output") as check_output_mock:
check_output_mock.return_value = b"hello"
run_cli_command(
cmd_arguments=cmd,
custom_environment=custom_env,
return_json=False,
timeout=None,
do_not_print=True,
stderr_to_stdout=False,
)

assert check_output_mock.call_args is not None
_, called_kwargs = check_output_mock.call_args

assert "stderr" not in called_kwargs
assert called_kwargs.get("env") == custom_env

def test_shell_metacharacters_not_interpreted(self):
"""Verify shell metacharacters are not interpreted in arguments."""
malicious_arg = "safe_path; echo INJECTED; #"
cmd = ["echo", malicious_arg]

with mock.patch("subprocess.check_output") as check_output_mock:
check_output_mock.return_value = b"output"
run_cli_command(
cmd_arguments=cmd,
do_not_print=True,
)

assert check_output_mock.call_args is not None
called_args, called_kwargs = check_output_mock.call_args

if os.name == "nt":
command_str = called_args[0]
assert called_kwargs.get("shell") is True
assert command_str == subprocess.list2cmdline(cmd)
else:
assert called_args[0] == cmd
assert called_kwargs.get("shell") is False

def test_return_json_parses_output(self):
"""Verify JSON output is parsed correctly."""
cmd = ["echo", "test"]
# exclude_warnings expects { and } on separate lines
json_output = b'{\n"key": "value"\n}'

with mock.patch("subprocess.check_output") as check_output_mock:
check_output_mock.return_value = json_output
result = run_cli_command(
cmd_arguments=cmd,
return_json=True,
do_not_print=True,
)

assert result == {"key": "value"}

def test_called_process_error_is_raised(self):
"""Verify CalledProcessError propagates correctly."""
cmd = ["bad_command"]

with mock.patch("subprocess.check_output") as check_output_mock:
check_output_mock.side_effect = subprocess.CalledProcessError(
1, cmd, output=b"error output"
)
with pytest.raises(subprocess.CalledProcessError):
run_cli_command(
cmd_arguments=cmd,
do_not_print=True,
)
Loading