Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,9 @@ def __init__(
) -> None:
super().__init__(state, load_context, trusted)
self.trusted = self._get_trusted(trusted, [bytes])
self.children = {"content": io.BytesIO(load_context.src.read(state["file"]))}
self.children = {
"content": io.BytesIO(load_context.read_zip_member(state["file"]))
}

def _construct(self):
content = self.children["content"].getvalue()
Expand Down
2 changes: 1 addition & 1 deletion skops/io/_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
)
if self.type == "numpy":
self.children = {
"content": io.BytesIO(load_context.src.read(state["file"]))
"content": io.BytesIO(load_context.read_zip_member(state["file"]))
}
elif self.type == "json":
self.children = {
Expand Down
4 changes: 3 additions & 1 deletion skops/io/_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def __init__(
f"Cannot load object of type {self.module_name}.{self.class_name}"
)

self.children = {"content": io.BytesIO(load_context.src.read(state["file"]))}
self.children = {
"content": io.BytesIO(load_context.read_zip_member(state["file"]))
}

def _construct(self):
# scipy load_npz uses numpy.save with allow_pickle=False under the
Expand Down
31 changes: 31 additions & 0 deletions skops/io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@

from ._protocol import PROTOCOL

# Largest declared uncompressed size (1 GiB) accepted for a single zip member.
MAX_ZIP_MEMBER_SIZE = 1024 * 1024 * 1024
# The compression-ratio check is applied only to members whose declared
# uncompressed size exceeds this threshold (10 MiB).
MIN_ZIP_MEMBER_SIZE_FOR_RATIO_CHECK = 10 * 1024 * 1024
# Largest accepted ratio of declared uncompressed size to compressed size.
MAX_ZIP_COMPRESSION_RATIO = 100


# The following two functions are copied from cpython's pickle.py file.
# ---------------------------------------------------------------------
Expand Down Expand Up @@ -141,6 +145,33 @@ class LoadContext:
protocol: int
memo: dict[int, Any] = field(default_factory=dict)

def read_zip_member(self, name: str) -> bytes:
    """Read a member of the source zip archive after checking its sizes.

    The sizes reported by ``self.src.getinfo(name)`` are inspected before
    the member's data is read, so that oversized or implausibly compressed
    members are rejected up front.

    Parameters
    ----------
    name : str
        Name of the member inside ``self.src``.

    Returns
    -------
    bytes
        The member's contents, as returned by ``self.src.read(name)``.

    Raises
    ------
    ValueError
        If the reported uncompressed size exceeds ``MAX_ZIP_MEMBER_SIZE``,
        if a non-empty member reports a compressed size of zero, or if a
        member larger than ``MIN_ZIP_MEMBER_SIZE_FOR_RATIO_CHECK`` has a
        compression ratio above ``MAX_ZIP_COMPRESSION_RATIO``.
    """
    member = self.src.getinfo(name)
    unpacked_size = member.file_size

    if unpacked_size > MAX_ZIP_MEMBER_SIZE:
        raise ValueError(
            f"Zip member {name!r} is too large to load safely: "
            f"{unpacked_size} bytes"
        )

    packed_size = member.compress_size
    if packed_size == 0:
        # Only an empty member may legitimately report zero compressed bytes;
        # this also keeps the ratio division below away from zero.
        if unpacked_size != 0:
            raise ValueError(
                f"Zip member {name!r} has an invalid compressed size"
            )
    else:
        compression_ratio = unpacked_size / packed_size
        ratio_check_applies = unpacked_size > MIN_ZIP_MEMBER_SIZE_FOR_RATIO_CHECK
        if ratio_check_applies and compression_ratio > MAX_ZIP_COMPRESSION_RATIO:
            raise ValueError(
                f"Zip member {name!r} has a suspicious compression ratio: "
                f"{compression_ratio:.1f}"
            )

    return self.src.read(name)

def memoize(self, obj: Any, id: int) -> None:
    """Store ``obj`` in ``self.memo`` under the key ``id``."""
    self.memo[id] = obj

Expand Down
40 changes: 39 additions & 1 deletion skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,15 @@
SCIPY_UFUNC_TYPE_NAMES,
SKLEARN_ESTIMATOR_TYPE_NAMES,
)
from skops.io._utils import LoadContext, SaveContext, _get_state, get_state, gettype
from skops.io._utils import (
MAX_ZIP_COMPRESSION_RATIO,
MIN_ZIP_MEMBER_SIZE_FOR_RATIO_CHECK,
LoadContext,
SaveContext,
_get_state,
get_state,
gettype,
)
from skops.io.exceptions import UnsupportedTypeException, UntrustedTypesFoundException
from skops.io.tests._utils import assert_method_outputs_equal, assert_params_equal
from skops.utils._fixes import construct_instances, get_tags
Expand Down Expand Up @@ -1196,6 +1204,36 @@ def test_compression_level():
assert len(dumped_raw) > 5 * len(dumped_compressed)


def test_rejects_suspicious_zip_member_compression_ratio():
    """Loading must fail when a zip member is large and highly compressed."""
    original = dumps(np.array([1], dtype=np.int64))

    # Pull the schema out of a genuine dump so the crafted archive still
    # points at a member name the loader will try to read.
    with ZipFile(io.BytesIO(original), "r") as src:
        schema = json.loads(src.read("schema.json"))
    npy_name = schema["file"]

    # One byte past the threshold, all zeros: big enough for the ratio check
    # to apply and extremely compressible.
    padding = b"\x00" * (MIN_ZIP_MEMBER_SIZE_FOR_RATIO_CHECK + 1)

    crafted = io.BytesIO()
    with ZipFile(crafted, "w") as dst:
        dst.writestr("schema.json", json.dumps(schema))
        dst.writestr(
            npy_name,
            padding,
            compress_type=ZIP_DEFLATED,
            compresslevel=9,
        )
    payload = crafted.getvalue()

    # Sanity check: the crafted member really exceeds the allowed ratio.
    with ZipFile(io.BytesIO(payload), "r") as zip_file:
        info = zip_file.getinfo(npy_name)
        assert info.file_size / info.compress_size > MAX_ZIP_COMPRESSION_RATIO

    expected_message = "suspicious compression ratio"

    with pytest.raises(ValueError, match=expected_message):
        loads(payload)

    with pytest.raises(ValueError, match=expected_message):
        get_untrusted_types(data=payload)


@pytest.mark.parametrize("call_has_canonical_format", [False, True])
def test_sparse_matrix(call_has_canonical_format):
# see https://github.com/skops-dev/skops/pull/375
Expand Down