Companion code for "Dual-space posterior sampling for Bayesian inference in constrained inverse problems". This package implements Stein Variational Gradient Descent (SVGD) combined with ADMM (the alternating direction method of multipliers) to handle constraints in Bayesian inference and sampling.
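Both samplers in this package build on the SVGD particle update. The sketch below shows a plain SVGD step with an RBF kernel and the median-heuristic bandwidth of Liu & Wang (2016). It is a minimal standalone illustration, not the package's `svgd_update!` or `compute_bandwidth`; the function names (`pairwise_sqdist`, `median_bandwidth`, `svgd_step`, `run_svgd`) and the kernel choice are assumptions, and the ADMM part is not shown.

```julia
using Statistics, Random

# Pairwise squared Euclidean distances between the columns (particles) of X (d × n).
function pairwise_sqdist(X::AbstractMatrix)
    sq = vec(sum(abs2, X; dims=1))           # ‖x_i‖² for each particle
    D = sq .+ sq' .- 2 .* (X' * X)           # ‖x_i‖² + ‖x_j‖² - 2⟨x_i, x_j⟩
    return max.(D, 0.0)                      # clip small negatives caused by round-off
end

# Median-heuristic bandwidth h = med² / log(n), where med is the median pairwise distance.
function median_bandwidth(D::AbstractMatrix)
    n = size(D, 1)
    dists = [sqrt(D[i, j]) for i in 1:n for j in 1:n if i < j]
    return max(median(dists)^2 / log(n), 1e-12)   # floor guards against collapsed particles
end

# One SVGD step: x_i ← x_i + stepsize * φ(x_i), with
# φ(x_i) = (1/n) Σ_j [k(x_j, x_i) ∇log p(x_j) + ∇_{x_j} k(x_j, x_i)].
function svgd_step(X::AbstractMatrix, grad_logp, stepsize::Real)
    n = size(X, 2)
    D = pairwise_sqdist(X)
    h = median_bandwidth(D)
    K = exp.(-D ./ h)                        # RBF kernel matrix (symmetric), K[j, i] = k(x_j, x_i)
    G = grad_logp(X)                         # d × n, column j = ∇log p(x_j)
    drive = G * K                            # column i: Σ_j k(x_j, x_i) ∇log p(x_j)
    # column i: Σ_j (2/h)(x_i - x_j) k(x_j, x_i), the kernel-gradient (repulsive) term
    repulse = (2 / h) .* (X .* sum(K; dims=1) .- X * K)
    return X .+ stepsize .* (drive .+ repulse) ./ n
end

# Run a fixed number of SVGD iterations starting from the particle matrix X0.
function run_svgd(X0::AbstractMatrix, grad_logp, stepsize::Real, iters::Int)
    X = copy(X0)
    for _ in 1:iters
        X = svgd_step(X, grad_logp, stepsize)
    end
    return X
end

# Usage: move 100 particles, started around (3, 3), toward a 2-D standard normal.
rng = MersenneTwister(0)
X0 = randn(rng, 2, 100) .+ 3.0
X = run_svgd(X0, Z -> -Z, 0.1, 500)          # score of a standard normal: ∇log p(x) = -x
println("particle mean: ", vec(mean(X; dims=2)))
```

The step size and iteration count here are arbitrary illustration values, not the ones in `config/`.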
Make sure you have matplotlib installed in your Python environment for plotting:
```bash
pip install matplotlib
```

Configure PyCall to use your Python installation:

```julia
ENV["PYTHON"] = "/usr/bin/python3"  # Adjust the path to your Python
using Pkg
Pkg.build("PyCall")
# Restart Julia afterwards
```

Then instantiate the project environment from the repository root:

```bash
julia --project=. -e 'using Pkg; Pkg.instantiate()'
```

The Rosenbrock example demonstrates conditional posterior inference.
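For reference, a commonly used 2-D Rosenbrock ("banana") log-density and its score (gradient of the log-density) can be written as below. The functional form and the constants `a`, `b`, `s` are assumptions for illustration and may not match what the package's scripts use; `rosen_logp` and `rosen_score` are hypothetical names.

```julia
# Assumed form (up to an additive constant):
#   log p(x) = -[(a - x₁)² + b (x₂ - x₁²)²] / s
function rosen_logp(x::AbstractVector; a=1.0, b=100.0, s=20.0)
    r1 = a - x[1]
    r2 = x[2] - x[1]^2
    return -(r1^2 + b * r2^2) / s
end

# Score ∇log p for every column (particle) of the 2 × n matrix X.
function rosen_score(X::AbstractMatrix; a=1.0, b=100.0, s=20.0)
    x1 = X[1, :]
    x2 = X[2, :]
    r1 = a .- x1
    r2 = x2 .- x1 .^ 2
    g1 = (2 .* r1 .+ 4 * b .* x1 .* r2) ./ s   # ∂/∂x₁
    g2 = (-2 * b .* r2) ./ s                    # ∂/∂x₂
    return permutedims(hcat(g1, g2))            # 2 × n, one column per particle
end

# Usage: check the analytic score against central finite differences.
x = [0.3, -0.2]
hstep = 1e-6
fd = [(rosen_logp(x .+ hstep .* e) - rosen_logp(x .- hstep .* e)) / (2 * hstep)
      for e in ([1.0, 0.0], [0.0, 1.0])]
g = rosen_score(reshape(x, 2, 1))[:, 1]
println("finite differences: ", fd, "  analytic: ", g)
```

A score function with this column-per-particle layout can be handed to any gradient-based particle sampler.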
Configuration files are in `config/`. Key parameters:

| Parameter | ADMM-SVGD | Standard SVGD |
|---|---|---|
| Particles | 1000 | 1000 |
| Step size | 0.15 | 0.05 |
| Iterations | 2500 | 5000 |
| Penalty parameter | 1.0 | -- |
Step 1: Run ADMM-SVGD conditional sampling. This generates the test instances (X_fixed, Y_fixed) and runs ADMM-SVGD for all five observations.
```bash
julia --project=. scripts/admm_svgd_conditional_sampling.jl
```

Step 2: Run standard SVGD conditional sampling. This loads the same test instances from Step 1 and runs plain SVGD with the direct posterior gradient.

```bash
julia --project=. scripts/svgd_conditional_sampling.jl
```

Step 3: Generate the paper figures. This produces all comparison figures: prior, data, ADMM-SVGD posteriors, standard SVGD posteriors, combined overlay, convergence diagnostics, and Q-Q plots.

```bash
julia --project=. scripts/admm_svgd_conditional_paper_figures.jl
```

Figures are saved to `plots/` and to the paper figures directory.
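A Q-Q plot compares two sample sets by plotting their matched quantiles against each other; points on the line y = x mean the two distributions agree. The sketch below computes those pairs for a 1-D marginal. It is a generic illustration, not the figure script's code: the name `qq_pairs` and the choice of 99 evenly spaced probability levels are assumptions.

```julia
using Statistics, Random

# Matched quantiles of two 1-D sample sets at probability levels 0.01, ..., 0.99.
function qq_pairs(a::AbstractVector, b::AbstractVector; nq::Int=99)
    ps = range(0.01, 0.99; length=nq)
    return quantile(a, ps), quantile(b, ps)
end

# Largest vertical distance of the Q-Q points from the y = x line.
qq_max_deviation(a, b; nq::Int=99) = maximum(abs.(.-(qq_pairs(a, b; nq=nq)...)))

# Usage: two independent draws from the same distribution give points near y = x.
rng = MersenneTwister(1)
a = randn(rng, 20_000)
b = randn(rng, 20_000)
qa, qb = qq_pairs(a, b)
println("max |Δq| (same distribution): ", qq_max_deviation(a, b))
println("max |Δq| (shifted by 1):      ", qq_max_deviation(a, b .+ 1.0))
```

The returned vectors can be scattered against each other with any plotting backend, for example matplotlib through PyCall.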
Sample from the Rosenbrock distribution using ADMM-SVGD and visualize the results:
```bash
julia --project=. scripts/admm_svgd_sampling.jl
julia --project=. scripts/admm_svgd_visualization.jl
```

Optionally, run pure SVGD (without ADMM) for comparison:

```bash
julia --project=. scripts/pure_svgd_sampling.jl
julia --project=. scripts/pure_svgd_visualization.jl
```

Repository layout:

```
config/                                    # Configuration files (JSON)
  admm_svgd_conditional_sampling.json      # ADMM-SVGD parameters
  svgd_conditional_sampling.json           # Standard SVGD parameters
src/
  SVGDADMMSampler.jl                       # Main module
  sampling/
    admm_svgd.jl                           # ADMMSVGDSampler, step!, compute_bandwidth, svgd_update!
    sample.jl                              # MCMC sampler (MALA)
scripts/
  admm_svgd_conditional_sampling.jl        # ADMM-SVGD conditional sampling
  svgd_conditional_sampling.jl             # Standard SVGD conditional sampling
  admm_svgd_conditional_paper_figures.jl   # Generate all paper figures
data/                                      # Saved results (JLD2, managed by DrWatson)
plots/                                     # Generated figures
```
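`src/sampling/sample.jl` is listed as an MCMC sampler using MALA (the Metropolis-adjusted Langevin algorithm). For readers unfamiliar with it, the sketch below shows a textbook MALA step: a Langevin proposal y = x + τ∇log p(x) + √(2τ)ξ followed by a Metropolis-Hastings correction. It is not the package's implementation; `mala_step`, `mala`, and the step size `tau` are hypothetical names and values.

```julia
using Statistics, Random

# One MALA step for a target with log-density logp and score grad_logp.
# Returns the new state and whether the proposal was accepted.
function mala_step(rng, x::Vector{Float64}, logp, grad_logp, tau::Real)
    gx = grad_logp(x)
    y = x .+ tau .* gx .+ sqrt(2 * tau) .* randn(rng, length(x))   # Langevin proposal
    gy = grad_logp(y)
    # Gaussian proposal densities q(y | x) and q(x | y), up to a shared constant
    log_q_fwd = -sum(abs2, y .- x .- tau .* gx) / (4 * tau)
    log_q_bwd = -sum(abs2, x .- y .- tau .* gy) / (4 * tau)
    log_alpha = logp(y) - logp(x) + log_q_bwd - log_q_fwd
    return log(rand(rng)) < log_alpha ? (y, true) : (x, false)
end

# Run a MALA chain and return all states plus the acceptance rate.
function mala(rng, x0::Vector{Float64}, logp, grad_logp, tau::Real, nsteps::Int)
    x = copy(x0)
    samples = Vector{Vector{Float64}}(undef, nsteps)
    accepts = 0
    for t in 1:nsteps
        x, acc = mala_step(rng, x, logp, grad_logp, tau)
        accepts += acc
        samples[t] = x
    end
    return samples, accepts / nsteps
end

# Usage: sample a 2-D standard normal starting far from the mode.
rng = MersenneTwister(2)
logp(x) = -sum(abs2, x) / 2
grad_logp(x) = -x
samples, acc_rate = mala(rng, [3.0, 3.0], logp, grad_logp, 0.5, 20_000)
Xs = reduce(hcat, samples[1001:end])          # discard the first 1000 states as burn-in
println("acceptance rate: ", acc_rate)
println("mean: ", vec(mean(Xs; dims=2)), "  variance: ", vec(var(Xs; dims=2)))
```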
Author: Ali Siahkoohi (alisk@ucf.edu)

License: MIT