Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
from concurrent.futures import CancelledError
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path, PosixPath
from pathlib import Path
from queue import Queue
from urllib.parse import urlparse

import click
import pexpect
import requests
from jumpstarter_driver_composite.client import CompositeClient
from jumpstarter_driver_opendal.client import FlasherClient, OpendalClient, operator_for_path
from jumpstarter_driver_opendal.client import FlasherClient, OpendalClient, clean_filename, operator_for_path
from jumpstarter_driver_opendal.common import PathBuf
from jumpstarter_driver_pyserial.client import Console
from opendal import Metadata, Operator
Expand Down Expand Up @@ -167,10 +167,14 @@ def flash( # noqa: C901
"http", root="/", endpoint=f"{parsed.scheme}://{parsed.netloc}", token=bearer_token
)
operator_scheme = "http"
path = Path(parsed.path)
# Preserve query parameters so that signed URLs
# (e.g. CloudFront with ?Expires=...&Signature=...) work correctly.
path = parsed.path
if parsed.query:
path = f"{path}?{parsed.query}"
else:
path, operator, operator_scheme = operator_for_path(path)
image_url = self.http.get_url() + "/" + path.name
image_url = self.http.get_url() + "/" + self._filename(path)

# start counting time for the flash operation
start_time = time.time()
Expand Down Expand Up @@ -968,7 +972,7 @@ def _transfer_bg_thread(
"""
self.logger.info(f"Writing image to storage in the background: {src_path}")
try:
filename = Path(src_path).name if isinstance(src_path, (str, os.PathLike)) else src_path.name
filename = self._filename(src_path)

if src_operator_scheme == "fs":
file_hash = self._sha256_file(src_operator, src_path)
Expand Down Expand Up @@ -1088,8 +1092,8 @@ def dump(
raise NotImplementedError("Dump is not implemented for this driver yet")

def _filename(self, path: PathBuf) -> str:
"""Extract filename from url or path"""
if path.startswith("oci://"):
"""Extract filename from url or path, stripping any query parameters"""
if isinstance(path, str) and path.startswith("oci://"):
oci_path = path[6:] # Remove "oci://" prefix
if ":" in oci_path:
repository, tag = oci_path.rsplit(":", 1)
Expand All @@ -1098,10 +1102,8 @@ def _filename(self, path: PathBuf) -> str:
else:
repo_name = oci_path.split("/")[-1] if "/" in oci_path else oci_path
return repo_name
elif path.startswith(("http://", "https://")):
return urlparse(path).path.split("/")[-1]
else:
return Path(path).name
return clean_filename(path)

def _upload_artifact(self, storage, path: PathBuf, operator: Operator):
"""Upload artifact to storage"""
Expand Down Expand Up @@ -1636,17 +1638,12 @@ def _get_decompression_command(filename_or_url) -> str:
Determine the appropriate decompression command based on file extension

Args:
filename (str): Name of the file to check
filename_or_url (str): Name of the file or URL to check

Returns:
str: Decompression command ('zcat', 'xzcat', or 'cat' for uncompressed)
"""
if type(filename_or_url) is PosixPath:
filename = filename_or_url.name
elif filename_or_url.startswith(("http://", "https://")):
filename = urlparse(filename_or_url).path.split("/")[-1]

filename = filename.lower()
filename = clean_filename(filename_or_url).lower()
if filename.endswith((".gz", ".gzip")):
return "zcat |"
elif filename.endswith(".xz"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,86 @@ def test_categorize_exception_preserves_cause_for_wrapped_exceptions():
assert "File not found" in str(result)


def test_filename_strips_query_params_from_url_path():
"""Test _filename strips query parameters from paths with signed URL params"""
client = MockFlasherClient()

# Full HTTP URL
assert client._filename("https://cdn.example.com/images/image.raw.xz") == "image.raw.xz"

# Full HTTP URL with query parameters (e.g. CloudFront signed URL)
assert (
client._filename("https://cdn.example.com/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz")
== "image.raw.xz"
)

# Path string with query parameters (as returned by operator_for_path after fix)
assert client._filename("/images/image.raw.xz?Expires=123&Signature=abc") == "image.raw.xz"

# Plain path without query parameters
assert client._filename("/images/image.raw.xz") == "image.raw.xz"

# OCI path
assert client._filename("oci://quay.io/org/myimage:latest") == "myimage-latest"


def test_decompression_command_with_query_params():
"""Test _get_decompression_command handles paths with query parameters"""
from pathlib import PosixPath

from .client import _get_decompression_command

# Standard PosixPath
assert _get_decompression_command(PosixPath("/images/image.raw.xz")) == "xzcat |"
assert _get_decompression_command(PosixPath("/images/image.raw.gz")) == "zcat |"
assert _get_decompression_command(PosixPath("/images/image.raw")) == ""

# Full HTTP URL
assert _get_decompression_command("https://cdn.example.com/images/image.raw.xz") == "xzcat |"

# String path with query parameters (as returned by operator_for_path for signed URLs)
assert _get_decompression_command("/images/image.raw.xz?Expires=123&Signature=abc") == "xzcat |"
assert _get_decompression_command("/images/image.raw.gz?Expires=123") == "zcat |"
assert _get_decompression_command("/images/image.raw?Expires=123") == ""


def test_flash_signed_url_preserves_query_params():
"""Test that flash with a signed HTTP URL preserves query parameters for image_url"""
client = MockFlasherClient()

class DummyService:
def __init__(self):
self.storage = object()

def start(self):
pass

def stop(self):
pass

def get_url(self):
return "http://exporter"

client.http = DummyService()
client.tftp = DummyService()
client.call = lambda *args, **kwargs: None

captured = {}

def capture_perform(*args):
captured["image_url"] = args[3]
captured["should_download_to_httpd"] = args[4]

client._perform_flash_operation = capture_perform

# Direct HTTP URL with query params (no force_exporter_http) should preserve full URL
signed_url = "https://cdn.example.com/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz"
client.flash(signed_url, method="fls", fls_version="")

assert captured["image_url"] == signed_url
assert captured["should_download_to_httpd"] is False
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice test! One thing to note: this covers the code path where the signed URL is passed through directly (no download to the exporter httpd). The more involved path where query params are reconstructed from parsed.path + parsed.query (the bearer token branch in flash()) is only covered indirectly via the unit tests of the individual helpers. Consider whether an integration-level test for that branch would be valuable -- not a blocker, just something to keep in mind.

AI-generated, human reviewed

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the note. Agreed that an integration-level test for the bearer token branch would add value. I'll keep that in mind as a follow-up -- the current unit tests for the individual helpers provide reasonable coverage for now and adding a full integration test for that path would require more complex mocking of the HTTP download + bearer auth flow which feels out of scope for this fix PR.



def test_resolve_flash_parameters():
"""Test flash parameter resolution for single file, partitions, and error cases"""
client = MockFlasherClient()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ async def aclose(self):
pass


def clean_filename(path: PathBuf) -> str:
"""Extract a clean filename from a path or URL, stripping query parameters.

This handles paths returned by operator_for_path() which may contain
query parameters for signed URLs (e.g. /path/to/image.raw.xz?Expires=...&Signature=...).
"""
path_str = str(path)
if path_str.startswith(("http://", "https://")):
return urlparse(path_str).path.split("/")[-1]
name = Path(path_str).name
if "?" in name:
name = name.split("?", 1)[0]
return name


def operator_for_path(path: PathBuf) -> tuple[PathBuf, Operator, str]:
"""Create an operator for the given path
Return a tuple of:
Expand All @@ -54,7 +69,13 @@ def operator_for_path(path: PathBuf) -> tuple[PathBuf, Operator, str]:
if type(path) is str and path.startswith(("http://", "https://")):
parsed_url = urlparse(path)
operator = Operator("http", root="/", endpoint=f"{parsed_url.scheme}://{parsed_url.netloc}")
return Path(parsed_url.path), operator, "http"
# Preserve query parameters in the path so that signed URLs
# (e.g. CloudFront URLs with ?Expires=...&Signature=...&Key-Pair-Id=...)
# are fetched correctly by the OpenDAL HTTP operator.
op_path = parsed_url.path
if parsed_url.query:
op_path = f"{op_path}?{parsed_url.query}"
return op_path, operator, "http"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that operator_for_path() returns a plain string with query parameters for HTTP URLs (instead of a Path), callers that do Path(returned_path).name will get a filename polluted with query params.

The flashers package handles this correctly after your changes, but the ridesx client at python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/client.py line 42 does:

filename = Path(path_buf).name

which would produce image.raw.xz?Expires=123&Signature=abc as the filename. That gets passed to write_from_path, creating a file with query params in its name on the exporter storage.

Would it make sense to either update the ridesx caller as well, or provide a shared helper for extracting a clean filename from the return value of operator_for_path()? The risk depends on whether ridesx is ever used with signed URLs, but it seems worth guarding against.

AI-generated, human reviewed

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch -- this is indeed a real bug. The ridesx client's _upload_file_if_needed() at line 42 would produce filenames like image.raw.xz?Expires=123&Signature=abc when passed a signed URL.

Fixed in 4aaa40a by:

  1. Extracting a shared clean_filename() helper in jumpstarter_driver_opendal.client that strips query parameters from any path/URL
  2. Updating the ridesx client to use clean_filename(path_buf) instead of Path(path_buf).name
  3. Also refactored _filename() and _get_decompression_command() in the flashers client to use the same shared helper, keeping everything in sync

else:
return Path(path).resolve(), Operator("fs", root="/"), "fs"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,31 @@ def test_copy_and_rename_tracking(tmp_path):
assert "copied_dir" in created_paths
assert "renamed_dir" in created_paths
assert len(created_paths) == 4


def test_operator_for_path_preserves_query_params():
"""Test that operator_for_path preserves query parameters for HTTP URLs"""
from .client import operator_for_path

# HTTP URL without query parameters
path, operator, scheme = operator_for_path("https://cdn.example.com/images/image.raw.xz")
assert scheme == "http"
assert path == "/images/image.raw.xz"

# HTTP URL with query parameters (e.g. CloudFront signed URL)
path, operator, scheme = operator_for_path(
"https://cdn.example.com/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz"
)
assert scheme == "http"
assert path == "/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz"
assert "Expires=123" in path
assert "Signature=abc" in path
assert "Key-Pair-Id=xyz" in path

# Filesystem path (use resolve() for the expected value since macOS
# resolves /tmp to /private/tmp)
from pathlib import Path

path, operator, scheme = operator_for_path("/tmp/image.raw.xz")
assert scheme == "fs"
assert path == Path("/tmp/image.raw.xz").resolve()
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import click
from jumpstarter_driver_composite.client import CompositeClient
from jumpstarter_driver_opendal.client import FlasherClient, operator_for_path
from jumpstarter_driver_opendal.client import FlasherClient, clean_filename, operator_for_path
from jumpstarter_driver_power.client import PowerClient
from opendal import Operator

Expand Down Expand Up @@ -39,7 +39,7 @@ def _upload_file_if_needed(self, file_path: str, operator: Operator | None = Non
path_buf = Path(file_path)
operator_scheme = "unknown"

filename = Path(path_buf).name
filename = clean_filename(path_buf)

if self._should_upload_file(self.storage, filename, path_buf, operator, operator_scheme):
if operator_scheme == "http":
Expand Down
Loading