[Feature]: Migrate Python bindings from pybind11 to apache-tvm-ffi #2054

@carlushuang

Description


[RFC] Migrate Python bindings from pybind11 to apache-tvm-ffi for lighter builds and framework-agnostic interop

Motivation

AITER's README states its design goal clearly:

Developers can focus on operators, and let the customers integrate this op collection into their own private/public/whatever framework.

However, the current pybind11 + torch::Tensor binding architecture contradicts this goal — every C++ interface is hard-coupled to PyTorch's headers, ABI, and build system. Migrating to apache-tvm-ffi with DLPack-based tensor interop would align the implementation with this stated vision.

Why this is needed

1. Extreme build time from PyTorch header inclusion

Every one of the ~70 csrc/pybind/*_pybind.cu files includes torch/extension.h, which transitively pulls in the entire PyTorch C++ API (ATen, c10, autograd, etc.). These headers are enormous and template-heavy. Since hipcc must process them for every .cu file, compilation time is dominated by PyTorch header parsing, not actual kernel compilation.

With tvm-ffi, kernel .hip files include only <hip/hip_runtime.h> and lightweight project headers. The thin FFI binding file includes only <tvm/ffi/tvm_ffi.h> (a few hundred lines). This can reduce per-file compile time from minutes to seconds.

2. PyTorch ABI version coupling

pybind11 bindings compiled against one PyTorch version may not work with another due to ABI changes in torch::Tensor, at::Tensor, c10::optional, etc. This forces users to rebuild aiter whenever they update PyTorch, and makes it difficult to ship pre-built binaries.

tvm-ffi uses DLPack, which is a stable, minimal C ABI for tensor exchange. DLPack is supported by PyTorch, JAX, TensorFlow, MXNet, CuPy, and others. Once compiled, a tvm-ffi .so works across PyTorch versions without recompilation.

3. Framework lock-in vs. framework agnosticism

Currently, all C++ interfaces in csrc/include/*.h use torch::Tensor:

// csrc/include/norm.h (current)
#include <torch/extension.h>
torch::Tensor layernorm2d(torch::Tensor &input, torch::Tensor &weight, ...);

This makes it impossible to use aiter operators from JAX, TensorFlow, or plain NumPy/CuPy without going through PyTorch. With tvm-ffi, the C++ interface uses tvm::ffi::Tensor (a thin DLPack wrapper), and any framework that supports __dlpack__ can call the operators directly.

4. Binding boilerplate reduction

The current rocm_ops.hpp is ~1800 lines of pybind11 macro definitions. Each operator requires:

  • A *_PYBIND macro in rocm_ops.hpp with verbose py::arg() decorations
  • A separate *_pybind.cu file with PYBIND11_MODULE
  • A C++ function signature using torch::Tensor

With tvm-ffi, each operator just needs a one-line registration:

TVM_FFI_DLL_EXPORT_TYPED_FUNC(layernorm2d, layernorm2d_ffi);

No pybind11 macros, no py::arg(), no PYBIND11_MODULE.

5. JIT system simplification

The current aiter/jit/core.py is ~970 lines of complex build orchestration that wraps PyTorch's cpp_extension. It handles hipification, blob generation, file batons for multi-process locking, and PyTorch-specific build quirks.

tvm-ffi's JIT is much simpler: generate a build.ninja, run ninja, call tvm_ffi.load_module(). The kernel compilation doesn't need PyTorch headers at all, so there's no hipification step, no torch ABI checks, and no TORCH_EXTENSION_NAME machinery.
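The generate-ninja / run-ninja / load flow described above can be sketched in a few lines. File names, compiler flags, and the module name below are placeholders, not aiter's actual build configuration:

```python
import subprocess
from pathlib import Path

# Illustrative ninja file for the proposed flow: hipcc compiles the kernel
# (no torch headers), g++ compiles the FFI glue, one link step produces the .so.
NINJA_TEMPLATE = """\
rule hip
  command = hipcc -O3 -fPIC -c $in -o $out
rule ffi
  command = g++ -O3 -fPIC -I$tvm_ffi_inc -c $in -o $out
rule link
  command = g++ -shared $in -ltvm_ffi -o $out

build norm_kernel.o: hip norm_kernel.hip
build norm_ffi.o: ffi norm_ffi.cc
build module_norm.so: link norm_kernel.o norm_ffi.o
"""

def write_ninja(build_dir: Path, tvm_ffi_inc: str) -> Path:
    ninja = build_dir / "build.ninja"
    ninja.write_text(NINJA_TEMPLATE.replace("$tvm_ffi_inc", tvm_ffi_inc))
    return ninja

def jit_build(build_dir: Path):
    """Generate build.ninja, run ninja, load the resulting module."""
    write_ninja(build_dir, "/path/to/tvm_ffi/include")
    subprocess.run(["ninja", "-C", str(build_dir)], check=True)
    import tvm_ffi
    return tvm_ffi.load_module(str(build_dir / "module_norm.so"))
```

Contrast this with the current core.py: no hipify pass, no torch ABI probe, no TORCH_EXTENSION_NAME define injected into the compile command.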

6. Dependency reduction

Current setup.py requires pybind11>=3.0.1 as an install dependency. tvm-ffi (pip install apache-tvm-ffi) is a lightweight pure-C library with no transitive dependencies.

How to achieve this

The migration can be done incrementally, operator-by-operator, without breaking existing users.

Architecture change

BEFORE (pybind11):
  csrc/include/norm.h          — declares functions with torch::Tensor
  csrc/kernels/norm_kernel.cu  — HIP kernel (includes torch headers)
  csrc/pybind/norm_pybind.cu   — PYBIND11_MODULE wrapper
  csrc/include/rocm_ops.hpp    — NORM_PYBIND macro definition

AFTER (tvm-ffi):
  csrc/include/norm.h          — declares functions with void* (framework-agnostic)
  csrc/kernels/norm_kernel.hip — HIP kernel (only HIP + project headers, no torch)
  csrc/ffi/norm_ffi.cc         — thin tvm_ffi registration (compiled with g++, not hipcc)

Step 1: Decouple kernel implementations from torch::Tensor

Replace torch::Tensor in kernel-facing C++ interfaces with raw pointers:

// csrc/include/norm.h (after)
#pragma once

void layernorm2d(
    const void* input, const void* weight, const void* bias,
    void* output,
    int m, int n, double epsilon,
    void* stream = nullptr);

The kernel .hip files already work with raw pointers internally — they just need to stop receiving torch::Tensor at the interface boundary.
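To pin down the contract that the raw-pointer signature above carries, here is a NumPy reference of the (m, n) row-wise layernorm math. This assumes standard layernorm semantics (normalize each row, then scale and shift), which may differ in detail from aiter's kernel:

```python
import numpy as np

def layernorm2d_ref(x: np.ndarray, weight: np.ndarray, bias: np.ndarray,
                    epsilon: float = 1e-5) -> np.ndarray:
    """Row-wise layernorm over an (m, n) input: per-row mean/variance,
    then an elementwise affine with weight and bias of length n."""
    mean = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon) * weight + bias
```

Whatever framework owns the buffers, the C interface only ever sees the four data pointers plus m, n, and epsilon; this reference fixes what those pointers are expected to contain.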

Step 2: Create tvm-ffi binding files

Replace each csrc/pybind/*_pybind.cu with a csrc/ffi/*_ffi.cc.

Why tvm-ffi is better than pybind11 as a binding mechanism:

Note: this comparison is about pybind11 vs tvm-ffi as FFI libraries,
independent of the separate question of torch::Tensor vs DLPack.

| Aspect | pybind11 | tvm-ffi |
| --- | --- | --- |
| Header weight | <pybind11/pybind11.h> — ~10K lines of C++ template metaprogramming | <tvm/ffi/tvm_ffi.h> — a few hundred lines, minimal templates |
| Python.h dependency | Yes — pybind11 includes Python.h and links against libpython at compile time | No — the tvm-ffi .so has no Python header dependency; the tvm_ffi runtime handles Python interop |
| Registration syntax | PYBIND11_MODULE(name, m) { m.def("fn", &fn, py::arg("x"), ...); } — verbose, needs arg names | TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, fn) — one-liner, no arg decoration |
| Module loading | Python's import mechanism — requires a PyInit_* symbol and Python ABI compatibility | tvm_ffi.load_module() — plain dlopen + symbol lookup, no Python ABI dependency |
| ABI | C++ ABI — name mangling, exception propagation, sensitive to compiler version | C ABI — stable exported symbols, works across compiler versions |
| Linked libraries | libpython3.x.so + pybind11 headers | libtvm_ffi.so only (~2 MB) |
| Compile time | Slower (template-heavy pybind11 headers + Python.h) | Faster (lightweight headers, no Python.h) |
| Python version coupling | Compiled .so is tied to the specific Python minor version (3.10 vs 3.11 vs 3.12) | .so is Python-version-independent; the tvm_ffi runtime bridges the gap |

The key insight: pybind11 creates a Python extension module (a .so with
PyInit_* that Python's import system loads directly), which couples the
compiled binary to both the C++ ABI and the Python ABI. tvm-ffi creates a
plain shared library with C-ABI exported symbols, and the tvm_ffi Python
runtime (installed via pip) handles all the Python interop. This means the
compiled .so is more portable and doesn't need recompilation when Python
or the C++ compiler changes.

For binding files specifically — which are pure glue code that just unpacks
tensor pointers and forwards to kernel launchers — the simplicity and
portability of tvm-ffi are a clear win over pybind11's heavier machinery.
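The "plain dlopen + symbol lookup" loading model can be demonstrated with ctypes, which uses the same mechanism tvm_ffi.load_module() builds on. The sketch below looks up libc's strlen as a stand-in for an exported kernel symbol (assumes a Unix-like platform):

```python
import ctypes

# dlopen the already-loaded process image; a tvm-ffi .so would be opened
# the same way: ctypes.CDLL("path/to/module_norm.so").
libc = ctypes.CDLL(None)

# dlsym-style lookup of an unmangled C-ABI symbol — no PyInit_* entry
# point, no Python ABI involvement, no C++ name mangling.
strlen = libc.strlen
strlen.argtypes = [ctypes.c_char_p]
strlen.restype = ctypes.c_size_t
assert strlen(b"layernorm2d") == 11
```

A pybind11 extension cannot be loaded this way: its entry point is PyInit_<module>, which only Python's import system (of the matching Python version) knows how to call.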

Example:

// csrc/ffi/norm_ffi.cc
#include <tvm/ffi/tvm_ffi.h>
#include "norm.h"

void layernorm2d_ffi(
    tvm::ffi::Tensor input, tvm::ffi::Tensor weight,
    tvm::ffi::Tensor bias, tvm::ffi::Tensor output,
    double epsilon)
{
    void* inp_ptr = static_cast<char*>(input.data_ptr()) + input.byte_offset();
    void* wt_ptr  = static_cast<char*>(weight.data_ptr()) + weight.byte_offset();
    void* bi_ptr  = static_cast<char*>(bias.data_ptr()) + bias.byte_offset();
    void* out_ptr = static_cast<char*>(output.data_ptr()) + output.byte_offset();

    int m = static_cast<int>(input.shape()[0]);
    int n = static_cast<int>(input.shape()[1]);

    layernorm2d(inp_ptr, wt_ptr, bi_ptr, out_ptr, m, n, epsilon, nullptr);
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(layernorm2d, layernorm2d_ffi);

Step 3: Establish a two-layer Python architecture (Backend + Adaptor)

A key design point: tvm-ffi serves as the backend — a unified gateway
where all Python calls pass through tvm-ffi into the C++ world. On top of
that, each framework (PyTorch, JAX, etc.) gets a thin adaptor layer that
accepts native framework tensors, handles output allocation, and translates
to/from DLPack.

For PyTorch users, we always need an API that accepts torch.Tensor directly.
The current aiter Python API (aiter/ops/*.py, aiter/mla.py, etc.) already
plays this role — it is the torch adaptor. Under this proposal, its role
doesn't change; it simply swaps out the pybind11 module call for a tvm-ffi
module call underneath.

┌─────────────────────────────────────────────────────┐
│  User code                                          │
│    aiter.layernorm2d(torch_input, torch_weight, ...)│
└──────────────────────┬──────────────────────────────┘
                       │
┌──────────────────────▼──────────────────────────────┐
│  Torch Adaptor Layer  (aiter/ops/norm.py)           │
│    - Accepts torch.Tensor                           │
│    - Allocates output torch.Tensor                  │
│    - Validates shapes/dtypes                        │
│    - Converts via tvm_ffi.from_dlpack()             │
└──────────────────────┬──────────────────────────────┘
                       │
┌──────────────────────▼──────────────────────────────┐
│  tvm-ffi Backend  (compiled .so)                    │
│    - Framework-agnostic C++ functions               │
│    - Receives tvm::ffi::Tensor (DLPack)             │
│    - Dispatches to HIP kernels                      │
└─────────────────────────────────────────────────────┘

Concretely, the adaptor layer looks like:

# aiter/ops/norm.py  (Torch Adaptor — user-facing API, accepts torch.Tensor)
import torch
import tvm_ffi

from aiter.ops._backend import _get_fn

def layernorm2d(input: torch.Tensor, weight: torch.Tensor,
                bias: torch.Tensor, epsilon: float = 1e-5) -> torch.Tensor:
    """User-facing PyTorch API — unchanged from today."""
    assert input.is_cuda and input.is_contiguous()
    output = torch.empty_like(input)

    # Cross into the tvm-ffi backend via DLPack (zero-copy)
    _get_fn("layernorm2d")(
        tvm_ffi.from_dlpack(input),
        tvm_ffi.from_dlpack(weight),
        tvm_ffi.from_dlpack(bias),
        tvm_ffi.from_dlpack(output),
        epsilon,
    )
    return output

# aiter/ops/_backend.py  (tvm-ffi Backend loader — internal, not user-facing)
from aiter.jit import load_module

_module = None

def _get_fn(name: str):
    global _module
    if _module is None:
        _module = load_module("module_norm")
    return _module[name]

Users still call aiter.layernorm2d(torch_tensor) exactly as before.
The adaptor layer is the only place that knows about torch.Tensor;
the tvm-ffi backend is completely framework-agnostic. Tomorrow, a JAX
adaptor could call the same backend .so without any C++ changes.

Note: The torch adaptor layer is always needed if you want a user-friendly
torch.Tensor-in, torch.Tensor-out API. Even in small tvm-ffi projects
(e.g., gcnasm/warp_sort_bitonic/tvmffi), the adaptor exists — it's just
inlined into __init__.py rather than a separate file. It accepts
torch.Tensor, allocates the output via torch.empty_like(), converts
to DLPack via tvm_ffi.from_dlpack(), calls the backend, and returns
torch.Tensor. The pattern is always the same; only the scale differs.

In aiter's case, the adaptor layer already exists — it's the current
aiter/ops/*.py and aiter/mla.py files. They already handle shape
validation, output allocation, and dispatch. This proposal does not
remove them. It only changes their internal plumbing: swap
pybind11_module.func(torch_tensor) for
tvm_ffi_module["func"](tvm_ffi.from_dlpack(torch_tensor)).

Step 4: Update build system

  • In aiter/jit/core.py, split compilation into two phases:
    • HIP phase: compile kernel .hip files with hipcc (no torch headers)
    • FFI phase: compile *_ffi.cc files with g++ (only tvm_ffi headers)
    • Link phase: link both into a single .so with -ltvm_ffi
  • Remove pybind11>=3.0.1 from setup.py install_requires
  • Add apache-tvm-ffi to install_requires
  • Delete csrc/include/rocm_ops.hpp (the 1800-line macro file)

Migration order (suggested)

  1. Start with a leaf operator (e.g., norm, activation) as a proof of concept
  2. Migrate attention/MLA kernels (highest value — slowest to compile today)
  3. Migrate GEMM/MoE kernels
  4. Remove pybind11 dependency entirely

Existing tvm-ffi examples in the wild

Two working examples of the tvm-ffi + Ninja JIT approach already exist in
gcnasm (AMD GPU kernel playground):

1. warp_sort_bitonic/tvmffi/ — simple single-kernel example

A warp-level bitonic merge sort kernel with tvm-ffi binding. Demonstrates
the minimal pattern: one .hip kernel + one tvm_api.cc FFI wrapper,
built via Python-driven Ninja JIT (no Makefile, no setup.py).

  • FFI binding (csrc/tvm_api.cc):

    #include <tvm/ffi/tvm_ffi.h>
    #include "warp_bitonic_sort.hpp"
    
    void warp_bitonic_sort_ffi(tvm::ffi::Tensor input, tvm::ffi::Tensor output,
                               int is_descending) {
        void* i_ptr = static_cast<char*>(input.data_ptr()) + input.byte_offset();
        void* o_ptr = static_cast<char*>(output.data_ptr()) + output.byte_offset();
        int num_element = static_cast<int>(input.shape()[0]);
        warp_bitonic_sort_kernel(i_ptr, o_ptr, num_element, is_descending ? 1 : 0);
    }
    
    TVM_FFI_DLL_EXPORT_TYPED_FUNC(warp_bitonic_sort, warp_bitonic_sort_ffi);
  • JIT build (warp_bitonic_sort/jit.py): ~130 lines of Python that
    generates build.ninja, runs ninja, and loads via tvm_ffi.load_module().
    Caches the .so in ~/.cache/ — subsequent imports skip compilation entirely.

  • Key takeaway: The HIP kernel file has zero PyTorch headers. The FFI
    .cc file is compiled with plain g++, not hipcc. Build time is seconds.

2. mla_agentic_v0/ — multi-kernel MLA decode with split-K

A more complex example: Flash MLA sparse decode with multiple templated MFMA
kernels, split-K parallelism, and a combine kernel — all bound through tvm-ffi.

  • 3 kernel functions exported via TVM_FFI_DLL_EXPORT_TYPED_FUNC:
    flash_mla_decode, flash_mla_sparse_decode, flash_mla_sparse_decode_splitk
  • Same JIT pattern: flash_mla/jit.py generates Ninja, compiles .hip +
    .cc, links against libtvm_ffi.so
  • Python side: tvm_ffi.from_dlpack(torch_tensor) for zero-copy tensor
    passing; tvm_ffi.load_module() for loading the compiled .so

Both examples use the same ~130-line JitSpec pattern for build/cache/load,
which could be extracted into a shared utility for aiter.

Summary

| Aspect | pybind11 (current) | tvm-ffi (proposed) |
| --- | --- | --- |
| Build deps per .cu | Full PyTorch C++ headers | Only <hip/hip_runtime.h> |
| FFI binding deps | pybind11 + torch headers | <tvm/ffi/tvm_ffi.h> only |
| Compile time (per module) | Minutes (torch template expansion) | Seconds |
| PyTorch version coupling | Hard ABI dependency | None (DLPack is stable) |
| Framework support | PyTorch only | Any DLPack framework |
| Binding boilerplate | ~1800 lines in rocm_ops.hpp | One-liner per function |
| Pre-built binary reuse | Breaks across PyTorch versions | Works across versions |
