3 changes: 2 additions & 1 deletion .gitignore
@@ -57,7 +57,8 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST

LICENSE.*
core.*
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
30 changes: 30 additions & 0 deletions README.md
@@ -15,6 +15,36 @@ High-Performance GPU Kernels for Inference
[![Build Status](https://ci.tlcpack.ai/job/flashinfer-ci/job/main/badge/icon)](https://ci.tlcpack.ai/job/flashinfer-ci/job/main/)
[![Documentation](https://github.com/flashinfer-ai/flashinfer/actions/workflows/build-doc.yml/badge.svg)](https://github.com/flashinfer-ai/flashinfer/actions/workflows/build-doc.yml)


> [!NOTE]
>
> This repo is a fork of the original FlashInfer project, with modifications to improve compatibility and integration with PaddlePaddle.
>
> **Installation**
>
> ```bash
> pip install paddlepaddle_gpu # Install PaddlePaddle with GPU support, refer to https://www.paddlepaddle.org.cn/install/quick for more details
> git clone https://github.com/PFCCLab/flashinfer.git
> cd flashinfer
> git submodule update --init
> pip install "apache-tvm-ffi>=0.1.2" # Use TVM FFI 0.1.2 or above (quoted so the shell does not parse `>=` as a redirect)
> pip install filelock jinja2 # Install dependencies for JIT compilation
> pip install --no-build-isolation . -v
> ```
>
> **Usage**
>
> ```python
> import paddle
> paddle.enable_compat(scope={"flashinfer"}) # Enable torch proxy before importing flashinfer
> import flashinfer
> # use flashinfer
> ```

The original README.md content is as follows:

---

**FlashInfer** is a library and kernel generator for inference that delivers state-of-the-art performance across diverse GPU architectures. It provides unified APIs for attention, GEMM, and MoE operations with multiple backend implementations including FlashAttention-2/3, cuDNN, CUTLASS, and TensorRT-LLM.

## Why FlashInfer?
16 changes: 13 additions & 3 deletions flashinfer/autotuner.py
@@ -17,6 +17,16 @@

import torch

# Paddle compat: `torch.cuda.OutOfMemoryError` is missing; bind a never-raised
# sentinel class so the `except _CudaOutOfMemoryError:` clauses below keep their
# structure without silently swallowing unrelated errors.
try:
_CudaOutOfMemoryError = torch.cuda.OutOfMemoryError # type: ignore[attr-defined]
except AttributeError:

class _CudaOutOfMemoryError(Exception): # type: ignore[no-redef]
"""Placeholder for torch.cuda.OutOfMemoryError under paddle compat."""


# from tensorrt_llm.bindings.internal.runtime import delay_kernel
# from tensorrt_llm.logger import logger
from flashinfer.tllm_utils import delay_kernel
@@ -1175,7 +1185,7 @@ def choose_one(
time_measured = self._profile_single_kernel(
r, tensors, tac, tuning_config, **kwargs
)
except torch.cuda.OutOfMemoryError:
except _CudaOutOfMemoryError:
raise
except Exception as e:
skipped_count += 1
@@ -1251,7 +1261,7 @@ def choose_one(
f"[Autotuner]: profiling chosen runner: {runners[runner_id]} {tactic} for {cache_key}"
)

except torch.cuda.OutOfMemoryError:
except _CudaOutOfMemoryError:
torch.cuda.empty_cache()
logger.warning(
"[Autotuner]: OOM detected, falling back to default tactic"
@@ -1275,7 +1285,7 @@
def _get_input_sizes(self, inputs: List[torch.Tensor]) -> List[tuple]:
    """Return a size tuple for each input, using ``(0,)`` for non-Tensor values."""
sizes = [
input.size() if isinstance(input, torch.Tensor) else torch.Size((0,))
tuple(input.size()) if isinstance(input, torch.Tensor) else (0,)
for input in inputs
]

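The sentinel pattern above generalizes to any exception type the paddle proxy lacks. A minimal self-contained sketch of the same idea (`run_kernel_safely` is illustrative, not part of the autotuner):

```python
import torch

# Bind the real class when present; otherwise a class that is never raised,
# so the `except` clause keeps its structure without catching anything else.
try:
    _OOM = torch.cuda.OutOfMemoryError
except AttributeError:
    class _OOM(Exception):
        """Never raised under paddle compat."""

def run_kernel_safely(fn, *args):
    try:
        return fn(*args)
    except _OOM:
        raise  # a genuine OOM must propagate to the caller
    except Exception:
        return None  # any other profiling failure is skipped
```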
6 changes: 5 additions & 1 deletion flashinfer/comm/__init__.py
@@ -76,4 +76,8 @@
# from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo

# AllGatherMatmul
from .all_gather_matmul import all_gather_matmul as all_gather_matmul
# Paddle compat: optional; depends on cuda.tile, which may be unavailable
try:
from .all_gather_matmul import all_gather_matmul as all_gather_matmul
except (ImportError, ModuleNotFoundError):
all_gather_matmul = None
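Since `all_gather_matmul` may now be bound to `None`, call sites need a check. A hedged sketch of that guard (the check is ours; only the import path comes from this file):

```python
from flashinfer.comm import all_gather_matmul

# Fail fast with a clear message instead of a TypeError later, when the
# optional cuda.tile-backed implementation could not be imported.
if all_gather_matmul is None:
    raise RuntimeError("all_gather_matmul unavailable: cuda.tile could not be imported")
```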
7 changes: 5 additions & 2 deletions flashinfer/comm/all_gather_matmul/all_gather_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from flashinfer.comm import all_gather_matmul
try:
import torch.distributed._symmetric_memory as symm_mem
except (ImportError, ModuleNotFoundError):
symm_mem = None
from flashinfer.comm import all_gather_matmul

# --- per-rank setup ---
device = torch.device(f"cuda:{rank}")
@@ -8,7 +8,11 @@

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

try:
import torch.distributed._symmetric_memory as symm_mem
except (ImportError, ModuleNotFoundError):
symm_mem = None
import cuda.tile as ct

from .broadcast_input import broadcast_input
@@ -4,7 +4,11 @@

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

try:
import torch.distributed._symmetric_memory as symm_mem
except (ImportError, ModuleNotFoundError):
symm_mem = None
import triton
import triton.language as tl

2 changes: 1 addition & 1 deletion flashinfer/comm/allreduce.py
@@ -56,7 +56,7 @@
from .workspace_base import AllReduceFusionWorkspace

import torch
from torch.distributed import ProcessGroup
from paddle.base.core import ProcessGroup

from flashinfer.api_logging import flashinfer_api
from flashinfer.trace.templates.comm import allreduce_fusion_trace
27 changes: 21 additions & 6 deletions flashinfer/comm/cuda_ipc.py
@@ -19,7 +19,7 @@
from typing import Any, Dict, List, Optional

import torch.distributed as dist
from torch.distributed import ProcessGroup
from paddle.base.core import ProcessGroup

# NOTE(Zihao): we should use cuda-python instead of ctypes cuda runtime bindings.
# However, cuda-python's API is not stable yet, so we use ctypes bindings instead.
@@ -204,13 +204,23 @@ def create_shared_buffer(
pointer = cudart.cudaMalloc(size_in_bytes)
handle = cudart.cudaIpcGetMemHandle(pointer)
if group is None:
group = dist.group.WORLD
try:
group = dist.group.WORLD
except AttributeError:
import paddle.distributed as _pdist

group = _pdist.get_group()
world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group)
# Paddle compat: dist.all_gather_object appends gathered objects to the
# output list rather than assigning by index; gather into an empty list
# and validate the result length afterwards.
handles: list = []
dist.all_gather_object(handles, handle, group=group)
# Defensive: tolerate a pre-seeded leading placeholder before validating.
if len(handles) == world_size + 1 and handles[0] is None:
    handles = handles[1:]
assert len(handles) == world_size, (
    f"all_gather_object returned {len(handles)} entries, expected {world_size}"
)

pointers: List[int] = []
for i, h in enumerate(handles):
@@ -230,7 +240,12 @@ def free_shared_buffer(
Frees a shared buffer.
"""
if group is None:
group = dist.group.WORLD
try:
group = dist.group.WORLD
except AttributeError:
import paddle.distributed as _pdist

group = _pdist.get_group()
rank = dist.get_rank(group=group)
if pointers and len(pointers) > rank and pointers[rank] is not None:
cudart.cudaFree(ctypes.c_void_p(pointers[rank]))
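A minimal sketch of the gather-semantics difference this hunk works around, assuming paddle's `dist.all_gather_object` appends to the list it is given while torch assigns into a pre-sized one (`gather_handles` is an illustrative helper, not library code):

```python
def gather_handles(dist, handle, world_size, group=None):
    # Mirror the patched paddle path: start from an empty list and let the
    # gather append one entry per rank, then validate the final length.
    handles: list = []
    dist.all_gather_object(handles, handle, group=group)
    if len(handles) == world_size + 1 and handles[0] is None:
        handles = handles[1:]  # tolerate a pre-seeded leading placeholder
    assert len(handles) == world_size, (
        f"expected {world_size} gathered handles, got {len(handles)}"
    )
    return handles
```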
2 changes: 1 addition & 1 deletion flashinfer/comm/nvshmem_allreduce.py
@@ -19,7 +19,7 @@

import numpy as np
import torch
from torch.distributed import ProcessGroup
from paddle.base.core import ProcessGroup

import nvshmem.core
from cuda.core import Buffer, Device
13 changes: 11 additions & 2 deletions flashinfer/comm/torch_symmetric_memory.py
@@ -2,8 +2,17 @@
from typing import Any

import torch
import torch.distributed._symmetric_memory as symm_mem
import torch.distributed.distributed_c10d as c10d

# Paddle compat: torch.distributed._symmetric_memory is not available under paddle proxy.
try:
import torch.distributed._symmetric_memory as symm_mem
import torch.distributed.distributed_c10d as c10d

_SYMM_MEM_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
symm_mem = None
c10d = None
_SYMM_MEM_AVAILABLE = False

_compat_patched = False

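Several modules branch on symmetric-memory support, so the single `_SYMM_MEM_AVAILABLE` flag keeps those checks uniform. A sketch of a call-site gate built on the flag above (the helper is ours, not part of this module):

```python
def _require_symm_mem() -> None:
    # Raise early, pointing at the fallback, rather than failing deep inside
    # an allocation when symm_mem is None.
    if not _SYMM_MEM_AVAILABLE:
        raise RuntimeError(
            "torch.distributed._symmetric_memory is unavailable under the "
            "paddle proxy; use the cudaIpc shared-buffer path instead"
        )
```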
91 changes: 61 additions & 30 deletions flashinfer/comm/trtllm_ar.py
@@ -24,13 +24,13 @@
from flashinfer.comm.mnnvl import CommBackend, SymmDeviceMemory, TorchDistBackend
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from paddle.base.core import ProcessGroup

from ..jit.comm import gen_trtllm_comm_module
from ..utils import register_custom_op, round_up

logger = logging.getLogger(__name__)
from .cuda_ipc import cudart
from .cuda_ipc import cudart, create_shared_buffer
from .torch_symmetric_memory import _alloc_symm_buffer_bytes


@@ -629,36 +629,42 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
lamport_buffer_size = lamport_comm_size * 3

device = torch.device(f"cuda:{torch.cuda.current_device()}")
group_name = (
group.group_name
if group is not None
else torch.distributed.group.WORLD.group_name
)
symm_refs: list[torch.Tensor] = []

# we should init 3 buffers for all reduce fusion:
# [buffer_size, flag_size, lamport_buffer_size]
# Paddle compat: symmetric memory is not available; fall back to cudaIpc-based shared buffers.
from .torch_symmetric_memory import _SYMM_MEM_AVAILABLE

symm_refs: list[torch.Tensor] = []
ipc_handles: List[List[int]] = list()
mem_handles: List[SymmDeviceMemory] = list()
lamport_buffer_dtype = torch.float16 if not use_fp32_lamport else torch.float32
for size, dtype in [
(buffer_size, torch.float32),
(flag_size, torch.int32),
(lamport_buffer_size, lamport_buffer_dtype),
]:
aligned_size = round_up(size, 16)

ptrs, tensor, handle = _alloc_symm_buffer_bytes(
aligned_size,
tp_size,
dtype,
device,
group_name,
if not _SYMM_MEM_AVAILABLE:
for size in [buffer_size, flag_size, lamport_buffer_size]:
aligned_size = round_up(size, 1 << 21)  # round each buffer up to 2 MiB
ipc_handles.append(create_shared_buffer(aligned_size, group))
else:
group_name = (
group.group_name
if group is not None
else torch.distributed.group.WORLD.group_name
)
symm_refs.append((tensor, handle))
ipc_handles.append(ptrs)
mem_handles.append(handle)
for size, dtype in [
(buffer_size, torch.float32),
(flag_size, torch.int32),
(lamport_buffer_size, lamport_buffer_dtype),
]:
aligned_size = round_up(size, 16)

ptrs, tensor, handle = _alloc_symm_buffer_bytes(
aligned_size,
tp_size,
dtype,
device,
group_name,
)
symm_refs.append((tensor, handle))
ipc_handles.append(ptrs)
mem_handles.append(handle)

logger.debug(
"rank %s allocated ipc_handles: %s",
@@ -712,10 +718,9 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
for i in range(len(workspace)):
logger.debug("Rank %s workspace[%d] %s", tp_rank, i, hex(workspace[i]))

# Store workspace pointers in device tensor
workspace_tensor = torch.tensor(
workspace, dtype=torch.int64, device=torch.device("cuda")
)
# Store workspace pointers in device tensor.
# Paddle compat: torch.tensor(..., device=torch.device("cuda")) is buggy; use .cuda().
workspace_tensor = torch.tensor(workspace, dtype=torch.int64).cuda()

if use_symm_dev_mem:
comm_backend.barrier() # must sync after create_workspace
@@ -891,11 +896,37 @@ def trtllm_custom_all_reduce(
}


# Paddle compat: paddle.dtype has no `itemsize` attribute; maintain a mapping fallback.
_DTYPE_SIZE_MAP: dict = {
torch.float16: 2,
torch.bfloat16: 2,
torch.float32: 4,
torch.float64: 8,
torch.int8: 1,
torch.int16: 2,
torch.int32: 4,
torch.int64: 8,
torch.uint8: 1,
torch.bool: 1,
torch.complex64: 8,
torch.complex128: 16,
}


def _dtype_itemsize(dtype) -> int:
itemsize = getattr(dtype, "itemsize", None)
if itemsize is not None:
return itemsize
if dtype in _DTYPE_SIZE_MAP:
return _DTYPE_SIZE_MAP[dtype]
raise TypeError(f"Cannot determine itemsize for dtype {dtype!r}")


def _should_use_oneshot(
token_num: int, hidden_dim: int, dtype: torch.dtype, world_size: int
) -> bool:
comm_size_mb = (
token_num * hidden_dim * 2 * world_size * dtype.itemsize / 1024 / 1024
token_num * hidden_dim * 2 * world_size * _dtype_itemsize(dtype) / 1024 / 1024
)
return comm_size_mb <= _use_oneshot_heuristics[world_size]

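As a sanity check on `_should_use_oneshot`, the communication volume works out as follows (illustrative numbers; `_use_oneshot_heuristics` is the per-world-size threshold table referenced above):

```python
# 128 tokens, hidden size 4096, fp16 (2 bytes per element), 8 ranks:
token_num, hidden_dim, world_size, itemsize = 128, 4096, 8, 2
comm_size_mb = token_num * hidden_dim * 2 * world_size * itemsize / 1024 / 1024
print(comm_size_mb)  # 16.0 -> use oneshot only if <= the threshold for world_size 8
```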
2 changes: 2 additions & 0 deletions flashinfer/cute_dsl/fp4_common.py
@@ -19,6 +19,8 @@
utilities used by both rmsnorm_fp4quant.py and add_rmsnorm_fp4quant.py.
"""

from __future__ import annotations

import functools
import math
import operator
2 changes: 1 addition & 1 deletion flashinfer/decode.py
@@ -2645,7 +2645,7 @@ def trtllm_batch_decode_with_kv_cache(
0, # sparse_mla_top_k
sm_count,
enable_pdl,
workspace_buffer.numel() * workspace_buffer.element_size(),
int(workspace_buffer.numel() * workspace_buffer.element_size()),
sinks,
cum_seq_lens_q,
k_block_scales,