Skip to content
This repository was archived by the owner on Aug 28, 2025. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions serotiny/datamodules/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .dataframe import DataframeDatamodule
from .folder import make_folder_dataloader
84 changes: 84 additions & 0 deletions serotiny/datamodules/folder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing import Union, Sequence, Callable, Optional
from upath import UPath as Path
from omegaconf import ListConfig
from monai.data import Dataset, PersistentDataset, DataLoader
from monai.transforms import Compose

from serotiny.dataframe.transforms.filter import _filter_columns as filter_filenames


def make_folder_dataloader(
    path: Union[Path, str],
    transforms: Union[Sequence[Callable], Callable],
    cache_dir: Optional[Union[Path, str]] = None,
    regex: Optional[str] = None,
    startswith: Optional[str] = None,
    endswith: Optional[str] = None,
    contains: Optional[str] = None,
    excludes: Optional[str] = None,
    **dataloader_kwargs
):
    """Create a dataloader based on a folder of samples. If no transforms are
    applied, each sample is a dictionary with a key "input" containing the
    corresponding path and a key "orig_fname" containing the original filename
    (with no extension).

    Files can be filtered out of the list with name-based rules, using `regex`,
    `startswith`, `endswith`, `contains`, `excludes`.

    Parameters
    ----------
    path: Union[Path, str],
        Path to folder

    transforms: Union[Sequence[Callable], Callable],
        Transforms to apply to each sample

    cache_dir: Optional[Union[Path, str]] = None
        Path to a directory in which to store cached transformed inputs, to
        accelerate batch loading.

    regex: Optional[str] = None
        A string containing a regular expression to be matched

    startswith: Optional[str] = None
        A substring the matching filenames must start with

    endswith: Optional[str] = None
        A substring the matching filenames must end with

    contains: Optional[str] = None
        A substring the matching filenames must contain

    excludes: Optional[str] = None
        A substring the matching filenames must not contain

    dataloader_kwargs:
        Additional keyword arguments are passed to the
        torch.utils.data.DataLoader class when instantiating it.
        Among these args are `num_workers`, `batch_size`, `shuffle`, etc.
        See the PyTorch docs for more info on these args:
        https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
    """

    # A sequence of transforms must be composed into a single callable before
    # being handed to the dataset. ListConfig is included because Hydra/omegaconf
    # configs deliver lists as ListConfig rather than plain list.
    if isinstance(transforms, (list, tuple, ListConfig)):
        transforms = Compose(transforms)

    # Reuse the dataframe column-filtering helper: it takes a list of strings
    # and the name-based filter rules, returning the strings that match.
    filenames = filter_filenames(
        [str(child) for child in Path(path).glob("*")],
        regex,
        startswith,
        endswith,
        contains,
        excludes,
    )

    # Note: a distinct loop variable is used here so the `path` parameter
    # (the folder) isn't shadowed by individual file paths.
    data = [
        {"input": fname, "orig_fname": Path(fname).stem} for fname in filenames
    ]

    # PersistentDataset caches transformed samples on disk across runs;
    # fall back to a plain (non-caching) Dataset when no cache_dir is given.
    if cache_dir is not None:
        dataset = PersistentDataset(data, transform=transforms, cache_dir=cache_dir)
    else:
        dataset = Dataset(data, transform=transforms)

    return DataLoader(dataset, **dataloader_kwargs)