Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,15 @@
def unzip_to_temporary_file(job_definition: JobBaseData, zip_content: Any) -> Path:
    """Unpack an in-memory zip archive into a per-job directory under the system temp dir.

    The extraction directory is ``<tempdir>/<AZUREML_RUNS_DIR>/<job_definition.name>`` and
    is created when missing. Every archive member is checked before anything is written:
    a member whose resolved location is neither the extraction directory itself nor a
    path beneath it aborts the whole operation.

    :param job_definition: Job whose ``name`` selects the extraction directory.
    :type job_definition: JobBaseData
    :param zip_content: Raw bytes of the zip archive.
    :type zip_content: Any
    :return: The directory the archive was extracted into.
    :rtype: Path
    :raises ValueError: If any member would resolve outside the extraction directory.
    """
    temp_dir = Path(tempfile.gettempdir(), AZUREML_RUNS_DIR, job_definition.name)
    temp_dir.mkdir(parents=True, exist_ok=True)

    extraction_root = temp_dir.resolve()
    # The trailing separator keeps a sibling such as "<root>-other" from passing the prefix test.
    root_prefix = str(extraction_root) + os.sep

    with zipfile.ZipFile(io.BytesIO(zip_content)) as archive:
        for member in archive.namelist():
            target = (extraction_root / member).resolve()
            # A directory entry may resolve to the extraction root itself; that is allowed.
            if target == extraction_root or str(target).startswith(root_prefix):
                continue
            raise ValueError(f"Zip archive contains a path traversal entry and cannot be extracted safely: {member}")
        # Reached only when every member passed the containment check.
        archive.extractall(temp_dir)
    return temp_dir

Expand Down Expand Up @@ -166,6 +174,49 @@ def is_local_run(job_definition: JobBaseData) -> bool:
local = job_definition.properties.services.get("Local", None)
return local is not None and EXECUTION_SERVICE_URL_KEY in local.endpoint

def _safe_tar_extractall(tar: tarfile.TarFile, dest_dir: str) -> None:
"""Extract tar archive members safely, preventing path traversal (TarSlip).

On Python 3.12+, uses the built-in 'data' filter. On older versions,
manually validates each member to ensure no path traversal, symlinks,
hard links, or other special entries that could write outside the
destination directory or create unsafe filesystem nodes.

:param tar: An opened tarfile.TarFile object.
:type tar: tarfile.TarFile
:param dest_dir: The destination directory for extraction.
:type dest_dir: str
:raises ValueError: If a tar member would escape the destination directory
or contains a symlink, hard link, or unsupported special entry type.
"""
resolved_dest = os.path.realpath(dest_dir)

# Python 3.12+ has built-in data_filter for safe extraction
if hasattr(tarfile, "data_filter"):
try:
tar.extractall(resolved_dest, filter="data")
except tarfile.TarError as exc:
raise ValueError(f"Failed to safely extract tar archive: {exc}") from exc
else:
for member in tar.getmembers():
# Reject symbolic and hard links
if member.issym() or member.islnk():
raise ValueError(
f"Tar archive contains a symbolic or hard link and cannot be extracted safely: {member.name}"
)
# Reject any non-regular, non-directory entries (e.g., devices, FIFOs)
if not (member.isfile() or member.isdir()):
raise ValueError(
f"Tar archive contains an unsupported special entry type and cannot be extracted safely: "
f"{member.name}"
)
member_path = os.path.realpath(os.path.join(resolved_dest, member.name))
if member_path != resolved_dest and not member_path.startswith(resolved_dest + os.sep):
raise ValueError(
f"Tar archive contains a path traversal entry and cannot be extracted safely: {member.name}"
)
# All members validated; safe to extract
tar.extractall(resolved_dest)

class CommonRuntimeHelper:
COMMON_RUNTIME_BOOTSTRAPPER_INFO = "common_runtime_bootstrapper_info.json"
Expand Down Expand Up @@ -266,8 +317,7 @@ def copy_bootstrapper_from_container(self, container: "docker.models.containers.
for chunk in data_stream:
f.write(chunk)
with tarfile.open(tar_file, mode="r") as tar:
for file_name in tar.getnames():
tar.extract(file_name, os.path.dirname(path_in_host))
_safe_tar_extractall(tar, os.path.dirname(path_in_host))
os.remove(tar_file)
except docker.errors.APIError as e:
msg = f"Copying {path_in_container} from container has failed. Detailed message: {e}"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
import io
import os
import shutil
import tarfile
import tempfile
import zipfile
from pathlib import Path
from unittest.mock import MagicMock

import pytest

from azure.ai.ml.operations._local_job_invoker import (
_get_creationflags_and_startupinfo_for_background_process,
_safe_tar_extractall,
patch_invocation_script_serialization,
unzip_to_temporary_file,
)


@pytest.mark.unittest
@pytest.mark.training_experiences_test
class TestLocalJobInvoker:
Expand Down Expand Up @@ -61,3 +67,118 @@ def test_creation_flags(self):
flags = _get_creationflags_and_startupinfo_for_background_process("linux")

assert flags == {"stderr": -2, "stdin": -3, "stdout": -3}
def _make_job_definition(name="test-run"):
job_def = MagicMock()
job_def.name = name
return job_def


@pytest.mark.unittest
@pytest.mark.training_experiences_test
class TestUnzipPathTraversalPrevention:
    """Tests for ZIP path traversal prevention in unzip_to_temporary_file."""

    @pytest.fixture(autouse=True)
    def _isolated_temp_root(self, tmp_path, monkeypatch):
        """Redirect ``tempfile.gettempdir()`` to pytest's per-test ``tmp_path``.

        unzip_to_temporary_file creates its extraction directory before it
        validates the archive, so the rejection tests would otherwise leave
        directories behind in the real system temp dir, and every test would
        share fixed job names there. ``tempfile.tempdir`` is the module-level
        override that ``gettempdir()`` returns when it is set.
        """
        monkeypatch.setattr(tempfile, "tempdir", str(tmp_path))

    def test_normal_zip_extracts_successfully(self):
        buf = io.BytesIO()
        with zipfile.ZipFile(buf, "w") as zf:
            zf.writestr("azureml-setup/invocation.sh", "#!/bin/bash\necho hello\n")
            zf.writestr("azureml-setup/config.json", '{"key": "value"}')
        zip_bytes = buf.getvalue()

        job_def = _make_job_definition("safe-run")
        result = unzip_to_temporary_file(job_def, zip_bytes)

        try:
            assert result.exists()
            assert (result / "azureml-setup" / "invocation.sh").exists()
            assert (result / "azureml-setup" / "config.json").exists()
        finally:
            # Eager cleanup; the per-test temp root is also removed by pytest.
            if result.exists():
                shutil.rmtree(result, ignore_errors=True)

    def test_zip_with_path_traversal_is_rejected(self, tmp_path):
        buf = io.BytesIO()
        with zipfile.ZipFile(buf, "w") as zf:
            zf.writestr("azureml-setup/invocation.sh", "#!/bin/bash\necho hello\n")
            zf.writestr("../../etc/evil.sh", "#!/bin/bash\necho pwned\n")
        zip_bytes = buf.getvalue()

        job_def = _make_job_definition("traversal-run")
        with pytest.raises(ValueError, match="path traversal"):
            unzip_to_temporary_file(job_def, zip_bytes)
        # Validation happens before extraction, so the malicious entry must not exist anywhere.
        assert not list(tmp_path.rglob("evil.sh"))

    def test_zip_with_absolute_path_is_rejected(self, tmp_path):
        buf = io.BytesIO()
        with zipfile.ZipFile(buf, "w") as zf:
            if os.name == "nt":
                zf.writestr("C:/Windows/Temp/evil.sh", "#!/bin/bash\necho pwned\n")
            else:
                zf.writestr("/tmp/evil.sh", "#!/bin/bash\necho pwned\n")
        zip_bytes = buf.getvalue()

        job_def = _make_job_definition("absolute-path-run")
        with pytest.raises(ValueError, match="path traversal"):
            unzip_to_temporary_file(job_def, zip_bytes)
        # Validation happens before extraction, so the malicious entry must not exist anywhere.
        assert not list(tmp_path.rglob("evil.sh"))


@pytest.mark.unittest
@pytest.mark.training_experiences_test
class TestSafeTarExtract:
    """Tests for tar path traversal prevention in _safe_tar_extractall."""

    @staticmethod
    def _archive_with(entry, payload=None):
        """Return a rewound in-memory tar stream holding the single ``entry``.

        When ``payload`` is given, it becomes the entry's file content and the
        entry's size is set to match; otherwise the entry is added header-only.
        """
        stream = io.BytesIO()
        with tarfile.open(fileobj=stream, mode="w") as writer:
            if payload is None:
                writer.addfile(entry)
            else:
                entry.size = len(payload)
                writer.addfile(entry, io.BytesIO(payload))
        stream.seek(0)
        return stream

    @staticmethod
    def _link_entry(entry_name, link_type, target):
        """Build a link-type TarInfo named ``entry_name`` pointing at ``target``."""
        entry = tarfile.TarInfo(name=entry_name)
        entry.type = link_type
        entry.linkname = target
        return entry

    @staticmethod
    def _assert_rejected(stream):
        """Assert that extracting ``stream`` into a fresh directory raises ValueError."""
        with tempfile.TemporaryDirectory() as dest:
            with tarfile.open(fileobj=stream, mode="r") as tar:
                with pytest.raises(ValueError):
                    _safe_tar_extractall(tar, dest)

    def test_normal_tar_extracts_successfully(self):
        stream = self._archive_with(tarfile.TarInfo(name="vm-bootstrapper"), b"normal content")
        with tempfile.TemporaryDirectory() as dest:
            with tarfile.open(fileobj=stream, mode="r") as tar:
                _safe_tar_extractall(tar, dest)

            assert os.path.exists(os.path.join(dest, "vm-bootstrapper"))

    def test_tar_with_path_traversal_is_rejected(self):
        escaping = tarfile.TarInfo(name="../../evil_script.sh")
        self._assert_rejected(self._archive_with(escaping, b"evil content"))

    def test_tar_with_symlink_is_rejected(self):
        symlink = self._link_entry("evil_link", tarfile.SYMTYPE, "/etc/passwd")
        self._assert_rejected(self._archive_with(symlink))

    def test_tar_with_hardlink_is_rejected(self):
        hardlink = self._link_entry("evil_hardlink", tarfile.LNKTYPE, "/etc/shadow")
        self._assert_rejected(self._archive_with(hardlink))
Loading