A probabilistic programming language implemented in ClojureScript on Node.js (nbb), using the MLX framework for GPU acceleration via mlx-node.
GenMLX implements Gen's Generative Function Interface (GFI) — the same architecture used by GenJAX (JAX) and Gen.jl (Julia).
Gen implementations exist for Julia and JAX — but nothing for MLX. MLX's unified memory model is a natural fit for probabilistic programming: MCMC control flow runs on CPU while all numerics stay on GPU, with zero data transfer cost. ClojureScript on Node.js gives direct access to MLX through a native addon with no FFI overhead, and nbb provides a fast REPL for interactive model development.
- MLX-native: unified memory, lazy evaluation, dynamic shapes, `mx/grad` through entire models (Apple Silicon)
- ~32,000 lines of ClojureScript: protocols, records, persistent data structures; the whole thing is readable in an afternoon
- GPU end-to-end: scores and choice values are MLX arrays throughout, extracted with `mx/item` only at inference boundaries
- 5-level compilation ladder: progressively moves work from the host interpreter into fused MLX computation graphs, from shape-based batching (L0) through auto-analytical elimination (L3) to single fused graphs (L4)
- macOS with Apple Silicon (M1/M2/M3/M4). macOS-only for now: `mlx-node` does not support Linux/CUDA yet.
- C++ toolchain
  - Xcode Command Line Tools: `xcode-select --install` (provides `clang++`, `make`, and the macOS SDK). Note: the CLT do not include CMake; install it separately (below).
  - First launch setup: `sudo xcodebuild -runFirstLaunch`
  - Metal Toolchain: `xcodebuild -downloadComponent MetalToolchain` (required on macOS 26+; the build hard-fails without it)
- CMake ≥ 3.25: drives the vendored MLX C++ build (it is not bundled with the Xcode CLT, despite the CLT providing the C++ compiler). `brew install cmake`.
- Rust toolchain: `mlx-node` is a Rust NAPI addon, so `cargo`/`rustc` must be on `PATH` to build it. `curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh` (then restart your shell, or `source "$HOME/.cargo/env"`)
- Node.js ≥ 18: provides `npm` (for installing nbb and yarn). `brew install node` on macOS, or nodejs.org.
- Bun: `curl -fsSL https://bun.sh/install | bash`. The `bun run --bun nbb …` commands throughout this README use it (recommended; 3-4x faster than Node.js for iterative inference). The installer only updates `~/.bashrc`/`~/.zshrc`; fish users must add it to `PATH` themselves: `fish_add_path ~/.bun/bin`. To run on Node.js instead, replace `bun run --bun nbb` with `nbb` and `bun install` with `npm install`.
- nbb `1.4.208`: `npm install -g nbb@1.4.208`. The repo also pins it via the `nbb` script in `package.json`, so `bun run --bun nbb …` resolves to `1.4.208` regardless of any globally installed nbb. ⚠️ Avoid `1.4.207` specifically: it shipped a short-lived SCI regression that failed to resolve record types across namespaces (`Unable to resolve symbol: cm/Node`), breaking the handler loop and therefore all inference. `1.4.208` fixes that and additionally exposes the `IPrintWithWriter` protocol, which is what lets GenMLX run stock upstream `malli` (no fork needed).
- Yarn: needed only to build the `mlx-node` submodule. You don't need a specific version: mlx-node vendors the exact release it wants (`.yarn/releases/yarn-4.13.0.cjs`) and any `yarn` launcher on `PATH` auto-delegates to it via the `yarnPath` setting in `.yarnrc.yml`. Install one with `npm install -g yarn` or `brew install yarn`. (If you prefer Corepack, note it is no longer bundled with current Node.js, so it needs `npm install -g corepack && corepack enable`, but it's unnecessary here.)
- `git` with submodule support: GenMLX vendors its dependencies as nested submodules (see below).
# Clone with ALL submodules (mlx-node, mlx, instaparse, test.check forks + upstream malli).
# --recurse-submodules is required: it also pulls the MLX C++ fork nested one
# level deeper inside mlx-node (crates/mlx-sys/mlx).
git clone --recurse-submodules https://github.com/robert-johansson/genmlx
cd genmlx
# Already cloned without --recurse-submodules? Initialize them now:
# git submodule update --init --recursive
# Build mlx-node (compiles the vendored MLX C++ via CMake + the Rust NAPI addon
# via cargo + the TypeScript wrapper via tsc, ~3-4 min first build). `yarn`
# auto-delegates to the pinned 4.13.0 release vendored in .yarn/releases —
# any yarn launcher on PATH works.
cd mlx-node
yarn install
yarn build # runs build:native (cmake + cargo) AND build:ts (tsc -b).
# Do NOT use `yarn build:native` alone — without build:ts,
# @mlx-node/lm has no dist/ and the LLM API won't import.
# Needs cmake, cargo, and (macOS 26+) the Metal Toolchain on PATH.
cd ..
# Install GenMLX's native dependency. @mlx-node/core and @mlx-node/lm resolve from
# the local mlx-node submodule via file: paths in package.json, so the build above
# must have completed first.
bun install
# Verify the install — this test runs real inference through the handler loop AND
# the freshly-built native MLX addon, so it confirms both the build and your nbb
# version are good. Should print "0 failures, 0 errors".
bun run --bun nbb test/genmlx/inference_test.cljs
# Confirm a COMPATIBLE native core loaded (not a stale prebuilt). Prints
# {:ok? true ... :version "x.y.z"}; on a stale/incompatible binary the Tier-A
# install guard throws a clear rebuild instruction instead of a cryptic NAPI
# error deep in an op. (LLM checkpoints are additionally capability-checked at
# load-model — Tier-B in genmlx.llm.backend.)
bun run --bun nbb -e '(require (quote [genmlx.mlx :as mx])) (prn (mx/native-core-report))'

Note on the submodules: GenMLX vendors five git submodules. Four are forks maintained alongside this repo; the fifth, `malli`, now tracks upstream `metosin/malli` directly (the earlier fork existed only for nbb 1.4.206 compatibility, which `1.4.208` makes unnecessary):
- `mlx-node`: adds a custom `genmlx.rs` Rust module with 138 module-level NAPI exports tuned for ClojureScript (`Either<MxArray, number>` inputs, `number[]` shapes, CPU-stream PRNG, fused scalar extraction). It vendors the MLX C++ source as a further nested submodule (`mlx` at `crates/mlx-sys/mlx`), which `--recurse-submodules` pulls automatically and `yarn build` compiles statically (the `build:native` part of the script).
- `instaparse` and `test.check`: forks patched for nbb/Babashka compatibility.
- `malli`: the official upstream library, pinned to a specific commit; no fork is needed under nbb `1.4.208`.

All three (`instaparse`, `test.check`, `malli`) sit directly on the nbb classpath (see `nbb.edn`), so no build step is needed. malli backs schema validation (`genmlx.schemas`), instaparse the LLM grammar layer (`genmlx.llm.grammar`), and test.check the property-based tests.
Run the included examples:
bun run --bun nbb examples/genmlx/linear_regression.cljs # Bayesian regression (IS + MALA)
bun run --bun nbb examples/genmlx/hidden_markov_model.cljs # HMM with particle filtering
bun run --bun nbb examples/genmlx/adaptive_hmc.cljs # HMC with adaptive step-size
bun run --bun nbb examples/genmlx/variational_inference.cljs # VI vs MCMC comparison
bun run --bun nbb examples/genmlx/analytical_elimination.cljs # Auto-conjugacy detection (L3)
bun run --bun nbb examples/genmlx/fit_api.cljs # One-call inference (L4)
bun run --bun nbb examples/genmlx/compositional_models.cljs # Splice, Map, Switch, GFI algebra
bun run --bun nbb examples/genmlx/scan_time_series.cljs # Scan combinator for SSMs
bun run --bun nbb examples/genmlx/neural_probabilistic.cljs # Neural networks + probabilistic models

(ns my-model
(:require [genmlx.mlx :as mx]
[genmlx.dist :as dist]
[genmlx.choicemap :as cm]
[genmlx.selection :as sel]
[genmlx.inference.mcmc :as mcmc])
(:require-macros [genmlx.gen :refer [gen]]))
;; Bayesian linear regression — all values stay as MLX arrays
;; Inside gen bodies, trace/splice/param are local bindings injected by the macro
(def model
(gen [xs]
(let [slope (trace :slope (dist/gaussian 0 10))
intercept (trace :intercept (dist/gaussian 0 10))]
(doseq [[j x] (map-indexed vector xs)]
(trace (keyword (str "y" j))
(dist/gaussian (mx/add (mx/multiply slope (mx/scalar x))
intercept) 1)))
slope)))
;; Observations — data generated from y = 2x + 0.1 + noise
(def xs [1.0 2.0 3.0 4.0 5.0])
(def observations
(cm/choicemap :y0 (mx/scalar 2.1) :y1 (mx/scalar 3.9) :y2 (mx/scalar 6.2)
:y3 (mx/scalar 7.8) :y4 (mx/scalar 10.1)))
;; Metropolis-Hastings
(def traces
(mcmc/mh {:samples 500 :burn 100
:selection (sel/select :slope :intercept)}
model [xs] observations))
;; Examine posterior
(let [slopes (mapv #(mx/item (cm/get-choice (:choices %) [:slope])) traces)]
(println "Posterior slope mean:" (/ (reduce + slopes) (count slopes))))
;; => ~2.0 (true slope)

GenMLX progressively moves work from the host interpreter into fused MLX computation graphs. Each level adds performance without breaking the GFI semantic contract: a model written at L0 runs unchanged at L4.
Level 0: Shape-based batching — [N]-shaped arrays + broadcasting (certified, 68/68 tests)
Level 1: Compiled gen functions — schema extraction, noise transforms, compiled simulate (506+ tests)
Level 2: Compiled inference sweeps — compiled MCMC, differentiable inference, particle methods (881+ tests)
Level 3: Auto-analytical elimination — conjugacy detection, Kalman/EKF/HMM handlers (426 tests)
Level 3.5: Extended analytical — N-d EKF, combinator-aware conjugacy, MVN Kalman (150 tests)
Level 4: Single fused graph — compiled optimizer, method selection, fit API (260+ tests)
The handler system is ground truth; compilation is optimization. Compiled paths produce the same traces, scores, and weights as the handler path.
Layer 0: MLX Foundation
genmlx.mlx — thin wrapper over mlx-node (Rust genmlx.rs + CLJS)
genmlx.mlx.random — functional PRNG key management
Layer 1: Core Data Types
genmlx.choicemap — hierarchical address → value trees
genmlx.trace — immutable execution records
genmlx.selection — composable address selection algebra
Layer 2: GFI Protocols & Execution
genmlx.protocols — IGenerativeFunction, IGenerate, IUpdate, IRegenerate,
IAssess, IProject, IEdit, IUpdateWithDiffs
genmlx.handler — pure state transitions (volatile! runtime in runtime.cljs)
genmlx.edit — parametric edit interface (constraint, selection, proposal)
genmlx.diff — incremental change tracking
Layer 3: DSL + Schema + Compilation
genmlx.gen — gen macro
genmlx.dynamic — DynamicDSLFunction (full GFI + vsimulate/vgenerate)
genmlx.schema — schema extraction from gen body source forms (L1)
genmlx.compiled — compiled simulate, partial prefix, branch rewriting (L1)
genmlx.compiled_gen — compiled generate with middle-tier score function (L1)
genmlx.affine — affine dependency analysis for conjugacy (L3)
genmlx.conjugacy — conjugate prior detection and Rao-Blackwellization (L3)
genmlx.method_selection — decision tree from model metadata (L4)
genmlx.fit — one-call entry point: (fit model args data) (L4)
Layer 4: Distributions
genmlx.dist — 36 distributions, each a GFI participant
genmlx.custom_gradient — CustomGradientGF (custom gradient gen functions)
genmlx.nn — Neural network generative functions (nn->gen-fn)
Layer 5: Combinators
genmlx.combinators — Map, Unfold, Switch, Scan, Mask, Mix, Recurse,
Vectorized Switch, Contramap/Dimap
genmlx.vmap — Vmap combinator (vmap-gf / repeat-gf, full GFI)
Layer 6: Inference
genmlx.inference — IS, MH, MALA, HMC, NUTS, Gibbs, Elliptical Slice,
Involutive MCMC, MAP, SMC, SMCP3, VI, ADEV,
Amortized Inference, Kernel Composition
— Kalman, EKF, N-d EKF, HMM forward algorithm (L3)
— Compiled gradient, compiled optimizer (L4)
— Compiled SMC, differentiable inference, PMCMC (L2)
Layer 7: Vectorized
genmlx.vectorized — VectorizedTrace, batched execution, dispatch amortization
Layer 8: Verification
genmlx.gfi — the GFI algebraic law catalog from Cusumano-Towner 2020 thesis
genmlx.verify — Static validator (validate-gen-fn)
Layer 9: Domain Verticals (domain-as-GF; inference orthogonal)
genmlx.agents.gridworld/agent — MDP/POMDP agents as GFs (build-mdp, make-mdp-agent;
the policy is a GF, inference is pluggable)
genmlx.agents.inverse — goal inference / IRL (Bayesian inference over the agent)
genmlx.agents.pomdp — belief-space planning, bandits; .biased-planners, .differentiable
genmlx.llm.core — make-llm-gf: wrap LLM as DynamicGF (each token = trace site)
genmlx.llm.backend — mlx-node loader, forward pass, KV cache
genmlx.llm.grammar — DFA-constrained generation (regex → token mask)
genmlx.llm.bytes — byte-level marginalization via TokenByteTrie
genmlx.llm.codegen — reader-as-grammar for syntactically-valid ClojureScript
genmlx.llm.msa — Model Synthesis Architecture (LLM proposes prob programs)
genmlx.llm.vision — VLM input adaptation
| Distribution | Constructor | Reparameterized |
|---|---|---|
| Gaussian | `(gaussian mu sigma)` | Yes |
| Uniform | `(uniform lo hi)` | Yes |
| Bernoulli | `(bernoulli p)` | No |
| Beta | `(beta-dist alpha beta)` | No |
| Gamma | `(gamma-dist shape rate)` | No |
| Exponential | `(exponential rate)` | Yes |
| Categorical | `(categorical logits)` | No |
| Poisson | `(poisson rate)` | No |
| Laplace | `(laplace loc scale)` | Yes |
| Student-t | `(student-t df loc scale)` | No |
| Log-Normal | `(log-normal mu sigma)` | Yes |
| Multivariate Normal | `(multivariate-normal mean cov)` | Yes |
| Dirichlet | `(dirichlet alpha)` | No |
| Delta | `(delta v)` | No |
| Cauchy | `(cauchy loc scale)` | Yes |
| Inverse Gamma | `(inv-gamma shape scale)` | No |
| Geometric | `(geometric p)` | No |
| Negative Binomial | `(neg-binomial r p)` | No |
| Binomial | `(binomial n p)` | No |
| Discrete Uniform | `(discrete-uniform lo hi)` | No |
| Truncated Normal | `(truncated-normal mu sigma lo hi)` | Yes |
| Piecewise Uniform | `(piecewise-uniform bounds probs)` | No |
| Beta-Uniform Mixture | `(beta-uniform-mixture theta alpha beta)` | No |
| Wishart | `(wishart df scale)` | No |
| Inverse Wishart | `(inv-wishart df scale)` | No |
| Broadcasted Normal | `(broadcasted-normal mu sigma)` | Yes |
| Gaussian Vec | `(gaussian-vec mu sigma)` | Yes |
| Von Mises | `(von-mises mu kappa)` | No |
| Wrapped Cauchy | `(wrapped-cauchy mu gamma)` | No |
| Wrapped Normal | `(wrapped-normal mu sigma)` | No |
| Mixture | `(mixture components log-weights)` | No |
| Product | `(product dists)` | No |
| IID | `(iid base-dist t)` | No |
| IID Gaussian | `(iid-gaussian mu sigma t)` | Yes |
| Categorical (weights) | `(categorical-weights weights)` | No |
| Weighted | `(weighted probs)` | No |
Aliases: normal → gaussian, flip → bernoulli
- Importance Sampling: `importance-sampling`, `importance-resampling`; vectorized variants via `vgenerate`
- Metropolis-Hastings: `mh` (via GFI `regenerate`), `mh-custom` (with proposal generative function)
- Compiled MCMC: `compiled-mh`, `compiled-mala` (random-walk in parameter space, `mx/tidy` per step)
- Fused MCMC: `fused-mh`, `fused-mala`, `fused-hmc` (entire chains compiled via `mx/compile-fn`)
- Vectorized MCMC: `vectorized-compiled-mh`, `vectorized-mala` (batched particles)
- MALA: `mala` (gradient-informed Langevin proposals)
- HMC: `hmc` (compiled leapfrog integration, adaptive step-size via dual averaging)
- NUTS: `nuts` (adaptive trajectory length, No-U-Turn Sampler)
- Gibbs Sampling: `gibbs` (systematic scan with enumerable support)
- Elliptical Slice Sampling: `elliptical-slice` (for Gaussian priors)
- Involutive MCMC: `involutive-mh` (deterministic involution-based proposals)
- MAP Optimization: `map-optimize`, `vectorized-map-optimize` (point estimates via gradient ascent)
- SMC: `smc` (particle filtering with resampling + rejuvenation), `csmc`, `vsmc` (vectorized)
- SMCP3: `smcp3` (Sequential Monte Carlo with Probabilistic Program Proposals)
- Compiled SMC: `compiled-smc` (fused particle operations)
- PMMH: `pmmh` (Particle Marginal Metropolis-Hastings)
- Particle Gibbs: `particle-gibbs` (conditional SMC + Gibbs)
- Variational Inference: `vi` (ADVI with mean-field Gaussian guide), `programmable-vi` with pluggable objectives (`elbo`, `iwelbo`, `wake-sleep`) and gradient estimators (reinforce, reparameterization); compiled variants via `compiled-vi`, `compiled-programmable-vi`
- ADEV: automatic differentiation of expected values with reparameterization and REINFORCE strategies, vectorized GPU execution, compiled optimization loops, baseline variance reduction
- Amortized Inference: `neural-importance-sampling` (learned neural proposals)
- Analytical Elimination: auto-conjugacy detection (5 families), joint linear-Gaussian regression, Rao-Blackwellization; exact marginal likelihood (matches the closed form to the float32 floor, ~1e-6 nats), ESS gains up to ~50×
- Kalman Filter: handler middleware for linear-Gaussian SSMs, sequential updates, exact marginal LL
- Extended Kalman Filter: nonlinear SSMs via auto-diff linearization (1D and N-dimensional)
- HMM Forward Algorithm: discrete latent state-space models, exact marginal likelihood
- Exact Enumeration: `enumerate` with full GFI (simulate, generate, update, regenerate, assess, project); discrete latent variables with finite support
- Fisher Information: `observed-fisher` (Fisher information matrix, Cramér-Rao bounds)
- Differentiable Inference: differentiable resampling, differentiable importance sampling
- Compiled Optimization: `mx/compile-fn` + Adam, 9.2x speedup over handler loops
- Method Selection: automatic inference method from model structure (L4 `fit` API)
- Kernel Composition: `chain`, `cycle-kernels`, `mix-kernels`, `repeat-kernel`, `seed`
- Diagnostics: `ess`, `r-hat`, `summarize`, `sample-quantiles`
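
The analytical-elimination entry in the list above rests on conjugacy: when a latent's prior and likelihood are conjugate, its posterior and the marginal likelihood have closed forms, so the latent need not be sampled. The following is a minimal sketch of that idea for the Normal-Normal case, in plain ClojureScript with no GenMLX calls. The function names are illustrative assumptions, not GenMLX's API, and the real implementation works on MLX arrays inside handlers.

```clojure
;; Illustrative only: hypothetical helper names, plain numbers instead of MLX arrays.
(def two-pi 6.283185307179586)

(defn gaussian-logpdf
  "Log density of N(mu, s2) at x, where s2 is a variance."
  [x mu s2]
  (let [d (- x mu)]
    (* -0.5 (+ (Math/log (* two-pi s2)) (/ (* d d) s2)))))

(defn normal-normal-posterior
  "Posterior over mu for prior mu ~ N(m0, v0) and observations y_i ~ N(mu, v),
   with the observation variance v known. Returns {:mean .. :var ..}."
  [m0 v0 v ys]
  (let [n         (count ys)
        precision (+ (/ 1.0 v0) (/ n v))
        mean      (/ (+ (/ m0 v0) (/ (reduce + 0.0 ys) v)) precision)]
    {:mean mean :var (/ 1.0 precision)}))

(defn log-marginal
  "Exact log p(ys) with mu integrated out. Each observation is scored under
   its predictive N(current-mean, current-var + v), then folded into the posterior."
  [m0 v0 v ys]
  (first
    (reduce (fn [[ll {m :mean s2 :var}] y]
              [(+ ll (gaussian-logpdf y m (+ s2 v)))
               (normal-normal-posterior m s2 v [y])])
            [0.0 {:mean m0 :var v0}]
            ys)))

;; Prior N(0, 1), unit observation noise, two observations.
(normal-normal-posterior 0.0 1.0 1.0 [1.0 3.0])
;; posterior mean is 4/3 and posterior variance is 1/3
```

The sequential form of `log-marginal` is the same predict-then-update pattern a Kalman filter uses, which is why both can report an exact marginal likelihood without any sampling.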
- Map: apply a generative function independently across sequences
- Unfold: sequential state-threading for time-series models
- Switch: select between branches for mixture models
- Scan: state-threading with accumulation (like `jax.lax.scan`)
- Mask: conditionally gate execution on a boolean
- Mix: first-class mixture model support
- Recurse: fixed-point combinator for recursive generative functions
- Vmap: `vmap-gf` / `repeat-gf` with full GFI (simulate, generate, update, regenerate, assess, propose, project)
- Vectorized Switch: executes all branches with `[N]`-shaped arrays, selects via `mx/where`
- Contramap / Dimap: transform arguments and/or return values of generative functions
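
For readers unfamiliar with `jax.lax.scan`, the state-threading contract that Scan is compared to in the list above can be sketched in a few lines of plain ClojureScript. This shows the semantics only: `scan` here is a hypothetical helper over ordinary functions, and the signature of GenMLX's Scan combinator (which operates on generative functions and traces) is not shown in this README.

```clojure
;; Illustrative only: a plain-function analogue of the scan contract.
(defn scan
  "Thread a carry through xs. `step` takes [carry x] and returns [carry' out].
   Returns [final-carry outs], mirroring the contract of jax.lax.scan."
  [step init xs]
  (reduce (fn [[carry outs] x]
            (let [[carry' out] (step carry x)]
              [carry' (conj outs out)]))
          [init []]
          xs))

;; Running sum: the carry is the total so far, and each step also emits it.
(scan (fn [total x] (let [t (+ total x)] [t t])) 0 [1 2 3 4])
;; => [10 [1 3 6 10]]
```

Unfold is the special case where only the carry is kept; Scan additionally accumulates a per-step output, which is what a state-space model needs when each step emits an observation.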
Every generative function supports the full Gen interface:
;; Forward sample
(p/simulate model args) ;; => Trace
;; Constrained execution
(p/generate model args constraints) ;; => {:trace Trace :weight MLX-scalar}
;; Update trace with new constraints
(p/update model trace new-constraints) ;; => {:trace Trace :weight MLX-scalar :discard ChoiceMap}
;; Resample selected addresses
(p/regenerate model trace selection) ;; => {:trace Trace :weight MLX-scalar}
;; Score fully-specified choices
(p/assess model args choices) ;; => {:weight MLX-scalar}
;; Log-probability of selected addresses
(p/project model trace selection) ;; => MLX-scalar
;; Parametric edit (constraint, selection, or proposal)
(p/edit model trace edit-request) ;; => {:trace Trace :weight MLX-scalar :discard ChoiceMap}

The GFI is a substrate. A domain becomes a library by writing its model as a generative function, and inference stays orthogonal: the same model runs under exact, Monte Carlo, or mixed inference, chosen by the caller, never baked into the domain. Two Layer-9 verticals ship today; a third is planned.
MDP/POMDP planning, inverse planning, and multi-agent models. An agent's policy is a generative function, so every GFI operation works on it, and inference (exact value iteration, Monte Carlo rollouts, inverse-planning posteriors) is pluggable and orthogonal to the agent.
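;; p, dyn, and cm are assumed to alias the namespaces listed in the architecture overview
(require '[genmlx.agents.gridworld :as gw]
'[genmlx.agents.agent :as agent]
'[genmlx.protocols :as p]
'[genmlx.dynamic :as dyn]
'[genmlx.choicemap :as cm])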
;; A gridworld MDP: reach goal :G under a small per-step cost
(def mdp (gw/build-mdp {:grid [[:empty :G] [:empty :empty]]
:utilities {:G 2.0 :timeCost -0.1}
:start [0 1] :gamma 1.0}))
;; A softmax-rational agent — its :policy IS a generative function
(def ag (agent/make-mdp-agent {:mdp mdp :alpha 1.0 :gamma 1.0 :n-iters 6}))
;; The full GFI works on the policy:
(p/simulate (dyn/auto-key (:policy ag)) [(:start-idx mdp)]) ;; => Trace with an :action choice
(p/assess (dyn/auto-key (:policy ag)) [(:start-idx mdp)]
(cm/choicemap :action 0)) ;; => {:weight log p(action)}

Also in the vertical: `genmlx.agents.inverse` (goal inference / IRL), `genmlx.agents.pomdp` (belief-space planning, bandits), `genmlx.agents.biased-planners` (time-inconsistent agents), and gridworld/restaurant environments. Because an agent is just a GF, inverse planning is ordinary Bayesian inference over it: observe actions, infer the goal.
A local LLM wrapped as a DynamicGF: each generated token is a trace site (:t0, :t1, …) sampling from a categorical over the model's logits, with a KV cache for O(n) generation. All GFI operations apply — simulate generates text, generate constrains and scores tokens, assess scores a sequence — and grammar / byte / reader constraints compose as handler middleware (genmlx.llm.grammar, genmlx.llm.bytes, genmlx.llm.codegen). Following "sync math, async events": model loading and tokenization are async (promesa) at the I/O boundary, while the GFI ops themselves are synchronous.
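;; p, cm, and mx are assumed to alias the namespaces listed in the architecture overview
(require '[genmlx.llm.backend :as llm]
'[genmlx.llm.core :as llm-core]
'[genmlx.protocols :as p]
'[genmlx.choicemap :as cm]
'[genmlx.mlx :as mx]
'[promesa.core :as pr])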
(pr/let [m (llm/load-model "<model-dir>/qwen3.5-…")
gf (llm-core/make-llm-gf m) ;; a DynamicGF over [prompt-ids max-tokens]
tok (:tokenizer m)
ids (vec (llm/encode tok "The best programming language is"))
tid (first (vec (llm/encode tok " Clojure")))]
(p/simulate gf [ids 10]) ;; free generation from the LLM prior
(p/generate gf [ids 10] ;; force + score the first token
(cm/set-value cm/EMPTY :t0 (mx/scalar tid mx/int32))))

The third axis: control ⊥ inference ⊥ model. A metareasoner over computation (steppable/budgeted inference, a compute-cost meter, and a controller that is itself an agent-GF, reusing `genmlx.agents`, pointed at the inference process) turns "how much inference, and which kind" into a first-class, rational, anytime decision. Designed, not yet in `src/`.
The key insight: MLX operations broadcast naturally. Sample [N] values instead of [] at each trace site, and all downstream arithmetic (log-prob, score accumulation, weight computation) just works.
;; Run model body ONCE for N particles
(dyn/vsimulate model args n key) ;; => VectorizedTrace
(dyn/vgenerate model args obs n key) ;; => VectorizedTrace with weights

- `VectorizedTrace`: choices where leaves hold `[N]`-shaped arrays
- 26 distributions have native batch sampling (`dist-sample-n`); others fall back to sequential sampling
- Vectorized importance sampling and SMC initialization built on `vgenerate`
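
Once a vectorized run has produced N log-weights, two standard summaries are the log marginal likelihood estimate (the log of the average weight) and the effective sample size. The sketch below computes both from a plain ClojureScript vector standing in for the `[N]`-shaped weight array. The function names are illustrative assumptions; in GenMLX these reductions would presumably stay on the GPU as MLX ops, and this importance-weight ESS is not necessarily the estimator behind the `ess` diagnostic.

```clojure
;; Illustrative only: plain vectors stand in for [N]-shaped MLX arrays.
(defn logsumexp
  "Numerically stable log(sum(exp(x_i))) for a non-empty seq of finite numbers."
  [xs]
  (let [m (apply max xs)]
    (+ m (Math/log (reduce + (map #(Math/exp (- % m)) xs))))))

(defn log-ml-estimate
  "Log of the average importance weight: logsumexp(w) - log N."
  [log-weights]
  (- (logsumexp log-weights) (Math/log (count log-weights))))

(defn effective-sample-size
  "ESS = (sum w)^2 / (sum w^2), computed in log space."
  [log-weights]
  (Math/exp (- (* 2.0 (logsumexp log-weights))
               (logsumexp (map #(* 2.0 %) log-weights)))))

;; Equal weights: every particle counts, so ESS is (up to rounding) N.
(effective-sample-size [0.0 0.0 0.0 0.0])

;; One dominant particle: ESS collapses toward 1.
(effective-sample-size [0.0 -50.0 -50.0 -50.0])
```

An ESS close to N means the proposal matches the posterior well; an ESS near 1 means a single particle carries almost all the weight and resampling or a better proposal is needed.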
GenMLX bridges neural networks and probabilistic programming via genmlx.nn:
;; Define a neural network
(def net (nn/sequential (nn/linear 10 32) nn/relu (nn/linear 32 1)))
;; Wrap as a deterministic generative function
(def nn-gf (nn/nn->gen-fn net))
;; Use inside gen bodies — gradients flow through the network
(def model
(gen [x]
(let [pred (splice :nn nn-gf [x])]
(trace :y (dist/gaussian pred (mx/scalar 0.1))))))

Layers: `linear`, `sequential`, `relu`, `gelu`, `tanh-act`, `sigmoid-act`, `layer-norm`, `embedding`, `dropout`. Native MLX optimizers via `nn/optimizer` (`:adam`, `:sgd`, `:adamw`).
Save and load traces and choicemaps to JSON:
(require '[genmlx.serialize :as ser])
;; Save/load choices (recommended — reconstruct trace via generate)
(ser/save-choices-to-file (:choices trace) "choices.json")
(let [choices (ser/load-choices-from-file "choices.json")
{:keys [trace]} (p/generate model args choices)]
trace)
;; Save/load full traces (convenience)
(ser/save-trace-to-file trace "trace.json")
(ser/load-trace-from-file "trace.json" model args)

# Individual test files
bun run --bun nbb test/genmlx/dist_test.cljs
bun run --bun nbb test/genmlx/schema_test.cljs
bun run --bun nbb test/genmlx/inference_test.cljs
# All core tests
for f in choicemap_test trace_test selection_test handler_test dist_test combinators_test inference_test; do
bun run --bun nbb "test/genmlx/${f}.cljs"
done
# Compatibility suites
bun run --bun nbb test/genmlx/gen_clj_compat_test.cljs # 356 assertions (from Gen.clj)
bun run --bun nbb test/genmlx/genjax_compat_test.cljs # 73 assertions (GenJAX compat)
# Vectorized tests + benchmarks
bun run --bun nbb test/genmlx/vectorized_test.cljs
bun run --bun nbb test/genmlx/vectorized_benchmark.cljs

356 assertions (across 17 grouped tests) adapted from Gen.clj's test suite verify that GenMLX produces matching results:
- Distribution logpdf spot checks: values verified against scipy.stats and Gen.jl (within float32 tolerance)
- Mathematical properties: symmetry, normalization, shift invariance
- GFI semantics: simulate, generate, update (discard/weight), regenerate
- Dynamic DSL: `gen` macro, `trace`, `splice`, nested tracing, score computation
- Importance sampling: rejection robustness with branching models
- End-to-end: line model with constrained observations
73 assertions (across 9 grouped tests) verify parity with GenJAX's design:
- Edit interface — constraint, selection, and proposal edits
- Diff tracking — incremental update with change hints
- SMCP3 — reversible kernel proposals
- Combinators — Scan, Mask, Mix, Vectorized Switch
- Programmable VI — ELBO, IWELBO, wake-sleep objectives
- Compilation ladder: 5 levels (L0–L4) progressively fuse more work into MLX graphs
- Loop compilation: entire MCMC chains compiled into single Metal dispatches via `mx/compile-fn`
- `mx/compile-fn` on score functions: JIT-compiles into cached Metal programs
- `mx/value-and-grad`: fused forward+backward in a single GPU dispatch
- Auto-analytical elimination: conjugacy detection and Rao-Blackwellization (L3)
- Compiled Adam: 9.2x faster than handler loop via `mx/compile-fn` + `mx/value-and-grad` (L4)
- Adaptive step-size: HMC dual averaging (Hoffman & Gelman 2014) auto-tunes during burn-in
- `mx/tidy` + `mx/eval!` discipline: bounds graph size, prevents Metal resource exhaustion
- `mx/vmap` in combinators: batch GPU execution across particles/instances
- Unified memory: MCMC control flow on CPU, all numerics on GPU, zero transfer cost
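
The dual-averaging step-size adaptation named in the list above can be sketched as a pure state update. This is a generic rendering of the scheme from Hoffman & Gelman 2014 in plain ClojureScript, not GenMLX's implementation. The function names, the state keys, and the default constants (`delta` 0.65, `gamma` 0.05, `t0` 10, `kappa` 0.75, which are the values commonly quoted from that paper) are assumptions made for illustration.

```clojure
;; Illustrative only: hypothetical names; plain numbers, no MLX arrays.
(defn dual-averaging-init
  "Initial adaptation state for a starting step size eps0."
  [eps0]
  {:m 0
   :h-bar 0.0                         ; running average of (delta - alpha)
   :mu (Math/log (* 10.0 eps0))       ; point the iterates are shrunk toward
   :log-eps (Math/log eps0)           ; step size to use during burn-in
   :log-eps-bar 0.0})                 ; averaged iterate, used after burn-in

(defn dual-averaging-step
  "One update given the acceptance statistic `alpha` in [0, 1] from the
   latest transition. Returns the next adaptation state."
  [{:keys [m h-bar mu log-eps-bar]} alpha
   & {:keys [delta gamma t0 kappa]
      :or   {delta 0.65 gamma 0.05 t0 10.0 kappa 0.75}}]
  (let [m           (inc m)
        w           (/ 1.0 (+ m t0))
        h-bar       (+ (* (- 1.0 w) h-bar) (* w (- delta alpha)))
        log-eps     (- mu (* (/ (Math/sqrt m) gamma) h-bar))
        eta         (Math/pow m (- kappa))
        log-eps-bar (+ (* eta log-eps) (* (- 1.0 eta) log-eps-bar))]
    {:m m :h-bar h-bar :mu mu :log-eps log-eps :log-eps-bar log-eps-bar}))

;; Usage: fold the observed acceptance statistics of the burn-in transitions.
(def adapted
  (reduce dual-averaging-step (dual-averaging-init 0.1) [0.9 0.8 0.4 0.7 0.6]))

(Math/exp (:log-eps-bar adapted)) ; step size to freeze once burn-in ends
```

When acceptance falls below the target, `h-bar` grows and the next `log-eps` shrinks; when acceptance is above target the step size grows. The averaged value `log-eps-bar` is what is kept fixed for the sampling phase.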
Apple Silicon has a hard kernel-level limit of 499,000 simultaneous Metal buffer objects per process. This limit is identical across M1–M4 and all Pro/Max/Ultra variants. GenMLX manages this automatically in all built-in inference algorithms, but understanding it helps when writing custom inference loops or running very large models.
;; Quick snapshot of Metal resource usage
(mx/memory-report)
;; => {:active-bytes 124344
;; :cache-bytes 0
;; :peak-bytes 126712
;; :wrappers 42
;; :resource-limit 499000}
;; Individual queries
(mx/get-active-memory) ;; bytes of Metal buffers in use
(mx/get-cache-memory) ;; bytes in recycling cache
(mx/get-peak-memory) ;; high-water mark since last reset
(mx/get-wrappers-count) ;; live JS-wrapped MLX array objects

All built-in inference algorithms clean up Metal buffers between iterations via `mx/eval!` (detaches computation graphs) and periodic `mx/clear-cache!` (releases cached buffers). Long chains run indefinitely without hitting the resource limit.
| Category | Algorithms | Cleanup strategy |
|---|---|---|
| MCMC (via `collect-samples`) | `mh`, `mh-custom`, `gibbs`, `elliptical-slice`, `involutive-mh`, all kernel combinators | `tidy-step` + `with-resource-guard` + periodic `clear-cache!` |
| Compiled MCMC | `compiled-mh`, `compiled-mala`, `hmc`, `nuts` | `mx/tidy` + `mx/eval!` per step |
| Vectorized MCMC | `vectorized-compiled-mh`, `vectorized-mala` | Internal eval + periodic `clear-cache!` |
| Particle methods | `smc`, `csmc`, `smcp3` | Per-particle eval + periodic `clear-cache!` |
| Importance sampling | `importance-sampling`, vectorized IS | Per-sample eval |
| Optimization | `vi`, `programmable-vi`, `adev-optimize`, `wake-sleep` | `mx/eval!` per iter + periodic `clear-cache!` |
When writing your own loops over MLX operations, follow this pattern:
(loop [i 0, state init-state]
(if (>= i n-iters)
state
(let [new-state (my-step state)
;; Materialize arrays — detaches computation graph
_ (mx/eval! (:score new-state) (:weight new-state))
;; Periodically release cached Metal buffers
_ (when (zero? (mod i 50)) (mx/clear-cache!))]
(recur (inc i) new-state))))

For tighter control, set the cache limit lower at program start:

(mx/set-cache-limit! (* 128 1024 1024)) ;; 128 MB cache (default is unlimited)

Q: I'm getting `[metal::malloc] Resource limit (499000) exceeded`
This means too many Metal buffer objects are alive simultaneously. Solutions:
- Use vectorized inference: `vectorized-importance-sampling` and `vsmc` run the model body once for N particles instead of N times
- Reduce sample count: fewer particles/samples means fewer simultaneous buffers
- Set a cache limit: add `(mx/set-cache-limit! (* 128 1024 1024))` at program start to cap the buffer recycling cache
- Clear cache between runs: call `(mx/clear-cache!)` between separate inference calls
- Use compiled inference: `compiled-mh`, `hmc`, `nuts` manage resources automatically via `mx/tidy`
Q: Inference is slow / memory keeps growing
- Call `(mx/eval!)` on result arrays inside your loop to materialize the computation graph. Without eval, MLX builds an ever-growing lazy graph.
- Check `(mx/get-active-memory)`: if it grows linearly with iterations, arrays aren't being freed. Use `mx/tidy` to release them.
MIT
