diff --git a/examples/sample_callback.rs b/examples/sample_callback.rs new file mode 100644 index 0000000..ae04954 --- /dev/null +++ b/examples/sample_callback.rs @@ -0,0 +1,310 @@ +//! Example demonstrating sample-level data access via ProgressCallback +//! +//! This example shows how to access per-sample data through the existing +//! ProgressCallback using the `latest_sample` field in ChainProgress. + +use std::{ + f64, + sync::{Arc, Mutex}, + time::Duration, +}; + +use anyhow::Result; +use nuts_rs::{ + CpuLogpFunc, CpuMath, CpuMathError, DiagGradNutsSettings, HashMapConfig, LogpError, Model, + ProgressCallback, Sampler, +}; +use nuts_storable::HasDims; +use rand::{Rng, RngExt}; +use thiserror::Error; + +// A simple multivariate normal distribution example +#[derive(Clone, Debug)] +struct MultivariateNormal { + mean: Vec, + precision: Vec>, +} + +impl MultivariateNormal { + fn new(mean: Vec, precision: Vec>) -> Self { + Self { mean, precision } + } +} + +// Custom LogpError implementation +#[allow(dead_code)] +#[derive(Debug, Error)] +enum MyLogpError { + #[error("Recoverable error in logp calculation: {0}")] + Recoverable(String), + #[error("Non-recoverable error in logp calculation: {0}")] + NonRecoverable(String), +} + +impl LogpError for MyLogpError { + fn is_recoverable(&self) -> bool { + matches!(self, MyLogpError::Recoverable(_)) + } +} + +// Implementation of the model's logp function +#[derive(Clone)] +struct MvnLogp { + model: MultivariateNormal, +} + +impl HasDims for MvnLogp { + fn dim_sizes(&self) -> std::collections::HashMap { + std::collections::HashMap::from([ + ( + "unconstrained_parameter".to_string(), + self.model.mean.len() as u64, + ), + ("dim".to_string(), self.model.mean.len() as u64), + ]) + } +} + +impl CpuLogpFunc for MvnLogp { + type LogpError = MyLogpError; + type FlowParameters = (); + type ExpandedVector = Vec; + + fn dim(&self) -> usize { + self.model.mean.len() + } + + fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result { + let n = x.len(); + // Compute (x - mean) + let mut diff = vec![0.0; n]; + for i in 0..n { + diff[i] = x[i] - self.model.mean[i]; + } + + let mut quad = 0.0; + // Compute quadratic form and gradient: logp = -0.5 * diff^T * P * diff + for i in 0..n { + // Compute i-th component of P * diff + let mut pdot = 0.0; + for j in 0..n { + let pij = self.model.precision[i][j]; + pdot += pij * diff[j]; + quad += diff[i] * pij * diff[j]; + } + // gradient of logp w.r.t. x_i: derivative of -0.5 * diff^T P diff is - (P * diff)_i + grad[i] = -pdot; + } + + Ok(-0.5 * quad) + } + + fn expand_vector( + &mut self, + _rng: &mut R, + array: &[f64], + ) -> Result { + // Simply return the parameter values + Ok(array.to_vec()) + } +} + +struct MvnModel { + math: CpuMath, +} + +/// Implementation of Model for the HashMap backend +impl Model for MvnModel { + type Math<'model> + = CpuMath + where + Self: 'model; + + fn math(&self, _rng: &mut R) -> Result> { + Ok(self.math.clone()) + } + + /// Generate random initial positions for the chain + fn init_position(&self, rng: &mut R, position: &mut [f64]) -> Result<()> { + // Initialize position randomly in [-2, 2] + for p in position.iter_mut() { + *p = rng.random_range(-2.0..2.0); + } + Ok(()) + } +} + +fn main() -> Result<()> { + println!("=== Sample-Level Data via ProgressCallback Example ===\n"); + println!("This example demonstrates accessing per-sample data through ProgressCallback."); + println!("The callback fires periodically (rate-limited to 10ms) with chain progress,"); + println!("including the latest sample data for each chain.\n"); + + // Create a 2D multivariate normal distribution + let mean = vec![0.0, 0.0]; + let precision = vec![vec![1.0, 0.5], vec![0.5, 1.0]]; + let mvn = MultivariateNormal::new(mean, precision); + + // Number of chains + let num_chains = 2; + + // Configure number of draws + let num_tune = 50; + let num_draws = 100; + + // Configure MCMC settings + let mut settings = DiagGradNutsSettings::default(); + settings.num_chains = num_chains as _; + settings.num_tune = num_tune; + settings.num_draws = num_draws as _; + settings.seed = 42; + + let model = MvnModel { + math: CpuMath::new(MvnLogp { model: mvn }), + }; + + // Track callback invocations for demonstration + let callback_count = Arc::new(Mutex::new(0)); + let callback_count_clone = callback_count.clone(); + + let divergence_count = Arc::new(Mutex::new(0)); + let divergence_count_clone = divergence_count.clone(); + + // Create progress callback that accesses latest sample data + let progress_callback = ProgressCallback { + callback: Box::new(move |elapsed, chains| { + let mut count = callback_count_clone.lock().unwrap(); + *count += 1; + + // Print progress information periodically + if *count <= 10 { + println!( + "Progress callback #{}: Elapsed: {:.1}s, {} chains", + count, + elapsed.as_secs_f64(), + chains.len() + ); + + for chain_progress in chains.iter() { + // Access the latest sample data if available + if let Some(sample_data) = &chain_progress.latest_sample { + // Demonstrate accessing optional fields with proper handling + let energy_str = sample_data + .draw_energy + .map(|e| format!("{:.3}", e)) + .unwrap_or_else(|| "N/A".to_string()); + let diverging_str = sample_data + .diverging + .map(|d| d.to_string()) + .unwrap_or_else(|| "N/A".to_string()); + let tree_depth_str = sample_data + .tree_depth + .map(|d| d.to_string()) + .unwrap_or_else(|| "N/A".to_string()); + + println!( + " Chain {}: Draw {}/{}, Energy: {}, Diverging: {}, Tree depth: {}", + sample_data.chain_id, + chain_progress.finished_draws, + chain_progress.total_draws, + energy_str, + diverging_str, + tree_depth_str + ); + + if let Some(step_size) = sample_data.step_size { + println!( + " Step size: {:.6}, Tuning: {}", + step_size, sample_data.is_tuning + ); + } + + if let Some(max_depth) = sample_data.reached_max_treedepth { + if max_depth { + println!(" ⚠ Maximum tree depth reached!"); + } + } + + // Track divergences + if sample_data.diverging.unwrap_or(false) { + let mut div_count = divergence_count_clone.lock().unwrap(); + *div_count += 1; + } + } + } + println!(); + } else if *count == 11 { + println!(" ... (suppressing further callback output) ...\n"); + } + }), + rate: Duration::from_millis(10), // Rate limit: at most one callback per 10ms + }; + + // Create a new sampler with the progress callback + let trace_config = HashMapConfig::new(); + let mut sampler = Sampler::new( + model, + settings, + trace_config, + 4, // num_cores + Some(progress_callback), // progress callback with sample data access + )?; + + println!("Starting sampling with progress callback...\n"); + + // Wait for sampling to complete + let traces = loop { + match sampler.wait_timeout(std::time::Duration::from_millis(100)) { + nuts_rs::SamplerWaitResult::Trace(traces) => break traces, + nuts_rs::SamplerWaitResult::Timeout(s) => sampler = s, + nuts_rs::SamplerWaitResult::Err(e, _) => return Err(e), + } + }; + + println!("\n=== Sampling Complete ==="); + println!( + "Total callback invocations: {}", + *callback_count.lock().unwrap() + ); + println!( + "Divergences detected via callback: {}", + *divergence_count.lock().unwrap() + ); + println!("Number of chains: {}", traces.len()); + + // Show some basic statistics from the traces + for (chain_idx, chain_result) in traces.iter().enumerate() { + println!("\nChain {}:", chain_idx); + + // Count divergences from stats + if let Some(nuts_rs::HashMapValue::Bool(divergences)) = chain_result.stats.get("diverging") + { + let div_count = divergences.iter().filter(|&&d| d).count(); + println!(" Divergences in trace: {}", div_count); + } + + // Calculate mean position + if let Some(nuts_rs::HashMapValue::F64(positions)) = chain_result.draws.get("theta") { + if positions.len() >= 2 { + let x_mean: f64 = + positions.iter().step_by(2).sum::() / (positions.len() / 2) as f64; + let y_mean: f64 = + positions.iter().skip(1).step_by(2).sum::() / (positions.len() / 2) as f64; + println!(" Mean position: [{:.4}, {:.4}]", x_mean, y_mean); + } + } + } + + println!("\n✓ Example completed successfully!"); + println!("\nKey features demonstrated:"); + println!(" - ProgressCallback provides both chain progress and latest sample data"); + println!(" - Time-based rate limiting (10ms) prevents excessive overhead"); + println!( + " - latest_sample includes rich optional data (energy, divergence, tree depth, etc.)" + ); + println!(" - All sampler-specific stats are Option for compatibility with other samplers"); + println!(" - Works seamlessly with multi-chain sampling"); + println!(" - Single callback mechanism for all monitoring needs"); + + Ok(()) +} diff --git a/src/chain.rs b/src/chain.rs index 222cc54..df0bc1e 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -177,6 +177,10 @@ where tuning: self.strategy.is_tuning(), step_size: self.hamiltonian.step_size(), num_steps: self.strategy.last_num_steps(), + depth: Some(info.depth), + reached_maxdepth: Some(info.reached_maxdepth), + initial_energy: Some(info.initial_energy), + draw_energy: Some(info.draw_energy), }; self.draw_count += 1; diff --git a/src/lib.rs b/src/lib.rs index 29f7248..f262886 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -129,7 +129,7 @@ pub use model::Model; pub use nuts::NutsError; pub use sampler::{ ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, NutsSettings, Progress, - ProgressCallback, Sampler, SamplerWaitResult, Settings, TransformedNutsSettings, + ProgressCallback, SampleData, Sampler, SamplerWaitResult, Settings, TransformedNutsSettings, sample_sequentially, }; pub use sampler_stats::SamplerStats; diff --git a/src/sampler.rs b/src/sampler.rs index 1a93c17..fa5fec2 100644 --- a/src/sampler.rs +++ b/src/sampler.rs @@ -140,6 +140,12 @@ pub struct Progress { pub tuning: bool, pub step_size: f64, pub num_steps: u64, + + // NUTS-specific fields (None for non-NUTS samplers) + pub depth: Option, + pub reached_maxdepth: Option, + pub initial_energy: Option, + pub draw_energy: Option, } mod private { @@ -466,8 +472,35 @@ pub fn sample_sequentially<'math, M: Math + 'math, R: Rng + ?Sized>( Ok((0..draws).map(move |_| sampler.draw())) } -#[non_exhaustive] -#[derive(Clone, Debug)] +/// Data for the most recent sample from a chain +#[derive(Debug, Clone)] +pub struct SampleData { + /// Chain identifier + pub chain_id: u64, + /// Draw number within the chain + pub draw: u64, + /// Whether this sample is from the tuning phase + pub is_tuning: bool, + + // NUTS-specific statistics (all optional to support other samplers) + /// NUTS tree depth for this sample + pub tree_depth: Option, + /// Whether the trajectory reached the maximum tree depth + pub reached_max_treedepth: Option, + /// Whether this sample had a divergence + pub diverging: Option, + + // Energy statistics + /// Energy at the start of the trajectory + pub initial_energy: Option, + /// Energy at the sampled point + pub draw_energy: Option, + + /// Current step size + pub step_size: Option, +} + +#[derive(Clone)] pub struct ChainProgress { pub finished_draws: usize, pub total_draws: usize, @@ -479,6 +512,8 @@ pub struct ChainProgress { pub step_size: f64, pub runtime: Duration, pub divergent_draws: Vec, + /// The most recent sample from this chain (if available) + pub latest_sample: Option, } impl ChainProgress { @@ -494,6 +529,7 @@ impl ChainProgress { total_num_steps: 0, runtime: Duration::ZERO, divergent_draws: Vec::new(), + latest_sample: None, } } @@ -509,6 +545,18 @@ impl ChainProgress { self.total_num_steps += stats.num_steps as usize; self.step_size = stats.step_size; self.runtime += draw_duration; + + self.latest_sample = Some(SampleData { + chain_id: stats.chain, + draw: stats.draw, + is_tuning: stats.tuning, + tree_depth: stats.depth, + reached_max_treedepth: stats.reached_maxdepth, + diverging: Some(stats.diverging), + initial_energy: stats.initial_energy, + draw_energy: stats.draw_energy, + step_size: Some(stats.step_size), + }); } } @@ -537,7 +585,7 @@ impl ChainProcess { } fn progress(&self) -> ChainProgress { - self.progress.lock().expect("Poisoned lock").clone() + (*self.progress.lock().expect("Poisoned lock")).clone() } fn resume(&self) -> Result<()> { @@ -720,6 +768,10 @@ pub struct ProgressCallback { } impl Sampler { + /// Create a new sampler. + /// + /// The optional `callback` is invoked periodically with progress updates for all chains. + /// Sample-level data is accessible via the `latest_sample` field in `ChainProgress`. pub fn new( model: M, settings: S, @@ -851,11 +903,13 @@ impl Sampler { Ok(SamplerCommand::Progress) => { let progress = chains.iter().map(|chain| chain.progress()).collect_vec(); - responses_tx.send(SamplerResponse::Progress(progress.into())).map_err(|e| { - anyhow::anyhow!( + responses_tx + .send(SamplerResponse::Progress(progress.into())) + .map_err(|e| { + anyhow::anyhow!( "Could not send progress response to controller thread: {e}" ) - })?; + })?; } Ok(SamplerCommand::Inspect) => { let traces = chains @@ -871,11 +925,13 @@ impl Sampler { .flatten() .collect_vec(); let finalized_trace = trace.inspect(traces)?; - responses_tx.send(SamplerResponse::Inspect(finalized_trace)).map_err(|e| { - anyhow::anyhow!( + responses_tx + .send(SamplerResponse::Inspect(finalized_trace)) + .map_err(|e| { + anyhow::anyhow!( "Could not send inspect response to controller thread: {e}" ) - })?; + })?; } Ok(SamplerCommand::Flush) => { for chain in chains.iter() { diff --git a/src/state.rs b/src/state.rs index fac1893..c0ae023 100644 --- a/src/state.rs +++ b/src/state.rs @@ -104,11 +104,10 @@ impl> State { impl> Drop for State { fn drop(&mut self) { let rc = unsafe { std::mem::ManuallyDrop::take(&mut self.inner) }; - if (Rc::strong_count(&rc) == 1) - && (Rc::weak_count(&rc) == 0) - && let Some(storage) = rc.reuser.upgrade() - { - storage.free_states.borrow_mut().push(rc); + if (Rc::strong_count(&rc) == 1) && (Rc::weak_count(&rc) == 0) { + if let Some(storage) = rc.reuser.upgrade() { + storage.free_states.borrow_mut().push(rc); + } } } } diff --git a/tests/sample_callback_test.rs b/tests/sample_callback_test.rs new file mode 100644 index 0000000..782d979 --- /dev/null +++ b/tests/sample_callback_test.rs @@ -0,0 +1,153 @@ +//! Unit tests for sample-level data access via ProgressCallback +//! +//! These tests verify that sample data can be accessed through ChainProgress +//! without requiring all heavy dependencies to compile. + +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +#[test] +fn test_sample_data_type_exists() { + // This test verifies that SampleData is properly exported + // It will fail to compile if the type is not available + + let _sample_data = nuts_rs::SampleData { + chain_id: 0, + draw: 42, + is_tuning: true, + tree_depth: Some(5), + reached_max_treedepth: Some(false), + diverging: Some(false), + initial_energy: None, + draw_energy: Some(-10.5), + step_size: Some(0.1), + }; +} + +#[test] +fn test_progress_callback_creation() { + // Verify we can create a ProgressCallback that accesses sample data + let callback_count = Arc::new(Mutex::new(0)); + let callback_count_clone = callback_count.clone(); + + let _callback = nuts_rs::ProgressCallback { + callback: Box::new(move |elapsed, chains| { + let mut count = callback_count_clone.lock().unwrap(); + *count += 1; + + // Verify we can access elapsed time + let _ = elapsed.as_secs_f64(); + + // Verify we can access chain progress and sample data + for chain_progress in chains.iter() { + let _ = chain_progress.finished_draws; + let _ = chain_progress.total_draws; + + // Verify we can access latest_sample + if let Some(sample_data) = &chain_progress.latest_sample { + let _ = sample_data.chain_id; + let _ = sample_data.draw; + let _ = sample_data.is_tuning; + let _ = sample_data.tree_depth; + let _ = sample_data.reached_max_treedepth; + let _ = sample_data.diverging; + let _ = sample_data.initial_energy; + let _ = sample_data.draw_energy; + let _ = sample_data.step_size; + } + } + }), + rate: Duration::from_millis(100), + }; +} + +#[test] +fn test_chain_progress_with_sample_data() { + // Test that ChainProgress can hold sample data + let sample_data = nuts_rs::SampleData { + chain_id: 1, + draw: 10, + is_tuning: true, + tree_depth: Some(3), + reached_max_treedepth: Some(false), + diverging: Some(false), + initial_energy: None, + draw_energy: Some(-5.2), + step_size: Some(0.05), + }; + + // Verify we can access sample data fields + assert_eq!(sample_data.chain_id, 1); + assert_eq!(sample_data.draw, 10); + assert!(sample_data.is_tuning); + assert_eq!(sample_data.tree_depth, Some(3)); + assert_eq!(sample_data.reached_max_treedepth, Some(false)); + assert_eq!(sample_data.diverging, Some(false)); + assert_eq!(sample_data.initial_energy, None); + assert_eq!(sample_data.draw_energy, Some(-5.2)); + assert_eq!(sample_data.step_size, Some(0.05)); +} + +#[test] +fn test_sample_data_clone() { + // Verify SampleData implements Clone + let data1 = nuts_rs::SampleData { + chain_id: 0, + draw: 1, + is_tuning: false, + tree_depth: Some(4), + reached_max_treedepth: Some(true), + diverging: Some(true), + initial_energy: Some(-2.5), + draw_energy: Some(-3.0), + step_size: Some(0.1), + }; + + let data2 = data1.clone(); + + assert_eq!(data1.chain_id, data2.chain_id); + assert_eq!(data1.draw, data2.draw); + assert_eq!(data1.is_tuning, data2.is_tuning); + assert_eq!(data1.tree_depth, data2.tree_depth); + assert_eq!(data1.reached_max_treedepth, data2.reached_max_treedepth); + assert_eq!(data1.diverging, data2.diverging); + assert_eq!(data1.initial_energy, data2.initial_energy); + assert_eq!(data1.draw_energy, data2.draw_energy); + assert_eq!(data1.step_size, data2.step_size); +} + +#[test] +fn test_progress_callback_invocation() { + // Test that the progress callback can be invoked with sample data + let callback_count = Arc::new(Mutex::new(0)); + let sample_count = Arc::new(Mutex::new(0)); + + let callback_count_clone = callback_count.clone(); + let sample_count_clone = sample_count.clone(); + + let mut callback = nuts_rs::ProgressCallback { + callback: Box::new(move |_elapsed, chains| { + let mut count = callback_count_clone.lock().unwrap(); + *count += 1; + + // Count chains with sample data + for chain_progress in chains.iter() { + if let Some(sample_data) = &chain_progress.latest_sample { + let mut samples = sample_count_clone.lock().unwrap(); + *samples += 1; + + // Verify we can access the data + let _ = sample_data.chain_id; + let _ = sample_data.draw_energy; + } + } + }), + rate: Duration::from_millis(100), + }; + + // Simulate invoking the callback (in real usage, the sampler does this) + let chains = vec![]; + (callback.callback)(Duration::from_secs(1), chains.into()); + + assert_eq!(*callback_count.lock().unwrap(), 1); +}