[RFC] Migrate Python bindings from pybind11 to apache-tvm-ffi for lighter builds and framework-agnostic interop
Motivation
AITER's README states its design goal clearly:
> Developers can focus on operators, and let the customers integrate this op collection into their own private/public/whatever framework.
However, the current pybind11 + torch::Tensor binding architecture contradicts this goal — every C++ interface is hard-coupled to PyTorch's headers, ABI, and build system. Migrating to apache-tvm-ffi with DLPack-based tensor interop would align the implementation with this stated vision.
Why this is needed
1. Extreme build time from PyTorch header inclusion
Every one of the ~70 csrc/pybind/*_pybind.cu files includes torch/extension.h, which transitively pulls in the entire PyTorch C++ API (ATen, c10, autograd, etc.). These headers are enormous and template-heavy. Since hipcc must process them for every .cu file, compilation time is dominated by PyTorch header parsing, not actual kernel compilation.
With tvm-ffi, kernel .hip files include only <hip/hip_runtime.h> and lightweight project headers. The thin FFI binding file includes only <tvm/ffi/tvm_ffi.h> (a few hundred lines). This can reduce per-file compile time from minutes to seconds.
2. PyTorch ABI version coupling
pybind11 bindings compiled against one PyTorch version may not work with another due to ABI changes in torch::Tensor, at::Tensor, c10::optional, etc. This forces users to rebuild aiter whenever they update PyTorch, and makes it difficult to ship pre-built binaries.
tvm-ffi uses DLPack, which is a stable, minimal C ABI for tensor exchange. DLPack is supported by PyTorch, JAX, TensorFlow, MXNet, CuPy, and others. Once compiled, a tvm-ffi .so works across PyTorch versions without recompilation.
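Since the contract is just the `__dlpack__` protocol, the zero-copy exchange can be demonstrated with nothing but NumPy (standing in here for any of the producer/consumer frameworks listed above):

```python
import numpy as np

# NumPy stands in for any DLPack-capable framework; a PyTorch, JAX,
# or CuPy tensor crosses the same boundary the same way.
a = np.arange(6, dtype=np.float32)

# from_dlpack consumes the producer's __dlpack__ capsule; the result
# is a zero-copy view over the same underlying buffer.
b = np.from_dlpack(a)

print(np.shares_memory(a, b))  # True
```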
3. Framework lock-in vs. framework agnosticism
Currently, all C++ interfaces in csrc/include/*.h use torch::Tensor:
```cpp
// csrc/include/norm.h (current)
#include <torch/extension.h>

torch::Tensor layernorm2d(torch::Tensor &input, torch::Tensor &weight, ...);
```

This makes it impossible to use aiter operators from JAX, TensorFlow, or plain NumPy/CuPy without going through PyTorch. With tvm-ffi, the C++ interface uses `tvm::ffi::Tensor` (a thin DLPack wrapper), and any framework that supports `__dlpack__` can call the operators directly.
4. Binding boilerplate reduction
The current rocm_ops.hpp is ~1800 lines of pybind11 macro definitions. Each operator requires:
- A `*_PYBIND` macro in `rocm_ops.hpp` with verbose `py::arg()` decorations
- A separate `*_pybind.cu` file with `PYBIND11_MODULE`
- A C++ function signature using `torch::Tensor`
With tvm-ffi, each operator just needs a one-line registration:
```cpp
TVM_FFI_DLL_EXPORT_TYPED_FUNC(layernorm2d, layernorm2d_ffi);
```

No pybind11 macros, no `py::arg()`, no `PYBIND11_MODULE`.
5. JIT system simplification
The current aiter/jit/core.py is ~970 lines of complex build orchestration that wraps PyTorch's cpp_extension. It handles hipification, blob generation, file batons for multi-process locking, and PyTorch-specific build quirks.
tvm-ffi's JIT is much simpler: generate a build.ninja, run ninja, call tvm_ffi.load_module(). The kernel compilation doesn't need PyTorch headers at all, so there's no hipification step, no torch ABI checks, and no TORCH_EXTENSION_NAME machinery.
6. Dependency reduction
Current setup.py requires pybind11>=3.0.1 as an install dependency. tvm-ffi (pip install apache-tvm-ffi) is a lightweight pure-C library with no transitive dependencies.
How to achieve this
The migration can be done incrementally, operator-by-operator, without breaking existing users.
Architecture change
BEFORE (pybind11):
csrc/include/norm.h — declares functions with torch::Tensor
csrc/kernels/norm_kernel.cu — HIP kernel (includes torch headers)
csrc/pybind/norm_pybind.cu — PYBIND11_MODULE wrapper
csrc/include/rocm_ops.hpp — NORM_PYBIND macro definition
AFTER (tvm-ffi):
csrc/include/norm.h — declares functions with void* (framework-agnostic)
csrc/kernels/norm_kernel.hip — HIP kernel (only HIP + project headers, no torch)
csrc/ffi/norm_ffi.cc — thin tvm_ffi registration (compiled with g++, not hipcc)
Step 1: Decouple kernel implementations from torch::Tensor
Replace torch::Tensor in kernel-facing C++ interfaces with raw pointers:
```cpp
// csrc/include/norm.h (after)
#pragma once

void layernorm2d(
    const void* input, const void* weight, const void* bias,
    void* output,
    int m, int n, double epsilon,
    void* stream = nullptr);
```

The kernel .hip files already work with raw pointers internally — they just need to stop receiving torch::Tensor at the interface boundary.
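With a `void*` interface, any caller's only job is to hand over a raw base pointer plus shape metadata. A dependency-free sketch of what that looks like from Python, using a host NumPy array as a stand-in for a GPU tensor (the `libaiter_norm.so` handle in the comment is hypothetical, mirroring the signature above):

```python
import ctypes
import numpy as np

# A host NumPy array stands in for a GPU tensor; the point is that the
# caller only needs a raw base pointer plus shape metadata.
x = np.ones((4, 8), dtype=np.float32)

ptr = x.ctypes.data_as(ctypes.c_void_p)  # the void* the kernel interface wants
m, n = x.shape

# A real caller would then invoke the C interface, e.g. via a hypothetical
# handle (names illustrative only):
#   lib = ctypes.CDLL("libaiter_norm.so")
#   lib.layernorm2d(inp_ptr, wt_ptr, bias_ptr, out_ptr, m, n,
#                   ctypes.c_double(1e-5), None)
print(ptr.value != 0, m, n)
```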
Step 2: Create tvm-ffi binding files
Replace each csrc/pybind/*_pybind.cu with a csrc/ffi/*_ffi.cc.
Why tvm-ffi is better than pybind11 as a binding mechanism:
Note: this comparison is about pybind11 vs tvm-ffi as FFI libraries,
independent of the separate question of torch::Tensor vs DLPack.
| Aspect | pybind11 | tvm-ffi |
|---|---|---|
| Header weight | `<pybind11/pybind11.h>` — ~10K lines of C++ template metaprogramming | `<tvm/ffi/tvm_ffi.h>` — a few hundred lines, minimal templates |
| Python.h dependency | Yes — pybind11 includes `Python.h` and links against libpython at compile time | No — the tvm-ffi `.so` has no Python header dependency; the `tvm_ffi` runtime handles Python interop |
| Registration syntax | `PYBIND11_MODULE(name, m) { m.def("fn", &fn, py::arg("x"), ...); }` — verbose, needs arg names | `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, fn)` — one-liner, no arg decoration |
| Module loading | Python's import mechanism — requires `PyInit_*` symbol, Python ABI compatibility | `tvm_ffi.load_module()` — plain `dlopen` + symbol lookup, no Python ABI dependency |
| ABI | C++ ABI — name mangling, exception propagation, sensitive to compiler version | C ABI — stable exported symbols, works across compiler versions |
| Linked libraries | `libpython3.x.so` + pybind11 headers | `libtvm_ffi.so` only (~2 MB) |
| Compile time | Slower (template-heavy pybind11 headers + `Python.h`) | Faster (lightweight headers, no `Python.h`) |
| Python version coupling | Compiled `.so` is tied to the specific Python minor version (3.10 vs 3.11 vs 3.12) | `.so` is Python-version-independent; the `tvm_ffi` runtime bridges the gap |
The key insight: pybind11 creates a Python extension module (a .so with
PyInit_* that Python's import system loads directly), which couples the
compiled binary to both the C++ ABI and the Python ABI. tvm-ffi creates a
plain shared library with C-ABI exported symbols, and the tvm_ffi Python
runtime (installed via pip) handles all the Python interop. This means the
compiled .so is more portable and doesn't need recompilation when Python
or the C++ compiler changes.
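The "plain dlopen + symbol lookup" loading model is exactly what `ctypes` already does; a quick stand-in demonstration using the system C math library (no `PyInit_*` involved, and the `.so` itself knows nothing about Python):

```python
import ctypes
import ctypes.util

# Plain dlopen + dlsym: the shared library exports C-ABI symbols and
# has no Python ABI dependency of its own.
libm = ctypes.CDLL(ctypes.util.find_library("m"))
libm.cos.restype = ctypes.c_double
libm.cos.argtypes = [ctypes.c_double]

print(libm.cos(0.0))  # 1.0
```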
For binding files specifically — which are pure glue code that just unpacks
tensor pointers and forwards to kernel launchers — the simplicity and
portability of tvm-ffi is a clear win over pybind11's heavier machinery.
Example:
```cpp
// csrc/ffi/norm_ffi.cc
#include <tvm/ffi/tvm_ffi.h>
#include "norm.h"

void layernorm2d_ffi(
    tvm::ffi::Tensor input, tvm::ffi::Tensor weight,
    tvm::ffi::Tensor bias, tvm::ffi::Tensor output,
    double epsilon)
{
    void* inp_ptr = static_cast<char*>(input.data_ptr()) + input.byte_offset();
    void* wt_ptr  = static_cast<char*>(weight.data_ptr()) + weight.byte_offset();
    void* bi_ptr  = static_cast<char*>(bias.data_ptr()) + bias.byte_offset();
    void* out_ptr = static_cast<char*>(output.data_ptr()) + output.byte_offset();
    int m = static_cast<int>(input.shape()[0]);
    int n = static_cast<int>(input.shape()[1]);
    layernorm2d(inp_ptr, wt_ptr, bi_ptr, out_ptr, m, n, epsilon, nullptr);
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(layernorm2d, layernorm2d_ffi);
```

Step 3: Establish a two-layer Python architecture (Backend + Adaptor)
A key design point: tvm-ffi serves as the backend — a unified gateway
where all Python calls pass through tvm-ffi into the C++ world. On top of
that, each framework (PyTorch, JAX, etc.) gets a thin adaptor layer that
accepts native framework tensors, handles output allocation, and translates
to/from DLPack.
For PyTorch users, we always need an API that accepts torch.Tensor directly.
The current aiter Python API (aiter/ops/*.py, aiter/mla.py, etc.) already
plays this role — it is the torch adaptor. Under this proposal, its role
doesn't change; it simply swaps out the pybind11 module call for a tvm-ffi
module call underneath.
┌─────────────────────────────────────────────────────┐
│ User code │
│ aiter.layernorm2d(torch_input, torch_weight, ...)│
└──────────────────────┬──────────────────────────────┘
│
┌──────────────────────▼──────────────────────────────┐
│ Torch Adaptor Layer (aiter/ops/norm.py) │
│ - Accepts torch.Tensor │
│ - Allocates output torch.Tensor │
│ - Validates shapes/dtypes │
│ - Converts via tvm_ffi.from_dlpack() │
└──────────────────────┬──────────────────────────────┘
│
┌──────────────────────▼──────────────────────────────┐
│ tvm-ffi Backend (compiled .so) │
│ - Framework-agnostic C++ functions │
│ - Receives tvm::ffi::Tensor (DLPack) │
│ - Dispatches to HIP kernels │
└─────────────────────────────────────────────────────┘
Concretely, the adaptor layer looks like:
```python
# aiter/ops/norm.py (Torch Adaptor — user-facing API, accepts torch.Tensor)
import torch
import tvm_ffi

def layernorm2d(input: torch.Tensor, weight: torch.Tensor,
                bias: torch.Tensor, epsilon: float = 1e-5) -> torch.Tensor:
    """User-facing PyTorch API — unchanged from today."""
    assert input.is_cuda and input.is_contiguous()
    output = torch.empty_like(input)
    # Cross into tvm-ffi backend via DLPack (zero-copy)
    _get_fn("layernorm2d")(
        tvm_ffi.from_dlpack(input),
        tvm_ffi.from_dlpack(weight),
        tvm_ffi.from_dlpack(bias),
        tvm_ffi.from_dlpack(output),
        epsilon,
    )
    return output
```

```python
# aiter/ops/_backend.py (tvm-ffi Backend — internal, not user-facing)
import tvm_ffi
from aiter.jit import load_module

_module = None

def _get_fn(name: str):
    global _module
    if _module is None:
        _module = load_module("module_norm")
    return _module[name]
```

Users still call `aiter.layernorm2d(torch_tensor)` exactly as before.
The adaptor layer is the only place that knows about torch.Tensor;
the tvm-ffi backend is completely framework-agnostic. Tomorrow, a JAX
adaptor could call the same backend .so without any C++ changes.
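The adaptor pattern itself is framework-neutral: validate, allocate output, dispatch, return. A minimal runnable sketch of that flow, with NumPy standing in for the framework and a pure-Python stub standing in for the compiled tvm-ffi backend (in the real system the stub would be `tvm_ffi_module["layernorm2d"]` receiving DLPack tensors):

```python
import numpy as np

def _backend_layernorm2d(inp, weight, bias, out, epsilon):
    """Stub for the compiled backend, writing into a caller-allocated out."""
    mu = inp.mean(axis=1, keepdims=True)
    var = inp.var(axis=1, keepdims=True)
    out[:] = (inp - mu) / np.sqrt(var + epsilon) * weight + bias

def layernorm2d(inp, weight, bias, epsilon=1e-5):
    """Adaptor layer: validate, allocate output, dispatch, return."""
    assert inp.ndim == 2 and inp.flags["C_CONTIGUOUS"]
    out = np.empty_like(inp)  # output allocation lives in the adaptor
    _backend_layernorm2d(inp, weight, bias, out, epsilon)
    return out

y = layernorm2d(np.random.rand(4, 8).astype(np.float32),
                np.ones(8, np.float32), np.zeros(8, np.float32))
print(y.shape)  # (4, 8)
```

Swapping NumPy for torch or JAX changes only the adaptor; the backend call is identical.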
Note: The torch adaptor layer is always needed if you want a user-friendly
torch.Tensor-in, torch.Tensor-out API. Even in small tvm-ffi projects
(e.g., gcnasm/warp_sort_bitonic/tvmffi), the adaptor exists — it's just
inlined into __init__.py rather than a separate file. It accepts
torch.Tensor, allocates the output via torch.empty_like(), converts
to DLPack via tvm_ffi.from_dlpack(), calls the backend, and returns
torch.Tensor. The pattern is always the same; only the scale differs.
In aiter's case, the adaptor layer already exists — it's the current
aiter/ops/*.py and aiter/mla.py files. They already handle shape
validation, output allocation, and dispatch. This proposal does not
remove them. It only changes their internal plumbing: swap
pybind11_module.func(torch_tensor) for
tvm_ffi_module["func"](tvm_ffi.from_dlpack(torch_tensor)).
Step 4: Update build system
- In `aiter/jit/core.py`, split compilation into three phases:
  - HIP phase: compile kernel `.hip` files with `hipcc` (no torch headers)
  - FFI phase: compile `*_ffi.cc` files with `g++` (only tvm_ffi headers)
  - Link phase: link both into a single `.so` with `-ltvm_ffi`
- Remove `pybind11>=3.0.1` from `setup.py` install_requires
- Add `apache-tvm-ffi` to install_requires
- Delete `csrc/include/rocm_ops.hpp` (the 1800-line macro file)
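A sketch of what the three build phases could look like as generated Ninja rules. This is an illustration only: the rule names, flags, and the `$TVM_FFI_INCLUDE` variable are assumptions, not aiter's actual build configuration.

```python
# Hypothetical generator for the three-phase build: hipcc for kernels,
# g++ for FFI glue, and one link step producing the loadable .so.
def make_ninja(kernels, ffi_srcs, out="module_norm.so"):
    lines = [
        "rule hip",
        "  command = hipcc -O3 -fPIC -c $in -o $out",
        "rule cxx",
        "  command = g++ -O3 -fPIC -I$$TVM_FFI_INCLUDE -c $in -o $out",
        "rule link",
        "  command = g++ -shared $in -ltvm_ffi -o $out",
    ]
    objs = []
    for src in kernels:
        obj = src.replace(".hip", ".o")
        lines.append(f"build {obj}: hip {src}")  # HIP phase: no torch headers
        objs.append(obj)
    for src in ffi_srcs:
        obj = src.replace(".cc", ".o")
        lines.append(f"build {obj}: cxx {src}")  # FFI phase: tvm_ffi headers only
        objs.append(obj)
    lines.append(f"build {out}: link {' '.join(objs)}")  # Link phase
    return "\n".join(lines) + "\n"

print(make_ninja(["norm_kernel.hip"], ["norm_ffi.cc"]))
```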
Migration order (suggested)
- Start with a leaf operator (e.g., `norm`, `activation`) as a proof of concept
- Migrate attention/MLA kernels (highest value — slowest to compile today)
- Migrate GEMM/MoE kernels
- Remove the pybind11 dependency entirely
Existing tvm-ffi examples in the wild
Two working examples of the tvm-ffi + Ninja JIT approach already exist in
gcnasm (AMD GPU kernel playground):
1. warp_sort_bitonic/tvmffi/ — simple single-kernel example
A warp-level bitonic merge sort kernel with tvm-ffi binding. Demonstrates
the minimal pattern: one .hip kernel + one tvm_api.cc FFI wrapper,
built via Python-driven Ninja JIT (no Makefile, no setup.py).
- FFI binding (`csrc/tvm_api.cc`):

  ```cpp
  #include <tvm/ffi/tvm_ffi.h>
  #include "warp_bitonic_sort.hpp"

  void warp_bitonic_sort_ffi(tvm::ffi::Tensor input, tvm::ffi::Tensor output,
                             int is_descending)
  {
      void* i_ptr = static_cast<char*>(input.data_ptr()) + input.byte_offset();
      void* o_ptr = static_cast<char*>(output.data_ptr()) + output.byte_offset();
      int num_element = static_cast<int>(input.shape()[0]);
      warp_bitonic_sort_kernel(i_ptr, o_ptr, num_element, is_descending ? 1 : 0);
  }

  TVM_FFI_DLL_EXPORT_TYPED_FUNC(warp_bitonic_sort, warp_bitonic_sort_ffi);
  ```

- JIT build (`warp_bitonic_sort/jit.py`): ~130 lines of Python that generates `build.ninja`, runs `ninja`, and loads via `tvm_ffi.load_module()`. Caches the `.so` in `~/.cache/` — subsequent imports skip compilation entirely.
- Key takeaway: the HIP kernel file has zero PyTorch headers. The FFI `.cc` file is compiled with plain `g++`, not `hipcc`. Build time is seconds.
2. mla_agentic_v0/ — multi-kernel MLA decode with split-K
A more complex example: Flash MLA sparse decode with multiple templated MFMA
kernels, split-K parallelism, and a combine kernel — all bound through tvm-ffi.
- 3 kernel functions exported via `TVM_FFI_DLL_EXPORT_TYPED_FUNC`: `flash_mla_decode`, `flash_mla_sparse_decode`, `flash_mla_sparse_decode_splitk`
- Same JIT pattern: `flash_mla/jit.py` generates Ninja, compiles `.hip` + `.cc`, links against `libtvm_ffi.so`
- Python side: `tvm_ffi.from_dlpack(torch_tensor)` for zero-copy tensor passing; `tvm_ffi.load_module()` for loading the compiled `.so`
Both examples use the same ~130-line JitSpec pattern for build/cache/load,
which could be extracted into a shared utility for aiter.
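The core of such a shared utility is small. A hedged sketch of the content-addressed cache step (function and directory names are hypothetical; compilation and `tvm_ffi.load_module()` are elided to comments):

```python
import hashlib
from pathlib import Path

# Hypothetical cache location, mirroring the ~/.cache/ pattern above.
CACHE_DIR = Path.home() / ".cache" / "aiter_ffi"

def cached_so_path(module_name, sources):
    """Content-addressed cache key: identical sources map to the same
    .so path, so a rebuild happens only when inputs actually change."""
    h = hashlib.sha256()
    for src in sources:
        h.update(Path(src).read_bytes())
    return CACHE_DIR / f"{module_name}-{h.hexdigest()[:16]}.so"

# Usage sketch (build/load elided):
#   so = cached_so_path("module_norm", ["norm_kernel.hip", "norm_ffi.cc"])
#   if not so.exists():
#       ... generate build.ninja, run ninja, copy the result to `so` ...
#   mod = tvm_ffi.load_module(str(so))
```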
Summary
| Aspect | pybind11 (current) | tvm-ffi (proposed) |
|---|---|---|
| Build deps per .cu | Full PyTorch C++ headers | Only <hip/hip_runtime.h> |
| FFI binding deps | pybind11 + torch headers | <tvm/ffi/tvm_ffi.h> only |
| Compile time (per module) | Minutes (torch template expansion) | Seconds |
| PyTorch version coupling | Hard ABI dependency | None (DLPack is stable) |
| Framework support | PyTorch only | Any DLPack framework |
| Binding boilerplate | ~1800 lines in rocm_ops.hpp | One-liner per function |
| Pre-built binary reuse | Breaks across PyTorch versions | Works across versions |