Conversation
…perature scaling

### Changes:

* Add `num_particles`, `gradient_clip_norm`, and `kl_tau` to VI `update_kwargs`
  - `num_particles` is passed to `TraceMeanField_ELBO`/`Trace_ELBO` for multi-particle gradient estimates
  - `gradient_clip_norm` chains `optax.clip_by_global_norm` before the base optimizer
  - `kl_tau` scales the prior log-prob by `tau * N_data / N_neurons` via `numpyro.handlers.scale` (Huix et al. 2022)
* Switch all optimizers from numpyro-native to optax via `optax_to_numpyro`, removing the dual-path logic
* Remove `clipped_adam`, `momentum`, and the `reduce_on_plateau` scheduler
* Remove the `batch_size` parameter from `create_update_model`; read it from `_update_kwargs` instead
* Add `optax` to `pyproject.toml` dependencies
* Introduces `kl_annealing_fraction` as a new VI update kwarg for BayesianNeuralNetwork. When set, the KL divergence term in the ELBO is linearly warmed up from 0 → 1 over the first `kl_annealing_fraction * total_steps` training steps, then held at 1.0 for the remainder.
* The annealing factor is threaded as a traced JAX scalar through the `model` and `guide` callables so the warm-up schedule runs entirely inside `lax.scan` without Python-loop overhead or recompilation. The factor composes multiplicatively with the existing `kl_tau` temperature scale when both are active.
* Validation rejects values outside [0, 1]. Hypothesis-based tests cover the full cross-product with existing VI options (`num_particles`, `gradient_clip_norm`, `kl_tau`, `lr_scheduler`).
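The warm-up schedule described above can be sketched as a pure function of the global step (a simplified, non-traced stand-in for the version that runs inside `lax.scan`; it uses the from-0 convention the description states):

```python
import numpy as np


def kl_annealing_factor(step, total_steps, kl_annealing_fraction):
    """Linear KL warm-up: 0 -> 1 over the first fraction of training, then 1."""
    if not kl_annealing_fraction:  # None or 0.0 both disable annealing
        return 1.0
    warmup_steps = max(1, int(np.ceil(kl_annealing_fraction * total_steps)))
    return min(1.0, step / warmup_steps)
```

With `total_steps=100` and `kl_annealing_fraction=0.5`, the factor ramps from 0 at step 0 to 1.0 at step 50 and stays there.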
📝 Walkthrough

The changes introduce optax-based optimizer construction with optional learning rate scheduling and gradient clipping, add KL control configuration parameters for variational inference, and implement KL annealing in the SVI training loop. The project version increments to 6.1.0 with optax added as a dependency.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Training as Training Loop
    participant SVI as SVI Handler
    participant Guide as Guide/Model
    participant KLScale as KL Scale Handler
    participant Loss as Loss Computation
    Training->>Training: Initialize with kl_annealing_factor=1.0
    Training->>Training: For each step, compute annealing factor
    Note over Training: If kl_annealing enabled, use jax.lax.scan
    Training->>SVI: Call svi.update(data, kl_annealing_factor)
    SVI->>KLScale: Wrap guide with scale handler
    SVI->>Guide: Sample parameters with KL scaling
    Guide->>KLScale: Apply scaling to parameter sampling
    KLScale-->>Guide: Scaled samples returned
    Guide->>Loss: Forward pass with scaled samples
    Loss-->>SVI: Compute ELBO loss
    SVI-->>Training: Return loss scalar
    Training->>Training: Update step, continue iteration
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 3
🧹 Nitpick comments (4)
tests/test_model.py (3)
**1849-1852**: `assert_array_equal` on float losses across two SVI runs is brittle.

Even though both branches follow the same code path (`kl_annealing_active = False`), comparing `bnn_none._approx_history` and `bnn_zero._approx_history` with `assert_array_equal` requires bit-identical floats across two independent JAX traces/compiles. JAX/XLA generally honors this for the same code on the same device, but trivial differences (e.g. CPU vs. GPU CI runners, future XLA layout changes, parallel reduction orderings) can produce ULP-level deltas and make the test flaky.

`np.testing.assert_allclose(..., rtol=0, atol=0)` is identical in semantics today but lets you loosen later without rewriting the assertion; alternatively `assert_allclose(..., rtol=1e-6)` would be a safer floor.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_model.py` around lines 1849 - 1852, Replace the brittle bit-exact comparison between bnn_none._approx_history and bnn_zero._approx_history with a tolerance-based comparison: change the np.testing.assert_array_equal(...) call to np.testing.assert_allclose(...) (or assert_allclose(..., rtol=1e-6, atol=0) / at least assert_allclose(..., rtol=0, atol=0) to allow future loosening) so floating-point ULP differences from separate JAX SVI runs won’t make the test flaky.
**1822-1894**: KL annealing tests don't exercise the probability path.

These four tests cover construction, equivalence of inactive paths, divergence vs. `kl_tau`-only, and serialization — but none of them call `bnn.sample_proba(...)` after `bnn.update(...)` to verify the trained model still produces well-formed `[0, 1]` probabilities. Since KL annealing directly affects the posterior the predictive distribution is drawn from, a smoke check on the probability output would catch regressions where annealing leaves `mu`/`sigma` in a bad state (NaN, infinite, out-of-range).

Consider augmenting at least `test_kl_annealing_interacts_with_kl_tau` (and ideally the round-trip test) with the same probability/finiteness assertions used in `test_vi_training_options` (lines 900-902):

```python
result = bnn_tau_anneal.sample_proba(context=context)
assert all(0 <= p[0] <= 1 for p in result)
assert all(np.isfinite(p[1]) for p in result)
```

As per coding guidelines: `**/*test*.py`: Always apply callable cost and probability logic testing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_model.py` around lines 1822 - 1894, Add probability/finiteness smoke checks after updating models in the KL-annealing tests: call sample_proba on the trained instances (e.g., bnn_tau_anneal and bnn_tau_only in test_kl_annealing_interacts_with_kl_tau, and bnn or bnn2 in test_kl_annealing_fraction_serialization_round_trip) and assert outputs are valid probabilities and finite (e.g., each p in result has 0<=p[0]<=1 and p[1] is finite). This mirrors the checks in test_vi_training_options and ensures kl_annealing doesn't produce NaN/inf or out-of-range probabilities. Ensure you call sample_proba after the update() or after reconstruction so the smoke check validates the trained/post-serialization model.
**793-811**: `test_lr_scheduler_valid` only checks the optimizer is constructed; no end-to-end exercise.

The test confirms `bnn._obj_optimizer is not None` but never runs `bnn.update(...)` or `bnn.sample_proba(...)`. That means a scheduler that builds successfully but produces NaN gradients, mis-shaped learning rates, or otherwise breaks SVI updates would still pass. Given the coding guidelines (`**/*test*.py`: Always apply callable cost and probability logic testing) and the fact that `test_vi_training_options` already exercises a tiny update + `sample_proba` round-trip, mirroring that pattern here would noticeably strengthen the LR-scheduler coverage.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_model.py` around lines 793 - 811, Extend test_lr_scheduler_valid to perform a small end-to-end exercise after constructing the BNN: call BayesianNeuralNetwork.cold_start as currently done, then run a minimal update (e.g., bnn.update with a tiny synthetic batch or the same fixtures used in test_vi_training_options) and invoke bnn.sample_proba on a small input to ensure training and inference succeed; assert the update returns a finite loss (not NaN/inf) and that sample_proba returns probabilities of the expected shape and finite values. This uses the existing test_lr_scheduler_valid, bnn.update, and bnn.sample_proba symbols so the scheduler is validated during an actual SVI step and inference round-trip.

pybandits/model.py (1)
**1530**: `n_samples: PositiveInt` is misleading — callers pass a traced JAX scalar.

`_create_update_model.model` invokes `kl_scale(n_samples, kl_annealing_factor)` with `n_samples = x.shape[0]`, which is a traced JAX integer when the function is JIT-compiled (e.g. inside `lax.scan`). A traced scalar is not a Python `int` and won't satisfy any `PositiveInt` runtime validator. There is no `validate_call` decorator here so it does not actually fail today, but the annotation is misleading and will trip up anyone who later adds Pydantic validation.

Consider widening the annotation to something like `Union[PositiveInt, jax.Array]` or simply `int` with a comment noting traced scalars are accepted.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pybandits/model.py` at line 1530, The type annotation for _kl_scale_ctx's parameter n_samples is misleading because callers pass JAX-traced scalars (e.g. from _create_update_model.model calling kl_scale with x.shape[0]), so change the annotation on def _kl_scale_ctx to accept traced JAX values—either widen it to Union[PositiveInt, jax.Array] or to plain int—and add a short comment on the parameter stating that JAX traced scalars (jax.Array) are allowed to avoid future Pydantic/validation surprises; ensure related uses such as kl_scale and calls in _create_update_model.model remain type-compatible.
ℹ️ Review info

⚙️ Run configuration
- Configuration used: defaults
- Review profile: CHILL
- Plan: Pro
- Run ID: 343828ea-f7fd-4f54-8772-c6257cdf38ff

📒 Files selected for processing (3)
- pybandits/model.py
- pyproject.toml
- tests/test_model.py
```python
fn = getattr(optax, name, None)
if fn is None or not callable(fn):
    raise ValueError(f"Invalid {kind}: '{name}' is not a callable attribute of optax.")
expected = self._optax_return_types[kind]
return_annotation = inspect.signature(fn).return_annotation
if isinstance(expected, type):
    # e.g. GradientTransformationExtraArgs — check via issubclass
    valid = isinstance(return_annotation, type) and issubclass(return_annotation, expected)
else:
    # e.g. optax.Schedule (a generic alias) — check for equality
    valid = return_annotation == expected
if not valid:
    raise ValueError(
        f"Invalid {kind}: '{name}' does not return {expected} (got return annotation: {return_annotation})."
    )
return fn
```
🧩 Analysis chain

🏁 Script executed:

```bash
#!/bin/bash
# Verify that the optax functions exercised in tests still carry the expected
# concrete return annotations on the pinned optax 0.1.x range.
python - <<'PY'
import inspect, optax
for name in [
    "sgd", "adam",
    "exponential_decay", "cosine_decay_schedule",
    "warmup_cosine_decay_schedule", "linear_schedule",
]:
    fn = getattr(optax, name, None)
    ann = inspect.signature(fn).return_annotation if fn else "<missing>"
    print(f"{name}: {ann!r} (type={type(ann).__name__})")
print("optax.Schedule =", optax.Schedule)
print("optax.GradientTransformation =", optax.GradientTransformation)
PY
```

Repository: PlaytikaOSS/pybandits
Length of output: 183

🏁 Script executed:

```bash
# First, explore the structure and find the requirements
fd -type f -name "setup.py" -o -name "pyproject.toml" -o -name "requirements*.txt"
```

Repository: PlaytikaOSS/pybandits
Length of output: 236

🏁 Script executed:

```bash
# Read the relevant section of model.py around lines 1135-1150
sed -n '1100,1160p' pybandits/model.py | cat -n
```

Repository: PlaytikaOSS/pybandits
Length of output: 2847

🏁 Script executed:

```bash
# Check the full _resolve_optax_fn method and surrounding context
sed -n '1110,1160p' pybandits/model.py | cat -n
```

Repository: PlaytikaOSS/pybandits
Length of output: 2406

🏁 Script executed:

```bash
find . -type f \( -name "setup.py" -o -name "pyproject.toml" \)
```

Repository: PlaytikaOSS/pybandits
Length of output: 82

🏁 Script executed:

```bash
# Search for optax version constraints
rg "optax" --type python -i | grep -i "version\|require" | head -20
```

Repository: PlaytikaOSS/pybandits
Length of output: 96

🏁 Script executed:

```bash
# Look for tests of _resolve_optax_fn
rg "_resolve_optax_fn|optimizer_name|lr_scheduler_name" --type python
```

Repository: PlaytikaOSS/pybandits
Length of output: 96

🏁 Script executed:

```bash
# Read pyproject.toml for dependencies
cat pyproject.toml
```

Repository: PlaytikaOSS/pybandits
Length of output: 3089

🏁 Script executed:

```bash
# Search for tests and usage of resolve_optax_fn without type filter
rg "_resolve_optax_fn" -A 5 -B 2
```

Repository: PlaytikaOSS/pybandits
Length of output: 1359

🏁 Script executed:

```bash
# Search for optimizer_name and lr_scheduler_name usage (which triggers _resolve_optax_fn)
rg "optimizer_name|lr_scheduler_name" -i -A 3 -B 1
```

Repository: PlaytikaOSS/pybandits
Length of output: 47

🏁 Script executed:

```bash
# Find test files
find . -name "*test*.py" -type f
```

Repository: PlaytikaOSS/pybandits
Length of output: 523

🏁 Script executed:

```bash
# Search for tests that use optimizer_type or lr_scheduler_type
rg "optimizer_type|lr_scheduler_type" -A 2 -B 1
```

Repository: PlaytikaOSS/pybandits
Length of output: 12176

🏁 Script executed:

```bash
# Look for actual test cases involving BayesianNeuralNetworkVI or similar models
rg "BayesianNeuralNetworkVI|adam|sgd" --max-count=20
```

Repository: PlaytikaOSS/pybandits
Length of output: 2209

🏁 Script executed:

```bash
# Read the test_model.py file to see actual tests
sed -n '1,100p' tests/test_model.py | head -50
```

Repository: PlaytikaOSS/pybandits
Length of output: 1855

🏁 Script executed:

```bash
# Look for the test with dummy_optimizer that should fail
rg "dummy_optimizer" -A 5 -B 5
```

Repository: PlaytikaOSS/pybandits
Length of output: 1163

🏁 Script executed:

```bash
# Check if there are any instances where annotations might be missing or string-based
rg "from __future__ import annotations"
```

Repository: PlaytikaOSS/pybandits
Length of output: 47

🏁 Script executed:

```bash
# Check if there's any try-except around _resolve_optax_fn calls
rg "_resolve_optax_fn" -B 3 -A 8
```

Repository: PlaytikaOSS/pybandits
Length of output: 1987

🏁 Script executed:

```bash
# Look for any eval_str usage in the codebase
rg "eval_str|get_type_hints"
```

Repository: PlaytikaOSS/pybandits
Length of output: 1214

🏁 Script executed:

```bash
# Check imports in model.py to see how inspect is used
sed -n '1,50p' pybandits/model.py | cat -n
```

Repository: PlaytikaOSS/pybandits
Length of output: 2502

🏁 Script executed:

```bash
# Check the exact lines where _resolve_optax_fn is called to see error handling
sed -n '1570,1630p' pybandits/model.py | cat -n
```

Repository: PlaytikaOSS/pybandits
Length of output: 3433

🏁 Script executed:

```bash
# Let's create a simple test to understand the actual behavior
python3 << 'PY'
import inspect

# Simulate what happens with missing annotation
def func_no_annotation():
    return None

# Simulate what happens with string annotation
def func_str_annotation():
    return "SomeType"

func_str_annotation.__annotations__ = {"return": "SomeType"}

# Test the logic
sig = inspect.signature(func_no_annotation)
print(f"Missing annotation: {sig.return_annotation!r}")
print(f"Is type? {isinstance(sig.return_annotation, type)}")
print()
sig2 = inspect.signature(func_str_annotation)
print(f"String annotation: {sig2.return_annotation!r}")
print(f"Is type? {isinstance(sig2.return_annotation, type)}")
print(f"Equals 'SomeType'? {sig2.return_annotation == 'SomeType'}")

# What about empty
print(f"\ninspect.Signature.empty: {inspect.Signature.empty!r}")
print(f"Is type? {isinstance(inspect.Signature.empty, type)}")
PY
```

Repository: PlaytikaOSS/pybandits
Length of output: 259

🏁 Script executed:

```bash
# This is crucial: what happens when issubclass is called with inspect._empty?
python3 << 'PY'
import inspect

empty = inspect.Signature.empty

# Try to call issubclass with inspect._empty
try:
    result = issubclass(empty, object)
    print(f"issubclass(inspect._empty, object) = {result}")
except Exception as e:
    print(f"Error: {type(e).__name__}: {e}")

# What if we try with a real type expected?
class FakeType:
    pass

try:
    result = issubclass(empty, FakeType)
    print(f"issubclass(inspect._empty, FakeType) = {result}")
except Exception as e:
    print(f"Error: {type(e).__name__}: {e}")
PY
```

Repository: PlaytikaOSS/pybandits
Length of output: 152

🏁 Script executed:

```bash
# Check which optax functions exist and whether they have annotations by looking at optax source docs
web_search "optax 0.1 sgd adam exponential_decay cosine_decay_schedule return annotation"
```

Repository: PlaytikaOSS/pybandits
Length of output: 116

🏁 Script executed:

```bash
# Let's look at the complete _arrange_update_kwargs method to understand validation flow
sed -n '1450,1480p' pybandits/model.py | cat -n
```

Repository: PlaytikaOSS/pybandits
Length of output: 1861

🏁 Script executed:

```bash
# Check if there's any docstring or comment explaining why annotation-based validation was chosen
sed -n '1110,1155p' pybandits/model.py
```

Repository: PlaytikaOSS/pybandits
Length of output: 1885

🏁 Script executed:

```bash
# Verify the full flow to see if there's error handling anywhere else
sed -n '1560,1600p' pybandits/model.py | cat -n
```

Repository: PlaytikaOSS/pybandits
Length of output: 2141
Annotation-based validation in `_resolve_optax_fn` is fragile to changes in optax internals.

The validation depends on three brittle assumptions:

1. Every supported optax callable carries a precise return annotation. If a function lacks an annotation, `inspect.signature(fn).return_annotation` returns `inspect.Signature.empty`, and the `issubclass` check fails with a misleading "does not return …" error message rather than a clear "missing annotation" error.
2. Annotations are concrete objects, not strings. If a future optax release adds `from __future__ import annotations` (PEP 563), return annotations become strings. The `isinstance(return_annotation, type)` check fails, and the equality check also fails, causing valid optimizers/schedulers to be rejected at import time.
3. Schedule annotations match `optax.Schedule` by `==`. This generic-alias equality is brittle; if optax changes its type alias or directly annotates functions with the underlying `Callable[...]`, the check fails even though the function is functionally compatible.

While tests confirm that optax 0.1.x functions carry the expected annotations, a future optax upgrade (e.g., adding PEP 563 or restructuring type hints) could silently break user configs at import time without clear guidance.

Consider implementing duck-typed validation instead (e.g. instantiate with dummy arguments: `fn(learning_rate=0.01)` for optimizers, `fn(0)` for schedules) or handle missing/unparseable annotations gracefully so import-time failures do not occur for minor optax versions.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@pybandits/model.py` around lines 1135 - 1150, The current annotation-based
checks in _resolve_optax_fn (which references self._optax_return_types and optax
types like optax.Schedule / GradientTransformationExtraArgs) are brittle; change
the function to first tolerate missing or string annotations (treat
inspect.Signature.empty or str annotations as "unreliable") and then perform a
pragmatic duck-typed validation fallback: try to call the candidate fn with
small dummy inputs appropriate to the kind (e.g., for "schedule" call fn(0) or
fn(learning_rate=0.01) for optimizer factories) and verify the returned object
supports the expected protocol (e.g., has .init and .update for
GradientTransformation-like objects, or is callable/returns numbers for
schedules), catching any exceptions and raising a clear ValueError indicating
either "missing/unparseable annotation" or "callable did not produce expected
interface" and returning fn on success; keep using the name checks and callable
guard but remove strict issubclass/== checks against return_annotation.
```python
svi_state = svi.init(subkey, x_jnp, y_jnp, 1.0)
if kl_annealing_active:
    kl_warmup_f = jnp.float32(kl_warmup_steps)

    def _svi_body(state, step):  # step is a traced int32 scalar (global index)
        kl_annealing_factor = jnp.minimum(1.0, (step.astype(jnp.float32) + 1.0) / kl_warmup_f)
        state, loss = svi.update(state, x_jnp, y_jnp, kl_annealing_factor)
        return state, loss

    _run_epoch = jax.jit(
        lambda state, steps: jax.lax.scan(_svi_body, state, steps),
    )  # steps is a 1-D int32 array; scan feeds each element as xs to _svi_body
```
KL annealing factor never reaches 0; it starts at `1/kl_warmup_steps`.

The PR description states that the annealing factor "linearly warms the KL term in the ELBO from 0 to 1 over the first `kl_annealing_fraction * total_steps` steps". The current formula is

```python
kl_annealing_factor = jnp.minimum(1.0, (step.astype(jnp.float32) + 1.0) / kl_warmup_f)
```

So at `step=0` the factor is `1/kl_warmup_steps` (and at `kl_warmup_steps == 1` the factor is already 1.0, i.e. effectively no warmup). If the goal is "from 0 to 1", drop the `+ 1.0` so the first step starts at exactly 0 and the factor reaches 1.0 only after `kl_warmup_steps` updates:

♻️ Suggested fix

```diff
-def _svi_body(state, step):  # step is a traced int32 scalar (global index)
-    kl_annealing_factor = jnp.minimum(1.0, (step.astype(jnp.float32) + 1.0) / kl_warmup_f)
+def _svi_body(state, step):  # step is a traced int32 scalar (global index)
+    kl_annealing_factor = jnp.minimum(1.0, step.astype(jnp.float32) / kl_warmup_f)
     state, loss = svi.update(state, x_jnp, y_jnp, kl_annealing_factor)
     return state, loss
```

If the off-by-one is intentional, please update the PR description / `_kl_scale_ctx` docstring to reflect that the factor is in (0, 1] and starts at 1/N.
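The difference between the two formulas is easy to check numerically. This plain-Python mirror of the traced expressions uses `kl_warmup_f = 4` as an arbitrary example:

```python
kl_warmup_f = 4.0

# Current formula: (step + 1) / warmup — never 0, saturates one step early.
current = [min(1.0, (step + 1.0) / kl_warmup_f) for step in range(6)]
# Proposed formula: step / warmup — starts at exactly 0.
proposed = [min(1.0, step / kl_warmup_f) for step in range(6)]

# current  -> [0.25, 0.5, 0.75, 1.0, 1.0, 1.0]
# proposed -> [0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
```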
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@pybandits/model.py` around lines 2037 - 2048, The kl annealing currently
computes kl_annealing_factor in _svi_body as (step.astype(jnp.float32) + 1.0) /
kl_warmup_f which makes the first step > 0; change the formula to
(step.astype(jnp.float32)) / kl_warmup_f so the first step yields 0 and the
factor linearly reaches 1.0 after kl_warmup_f steps; update the expression
computing kl_annealing_factor in _svi_body (which is passed into svi.update)
accordingly and ensure jnp.minimum(1.0, ...) is still applied.
```python
_run_epoch = jax.jit(
    lambda state, n: jax.lax.scan(_svi_body, state, None, length=n),
    static_argnums=(1,),
)
```
`static_argnums=(1,)` will recompile when an epoch has a different `epoch_steps`.

In the inactive-annealing branch, `_run_epoch` is JIT-compiled with `n` marked static. The epoch-construction logic above (lines 1996-2002) appends a final `remaining` epoch when `num_steps % steps_per_epoch != 0`, so the last `_run_epoch` call uses a different `n` and triggers an additional XLA compilation — defeating the very reason the comment cites for moving to `lax.scan`. This is a minor performance issue (one extra compile per training run) but worth noting.

A simple mitigation is to align `epoch_steps_list` to a uniform step count (drop or absorb the remainder), or to pass `n` as an int32 array length argument that matches the active-path style.
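One way to keep every `_run_epoch` call at a uniform length is to round the step budget up to a whole number of identical epochs (a sketch; `uniform_epochs` is a hypothetical helper, and it assumes running a few extra SVI steps is an acceptable trade for a single compile):

```python
import math


def uniform_epochs(num_steps, steps_per_epoch):
    """Round num_steps up to a whole number of equal-length epochs so a
    jitted runner with a static length argument compiles exactly once."""
    n_epochs = max(1, math.ceil(num_steps / steps_per_epoch))
    return [steps_per_epoch] * n_epochs
```

Every entry of the returned list is identical, so `static_argnums=(1,)` sees one value of `n` for the whole run.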
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@pybandits/model.py` around lines 2056 - 2060, The JIT for _run_epoch
currently marks the epoch length argument `n` static (static_argnums=(1,)), so
when the final epoch uses a different step count (because epoch_steps_list can
include a smaller `remaining` epoch) it triggers an extra XLA compile; change to
use a non-static, numeric-length argument or make all epochs the same length.
Concretely: either remove `n` from static_argnums on the jax.jit of _run_epoch
(so pass an int32/array length like the active path does) or normalize
`epoch_steps_list` (drop/absorb the remainder so every entry equals
steps_per_epoch) so _run_epoch is always called with the same `n`; update
references to _run_epoch, _svi_body, epoch_steps_list, num_steps and
steps_per_epoch accordingly.
```diff
-def model(x: jax.Array, y: jax.Array):
+def model(x: jax.Array, y: jax.Array, kl_annealing_factor=1.0):
     N = x.shape[0]
```
Add a `Float01` type hint for `kl_annealing_factor`.
```python
total_steps = sum(epoch_steps_list)
kl_annealing_fraction = self._update_kwargs.get("kl_annealing_fraction")
kl_annealing_active = kl_annealing_fraction not in (None, 0.0)
kl_warmup_steps = max(1, int(np.ceil(kl_annealing_fraction * total_steps))) if kl_annealing_active else None
```
I think you can drop the `np.ceil` and just directly apply `int()`.
```python
if kl_tau_active or kl_annealing_active:
    kl_scale = self._kl_scale_ctx

    def guide(x, y, kl_annealing_factor, *args, **kwargs):
```
```python
n_samples : PositiveInt
    Number of data points in the current batch.
kl_annealing_factor : float or jax scalar, optional
    Multiplicative factor in [0, 1]. Default 1.0 (no annealing).
```
I think you told me it is in (0, 1].
```python
global_step_offset = 0
for epoch_idx, epoch_steps in enumerate(epoch_steps_list):
    svi_state, epoch_losses = _run_epoch(svi_state, epoch_steps)
    if kl_annealing_active:
```
Instead of calculating steps, wouldn't it be better to precompute the `kl_factor` here and pass it to `_run_epoch`?
```python
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("bad_value", [-0.1, 1.5, 2.0, -1.0])
```
Put all these tests in a designated class.
```python
@pytest.mark.parametrize("bad_value", [-0.1, 1.5, 2.0, -1.0])
def test_kl_annealing_fraction_range_validation(bad_value: float):
    """Out-of-range kl_annealing_fraction raises ValueError at construction time."""
    with pytest.raises(ValueError, match="kl_annealing_fraction"):
```
Please avoid hard-coded constants in the tests to keep them easily maintainable.
Try using `pytest.mark.parametrize` and hypothesis `given` for better coverage.
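A property-based version of the range check could look like this (a sketch using `hypothesis`; `validate_kl_annealing_fraction` is a hypothetical stand-in for wherever the real constructor performs the check):

```python
import pytest
from hypothesis import given
from hypothesis import strategies as st


def validate_kl_annealing_fraction(value):
    """Stand-in for the constructor-time check (hypothetical helper)."""
    if value is not None and not 0.0 <= value <= 1.0:
        raise ValueError("kl_annealing_fraction must be in [0, 1]")
    return value


# Draw any finite float in [-10, 10] that falls outside the valid range.
out_of_range = st.floats(min_value=-10.0, max_value=10.0, allow_nan=False).filter(
    lambda v: not 0.0 <= v <= 1.0
)


@given(out_of_range)
def test_kl_annealing_fraction_rejects_out_of_range(bad_value):
    with pytest.raises(ValueError, match="kl_annealing_fraction"):
        validate_kl_annealing_fraction(bad_value)
```

Hypothesis then explores the boundary region (values just below 0 and just above 1) automatically instead of relying on a fixed list of constants.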
Add KL annealing to BNN VI training

Summary

* `kl_annealing_fraction` as a new VI `update_kwargs` parameter for `BayesianNeuralNetwork`. When set, the KL divergence term in the ELBO is linearly warmed up from `0 → 1` over the first `kl_annealing_fraction * total_steps` training steps, then held at `1.0` for the remainder.
* The annealing factor is threaded through the `model` and `guide` callables, so the warm-up schedule runs entirely inside `lax.scan` with no Python-loop overhead or XLA recompilation per step.
* Composes multiplicatively with the existing `kl_tau` temperature scale when both are active: `effective_scale = kl_base_factor * kl_annealing_factor`.
* `kl_annealing_fraction=None` and `=0.0` are both treated as inactive (no-op), validated at construction time.

Test plan

* Validation rejects values outside `[0, 1]`
* `None` and `0.0` produce equivalent training behavior (both inactive paths)
* Cross-product coverage with existing VI options: `num_particles`, `gradient_clip_norm`, `kl_tau`, `lr_scheduler`

Summary by CodeRabbit
New Features
Chores