(prepared with the help of claude code)
Describe the bug
SegmentedPolynomial checkpoints saved on CUDA cannot be loaded on CPU because the backend classes do not define __reduce__, which prevents model portability.
SegmentedPolynomial.__init__ selects a backend at construction time based on whether cuequivariance_ops_torch is installed:
- CUDA machine → self.m = SegmentedPolynomialFromUniform1dJit(...)
- CPU machine → self.m = SegmentedPolynomialNaive(...)
PyTorch's default nn.Module pickling freezes the class type into the checkpoint. When the checkpoint is loaded on a CPU-only machine, SegmentedPolynomialFromUniform1dJit is reconstructed directly (bypassing __init__), but tensor_product_uniform_1d_jit is None at the module level. The first forward() call raises:
TypeError: 'NoneType' object is not callable
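The mechanism can be sketched without torch or cuequivariance installed. This is a minimal illustration, not the library code: JitBackend and fast_kernel are hypothetical stand-ins for SegmentedPolynomialFromUniform1dJit and tensor_product_uniform_1d_jit, and plain pickle stands in for the checkpoint path.

```python
import pickle


def fast_kernel(x):
    # Stand-in for the optional kernel import (hypothetical). In the real
    # library the corresponding name is None when the ops package is missing.
    return 2 * x


class JitBackend:
    """Hypothetical stand-in for SegmentedPolynomialFromUniform1dJit."""

    def __init__(self, scale):
        # Any availability check placed here only runs at construction time.
        if fast_kernel is None:
            raise ImportError("ops package not installed")
        self.scale = scale

    def forward(self, x):
        # Looks up the module-level name at call time.
        return self.scale * fast_kernel(x)


payload = pickle.dumps(JitBackend(3))  # "saved on the CUDA machine"
fast_kernel = None                     # simulate the CPU-only machine
restored = pickle.loads(payload)       # succeeds: __init__ is bypassed

try:
    restored.forward(1.0)
    error_message = None
except TypeError as exc:
    error_message = str(exc)
print(type(restored).__name__, "->", error_message)
```

Unpickling restores the recorded class and its attribute dictionary without calling __init__, so the failure only surfaces at the first forward() call.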
Fix: Add __reduce__ to SegmentedPolynomialFromUniform1dJit (and likewise to SegmentedPolynomialFusedTP and SegmentedPolynomialIndexedLinear) so that unpickling delegates back to SegmentedPolynomial.__init__, which re-runs the HAS_CUE_OPS check and selects the correct backend for the loading machine. This requires storing the original polynomial and math_dtype arguments as instance attributes.
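A rough sketch of the proposed pattern, using plain Python classes rather than the real nn.Module backends. All names here (JitBackend, NaiveBackend, make_backend, HAS_OPS) are hypothetical stand-ins; make_backend plays the role of the backend selection in SegmentedPolynomial.__init__, and a real implementation would also need to carry over any module state such as buffers.

```python
import pickle

HAS_OPS = True  # stand-in for the HAS_CUE_OPS check (hypothetical)


class NaiveBackend:
    """Hypothetical stand-in for SegmentedPolynomialNaive."""

    def __init__(self, polynomial, math_dtype):
        self.polynomial = polynomial
        self.math_dtype = math_dtype


class JitBackend:
    """Hypothetical stand-in for SegmentedPolynomialFromUniform1dJit."""

    def __init__(self, polynomial, math_dtype):
        # Keep the constructor arguments so they can be replayed on load.
        self.polynomial = polynomial
        self.math_dtype = math_dtype

    def __reduce__(self):
        # Pickle the recipe (factory + arguments), not the concrete class.
        return (make_backend, (self.polynomial, self.math_dtype))


def make_backend(polynomial, math_dtype):
    # Backend selection runs again wherever the object is unpickled.
    cls = JitBackend if HAS_OPS else NaiveBackend
    return cls(polynomial, math_dtype)


saved = make_backend("poly", "float32")  # built where the ops are available
payload = pickle.dumps(saved)

HAS_OPS = False                          # simulate the CPU-only machine
restored = pickle.loads(payload)
print(type(saved).__name__, "->", type(restored).__name__)
```

Because the payload records the factory and its arguments instead of the concrete class, the object that comes back matches whatever the loading machine supports.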
Affected files:
- cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py — SegmentedPolynomialFromUniform1dJit (lines 89–301)
- cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py — SegmentedPolynomialFusedTP
- cuequivariance_torch/primitives/segmented_polynomial_indexed_linear.py — SegmentedPolynomialIndexedLinear
- cuequivariance_torch/primitives/segmented_polynomial.py — SegmentedPolynomial (lines 178–203 for the backend selection; no __reduce__ defined)
To Reproduce
This Python script illustrates the issue (requires torch, cuequivariance, cuequivariance_torch):
issue-1-cuequivariance-checkpoint-portability-mre.py
Expected behavior
Serializing/deserializing models based on cuequivariance should use the hardware available on the current system, rather than the system used when the model was initially constructed.
GPU HW/SW (please complete the following information):
Environment:
torch version: 2.11.0+cu130
cuequivariance: 0.9.1
cuequivariance_torch: 0.9.1