Exception when using FullyConnectedTensorProductConv's second order derivatives #264

@rwkeane

Description

Describe the bug
When called through the torch.func stack (I haven't verified with plain autograd since it isn't relevant for me), FullyConnectedTensorProductConv crashes because it considers the tensor sizes it receives for the bwd_bwd pass to be incorrect.

Note that, due to autograd limitations, torch.func is the ONLY path that supports compiled second-order derivatives, so this bug blocks compilation support for the library.
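For context, the repro below uses a vjp-inside-grad pattern to get second-order derivatives through torch.func. Here is a minimal CPU-only sketch of that same pattern on a toy quadratic (the names energy and loss are illustrative, not the library's API):

```python
import torch

# Toy scalar "energy" of a vector input.
def energy(x):
    return (x ** 2).sum()

def loss(x):
    # Inner vjp: first-order gradient of energy w.r.t. x.
    val, vjp_fn = torch.func.vjp(energy, x)
    (g,) = vjp_fn(torch.ones_like(val))  # g = d(energy)/dx = 2*x
    return val + g.sum()

# Outer grad_and_value differentiates *through* the inner vjp (second order).
grad, value = torch.func.grad_and_value(loss)(torch.tensor([1.0, 2.0]))
# grad = 2*x + 2 = [4., 6.];  value = sum(x**2 + 2*x) = 11.
```

The crash in the repro occurs at exactly this outer differentiation step, when the custom op's double-backward is invoked.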

NOTE: The repro may require applying the fix called out here: pytorch/pytorch#168393

To Reproduce

import torch
import cuequivariance as cue
from cuequivariance import ir_mul
from cuequivariance_torch.layers import FullyConnectedTensorProductConv

conv = FullyConnectedTensorProductConv(
    in_irreps=cue.Irreps("SO3", "3x0 + 3x1"),
    sh_irreps=cue.Irreps("SO3", "1x0 + 1x1"),
    out_irreps=cue.Irreps("SO3", "3x0 + 3x1"),
    batch_norm=False,
    mlp_channels=None,
    layout=ir_mul,
    use_fallback=None,
).cuda()

N, E, W = 8, 32, conv.tp.weight_numel  # nodes, edges, weight count (W = 45)
torch.manual_seed(0)
src_idx = torch.randint(0, N, (E,), device="cuda")
dst_idx = torch.randint(0, N, (E,), device="cuda")
graph = (torch.stack([src_idx, dst_idx]), (N, N))  # (edge_index, (num_src, num_dst))

src_features = torch.randn(N, 12, device="cuda")
edge_sh = torch.randn(E, 4, device="cuda", requires_grad=True)
edge_emb = torch.randn(E, W, device="cuda", requires_grad=True)


def energy_from_edge_sh(edge_sh_):
    return conv(
        src_features=src_features,
        edge_sh=edge_sh_,
        edge_emb=edge_emb,
        graph=graph,
        reduce="sum",
    ).sum()


def loss_from_edge_emb(edge_emb_):
    # Rebind the closed-over edge_emb so the outer grad_and_value
    # differentiates with respect to it.
    global edge_emb
    edge_emb = edge_emb_
    energy, vjp_fn = torch.func.vjp(energy_from_edge_sh, edge_sh)
    (neg_forces,) = vjp_fn(torch.ones_like(energy))
    return energy + neg_forces.sum()


# Crashes: ValueError: Received invalid size of tensor,
#          expected [32, 45] but received [32, 4].
torch.func.grad_and_value(loss_from_edge_emb)(edge_emb)

Note that I see the same behavior when using full_tensor_product with SegmentedPolynomial and method="fused_tp", but not when using method="uniform_1d".

Expected behavior
The bwd_bwd pass should complete without raising a size-mismatch error, so that second-order derivatives work through torch.func.

GPU HW/SW(please complete the following information):

  • torch or ngc docker version: PyTorch 2.8
  • Driver version: 12.9
  • cuEquivariance version: 0.9.1
  • full name of GPU: RTX 5060 Ti
