diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py index b7436d307902..497476198a25 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py @@ -41,7 +41,17 @@ def unzip_to_temporary_file(job_definition: JobBaseData, zip_content: Any) -> Path: temp_dir = Path(tempfile.gettempdir(), AZUREML_RUNS_DIR, job_definition.name) temp_dir.mkdir(parents=True, exist_ok=True) + resolved_temp_dir = temp_dir.resolve() with zipfile.ZipFile(io.BytesIO(zip_content)) as zip_ref: + for member in zip_ref.namelist(): + member_path = (resolved_temp_dir / member).resolve() + # Ensure the member extracts within temp_dir (allow temp_dir itself for directory entries) + if member_path != resolved_temp_dir and not str(member_path).startswith( + str(resolved_temp_dir) + os.sep + ): + raise ValueError( + f"Zip archive contains a path traversal entry and cannot be extracted safely: {member}" + ) zip_ref.extractall(temp_dir) return temp_dir @@ -94,19 +104,25 @@ def patch_invocation_script_serialization(invocation_path: Path) -> None: if searchRes: patched_json = searchRes.group(2).replace('"', '\\"') patched_json = patched_json.replace("'", '"') - invocation_path.write_text(searchRes.group(1) + patched_json + searchRes.group(3)) + invocation_path.write_text( + searchRes.group(1) + patched_json + searchRes.group(3) + ) def invoke_command(project_temp_dir: Path) -> None: if os.name == "nt": - invocation_script = project_temp_dir / AZUREML_RUN_SETUP_DIR / INVOCATION_BAT_FILE + invocation_script = ( + project_temp_dir / AZUREML_RUN_SETUP_DIR / INVOCATION_BAT_FILE + ) # There is a bug in Execution service on the serialized json for snapshots. # This is a client-side patch until the service fixes it, at which point it should # be a no-op patch_invocation_script_serialization(invocation_script) invoked_command = ["cmd.exe", "/c", "{0}".format(invocation_script)] else: - invocation_script = project_temp_dir / AZUREML_RUN_SETUP_DIR / INVOCATION_BASH_FILE + invocation_script = ( + project_temp_dir / AZUREML_RUN_SETUP_DIR / INVOCATION_BASH_FILE + ) subprocess.check_output(["chmod", "+x", invocation_script]) invoked_command = ["/bin/bash", "-c", "{0}".format(invocation_script)] @@ -142,10 +158,12 @@ def get_execution_service_response( try: local = job_definition.properties.services.get("Local", None) - (url, encodedBody) = local.endpoint.split(EXECUTION_SERVICE_URL_KEY) + url, encodedBody = local.endpoint.split(EXECUTION_SERVICE_URL_KEY) body = urllib.parse.unquote_plus(encodedBody) body_dict: Dict = json.loads(body) - response = requests_pipeline.post(url, json=body_dict, headers={"Authorization": "Bearer " + token}) + response = requests_pipeline.post( + url, json=body_dict, headers={"Authorization": "Bearer " + token} + ) response.raise_for_status() return (response.content, body_dict.get("SnapshotId", None)) # type: ignore[return-value] except AzureError as err: @@ -167,6 +185,53 @@ def is_local_run(job_definition: JobBaseData) -> bool: return local is not None and EXECUTION_SERVICE_URL_KEY in local.endpoint +def _safe_tar_extractall(tar: tarfile.TarFile, dest_dir: str) -> None: + """Extract tar archive members safely, preventing path traversal (TarSlip). + + On Python 3.12+, uses the built-in 'data' filter. On older versions, + manually validates each member to ensure no path traversal, symlinks, + hard links, or other special entries that could write outside the + destination directory or create unsafe filesystem nodes. + + :param tar: An opened tarfile.TarFile object. + :type tar: tarfile.TarFile + :param dest_dir: The destination directory for extraction. + :type dest_dir: str + :raises ValueError: If a tar member would escape the destination directory + or contains a symlink, hard link, or unsupported special entry type. + """ + resolved_dest = os.path.realpath(dest_dir) + + # Python 3.12+ has built-in data_filter for safe extraction + if hasattr(tarfile, "data_filter"): + try: + tar.extractall(resolved_dest, filter="data") + except tarfile.TarError as exc: + raise ValueError(f"Failed to safely extract tar archive: {exc}") from exc + else: + for member in tar.getmembers(): + # Reject symbolic and hard links + if member.issym() or member.islnk(): + raise ValueError( + f"Tar archive contains a symbolic or hard link and cannot be extracted safely: {member.name}" + ) + # Reject any non-regular, non-directory entries (e.g., devices, FIFOs) + if not (member.isfile() or member.isdir()): + raise ValueError( + f"Tar archive contains an unsupported special entry type and cannot be extracted safely: " + f"{member.name}" + ) + member_path = os.path.realpath(os.path.join(resolved_dest, member.name)) + if member_path != resolved_dest and not member_path.startswith( + resolved_dest + os.sep + ): + raise ValueError( + f"Tar archive contains a path traversal entry and cannot be extracted safely: {member.name}" + ) + # All members validated; safe to extract + tar.extractall(resolved_dest) + + class CommonRuntimeHelper: COMMON_RUNTIME_BOOTSTRAPPER_INFO = "common_runtime_bootstrapper_info.json" COMMON_RUNTIME_JOB_SPEC = "common_runtime_job_spec.json" @@ -192,14 +257,18 @@ class CommonRuntimeHelper: "Unable to communicate with Docker daemon. Is Docker running/installed?\n " "For local submissions, we need to build a Docker container to run your job in.\n Detailed message: {}" ) - DOCKER_LOGIN_FAILURE_MSG = "Login to Docker registry '{}' failed. See error message: {}" + DOCKER_LOGIN_FAILURE_MSG = ( + "Login to Docker registry '{}' failed. See error message: {}" + ) BOOTSTRAP_BINARY_FAILURE_MSG = ( "Azure Common Runtime execution failed. See detailed message below for troubleshooting " "information or re-submit with flag --use-local-runtime to try running on your local runtime: {}" ) def __init__(self, job_name: str): - self.common_runtime_temp_folder = os.path.join(Path.home(), ".azureml-common-runtime", job_name) + self.common_runtime_temp_folder = os.path.join( + Path.home(), ".azureml-common-runtime", job_name + ) if os.path.exists(self.common_runtime_temp_folder): shutil.rmtree(self.common_runtime_temp_folder) Path(self.common_runtime_temp_folder).mkdir(parents=True) @@ -208,10 +277,14 @@ def __init__(self, job_name: str): CommonRuntimeHelper.VM_BOOTSTRAPPER_FILE_NAME, ) self.stdout = open( # pylint: disable=consider-using-with - os.path.join(self.common_runtime_temp_folder, "stdout"), "w+", encoding=DefaultOpenEncoding.WRITE + os.path.join(self.common_runtime_temp_folder, "stdout"), + "w+", + encoding=DefaultOpenEncoding.WRITE, ) self.stderr = open( # pylint: disable=consider-using-with - os.path.join(self.common_runtime_temp_folder, "stderr"), "w+", encoding=DefaultOpenEncoding.WRITE + os.path.join(self.common_runtime_temp_folder, "stderr"), + "w+", + encoding=DefaultOpenEncoding.WRITE, ) # Bug Item number: 2885723 @@ -243,9 +316,13 @@ def get_docker_client(self, registry: Dict) -> "docker.DockerClient": # type: i registry=registry.get("url"), ) except Exception as e: - raise RuntimeError(self.DOCKER_LOGIN_FAILURE_MSG.format(registry.get("url"), e)) from e + raise RuntimeError( + self.DOCKER_LOGIN_FAILURE_MSG.format(registry.get("url"), e) + ) from e else: - raise RuntimeError("Registry information is missing from bootstrapper configuration.") + raise RuntimeError( + "Registry information is missing from bootstrapper configuration." + ) return client @@ -266,14 +343,15 @@ def copy_bootstrapper_from_container(self, container: "docker.models.containers. for chunk in data_stream: f.write(chunk) with tarfile.open(tar_file, mode="r") as tar: - for file_name in tar.getnames(): - tar.extract(file_name, os.path.dirname(path_in_host)) + _safe_tar_extractall(tar, os.path.dirname(path_in_host)) os.remove(tar_file) except docker.errors.APIError as e: msg = f"Copying {path_in_container} from container has failed. Detailed message: {e}" raise MlException(message=msg, no_personal_data_message=msg) from e - def get_common_runtime_info_from_response(self, response: Any) -> Tuple[Dict[str, str], str]: + def get_common_runtime_info_from_response( + self, response: Any + ) -> Tuple[Dict[str, str], str]: """Extract common-runtime info from Execution Service response. :param response: Content of zip file from Execution Service containing all the @@ -284,13 +362,22 @@ def get_common_runtime_info_from_response(self, response: Any) -> Tuple[Dict[str """ with zipfile.ZipFile(io.BytesIO(response)) as zip_ref: - bootstrapper_path = f"{AZUREML_RUN_SETUP_DIR}/{self.COMMON_RUNTIME_BOOTSTRAPPER_INFO}" + bootstrapper_path = ( + f"{AZUREML_RUN_SETUP_DIR}/{self.COMMON_RUNTIME_BOOTSTRAPPER_INFO}" + ) job_spec_path = f"{AZUREML_RUN_SETUP_DIR}/{self.COMMON_RUNTIME_JOB_SPEC}" - if not all(file_path in zip_ref.namelist() for file_path in [bootstrapper_path, job_spec_path]): - raise RuntimeError(f"{bootstrapper_path}, {job_spec_path} are not in the execution service response.") + if not all( + file_path in zip_ref.namelist() + for file_path in [bootstrapper_path, job_spec_path] + ): + raise RuntimeError( + f"{bootstrapper_path}, {job_spec_path} are not in the execution service response." + ) with zip_ref.open(bootstrapper_path, "r") as bootstrapper_file: - bootstrapper_json = json.loads(base64.b64decode(bootstrapper_file.read())) + bootstrapper_json = json.loads( + base64.b64decode(bootstrapper_file.read()) + ) with zip_ref.open(job_spec_path, "r") as job_spec_file: job_spec = job_spec_file.read().decode("utf-8") @@ -312,9 +399,13 @@ def get_bootstrapper_binary(self, bootstrapper_info: Dict) -> None: tag = bootstrapper_info.get("tag") if repo_prefix: - bootstrapper_image = f"{repository}/{repo_prefix}/boot/vm-bootstrapper/binimage/linux:{tag}" + bootstrapper_image = ( + f"{repository}/{repo_prefix}/boot/vm-bootstrapper/binimage/linux:{tag}" + ) else: - bootstrapper_image = f"{repository}/boot/vm-bootstrapper/binimage/linux:{tag}" + bootstrapper_image = ( + f"{repository}/boot/vm-bootstrapper/binimage/linux:{tag}" + ) try: boot_img = docker_client.images.pull(bootstrapper_image) @@ -328,7 +419,9 @@ def get_bootstrapper_binary(self, bootstrapper_info: Dict) -> None: boot_container.stop() boot_container.remove() - def execute_bootstrapper(self, bootstrapper_binary: str, job_spec: str) -> subprocess.Popen: + def execute_bootstrapper( + self, bootstrapper_binary: str, job_spec: str + ) -> subprocess.Popen: """Runs vm-bootstrapper with the job specification passed to it. This will build the Docker container, create all necessary files and directories, and run the job locally. Command args are defined by Common Runtime team here: https://msdata.visualstudio.com/Vienna/_git/vienna?path=/src/azureml- job-runtime/common- @@ -372,7 +465,9 @@ def execute_bootstrapper(self, bootstrapper_binary: str, job_spec: str) -> subpr process.kill() raise RuntimeError(LOCAL_JOB_FAILURE_MSG.format(self.stderr.read())) - def check_bootstrapper_process_status(self, bootstrapper_process: subprocess.Popen) -> Optional[int]: + def check_bootstrapper_process_status( + self, bootstrapper_process: subprocess.Popen + ) -> Optional[int]: """Check if bootstrapper process status is non-zero. :param bootstrapper_process: bootstrapper process @@ -383,7 +478,9 @@ def check_bootstrapper_process_status(self, bootstrapper_process: subprocess.Pop return_code = bootstrapper_process.poll() if return_code: self.stderr.seek(0) - raise RuntimeError(self.BOOTSTRAP_BINARY_FAILURE_MSG.format(self.stderr.read())) + raise RuntimeError( + self.BOOTSTRAP_BINARY_FAILURE_MSG.format(self.stderr.read()) + ) return return_code @@ -408,7 +505,9 @@ def start_run_if_local( :rtype: str """ token = credential.get_token(ws_base_url + "/.default").token - (zip_content, snapshot_id) = get_execution_service_response(job_definition, token, requests_pipeline) + zip_content, snapshot_id = get_execution_service_response( + job_definition, token, requests_pipeline + ) try: temp_dir = unzip_to_temporary_file(job_definition, zip_content) diff --git a/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py b/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py index 53f1ab6db677..cca25c7313c6 100644 --- a/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py +++ b/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py @@ -1,12 +1,19 @@ +import io import os +import shutil +import tarfile import tempfile +import zipfile from pathlib import Path +from unittest.mock import MagicMock import pytest from azure.ai.ml.operations._local_job_invoker import ( _get_creationflags_and_startupinfo_for_background_process, + _safe_tar_extractall, patch_invocation_script_serialization, + unzip_to_temporary_file, ) @@ -61,3 +68,120 @@ def test_creation_flags(self): flags = _get_creationflags_and_startupinfo_for_background_process("linux") assert flags == {"stderr": -2, "stdin": -3, "stdout": -3} + + +def _make_job_definition(name="test-run"): + job_def = MagicMock() + job_def.name = name + return job_def + + +@pytest.mark.unittest +@pytest.mark.training_experiences_test +class TestUnzipPathTraversalPrevention: + """Tests for ZIP path traversal prevention in unzip_to_temporary_file.""" + + def test_normal_zip_extracts_successfully(self): + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr("azureml-setup/invocation.sh", "#!/bin/bash\necho hello\n") + zf.writestr("azureml-setup/config.json", '{"key": "value"}') + zip_bytes = buf.getvalue() + + job_def = _make_job_definition("safe-run") + result = unzip_to_temporary_file(job_def, zip_bytes) + + try: + assert result.exists() + assert (result / "azureml-setup" / "invocation.sh").exists() + assert (result / "azureml-setup" / "config.json").exists() + finally: + if result.exists(): + shutil.rmtree(result, ignore_errors=True) + + def test_zip_with_path_traversal_is_rejected(self): + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr("azureml-setup/invocation.sh", "#!/bin/bash\necho hello\n") + zf.writestr("../../etc/evil.sh", "#!/bin/bash\necho pwned\n") + zip_bytes = buf.getvalue() + + job_def = _make_job_definition("traversal-run") + with pytest.raises(ValueError, match="path traversal"): + unzip_to_temporary_file(job_def, zip_bytes) + + def test_zip_with_absolute_path_is_rejected(self): + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + if os.name == "nt": + zf.writestr("C:/Windows/Temp/evil.sh", "#!/bin/bash\necho pwned\n") + else: + zf.writestr("/tmp/evil.sh", "#!/bin/bash\necho pwned\n") + zip_bytes = buf.getvalue() + + job_def = _make_job_definition("absolute-path-run") + with pytest.raises(ValueError, match="path traversal"): + unzip_to_temporary_file(job_def, zip_bytes) + + +@pytest.mark.unittest +@pytest.mark.training_experiences_test +class TestSafeTarExtract: + """Tests for tar path traversal prevention in _safe_tar_extractall.""" + + def test_normal_tar_extracts_successfully(self): + with tempfile.TemporaryDirectory() as dest: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tar: + data = b"normal content" + info = tarfile.TarInfo(name="vm-bootstrapper") + info.size = len(data) + tar.addfile(info, io.BytesIO(data)) + buf.seek(0) + + with tarfile.open(fileobj=buf, mode="r") as tar: + _safe_tar_extractall(tar, dest) + + assert os.path.exists(os.path.join(dest, "vm-bootstrapper")) + + def test_tar_with_path_traversal_is_rejected(self): + with tempfile.TemporaryDirectory() as dest: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tar: + data = b"evil content" + info = tarfile.TarInfo(name="../../evil_script.sh") + info.size = len(data) + tar.addfile(info, io.BytesIO(data)) + buf.seek(0) + + with tarfile.open(fileobj=buf, mode="r") as tar: + with pytest.raises(ValueError): + _safe_tar_extractall(tar, dest) + + def test_tar_with_symlink_is_rejected(self): + with tempfile.TemporaryDirectory() as dest: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tar: + info = tarfile.TarInfo(name="evil_link") + info.type = tarfile.SYMTYPE + info.linkname = "/etc/passwd" + tar.addfile(info) + buf.seek(0) + + with tarfile.open(fileobj=buf, mode="r") as tar: + with pytest.raises(ValueError): + _safe_tar_extractall(tar, dest) + + def test_tar_with_hardlink_is_rejected(self): + with tempfile.TemporaryDirectory() as dest: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tar: + info = tarfile.TarInfo(name="evil_hardlink") + info.type = tarfile.LNKTYPE + info.linkname = "/etc/shadow" + tar.addfile(info) + buf.seek(0) + + with tarfile.open(fileobj=buf, mode="r") as tar: + with pytest.raises(ValueError): + _safe_tar_extractall(tar, dest)