diff --git a/python/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py b/python/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py index c8d737668..33c457a5a 100644 --- a/python/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py +++ b/python/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py @@ -10,7 +10,7 @@ from concurrent.futures import CancelledError from contextlib import contextmanager from dataclasses import dataclass -from pathlib import Path, PosixPath +from pathlib import Path from queue import Queue from urllib.parse import urlparse @@ -18,7 +18,7 @@ import pexpect import requests from jumpstarter_driver_composite.client import CompositeClient -from jumpstarter_driver_opendal.client import FlasherClient, OpendalClient, operator_for_path +from jumpstarter_driver_opendal.client import FlasherClient, OpendalClient, clean_filename, operator_for_path from jumpstarter_driver_opendal.common import PathBuf from jumpstarter_driver_pyserial.client import Console from opendal import Metadata, Operator @@ -167,10 +167,14 @@ def flash( # noqa: C901 "http", root="/", endpoint=f"{parsed.scheme}://{parsed.netloc}", token=bearer_token ) operator_scheme = "http" - path = Path(parsed.path) + # Preserve query parameters so that signed URLs + # (e.g. CloudFront with ?Expires=...&Signature=...) work correctly. + path = parsed.path + if parsed.query: + path = f"{path}?{parsed.query}" else: path, operator, operator_scheme = operator_for_path(path) - image_url = self.http.get_url() + "/" + path.name + image_url = self.http.get_url() + "/" + self._filename(path) # start counting time for the flash operation start_time = time.time() @@ -968,7 +972,7 @@ def _transfer_bg_thread( """ self.logger.info(f"Writing image to storage in the background: {src_path}") try: - filename = Path(src_path).name if isinstance(src_path, (str, os.PathLike)) else src_path.name + filename = self._filename(src_path) if src_operator_scheme == "fs": file_hash = self._sha256_file(src_operator, src_path) @@ -1088,8 +1092,8 @@ def dump( raise NotImplementedError("Dump is not implemented for this driver yet") def _filename(self, path: PathBuf) -> str: - """Extract filename from url or path""" - if path.startswith("oci://"): + """Extract filename from url or path, stripping any query parameters""" + if isinstance(path, str) and path.startswith("oci://"): oci_path = path[6:] # Remove "oci://" prefix if ":" in oci_path: repository, tag = oci_path.rsplit(":", 1) @@ -1098,10 +1102,8 @@ def _filename(self, path: PathBuf) -> str: else: repo_name = oci_path.split("/")[-1] if "/" in oci_path else oci_path return repo_name - elif path.startswith(("http://", "https://")): - return urlparse(path).path.split("/")[-1] else: - return Path(path).name + return clean_filename(path) def _upload_artifact(self, storage, path: PathBuf, operator: Operator): """Upload artifact to storage""" @@ -1636,17 +1638,12 @@ def _get_decompression_command(filename_or_url) -> str: Determine the appropriate decompression command based on file extension Args: - filename (str): Name of the file to check + filename_or_url (str): Name of the file or URL to check Returns: str: Decompression command ('zcat', 'xzcat', or 'cat' for uncompressed) """ - if type(filename_or_url) is PosixPath: - filename = filename_or_url.name - elif filename_or_url.startswith(("http://", "https://")): - filename = urlparse(filename_or_url).path.split("/")[-1] - - filename = filename.lower() + filename = clean_filename(filename_or_url).lower() if filename.endswith((".gz", ".gzip")): return "zcat |" elif filename.endswith(".xz"): diff --git a/python/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py b/python/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py index 44f1214c4..e462f2a87 100644 --- a/python/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py +++ b/python/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py @@ -428,6 +428,86 @@ def test_categorize_exception_preserves_cause_for_wrapped_exceptions(): assert "File not found" in str(result) +def test_filename_strips_query_params_from_url_path(): + """Test _filename strips query parameters from paths with signed URL params""" + client = MockFlasherClient() + + # Full HTTP URL + assert client._filename("https://cdn.example.com/images/image.raw.xz") == "image.raw.xz" + + # Full HTTP URL with query parameters (e.g. CloudFront signed URL) + assert ( + client._filename("https://cdn.example.com/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz") + == "image.raw.xz" + ) + + # Path string with query parameters (as returned by operator_for_path after fix) + assert client._filename("/images/image.raw.xz?Expires=123&Signature=abc") == "image.raw.xz" + + # Plain path without query parameters + assert client._filename("/images/image.raw.xz") == "image.raw.xz" + + # OCI path + assert client._filename("oci://quay.io/org/myimage:latest") == "myimage-latest" + + +def test_decompression_command_with_query_params(): + """Test _get_decompression_command handles paths with query parameters""" + from pathlib import PosixPath + + from .client import _get_decompression_command + + # Standard PosixPath + assert _get_decompression_command(PosixPath("/images/image.raw.xz")) == "xzcat |" + assert _get_decompression_command(PosixPath("/images/image.raw.gz")) == "zcat |" + assert _get_decompression_command(PosixPath("/images/image.raw")) == "" + + # Full HTTP URL + assert _get_decompression_command("https://cdn.example.com/images/image.raw.xz") == "xzcat |" + + # String path with query parameters (as returned by operator_for_path for signed URLs) + assert _get_decompression_command("/images/image.raw.xz?Expires=123&Signature=abc") == "xzcat |" + assert _get_decompression_command("/images/image.raw.gz?Expires=123") == "zcat |" + assert _get_decompression_command("/images/image.raw?Expires=123") == "" + + +def test_flash_signed_url_preserves_query_params(): + """Test that flash with a signed HTTP URL preserves query parameters for image_url""" + client = MockFlasherClient() + + class DummyService: + def __init__(self): + self.storage = object() + + def start(self): + pass + + def stop(self): + pass + + def get_url(self): + return "http://exporter" + + client.http = DummyService() + client.tftp = DummyService() + client.call = lambda *args, **kwargs: None + + captured = {} + + def capture_perform(*args): + captured["image_url"] = args[3] + captured["should_download_to_httpd"] = args[4] + + client._perform_flash_operation = capture_perform + + # Direct HTTP URL with query params (no force_exporter_http) should preserve full URL + signed_url = "https://cdn.example.com/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz" + client.flash(signed_url, method="fls", fls_version="") + + assert captured["image_url"] == signed_url + assert captured["should_download_to_httpd"] is False + + def test_resolve_flash_parameters(): """Test flash parameter resolution for single file, partitions, and error cases""" client = MockFlasherClient() diff --git a/python/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py b/python/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py index eca6aa7d1..89024463e 100644 --- a/python/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py +++ b/python/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py @@ -44,6 +44,21 @@ async def aclose(self): pass +def clean_filename(path: PathBuf) -> str: + """Extract a clean filename from a path or URL, stripping query parameters. + + This handles paths returned by operator_for_path() which may contain + query parameters for signed URLs (e.g. /path/to/image.raw.xz?Expires=...&Signature=...). + """ + path_str = str(path) + if path_str.startswith(("http://", "https://")): + return urlparse(path_str).path.split("/")[-1] + name = Path(path_str).name + if "?" in name: + name = name.split("?", 1)[0] + return name + + def operator_for_path(path: PathBuf) -> tuple[PathBuf, Operator, str]: """Create an operator for the given path Return a tuple of: @@ -54,7 +69,13 @@ def operator_for_path(path: PathBuf) -> tuple[PathBuf, Operator, str]: if type(path) is str and path.startswith(("http://", "https://")): parsed_url = urlparse(path) operator = Operator("http", root="/", endpoint=f"{parsed_url.scheme}://{parsed_url.netloc}") - return Path(parsed_url.path), operator, "http" + # Preserve query parameters in the path so that signed URLs + # (e.g. CloudFront URLs with ?Expires=...&Signature=...&Key-Pair-Id=...) + # are fetched correctly by the OpenDAL HTTP operator. + op_path = parsed_url.path + if parsed_url.query: + op_path = f"{op_path}?{parsed_url.query}" + return op_path, operator, "http" else: return Path(path).resolve(), Operator("fs", root="/"), "fs" diff --git a/python/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py b/python/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py index 2668b1760..8cfb74744 100644 --- a/python/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py +++ b/python/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py @@ -322,3 +322,31 @@ def test_copy_and_rename_tracking(tmp_path): assert "copied_dir" in created_paths assert "renamed_dir" in created_paths assert len(created_paths) == 4 + + +def test_operator_for_path_preserves_query_params(): + """Test that operator_for_path preserves query parameters for HTTP URLs""" + from .client import operator_for_path + + # HTTP URL without query parameters + path, operator, scheme = operator_for_path("https://cdn.example.com/images/image.raw.xz") + assert scheme == "http" + assert path == "/images/image.raw.xz" + + # HTTP URL with query parameters (e.g. CloudFront signed URL) + path, operator, scheme = operator_for_path( + "https://cdn.example.com/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz" + ) + assert scheme == "http" + assert path == "/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz" + assert "Expires=123" in path + assert "Signature=abc" in path + assert "Key-Pair-Id=xyz" in path + + # Filesystem path (use resolve() for the expected value since macOS + # resolves /tmp to /private/tmp) + from pathlib import Path + + path, operator, scheme = operator_for_path("/tmp/image.raw.xz") + assert scheme == "fs" + assert path == Path("/tmp/image.raw.xz").resolve() diff --git a/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/client.py b/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/client.py index 5edcbe391..f42e683a9 100644 --- a/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/client.py +++ b/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/client.py @@ -5,7 +5,7 @@ import click from jumpstarter_driver_composite.client import CompositeClient -from jumpstarter_driver_opendal.client import FlasherClient, operator_for_path +from jumpstarter_driver_opendal.client import FlasherClient, clean_filename, operator_for_path from jumpstarter_driver_power.client import PowerClient from opendal import Operator @@ -39,7 +39,7 @@ def _upload_file_if_needed(self, file_path: str, operator: Operator | None = Non path_buf = Path(file_path) operator_scheme = "unknown" - filename = Path(path_buf).name + filename = clean_filename(path_buf) if self._should_upload_file(self.storage, filename, path_buf, operator, operator_scheme): if operator_scheme == "http":