diff --git a/skops/io/_general.py b/skops/io/_general.py index 8437f863..8105a600 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -633,7 +633,9 @@ def __init__( ) -> None: super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, [bytes]) - self.children = {"content": io.BytesIO(load_context.src.read(state["file"]))} + self.children = { + "content": io.BytesIO(load_context.read_zip_member(state["file"])) + } def _construct(self): content = self.children["content"].getvalue() diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index 39ec8dcb..181b66a9 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -70,7 +70,7 @@ def __init__( ) if self.type == "numpy": self.children = { - "content": io.BytesIO(load_context.src.read(state["file"])) + "content": io.BytesIO(load_context.read_zip_member(state["file"])) } elif self.type == "json": self.children = { diff --git a/skops/io/_scipy.py b/skops/io/_scipy.py index 6c8ff4dd..421963b2 100644 --- a/skops/io/_scipy.py +++ b/skops/io/_scipy.py @@ -48,7 +48,9 @@ def __init__( f"Cannot load object of type {self.module_name}.{self.class_name}" ) - self.children = {"content": io.BytesIO(load_context.src.read(state["file"]))} + self.children = { + "content": io.BytesIO(load_context.read_zip_member(state["file"])) + } def _construct(self): # scipy load_npz uses numpy.save with allow_pickle=False under the diff --git a/skops/io/_utils.py b/skops/io/_utils.py index a3519ab9..4b63d3dd 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -11,6 +11,10 @@ from ._protocol import PROTOCOL +MAX_ZIP_MEMBER_SIZE = 1024 * 1024 * 1024 +MIN_ZIP_MEMBER_SIZE_FOR_RATIO_CHECK = 10 * 1024 * 1024 +MAX_ZIP_COMPRESSION_RATIO = 100 + # The following two functions are copied from cpython's pickle.py file. # --------------------------------------------------------------------- @@ -141,6 +145,33 @@ class LoadContext: protocol: int memo: dict[int, Any] = field(default_factory=dict) + def read_zip_member(self, name: str) -> bytes: + info = self.src.getinfo(name) + if info.file_size > MAX_ZIP_MEMBER_SIZE: + raise ValueError( + f"Zip member {name!r} is too large to load safely: " + f"{info.file_size} bytes" + ) + + if info.compress_size == 0: + if info.file_size == 0: + return self.src.read(name) + raise ValueError( + f"Zip member {name!r} has an invalid compressed size" + ) + + compression_ratio = info.file_size / info.compress_size + if ( + info.file_size > MIN_ZIP_MEMBER_SIZE_FOR_RATIO_CHECK + and compression_ratio > MAX_ZIP_COMPRESSION_RATIO + ): + raise ValueError( + f"Zip member {name!r} has a suspicious compression ratio: " + f"{compression_ratio:.1f}" + ) + + return self.src.read(name) + def memoize(self, obj: Any, id: int) -> None: self.memo[id] = obj diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 6b8eed41..bf66fa91 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -71,7 +71,15 @@ SCIPY_UFUNC_TYPE_NAMES, SKLEARN_ESTIMATOR_TYPE_NAMES, ) -from skops.io._utils import LoadContext, SaveContext, _get_state, get_state, gettype +from skops.io._utils import ( + MAX_ZIP_COMPRESSION_RATIO, + MIN_ZIP_MEMBER_SIZE_FOR_RATIO_CHECK, + LoadContext, + SaveContext, + _get_state, + get_state, + gettype, +) from skops.io.exceptions import UnsupportedTypeException, UntrustedTypesFoundException from skops.io.tests._utils import assert_method_outputs_equal, assert_params_equal from skops.utils._fixes import construct_instances, get_tags @@ -1196,6 +1204,36 @@ def test_compression_level(): assert len(dumped_raw) > 5 * len(dumped_compressed) +def test_rejects_suspicious_zip_member_compression_ratio(): + dumped = dumps(np.array([1], dtype=np.int64)) + buffer = io.BytesIO() + + with ZipFile(io.BytesIO(dumped), "r") as src, ZipFile(buffer, "w") as dst: + schema = json.loads(src.read("schema.json")) + npy_name = schema["file"] + + dst.writestr("schema.json", json.dumps(schema)) + dst.writestr( + npy_name, + b"\x00" * (MIN_ZIP_MEMBER_SIZE_FOR_RATIO_CHECK + 1), + compress_type=ZIP_DEFLATED, + compresslevel=9, + ) + + dumped = buffer.getvalue() + message = "suspicious compression ratio" + + with ZipFile(io.BytesIO(dumped), "r") as zip_file: + info = zip_file.getinfo(npy_name) + assert info.file_size / info.compress_size > MAX_ZIP_COMPRESSION_RATIO + + with pytest.raises(ValueError, match=message): + loads(dumped) + + with pytest.raises(ValueError, match=message): + get_untrusted_types(data=dumped) + + @pytest.mark.parametrize("call_has_canonical_format", [False, True]) def test_sparse_matrix(call_has_canonical_format): # see https://github.com/skops-dev/skops/pull/375