diff --git a/CMakeLists.txt b/CMakeLists.txt index 09eeecd6b28..2cd3dd7389d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -997,6 +997,18 @@ if(EXECUTORCH_BUILD_PYBIND) EXPORT ExecuTorchTargets LIBRARY DESTINATION executorch/extension/pybindings ) + + # pybind data_loader module - provides PyDataLoader type for external + # pybinding extensions to create custom data loaders + pybind11_add_module( + data_loader SHARED extension/pybindings/pybindings_data_loader.cpp + ) + target_include_directories(data_loader PRIVATE ${_common_include_directories}) + target_compile_options(data_loader PUBLIC ${_pybind_compile_options}) + target_link_libraries(data_loader PRIVATE executorch) + install(TARGETS data_loader + LIBRARY DESTINATION executorch/extension/pybindings + ) endif() if(EXECUTORCH_BUILD_WASM) diff --git a/extension/pybindings/BUCK b/extension/pybindings/BUCK index 878626a1361..4a1ce1a5b76 100644 --- a/extension/pybindings/BUCK +++ b/extension/pybindings/BUCK @@ -1,4 +1,6 @@ load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target") +load("@fbcode_macros//build_defs:cpp_library.bzl", "cpp_library") +load("@fbcode_macros//build_defs:cpp_python_extension.bzl", "cpp_python_extension") # Any targets that should be shared between fbcode and xplat must be defined in # targets.bzl. This file can contain fbcode-only targets. @@ -72,3 +74,33 @@ fbcode_target(_kind = runtime.python_library, "//executorch/exir:_warnings", ], ) + +# Header-only library that provides PyDataLoader for external pybinding extensions. +# This allows external libraries (like PTEZ) to create custom data loaders that can +# be passed to _load_for_executorch_from_data_loader(). +fbcode_target( + _kind = cpp_library, + name = "data_loader_types", + headers = ["pybindings_data_loader.h"], + exported_deps = [ + "//executorch/runtime/core:core", + ], + visibility = ["PUBLIC"], +) + +# Python extension that registers the PyDataLoader type. +# This allows external libraries to create PyDataLoader instances without +# importing the full core pybindings. +fbcode_target( + _kind = cpp_python_extension, + name = "data_loader", + srcs = ["pybindings_data_loader.cpp"], + base_module = "executorch.extension.pybindings", + deps = [ + ":data_loader_types", + ], + external_deps = [ + "pybind11", + ], + visibility = ["PUBLIC"], +) diff --git a/extension/pybindings/data_loader.pyi b/extension/pybindings/data_loader.pyi new file mode 100644 index 00000000000..90ae1c1889d --- /dev/null +++ b/extension/pybindings/data_loader.pyi @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +class PyDataLoader: + """Pybind11 wrapper for DataLoader.""" + + ... diff --git a/extension/pybindings/pybindings.cpp b/extension/pybindings/pybindings.cpp index eb81bda22f7..684a345e334 100644 --- a/extension/pybindings/pybindings.cpp +++ b/extension/pybindings/pybindings.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -85,6 +86,8 @@ using ::executorch::extension::BufferDataLoader; using ::executorch::extension::MallocMemoryAllocator; using ::executorch::extension::MmapDataLoader; using ::executorch::extension::ET_BUNDLED_MODULE_NAMESPACE::BundledModule; +using ::executorch::extension::pybindings::PyDataLoader; +using ::executorch::extension::pybindings::SharedPtrDataLoader; using ::executorch::runtime::ArrayRef; using ::executorch::runtime::DataLoader; using ::executorch::runtime::Error; @@ -246,6 +249,29 @@ inline std::unique_ptr load_module_from_buffer_with_data_file( std::move(data_loader)); } +inline std::unique_ptr load_module_from_data_loader( + std::shared_ptr loader, + std::optional data_map_path, + std::unique_ptr event_tracer) { + EXECUTORCH_SCOPE_PROF("load_module_from_data_loader"); + + if (data_map_path.has_value()) { + auto data_map_loader = loader_from_file(data_map_path.value()); + return std::make_unique( + loader->make_delegating_loader(), + nullptr, // memory_allocator + nullptr, // temp_allocator + std::move(event_tracer), // event_tracer + std::move(data_map_loader)); // data_map_loader + } + return std::make_unique( + loader->make_delegating_loader(), + nullptr, // memory_allocator + nullptr, // temp_allocator + std::move(event_tracer), // event_tracer + nullptr); // data_map_loader +} + inline py::list get_outputs_as_py_list( const std::vector& outputs, bool clone_outputs = true) { @@ -601,6 +627,17 @@ struct PyModule final { setup_event_tracer(enable_etdump, debug_buffer_size), program_verification)) {} + explicit PyModule( + std::shared_ptr loader, + std::optional data_path, + bool enable_etdump, + size_t debug_buffer_size = 0) + : debug_buffer_size_(debug_buffer_size), + module_(load_module_from_data_loader( + std::move(loader), + data_path, + setup_event_tracer(enable_etdump, debug_buffer_size))) {} + PyModule(const PyModule&) = delete; PyModule& operator=(const PyModule&) = delete; PyModule(PyModule&&) = default; @@ -676,6 +713,17 @@ struct PyModule final { Program::Verification::InternalConsistency); } + // Load from an external data loader. + // This allows external libraries (like PTEZ) to provide custom data loaders. + static std::unique_ptr load_from_data_loader( + std::shared_ptr loader, + std::optional data_path, + bool enable_etdump, + size_t debug_buffer_size = 0) { + return std::make_unique( + std::move(loader), data_path, enable_etdump, debug_buffer_size); + } + py::list run_method( const std::string& method_name, const py::sequence& inputs, @@ -1529,6 +1577,20 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) { py::arg("buffer"), py::arg("non_const_pool_size") = kDEFAULT_BUNDLED_INPUT_POOL_SIZE, call_guard); + + // Import the PyDataLoader type from the shared module. + // This ensures the type is registered once and shared across all modules. + py::module_::import("executorch.extension.pybindings.data_loader"); + + m.def( + "_load_for_executorch_from_data_loader", + &PyModule::load_from_data_loader, + py::arg("loader"), + py::arg("data_path") = py::none(), + py::arg("enable_etdump") = false, + py::arg("debug_buffer_size") = 0, + call_guard); + m.def( "_dump_profile_results", []() { diff --git a/extension/pybindings/pybindings.pyi b/extension/pybindings/pybindings.pyi index 9e5ab6211ce..9e38b2b3c6d 100644 --- a/extension/pybindings/pybindings.pyi +++ b/extension/pybindings/pybindings.pyi @@ -154,6 +154,9 @@ class MethodMeta: def __repr__(self) -> str: ... +# Re-export PyDataLoader from the shared module for backward compatibility. +from executorch.extension.pybindings.data_loader import PyDataLoader as PyDataLoader + @experimental("This API is experimental and subject to change without notice.") def _load_for_executorch( program_path: str, @@ -215,6 +218,33 @@ def _load_for_executorch_from_bundled_program( """ ... +@experimental("This API is experimental and subject to change without notice.") +def _load_for_executorch_from_data_loader( + loader: PyDataLoader, + data_path: Optional[str] = None, + enable_etdump: bool = False, + debug_buffer_size: int = 0, +) -> ExecuTorchModule: + """Load an ExecuTorch Program from a PyDataLoader. + + This function allows external libraries to provide custom data loaders + (e.g., for compressed files) and load programs using them. + + .. warning:: + + This API is experimental and subject to change without notice. + + Args: + loader: A PyDataLoader wrapping a custom DataLoader implementation. + data_path: Optional path to a data file (e.g., for external weights). + enable_etdump: If true, enables an ETDump which can store profiling information. + debug_buffer_size: If non-zero, enables a debug buffer for intermediate results. + + Returns: + An ExecuTorchModule ready for execution. + """ + ... + @experimental("This API is experimental and subject to change without notice.") def _load_bundled_program_from_buffer( buffer: bytes, non_const_pool_size: int = ... diff --git a/extension/pybindings/pybindings_data_loader.cpp b/extension/pybindings/pybindings_data_loader.cpp new file mode 100644 index 00000000000..eaa346a84a1 --- /dev/null +++ b/extension/pybindings/pybindings_data_loader.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace py = pybind11; + +using ::executorch::extension::pybindings::PyDataLoader; + +PYBIND11_MODULE(data_loader, m) { + py::class_>(m, "PyDataLoader"); +} diff --git a/extension/pybindings/pybindings_data_loader.h b/extension/pybindings/pybindings_data_loader.h new file mode 100644 index 00000000000..fbb76c00c17 --- /dev/null +++ b/extension/pybindings/pybindings_data_loader.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include +#include + +namespace executorch { +namespace extension { +namespace pybindings { + +/// DataLoader wrapper holding a shared_ptr, allowing sharing between Python +/// and C++ while Module takes ownership via unique_ptr. +class SharedPtrDataLoader final : public runtime::DataLoader { + public: + explicit SharedPtrDataLoader(std::shared_ptr loader) + : loader_(std::move(loader)) {} + + ET_NODISCARD runtime::Result load( + size_t offset, + size_t size, + const SegmentInfo& segment_info) const override { + return loader_->load(offset, size, segment_info); + } + + ET_NODISCARD runtime::Result size() const override { + return loader_->size(); + } + + ET_NODISCARD runtime::Error load_into( + size_t offset, + size_t size, + const SegmentInfo& segment_info, + void* buffer) const override { + return loader_->load_into(offset, size, segment_info, buffer); + } + + private: + std::shared_ptr loader_; +}; + +/// Pybind11 wrapper for DataLoader. Use shared_ptr holder type in pybind11. +struct PyDataLoader { + explicit PyDataLoader(std::shared_ptr loader) + : loader_(std::move(loader)) {} + + PyDataLoader(const PyDataLoader&) = delete; + PyDataLoader& operator=(const PyDataLoader&) = delete; + PyDataLoader(PyDataLoader&&) = default; + PyDataLoader& operator=(PyDataLoader&&) = default; + + std::shared_ptr get() const { + return loader_; + } + + /// Creates a unique_ptr DataLoader that delegates to the shared loader. + std::unique_ptr make_delegating_loader() const { + return std::make_unique(loader_); + } + + private: + std::shared_ptr loader_; +}; + +} // namespace pybindings +} // namespace extension +} // namespace executorch diff --git a/setup.py b/setup.py index f60e6202c30..e8527c47d95 100644 --- a/setup.py +++ b/setup.py @@ -805,6 +805,13 @@ def run(self): # noqa C901 modpath="executorch.extension.pybindings._portable_lib", dependent_cmake_flags=["EXECUTORCH_BUILD_PYBIND"], ), + # Install the data_loader pybindings extension which provides the + # PyDataLoader type for external pybinding extensions. + BuiltExtension( + src="data_loader.cp*" if _is_windows() else "data_loader.*", + modpath="executorch.extension.pybindings.data_loader", + dependent_cmake_flags=["EXECUTORCH_BUILD_PYBIND"], + ), BuiltExtension( src="extension/training/_training_lib.*", # @lint-ignore https://github.com/pytorch/executorch/blob/cb3eba0d7f630bc8cec0a9cc1df8ae2f17af3f7a/scripts/lint_xrefs.sh modpath="executorch.extension.training.pybindings._training_lib", diff --git a/shim_et/xplat/executorch/extension/pybindings/pybindings.bzl b/shim_et/xplat/executorch/extension/pybindings/pybindings.bzl index 7e14ca8713a..23bd0153e9c 100644 --- a/shim_et/xplat/executorch/extension/pybindings/pybindings.bzl +++ b/shim_et/xplat/executorch/extension/pybindings/pybindings.bzl @@ -61,6 +61,7 @@ def executorch_pybindings(python_module_name, srcs = [], cppdeps = [], visibilit deps = [ "//executorch/runtime/core:core", "//executorch/extension/threadpool:threadpool", + "//executorch/extension/pybindings:data_loader_types", ] + cppdeps, external_deps = [ "pybind11",