```python
import torch

import cuequivariance as cue
from cuequivariance import ir_mul
from cuequivariance_torch.layers import FullyConnectedTensorProductConv

conv = FullyConnectedTensorProductConv(
    in_irreps=cue.Irreps("SO3", "3x0 + 3x1"),
    sh_irreps=cue.Irreps("SO3", "1x0 + 1x1"),
    out_irreps=cue.Irreps("SO3", "3x0 + 3x1"),
    batch_norm=False,
    mlp_channels=None,
    layout=ir_mul,
    use_fallback=None,
).cuda()

N, E, W = 8, 32, conv.tp.weight_numel  # W = 45

torch.manual_seed(0)
src_idx = torch.randint(0, N, (E,), device="cuda")
dst_idx = torch.randint(0, N, (E,), device="cuda")
graph = (torch.stack([src_idx, dst_idx]), (N, N))
src_features = torch.randn(N, 12, device="cuda")
edge_sh = torch.randn(E, 4, device="cuda", requires_grad=True)
edge_emb = torch.randn(E, W, device="cuda", requires_grad=True)

def energy_from_edge_sh(edge_sh_):
    return conv(
        src_features=src_features,
        edge_sh=edge_sh_,
        edge_emb=edge_emb,
        graph=graph,
        reduce="sum",
    ).sum()

def loss_from_edge_emb(edge_emb_):
    global edge_emb
    edge_emb = edge_emb_
    energy, vjp_fn = torch.func.vjp(energy_from_edge_sh, edge_sh)
    (neg_forces,) = vjp_fn(torch.ones_like(energy))
    return energy + neg_forces.sum()

# Crashes: ValueError: Received invalid size of tensor,
# expected [32, 45] but received [32, 4].
torch.func.grad_and_value(loss_from_edge_emb)(edge_emb)
```
**Describe the bug**
When called through the `torch.func` stack (I haven't verified with plain autograd since it isn't relevant for me), `FullyConnectedTensorProductConv` crashes because it considers the tensor sizes it receives in the bwd_bwd pass invalid. Note that, due to autograd limitations, `torch.func` is the ONLY path that supports compiled second-order derivatives, so this bug blocks compilation support for the library.

NOTE: The repro may require applying the fix called out here: pytorch/pytorch#168393
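As a minimal illustration of the `torch.func` nesting at stake (plain PyTorch only, no cuequivariance; the function `f` below is just a stand-in):

```python
import torch

# Second-order derivative via nested torch.func.grad — the same
# functional path the repro relies on. Unlike backward(create_graph=True),
# this composition is the one that works under torch.compile.
def f(x):
    return x ** 3  # scalar in, scalar out

d2f = torch.func.grad(torch.func.grad(f))
print(d2f(torch.tensor(3.0)))  # d²/dx² x³ = 6x → tensor(18.)
```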
**To Reproduce**
Note that I see the same behavior when using `full_tensor_product` with `SegmentedPolynomial` and `method="fused_tp"`, but not when using `method="uniform_1d"`.

**Expected behavior**
The bwd_bwd function should work as expected.
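For what it's worth, the grad-around-vjp nesting itself is valid `torch.func` usage: a toy version of the same call pattern with a plain dense map runs without error (a sketch with made-up names, pure PyTorch), which points at the layer's double-backward rather than the call pattern:

```python
import torch

torch.manual_seed(0)

# Toy stand-in for the repro: "energy" depends on two inputs (a, b);
# we take a vjp w.r.t. a inside a function differentiated w.r.t. b —
# the same grad_and_value(vjp(...)) nesting that crashes with the conv.
W1 = torch.randn(4, 4)
a = torch.randn(4, 4)

def energy_from_a(a_, b):
    return (a_ @ W1 * b).sum()

def loss_from_b(b):
    energy, vjp_fn = torch.func.vjp(lambda a_: energy_from_a(a_, b), a)
    (neg_forces,) = vjp_fn(torch.ones_like(energy))
    return energy + neg_forces.sum()

grad_b, loss = torch.func.grad_and_value(loss_from_b)(torch.randn(4, 4))
print(grad_b.shape)  # torch.Size([4, 4])
```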
**GPU HW/SW (please complete the following information):**