Scaling Efficient Transformers via Low-Rank Semantic Routing
This repo contains the production implementation and paper reproduction harness for Decoupled Bottleneck Attention.

It also contains `caramba`, an evolution of the original production implementation: a full, more modular rewrite built from composable atomic building blocks for rapid experimentation with new architectures and optimizations. Among its more advanced features, `caramba` can "upcycle" a Llama model, retrofitting Decoupled Bottleneck Attention via a method inspired by Attention Surgery (Ghafoorian et al., 2025) [arXiv:2509.24899].
- Single source of truth: `production/` (invoked via `main.py`)
- Single dataset: FineWeb-Edu GPT-2 tokens (`fineweb_100m.npy`, `fineweb_20b.npy`)
- Single reproduction entrypoint: `paper_manifest.json` + `run_paper_manifest.py`
Paper numbers and plots are generated from run artifacts into `assets/paper_results.json` by `generate_paper_figures.py`.
Key Insight: Attention is a router, not a processor. Semantic routing operates in ~32 dimensions; positional geometry needs ~64. By decoupling them, we slash memory while preserving quality.
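To make the memory claim concrete, here is a back-of-envelope KV-cache sizing sketch. The layer/head counts and the 128-dim standard head vs. 32+64-dim decoupled split are illustrative assumptions, not the paper's exact configuration, and K/V are treated symmetrically for simplicity:

```python
# Rough KV-cache sizing (illustrative numbers, not the paper's exact config).
def kv_cache_bytes(n_layers, n_heads, dims_per_head, seq_len, bytes_per_elem):
    """K and V (factor of 2) cached per layer, per head, per position."""
    return n_layers * n_heads * 2 * dims_per_head * seq_len * bytes_per_elem

seq, layers, heads = 128_000, 32, 32

# Standard attention: full head_dim (assumed 128 here) for K and V, FP16.
standard = kv_cache_bytes(layers, heads, 128, seq, 2)

# Decoupled: ~32-dim semantic + ~64-dim geometric state (hypothetical split).
decoupled = kv_cache_bytes(layers, heads, 32 + 64, seq, 2)

print(f"standard  @128k: {standard / 2**30:.1f} GiB")   # 62.5 GiB
print(f"decoupled @128k: {decoupled / 2**30:.1f} GiB")  # 46.9 GiB
```

Heterogeneous quantization of the cached states (Q4/Q8, below) then multiplies this architectural saving.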
- Training/eval metrics: written to `runs/<id>/train.jsonl` by the production runner.
- KV-cache @128k memory: written to `runs/<id>/mem128k.json` by `production/bench_end_to_end_memory.py`.
- For decoupled runs we also report the paper decomposition:
  - architecture-only: standard FP16 → decoupled FP16
  - quant-only: decoupled FP16 → hetero Q4/Q8/Q4
  - end-to-end: standard FP16 → hetero Q4/Q8/Q4
- Paper artifacts: `generate_paper_figures.py` produces:
  - `assets/paper_results.json`
  - `assets/fig_convergence.png`
  - `assets/fig_pareto_memory_vs_loss.png`
  - `assets/table_main.tex`, `assets/table_scale.tex`
You need Python 3.10+ and PyTorch with either CUDA or MPS support.
```
python -m pip install -r requirements.txt
```

Notes:
- On CUDA machines, install a CUDA-enabled torch separately (do not `pip install torch` over it).
- `matplotlib` is optional unless you want to generate plots.
The harness expects tokenized arrays:
- `fineweb_100m.npy` (local suite)
- `fineweb_20b.npy` (A100 scale suite)
If you don’t have them, generate them with:
```
python prepare_fineweb.py --tokens 100M --output fineweb_100m.npy
```

Note: the A100 suite expects `fineweb_20b.npy`. If you already have a prebuilt 20B-token shard on your A100 instance, copy it into the repo root under that exact filename.
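Before launching a long run, it can be worth sanity-checking a shard. The sketch below assumes the shard is a flat numpy array of GPT-2 token ids (which is what `prepare_fineweb.py` appears to produce); the function name is hypothetical:

```python
import numpy as np

def check_shard(path: str, vocab_size: int = 50_257) -> int:
    """Sanity-check a tokenized shard; returns the token count.
    Assumes a flat array of GPT-2 token ids (an assumption about the format)."""
    tokens = np.load(path, mmap_mode="r")  # memmap: avoids loading 100M+ tokens
    assert tokens.ndim == 1, "expected a flat token array"
    assert int(tokens[:1_000_000].max()) < vocab_size, "token id outside GPT-2 vocab"
    return int(tokens.shape[0])

# e.g. check_shard("fineweb_100m.npy")
```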
All paper runs are defined in `paper_manifest.json`. The harness:

- validates resolved configs (no flag ambiguity),
- writes per-run provenance (`command.txt`, `resolved_config.json`, `resolved_run.json`),
- runs training,
- optionally runs `mem128k.json` benchmarking after training.
Dry-run previews:

```
python run_paper_manifest.py --group mac_fw100m --dry-run
python run_paper_manifest.py --group a100_fw1b_1bscale --dry-run
```

Full runs with post-training memory benchmarking:

```
python run_paper_manifest.py --group mac_fw100m --post-mem128k
python run_paper_manifest.py --group a100_fw1b_1bscale --post-mem128k
```

The A100 runs are resumable. Re-running the same command will resume from `runs/<id>/last.pt` if present.
```
.
├── production/               # Canonical implementation (single source of truth)
├── main.py                   # Canonical CLI entrypoint (calls production/cli.py)
├── paper_manifest.json       # Canonical paper run manifest
├── run_paper_manifest.py     # Canonical paper harness runner
├── generate_paper_figures.py # Generates assets/ paper artifacts from run dirs
├── paper.tex                 # Paper source (inputs assets/table_*.tex, assets/*.png)
├── references.bib            # Bibliography
├── Makefile                  # Misc helpers (not used for paper reproduction)
├── experiments/              # Historical experiments (not paper-canonical)
├── runs/                     # Training logs (and checkpoints once run locally)
├── assets/                   # Generated figures
│   ├── paper_results.json
│   ├── fig_convergence.png
│   ├── fig_pareto_memory_vs_loss.png
│   ├── table_main.tex
│   └── table_scale.tex
└── docs/
    └── RESEARCH_MASTER_PLAN.md
```
For any run id in `paper_manifest.json`, reviewers should be able to verify:

- Provenance
  - `runs/<id>/command.txt` matches what was executed
  - `runs/<id>/resolved_config.json` contains the fully resolved config used
  - `runs/<id>/train.jsonl` contains `meta`/`resume_meta`, `train`, `eval`, `done` events
- Memory measurement
  - `runs/<id>/mem128k.json` exists (when using `--post-mem128k`)
  - decoupled runs include `decomposition.estimate_bytes` and `decomposition.measured`
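A reviewer could script the provenance half of this checklist along the following lines. The `"event"` field name is an assumption about the `train.jsonl` schema (the actual key may differ), and `check_run` is a hypothetical helper, not part of the harness:

```python
import json
from pathlib import Path

REQUIRED = frozenset({"train", "eval", "done"})  # plus meta/resume_meta in practice

def check_run(run_dir: str) -> set:
    """Verify provenance files exist and train.jsonl covers the required events.
    Assumes each JSONL line carries an "event" key (schema assumption)."""
    run = Path(run_dir)
    for name in ("command.txt", "resolved_config.json", "train.jsonl"):
        assert (run / name).exists(), f"missing {name}"
    seen = set()
    with open(run / "train.jsonl") as fh:
        for line in fh:
            seen.add(json.loads(line).get("event"))
    missing = REQUIRED - seen
    assert not missing, f"missing events: {missing}"
    return seen
```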
After running experiments:
```
python generate_paper_figures.py
```

This writes:

- `assets/paper_results.json`
- `assets/fig_convergence.png`
- `assets/fig_pareto_memory_vs_loss.png`
- `assets/table_main.tex`, `assets/table_scale.tex`
This project now defaults to a minimal, self-optimizing CLI: you specify intent (train/sample, model size, data, instrumentation) and the system self-tunes performance/kv-cache policies based on your hardware and workload.
```
python3 main.py --mode train --size medium --exp paper_decoupled --data fineweb_100m.npy
python3 main.py --mode sample --ckpt runs/<run_id>/best.pt --prompt-tokens "0 1 2 3" --max-new-tokens 64
```

Optimization behavior is fully self-driven; there are no environment-variable toggles for disabling core optimization paths.
Model/training recipes are defined via presets and manifests:
- Presets: `production/config.py`
- Paper harness: `run_paper_manifest.py` + `paper_manifest.json`
In decoupled attention, attention logits are the sum of two paths:
- Semantic path: content similarity (no RoPE)
- Geometric path: relative positional similarity (RoPE applied here only)
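The two-path sum above can be sketched in a few lines of PyTorch. The path widths (32/64), the minimal rotate-half RoPE, and the per-path `1/sqrt(d)` scaling are illustrative assumptions, not the production implementation:

```python
import torch

def rope(x: torch.Tensor, base: float = 10_000.0) -> torch.Tensor:
    """Minimal rotary positional embedding (rotate-half form) over the last dim."""
    T, d = x.shape[-2], x.shape[-1]
    inv = base ** (-torch.arange(0, d, 2, dtype=x.dtype) / d)     # (d/2,) inverse freqs
    ang = torch.arange(T, dtype=x.dtype)[:, None] * inv[None, :]  # (T, d/2) angles
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin   # rotate each 2D plane by its angle
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

T, d_sem, d_geo = 16, 32, 64  # illustrative widths from the paper's framing
q_sem, k_sem = torch.randn(T, d_sem), torch.randn(T, d_sem)
q_geo, k_geo = torch.randn(T, d_geo), torch.randn(T, d_geo)

# Decoupled logits: content similarity (no RoPE) + positional similarity (RoPE only here).
logits = (q_sem @ k_sem.T) / d_sem**0.5 + (rope(q_geo) @ rope(k_geo).T) / d_geo**0.5
```

Only the low-rank `k_sem` and the `k_geo` state need to be cached per token, which is where the memory savings come from.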
Because RoPE encodes a geometric/rotational positional signal, the geometric K/V state is often more sensitive to quantization error. In practice, this is why default heterogeneous KV-cache policies commonly quantize `k_sem` more aggressively (e.g. `q4_0`/`nf4`) while keeping `k_geo` at higher precision (e.g. `q8_0` or `fp16`) to preserve positional fidelity over long contexts.
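The precision gap can be illustrated with a fake-quantization round trip. This per-row symmetric scheme is a simplified stand-in for block formats like `q4_0`/`q8_0` (it is not llama.cpp's actual implementation):

```python
import torch

def fake_quant(x: torch.Tensor, bits: int) -> torch.Tensor:
    """Per-row symmetric fake-quantization round trip (stand-in for q4_0/q8_0)."""
    qmax = 2 ** (bits - 1) - 1                         # 7 for 4-bit, 127 for 8-bit
    scale = x.abs().amax(dim=-1, keepdim=True) / qmax  # per-row scale
    return (x / scale).round().clamp(-qmax, qmax) * scale

k = torch.randn(1024, 64)  # toy stand-in for cached keys
for bits in (8, 4):
    rms = (fake_quant(k, bits) - k).pow(2).mean().sqrt()
    print(f"{bits}-bit RMS error: {rms:.4f}")
```

The 4-bit grid is ~16x coarser than the 8-bit one, so the rotational phase information RoPE encodes in `k_geo` degrades correspondingly faster, which motivates spending the extra bits there.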
The production self-optimizer can still explore counterfactual configurations (including sem/geo swaps), but policy acceptance is guarded by short- and optional long-horizon quality checks against an `fp16`-cache baseline.
```bibtex
@article{vandommelen2025decoupled,
  title={Decoupled Bottleneck Attention: Scaling Efficient Transformers via Low-Rank Semantic Routing},
  author={van Dommelen, Daniel Owen},
  journal={arXiv preprint arXiv:2025.XXXXX},
  year={2025}
}
```

If you reference the Attention Surgery method discussed above:
```bibtex
@article{ghafoorian2025attentionsurgery,
  title={Attention Surgery: An Efficient Recipe to Linearize Your Video Diffusion Transformer},
  author={Ghafoorian, Mohsen and Korzhenkov, Denis and Habibian, Amirhossein},
  journal={arXiv preprint arXiv:2509.24899},
  year={2025},
  doi={10.48550/arXiv.2509.24899}
}
```

This work builds on insights from:
- LoRA (Hu et al., 2021) — Low-rank adaptation
- AdaRankGrad (Refael et al., 2024) — Gradient rank dynamics
- DeepSeek-V2 MLA — Multi-head latent attention
- ExLlamaV2 — 4-bit KV cache quantization
- llama.cpp — Production Q4_0 implementation
MIT License. See LICENSE for details.