
[ROCm] Include ROCM support for CUDA extensions#4180

Open
amd-sriram wants to merge 23 commits into pytorch:main from ROCm:main

Conversation


@amd-sriram amd-sriram commented Feb 23, 2026

Motivation

Port the following CUDA extensions to ROCm:

  • RNNTLoss
  • lfilter (iir)
  • forced align
  • CU CTC
  • rir (removed)

Technical Details

Changes to tools/setup_helpers/extension.py

CUDA source files are now also added when the _USE_ROCM flag is set, e.g.:

    if _USE_CUDA or _USE_ROCM:
        sources.append("iir_cuda.cu")
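As a minimal sketch of this gating pattern (the flag values and file names below are illustrative, not the exact contents of extension.py):

```python
# Minimal sketch of the source-gating pattern in extension.py.
# The flag values and file names here are illustrative only.
_USE_CUDA = False  # normally derived from the build environment
_USE_ROCM = True   # e.g. set when torch is built with HIP support

sources = ["lfilter.cpp", "rnnt.cpp"]

# Under ROCm the .cu files are hipified and compiled by the HIP
# toolchain, so they are appended for either GPU backend.
if _USE_CUDA or _USE_ROCM:
    sources.append("iir_cuda.cu")
```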

Fixing compilation issues

The following errors were encountered and fixed:

1. TORCH_HIP_VERSION is not defined

/skishore/github/audio/src/libtorchaudio/utils_hip.cpp:20:10: error: ‘TORCH_HIP_VERSION’ was not declared in this scope; did you mean ‘TORCH_ABI_VERSION’? 

TORCH_HIP_VERSION is now defined in tools/setup_helpers/extension.py, similar to https://github.com/ROCm/pytorch/blob/develop/cmake/public/LoadHIP.cmake#L166:

    math(EXPR TORCH_HIP_VERSION "(${HIP_VERSION_MAJOR} * 100) + ${HIP_VERSION_MINOR}")
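The same arithmetic, sketched in Python (the helper name is hypothetical; in practice the computed value would be passed to the compiler as a -D define):

```python
# Hypothetical helper mirroring LoadHIP.cmake's arithmetic:
# TORCH_HIP_VERSION = (HIP_VERSION_MAJOR * 100) + HIP_VERSION_MINOR
def torch_hip_version(hip_version: str) -> int:
    major, minor = (int(p) for p in hip_version.split(".")[:2])
    return major * 100 + minor

# For HIP 6.2.x this yields 602; the value is then handed to the
# compiler, e.g. as an extra compile argument:
extra_compile_args = [f"-DTORCH_HIP_VERSION={torch_hip_version('6.2.41134')}"]
```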

2. Incorrect kernel launch parameters

/skishore/github/audio/src/libtorchaudio/iir_hip.hip:75:8: error: too few arguments provided to function-like macro invocation 

         75 |        hipLaunchKernelGGL(( (iir_cu_kernel<scalar_t>), dim3(blocks), dim3(threads), 0, 0, 

The parameters in the THO_DISPATCH_V2 invocation were corrected, following https://github.com/ROCm/pytorch/blob/develop/test/cpp_extensions/libtorch_agn_2_9_extension/csrc/kernel.cpp#L361. Wrapping the lambda in AT_WRAP keeps the commas inside its body from being treated as argument separators by the function-like macro.

  THO_DISPATCH_V2(m.scalar_type(), "mv_tensor_accessor_cpu",
                  AT_WRAP(([&]() {
                    auto resa = Accessor_cpu<scalar_t, 1>(reinterpret_cast<scalar_t*>(res.data_ptr()), res.sizes().data(), res.strides().data());
                    auto ma = Accessor_cpu<scalar_t, 2>(reinterpret_cast<scalar_t*>(m.data_ptr()), m.sizes().data(), m.strides().data());
                    auto va = Accessor_cpu<scalar_t, 1>(reinterpret_cast<scalar_t*>(v.data_ptr()), v.sizes().data(), v.strides().data());
                    mv_tensor_accessor_kernel<Accessor_cpu, scalar_t>(resa, ma, va);
                  })),
                  AT_FLOATING_TYPES);

Test Plan

Run this branch on both an Nvidia machine and an AMD machine, check that it installs, and run the unit tests for the CUDA extensions:

python -m pip install . --no-build-isolation

pytest test/torchaudio_unittest/functional/functional_cuda_test.py -k test_rnnt  
pytest test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py -k test_rnnt 
pytest test/torchaudio_unittest/functional/autograd_cuda_test.py -k test_rnnt
pytest test/torchaudio_unittest/transforms/autograd_cuda_test.py -k test_rnnt
pytest test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py -k test_rnnt 

pytest test/torchaudio_unittest/functional/functional_cuda_test.py -k test_lfilter
pytest test/torchaudio_unittest/functional/autograd_cuda_test.py -k test_lfilter
pytest test/torchaudio_unittest/functional/batch_consistency_test.py -k test_lfilter
pytest test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py -k test_lfilter

pytest test/torchaudio_unittest/functional/functional_cuda_test.py -k test_forced_align

pytest test/torchaudio_unittest/models/decoder/cuda_ctc_decoder_test.py

Test Result

Number of passed unit tests:

    Extension      Unit tests passing for each test run
    RNNT loss      18, 1, 3, 3, 1
    lfilter        19, 6, 2, 1
    forced_align   120
    cu ctc         3

Attached log for torch 2.11
torch211_log.txt

This branch does not support torch 2.10.

@amd-sriram amd-sriram requested a review from a team as a code owner February 23, 2026 19:45

pytorch-bot Bot commented Feb 23, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/audio/4180

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Feb 23, 2026
@amd-sriram amd-sriram marked this pull request as draft February 25, 2026 14:04
@amd-sriram amd-sriram marked this pull request as ready for review March 6, 2026 21:35
amd-sriram (Author) commented

@NicolasHug Could you please review this PR? Thanks.

