
torch.compile with reduce-overhead + GatedNonLinearity Reimplementation#1398

Merged
ilyes319 merged 2 commits into ACEsuit:develop from ilyes319:torchsim-compile-clean
Mar 25, 2026

Conversation

ilyes319 (Contributor) commented Mar 5, 2026

Summary

This PR replaces e3nn's nn.Gate with a pure-torch GatedEquivariantBlock that eliminates all graph breaks under torch.compile, and adds CUDA graph support for reduce-overhead mode.

e3nn's Gate causes three kinds of graph breaks: a data-dependent branch (if gates.shape[-1]:), _Sortcut/Extract modules doing dynamic slicing via fx.Graph codegen, and ElementwiseTensorProduct running as TorchScript. The new GatedEquivariantBlock uses only torch.narrow and reshape with Python int arguments, which dynamo handles natively. It supports both the mul_ir (e3nn) and ir_mul (cuequivariance) layouts, so the TransposeIrrepsLayout wrappers previously needed around the gate are no longer required.
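A minimal sketch of the mechanism, assuming a single gated irrep and a 2-D input (names and layout are illustrative, not the PR's code): torch.narrow with Python int start/length plus a fixed reshape give dynamo a fully static slicing pattern, so fullgraph=True compiles without breaks.

```python
import torch

class TinyGatedBlock(torch.nn.Module):
    # Illustrative sketch, not the PR's GatedEquivariantBlock.
    def __init__(self, mul: int, ir_dim: int = 3):
        super().__init__()
        self.mul = mul          # multiplicity of the gated irreps
        self.ir_dim = ir_dim    # e.g. 3 for an l=1 irrep

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Layout assumption: mul gate scalars, then mul * ir_dim gated features.
        gates = torch.narrow(x, -1, 0, self.mul)
        feats = torch.narrow(x, -1, self.mul, self.mul * self.ir_dim)
        # mul_ir layout: reshape to (batch, mul, ir_dim) so the gate
        # broadcasts over each irrep's components.
        feats = feats.reshape(feats.shape[0], self.mul, self.ir_dim)
        out = torch.sigmoid(gates).unsqueeze(-1) * feats
        return out.reshape(out.shape[0], self.mul * self.ir_dim)

# fullgraph=True would raise on any graph break.
block = torch.compile(TinyGatedBlock(mul=4), fullgraph=True)
print(block(torch.randn(2, 4 + 4 * 3)).shape)  # torch.Size([2, 12])
```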

The TorchSim wrapper also gains CUDA graph support (reduce-overhead mode), which requires PyTorch >= 2.10 for proper CUDA graph partitioning of custom ops.

Additionally, MaceTorchSimModel now accepts a head parameter (string name, integer index, or None) to select which head to use for multi-head models like mace-mh-1. Previously head index 0 was hardcoded.
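A hedged sketch of how such a head argument might be resolved (the helper and error messages are hypothetical, not the wrapper's actual code):

```python
from typing import Union

def resolve_head(head: Union[str, int, None], head_names: list) -> int:
    # Hypothetical resolution: None -> first head (the old hardcoded
    # behaviour), int -> bounds-checked index, str -> looked up by name.
    if head is None:
        return 0
    if isinstance(head, int):
        if not 0 <= head < len(head_names):
            raise ValueError(f"head index {head} out of range for {head_names}")
        return head
    if head not in head_names:
        raise ValueError(f"unknown head {head!r}; available: {head_names}")
    return head_names.index(head)
```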

Changes by file

mace/modules/gate.py (new)

  • GatedEquivariantBlock: a drop-in Gate replacement with a layout parameter. The forward path branches on layout to reshape gated chunks correctly for both mul_ir and ir_mul. Normalization constants (normalize2mom) are cached so the 1M-sample Monte Carlo estimate is not recomputed on every construction; a sketch of the caching idea follows.
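
A hedged sketch of the caching idea (function names are illustrative): e3nn's normalize2mom rescales an activation so its output has unit second moment under a standard normal input, estimated by Monte Carlo; memoizing the constant per activation means the 1M-sample estimate runs once per process rather than once per block construction.

```python
import functools
import torch

@functools.lru_cache(maxsize=None)
def _act_norm_const(act_name: str) -> float:
    # One-off Monte Carlo estimate of 1 / sqrt(E[act(z)^2]) for z ~ N(0, 1),
    # memoized so repeated block construction reuses the constant.
    act = getattr(torch, act_name)  # e.g. "tanh", "sigmoid"
    z = torch.randn(1_000_000, generator=torch.Generator().manual_seed(0))
    return float(act(z).pow(2).mean().rsqrt())

def normalized_act(act_name: str):
    # Scales the activation so its output has ~unit second moment,
    # the same contract as e3nn's normalize2mom.
    c = _act_norm_const(act_name)
    act = getattr(torch, act_name)
    return lambda x: act(x) * c
```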

mace/modules/blocks.py

  • Replace nn.Gate with GatedEquivariantBlock at all four call sites: NonLinearDipoleReadoutBlock, NonLinearDipolePolarReadoutBlock, GeneralNonLinearBiasReadoutBlock, RealAgnosticResidualNonLinearInteractionBlock.
  • Remove the TransposeIrrepsLayoutWrapper creation and forward-path conditionals from the two blocks that had them.
  • Replace torch.nonzero(node_attrs)[:, 1] with node_attrs.argmax(dim=-1) in EquivariantProductBasisBlock, which has a compile-friendly fixed output shape (see the comparison after this list).
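
An illustrative comparison of the two forms: torch.nonzero's output shape depends on the data, which forces a graph break under torch.compile, while argmax over one-hot node_attrs recovers the same indices with a shape known from the input alone.

```python
import torch

# node_attrs: one-hot element types, shape (num_nodes, num_elements)
node_attrs = torch.eye(3)[torch.tensor([2, 0, 1, 1])]

old = torch.nonzero(node_attrs)[:, 1]  # shape depends on data -> graph break
new = node_attrs.argmax(dim=-1)        # fixed shape (num_nodes,)
assert torch.equal(old, new)
```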

mace/modules/wrapper_ops.py

  • Add a get_layout() helper that extracts the layout string from a CuEquivarianceConfig; a hedged sketch of such a helper follows.
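
A plausible shape for such a helper, assuming CuEquivarianceConfig carries an enabled flag and a layout string (the actual field names may differ):

```python
def get_layout(cueq_config) -> str:
    # Assumption: no config, or a disabled one, means the default e3nn layout.
    if cueq_config is None or not getattr(cueq_config, "enabled", False):
        return "mul_ir"
    return getattr(cueq_config, "layout_str", "ir_mul")
```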

mace/calculators/mace_torchsim.py

  • Add a cudagraphs parameter for reduce-overhead mode (see the sketch after this list).
  • Call torch.compiler.cudagraph_mark_step_begin() at forward entry.
  • Clone outputs when CUDA graphs are active (buffer-reuse safety).
  • Add a head parameter for multi-head model support.
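
A hedged sketch of the pattern these bullets describe (the wrapper itself is illustrative, not MaceTorchSimModel; reduce-overhead mode, torch.compiler.cudagraph_mark_step_begin(), and output cloning are the standard PyTorch recipe for CUDA-graph-backed inference):

```python
import torch

class CompiledWrapper(torch.nn.Module):
    # Illustrative wrapper, not the actual MaceTorchSimModel code.
    def __init__(self, model: torch.nn.Module, cudagraphs: bool = False):
        super().__init__()
        self.cudagraphs = cudagraphs
        self.model = torch.compile(
            model, mode="reduce-overhead" if cudagraphs else None
        )

    def forward(self, batch):
        if self.cudagraphs:
            # Mark a new iteration so the CUDA graph runner can safely replay.
            torch.compiler.cudagraph_mark_step_begin()
        out = self.model(batch)
        if self.cudagraphs:
            # CUDA graphs reuse static output buffers across replays; clone so
            # the caller's results are not overwritten on the next step.
            # (Assumes the model returns a dict of tensors.)
            out = {k: v.clone() for k, v in out.items()}
        return out
```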

mace/data/padding_tools.py

  • Change fake-edge construction from modular cross-atom indexing to self-loops on the last fake atom, consistent with the TorchSim wrapper's isolated-system padding (sketched below).
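
A sketch of the new construction, under the assumption that fake atoms occupy the tail indices of the padded system (helper name and signature are hypothetical):

```python
import torch

def pad_edge_index(edge_index: torch.Tensor, num_total_atoms: int,
                   num_total_edges: int) -> torch.Tensor:
    # Hypothetical helper: pad the edge list with self-loops on the last
    # (fake) atom, so padding never connects real and fake atoms.
    num_fake = num_total_edges - edge_index.shape[1]
    last = num_total_atoms - 1
    fake = torch.full((2, num_fake), last, dtype=edge_index.dtype)
    return torch.cat([edge_index, fake], dim=1)
```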

tests/test_gate.py (new)

  • 94 parametrized tests: forward/backward/second-derivative match vs e3nn, cross-layout equivalence via manual transpose, zero graph breaks for both layouts, custom activations, caching.

tests/test_compile.py

  • Skip cueq + reduce-overhead/max-autotune tests on PyTorch < 2.10 (CUDA graph partitioning for custom ops landed in 2.10); a sketch of the skip condition follows.
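
A sketch of the skip condition (the marker name is illustrative):

```python
import pytest
import torch
from packaging.version import Version

requires_pt_2_10 = pytest.mark.skipif(
    Version(torch.__version__.split("+")[0]) < Version("2.10"),
    reason="CUDA graph partitioning for custom ops landed in PyTorch 2.10",
)
```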

Test plan

  • 94 gate tests pass (both layouts, all match e3nn, zero graph breaks)
  • Padding tests pass
  • Compile tests pass (CPU + CUDA, stress, graph breaks = 0)
  • cueq + reduce-overhead verified on PyTorch 2.10 / H100
  • Head selection: string, int, None all work; error on invalid head
  • mace-mh-1 + head=mp_pbe_refit_add + cueq + reduce-overhead: correct forces
  • Pre-commit clean (black, isort, pylint 10/10)

ilyes319 force-pushed the torchsim-compile-clean branch from bf9b236 to a10f1c4 on March 5, 2026 16:33
…ompile

Add a pure-torch gated equivariant nonlinearity that eliminates graph breaks
from e3nn's nn.Gate (data-dependent control flow, TorchScript codegen,
ElementwiseTensorProduct). Supports both mul_ir and ir_mul layouts natively,
removing the need for TransposeIrrepsLayout wrappers around the gate.

- Add mace/modules/gate.py with GatedEquivariantBlock (cached normalize2mom)
- Add get_layout() helper to wrapper_ops.py for backend layout detection
- Update all Gate call sites in blocks.py with layout=get_layout(cueq_config)
- Remove TransposeIrrepsLayoutWrapper from interaction/readout blocks
- Replace torch.nonzero with torch.argmax in EquivariantProductBasisBlock
- Use self-loop padding in padding_tools.py (consistent with TorchSim)
- Add CUDA graph support to TorchSim wrapper (cudagraph_mark_step_begin, clone)
- Skip cueq+reduce-overhead tests on PyTorch < 2.10 (CUDA graph partitioning)
- Add 94 parametrized tests for GatedEquivariantBlock (both layouts, compile)
ilyes319 force-pushed the torchsim-compile-clean branch from a10f1c4 to 9a63ef1 on March 5, 2026 17:44
ilyes319 force-pushed the torchsim-compile-clean branch from 2406088 to 9e5f377 on March 9, 2026 11:08
ilyes319 merged commit 9207e22 into ACEsuit:develop on Mar 25, 2026
10 of 85 checks passed