torch.compile with reduce-overhead + GatedNonLinearity Reimplementation#1398
Merged
ilyes319 merged 2 commits into ACEsuit:develop on Mar 25, 2026
Conversation
Force-pushed from bf9b236 to a10f1c4
…ompile

Add a pure-torch gated equivariant nonlinearity that eliminates graph breaks from e3nn's `nn.Gate` (data-dependent control flow, TorchScript codegen, `ElementwiseTensorProduct`). Supports both `mul_ir` and `ir_mul` layouts natively, removing the need for `TransposeIrrepsLayout` wrappers around the gate.

- Add mace/modules/gate.py with GatedEquivariantBlock (cached normalize2mom)
- Add get_layout() helper to wrapper_ops.py for backend layout detection
- Update all Gate call sites in blocks.py with layout=get_layout(cueq_config)
- Remove TransposeIrrepsLayoutWrapper from interaction/readout blocks
- Replace torch.nonzero with torch.argmax in EquivariantProductBasisBlock
- Use self-loop padding in padding_tools.py (consistent with TorchSim)
- Add CUDA graph support to TorchSim wrapper (cudagraph_mark_step_begin, clone)
- Skip cueq+reduce-overhead tests on PyTorch < 2.10 (CUDA graph partitioning)
- Add 94 parametrized tests for GatedEquivariantBlock (both layouts, compile)
Force-pushed from a10f1c4 to 9a63ef1
Force-pushed from 2406088 to 9e5f377
Summary
This PR replaces e3nn's `nn.Gate` with a pure-torch `GatedEquivariantBlock` that eliminates all graph breaks under `torch.compile`, and adds CUDA graph support for `reduce-overhead` mode.

e3nn's Gate causes three kinds of graph breaks: the data-dependent branch `if gates.shape[-1]:`, `_Sortcut`/`Extract` with dynamic slicing via `fx.Graph` codegen, and `ElementwiseTensorProduct` as TorchScript. The new `GatedEquivariantBlock` uses only `torch.narrow` and `reshape` with Python int arguments, which dynamo handles natively. It supports both `mul_ir` (e3nn) and `ir_mul` (cuequivariance) layouts, so the `TransposeIrrepsLayout` wrappers that were previously needed around the gate are no longer required.

The TorchSim wrapper also gains CUDA graph support (`reduce-overhead` mode), which requires PyTorch >= 2.10 for proper CUDA graph partitioning of custom ops.

Additionally, `MaceTorchSimModel` now accepts a `head` parameter (string name, integer index, or None) to select which head to use for multi-head models like mace-mh-1. Previously head index 0 was hardcoded.

Changes by file
mace/modules/gate.py (new)
- `GatedEquivariantBlock`: drop-in Gate replacement with a `layout` parameter. The forward path branches on layout to reshape gated chunks correctly for both `mul_ir` and `ir_mul`. Normalization constants (`normalize2mom`) are cached to avoid recomputing 1M-sample Monte Carlo estimates on every construction.

mace/modules/blocks.py
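One compile-friendliness change in this file swaps an op with a data-dependent output shape for a fixed-shape one. A toy illustration (standalone tensors, not MACE code): `torch.nonzero` returns one row per nonzero entry, so its shape is only known at runtime, while `argmax` over one-hot node attributes yields the same indices with a shape fixed at trace time.

```python
import torch

# One-hot node attributes: 4 nodes, 3 element types (toy example).
node_attrs = torch.tensor(
    [[1.0, 0.0, 0.0],
     [0.0, 0.0, 1.0],
     [0.0, 1.0, 0.0],
     [1.0, 0.0, 0.0]]
)

# Old approach: output shape depends on how many entries are nonzero.
idx_nonzero = torch.nonzero(node_attrs)[:, 1]

# New approach: exactly one index per node, fixed output shape.
idx_argmax = node_attrs.argmax(dim=-1)

assert torch.equal(idx_nonzero, idx_argmax)
print(idx_argmax.tolist())  # [0, 2, 1, 0]
```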
- Replace `nn.Gate` with `GatedEquivariantBlock` at all 4 call sites: `NonLinearDipoleReadoutBlock`, `NonLinearDipolePolarReadoutBlock`, `GeneralNonLinearBiasReadoutBlock`, `RealAgnosticResidualNonLinearInteractionBlock`.
- Remove the `TransposeIrrepsLayoutWrapper` creation and forward-path conditionals from the two blocks that had them.
- Replace `torch.nonzero(node_attrs)[:, 1]` with `node_attrs.argmax(dim=-1)` in `EquivariantProductBasisBlock` (compile-friendly fixed output shape).

mace/modules/wrapper_ops.py
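A sketch of what a backend layout helper could look like. This is hypothetical: the real `get_layout()` lives in mace/modules/wrapper_ops.py, and the `CuEqConfigStub` dataclass here merely stands in for `CuEquivarianceConfig`, whose actual fields may differ.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class CuEqConfigStub:
    """Stand-in for CuEquivarianceConfig (assumed fields)."""
    enabled: bool = False
    layout: str = "ir_mul"


def get_layout(cueq_config: Optional[CuEqConfigStub]) -> str:
    """Return the irreps memory layout for the active backend.

    e3nn stores tensors as mul_ir; cuequivariance typically uses ir_mul.
    """
    if cueq_config is not None and cueq_config.enabled:
        return cueq_config.layout
    return "mul_ir"


print(get_layout(None))                  # mul_ir
print(get_layout(CuEqConfigStub(True)))  # ir_mul
```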
- Add a `get_layout()` helper that extracts the layout string from `CuEquivarianceConfig`.

mace/calculators/mace_torchsim.py
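The two CUDA-graph hooks described below (mark-step at forward entry, cloning graph-owned outputs) can be sketched as follows. This is an illustrative wrapper, not the actual `MaceTorchSimModel`; the class name and constructor are invented for the example.

```python
import torch


class CudaGraphWrapper(torch.nn.Module):
    """Illustrative wrapper showing where the reduce-overhead CUDA-graph
    hooks go; not the real MaceTorchSimModel."""

    def __init__(self, model: torch.nn.Module, cudagraphs: bool = False):
        super().__init__()
        self.cudagraphs = cudagraphs
        self.model = (
            torch.compile(model, mode="reduce-overhead") if cudagraphs else model
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.cudagraphs:
            # Tell the CUDA-graph runtime that a new inference step begins,
            # so cached graph memory can be reused safely.
            torch.compiler.cudagraph_mark_step_begin()
        out = self.model(x)
        # Clone outputs: buffers owned by a CUDA graph are overwritten on
        # the next replay, so callers must not hold the raw output.
        return out.clone() if self.cudagraphs else out
```

Without the clone, a caller keeping a reference to the previous step's forces or energies would see them silently overwritten by the next graph replay.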
- Add a `cudagraphs` parameter for `reduce-overhead` mode.
- Call `torch.compiler.cudagraph_mark_step_begin()` at forward entry.
- Add a `head` parameter for multi-head model support.

mace/data/padding_tools.py
- Use self-loop padding (consistent with TorchSim).
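Self-loop padding pads the edge index up to a static size with edges that point from a padding node to itself, so the extra edges keep tensor shapes fixed without introducing new interactions. A toy sketch of the idea (function name and signature are invented, not the actual padding_tools code):

```python
import torch


def pad_edge_index(edge_index: torch.Tensor, n_edges_max: int,
                   pad_node: int) -> torch.Tensor:
    """Pad a (2, E) edge_index to (2, n_edges_max) with self-loops on
    pad_node. Toy sketch of self-loop padding."""
    n_pad = n_edges_max - edge_index.shape[1]
    # Each padded edge is (pad_node -> pad_node): a self-loop that adds
    # no new neighbor relations to real nodes.
    loops = torch.full((2, n_pad), pad_node, dtype=edge_index.dtype)
    return torch.cat([edge_index, loops], dim=1)


ei = torch.tensor([[0, 1], [1, 0]])
padded = pad_edge_index(ei, n_edges_max=4, pad_node=2)
print(padded.tolist())  # [[0, 1, 2, 2], [1, 0, 2, 2]]
```

A static edge count is what lets `torch.compile` and CUDA graphs reuse one captured graph across frames with varying neighbor lists.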
tests/test_gate.py (new)
- Add 94 parametrized tests for `GatedEquivariantBlock` (both layouts, with and without compile).
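A heavily simplified toy version of the gating operation under test, using only `narrow`/`reshape` with Python int arguments so dynamo traces it without graph breaks. This is not the actual `GatedEquivariantBlock` (irreps handling, layouts, and normalization are all elided); it only demonstrates the compile-friendly slicing pattern in `mul_ir` layout.

```python
import torch


def toy_gate(x: torch.Tensor, n_scalars: int, n_gates: int,
             mul: int, ir_dim: int) -> torch.Tensor:
    """Toy gated nonlinearity, mul_ir layout: x = [scalars | gates | gated],
    where `gated` holds `mul` irreps of dimension `ir_dim` and there is one
    gate per irrep. Only narrow/reshape with Python ints: no graph breaks."""
    scalars = torch.narrow(x, -1, 0, n_scalars)
    gates = torch.narrow(x, -1, n_scalars, n_gates)
    gated = torch.narrow(x, -1, n_scalars + n_gates, mul * ir_dim)
    # mul_ir: reshape to (..., mul, ir_dim) and gate each irrep.
    gated = gated.reshape(*gated.shape[:-1], mul, ir_dim)
    gated = gated * torch.sigmoid(gates).unsqueeze(-1)
    gated = gated.reshape(*gated.shape[:-2], mul * ir_dim)
    return torch.cat([torch.nn.functional.silu(scalars), gated], dim=-1)


x = torch.randn(5, 2 + 3 + 3 * 3)  # 2 scalars, 3 gates, 3 l=1 irreps
out = toy_gate(x, n_scalars=2, n_gates=3, mul=3, ir_dim=3)
assert out.shape == (5, 11)  # gates are consumed: 2 scalars + 9 gated

# The claim at issue: dynamo captures this in a single graph.
compiled = torch.compile(toy_gate, fullgraph=True, backend="eager")
assert torch.allclose(compiled(x, 2, 3, 3, 3), out)
```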
tests/test_compile.py
- Skip `reduce-overhead`/`max-autotune` on PyTorch < 2.10 (CUDA graph partitioning for custom ops landed in 2.10).

Test plan
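The PyTorch >= 2.10 gate used to skip those tests can be expressed as below (illustrative, not the exact test_compile.py code). Note that tuple comparison is used because a plain string comparison gets "2.10" vs "2.9" wrong.

```python
import torch


def parse_torch_version(version: str) -> tuple:
    """'2.10.0+cu121' -> (2, 10); ignores patch and local build suffixes.

    Tuples compare numerically, unlike strings ("2.9" > "2.10" lexically).
    """
    base = version.split("+")[0]
    major, minor = base.split(".")[:2]
    return (int(major), int(minor))


# CUDA graph partitioning of custom ops landed in PyTorch 2.10; older
# versions would skip the reduce-overhead / max-autotune tests.
SUPPORTS_CUDAGRAPH_PARTITION = parse_torch_version(torch.__version__) >= (2, 10)

print(parse_torch_version("2.10.0+cu121"))  # (2, 10)
print(parse_torch_version("2.9.1"))         # (2, 9)
```

In a pytest suite this condition would feed a `skipif` marker on the affected tests.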