
Add KL annealing to BNN VI training #137

Open
Moritesh wants to merge 2 commits into develop from feature/kl_annealing

Conversation

@Moritesh (Collaborator) commented Apr 26, 2026

Add KL annealing to BNN VI training

Note: This PR is based on the VI training enhancements from PR #132 (num_particles, gradient clipping, kl_tau, lr scheduler) and should be rebased & merged only after #132 lands on develop.

Summary

  • Introduces kl_annealing_fraction as a new VI update_kwargs parameter for BayesianNeuralNetwork. When set, the KL divergence term in the ELBO is linearly warmed up from 0 → 1 over the first kl_annealing_fraction * total_steps training steps, then held at 1.0 for the remainder (a schedule sketch follows this list).
  • The annealing factor is threaded as a traced JAX scalar through the model and guide callables, so the warm-up schedule runs entirely inside lax.scan with no Python-loop overhead or XLA recompilation per step.
  • Composes multiplicatively with the existing kl_tau temperature scale when both are active: effective_scale = kl_base_factor * kl_annealing_factor.
  • kl_annealing_fraction=None and =0.0 are both treated as inactive (no-op), validated at construction time.
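
A minimal sketch of the schedule described above, assuming illustrative helper names (kl_annealing_factor and effective_kl_scale are not the actual pybandits symbols):

import jax.numpy as jnp

def kl_annealing_factor(step, total_steps, kl_annealing_fraction):
    # Linear warm-up: 0 -> 1 over the first kl_annealing_fraction * total_steps steps,
    # then clamped at 1.0 for the remaining steps.
    warmup_steps = jnp.maximum(1.0, kl_annealing_fraction * total_steps)
    return jnp.minimum(1.0, step / warmup_steps)

def effective_kl_scale(step, total_steps, kl_annealing_fraction, kl_base_factor=1.0):
    # Composes multiplicatively with the kl_tau-derived base factor when both are set.
    return kl_base_factor * kl_annealing_factor(step, total_steps, kl_annealing_fraction)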

Test plan

  • Validation rejects values outside [0, 1]
  • None and 0.0 produce equivalent training behavior (both inactive paths)
  • Hypothesis-based test covers the full cross-product with existing VI options: num_particles, gradient_clip_norm, kl_tau, lr_scheduler (a sketch of this pattern follows the list)
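
A rough sketch of that property-based cross-product pattern (the strategy values and test name below are assumptions, not the actual test code):

from hypothesis import given, settings, strategies as st

@settings(max_examples=10, deadline=None)
@given(
    num_particles=st.sampled_from([1, 4]),
    gradient_clip_norm=st.sampled_from([None, 1.0]),
    kl_tau=st.sampled_from([None, 0.5]),
    kl_annealing_fraction=st.sampled_from([None, 0.3]),
)
def test_vi_options_cross_product(num_particles, gradient_clip_norm, kl_tau, kl_annealing_fraction):
    # Build update_kwargs from the drawn combination, run a tiny update + sample_proba,
    # and assert finite losses and probabilities in [0, 1].
    ...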

Summary by CodeRabbit

  • New Features

    • Added Optax optimizer support with optional learning rate scheduling and gradient clipping configuration.
    • Introduced KL annealing framework for variational inference with configurable annealing parameters.
  • Chores

    • Version bumped to 6.1.0.
    • Added Optax as a runtime dependency.

shaharbar1 and others added 2 commits April 6, 2026 14:22
…perature scaling

### Changes:
* Add `num_particles`, `gradient_clip_norm`, and `kl_tau` to VI `update_kwargs`
  - `num_particles` is passed to `TraceMeanField_ELBO`/`Trace_ELBO` for multi-particle gradient estimates
  - `gradient_clip_norm` chains `optax.clip_by_global_norm` before the base optimizer
  - `kl_tau` scales prior log-prob by `tau * N_data / N_neurons` via `numpyro.handlers.scale` (Huix et al. 2022); a usage sketch follows this commit summary
* Switch all optimizers from numpyro-native to optax via `optax_to_numpyro`, removing the dual-path logic
* Remove `clipped_adam`, `momentum`, and `reduce_on_plateau` scheduler
* Remove `batch_size` parameter from `create_update_model`; read from `_update_kwargs` instead
* Add `optax` to `pyproject.toml` dependencies
* Introduces `kl_annealing_fraction` as a new VI update kwarg for BayesianNeuralNetwork. When set, the KL divergence term in the ELBO is linearly warmed up from 0 → 1 over the first `kl_annealing_fraction * total_steps` training steps, then held at 1.0 for the remainder.
* The annealing factor is threaded as a traced JAX scalar through the `model` and `guide` callables so the warm-up schedule runs entirely inside `lax.scan` without Python-loop overhead or recompilation. The factor composes multiplicatively with the existing `kl_tau` temperature scale when both are active.
* Validation rejects values outside [0, 1]. Hypothesis-based tests cover the full cross-product with existing VI options (num_particles, gradient_clip_norm, kl_tau, lr_scheduler).
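
A usage sketch of the scaling mechanism (illustrative only; the actual helper in the PR is `_kl_scale_ctx`, and the name scaled_prior below is hypothetical):

import numpyro
import numpyro.distributions as dist

def scaled_prior(name, kl_scale_factor):
    # numpyro.handlers.scale multiplies the log-probability of the sample sites
    # created inside it by `scale`; this is how the prior/KL term is tempered.
    with numpyro.handlers.scale(scale=kl_scale_factor):
        return numpyro.sample(name, dist.Normal(0.0, 1.0))
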
@Moritesh requested a review from shaharbar1 April 26, 2026 14:39
coderabbitai Bot commented Apr 26, 2026

📝 Walkthrough

The changes introduce optax-based optimizer construction with optional learning rate scheduling and gradient clipping, add KL control configuration parameters for variational inference, and implement KL annealing in the SVI training loop. The project version increments to 6.1.0 with optax added as a dependency.
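
As an illustration of that composition (a sketch with assumed hyperparameters, not the exact code in pybandits/model.py):

import optax
from numpyro.optim import optax_to_numpyro

schedule = optax.exponential_decay(init_value=1e-2, transition_steps=100, decay_rate=0.9)
tx = optax.chain(
    optax.clip_by_global_norm(1.0),        # gradient clipping, applied before the optimizer step
    optax.adam(learning_rate=schedule),    # learning rate driven by the scheduler
)
optimizer = optax_to_numpyro(tx)           # adapter so numpyro's SVI can consume the optax chain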

Changes

Cohort / File(s): Summary

Core Model Implementation (pybandits/model.py)
Introduces optax-based optimizer selection via _resolve_optax_fn and _get_obj_optimizer, supporting composition of LR schedulers and gradient clipping. Adds KL control configuration (kl_tau, kl_annealing_fraction) with validation. Implements KL scaling via _kl_scale_ctx in model/guide and integrates per-step KL annealing in the SVI training loop using jax.lax.scan. Returns the unwrapped guide and updates the early-stopping logic.

Project Configuration (pyproject.toml)
Updates version from 6.0.1 to 6.1.0 and adds optax ^0.1 as a runtime dependency.

Test Coverage (tests/test_model.py)
Updates existing tests to use the _create_update_model() API and standardizes reward generation. Extends the categorical-BNN helper to accept update_kwargs. Adds comprehensive coverage for optimizer/scheduler composition, KL annealing range validation, equivalence checking, and serialization preservation. Includes property-based end-to-end tests combining multiple VI options.

Sequence Diagram

sequenceDiagram
    participant Training as Training Loop
    participant SVI as SVI Handler
    participant Guide as Guide/Model
    participant KLScale as KL Scale Handler
    participant Loss as Loss Computation

    Training->>Training: Initialize with kl_annealing_factor=1.0
    Training->>Training: For each step, compute annealing factor
    
    Note over Training: If kl_annealing enabled, use jax.lax.scan
    Training->>SVI: Call svi.update(data, kl_annealing_factor)
    
    SVI->>KLScale: Wrap guide with scale handler
    SVI->>Guide: Sample parameters with KL scaling
    Guide->>KLScale: Apply scaling to parameter sampling
    KLScale-->>Guide: Scaled samples returned
    Guide->>Loss: Forward pass with scaled samples
    Loss-->>SVI: Compute ELBO loss
    
    SVI-->>Training: Return loss scalar
    Training->>Training: Update step, continue iteration
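
A sketch of the scan-based step that the diagram depicts (variable names follow the review snippets further down; the exact wiring in the PR may differ):

import jax
import jax.numpy as jnp

def make_run_epoch(svi, x_jnp, y_jnp, kl_warmup_steps):
    kl_warmup_f = jnp.float32(kl_warmup_steps)

    def _svi_body(state, step):
        # step is a traced int32 global index; the annealing factor is computed
        # inside the scan body, so there is no Python loop and no per-step recompile.
        kl_annealing_factor = jnp.minimum(1.0, step.astype(jnp.float32) / kl_warmup_f)
        state, loss = svi.update(state, x_jnp, y_jnp, kl_annealing_factor)
        return state, loss

    # steps is a 1-D int32 array of global step indices fed as xs to the scan body.
    return jax.jit(lambda state, steps: jax.lax.scan(_svi_body, state, steps))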

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐰 Optax comes hopping to aid our stride,
With schedulers dancing, and clipping guide,
KL annealing scales with every step,
The SVI loop gains vigor and pep! 🌟

🚥 Pre-merge checks | ✅ 4 passed | ❌ 1 warning

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 63.16%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Title check ✅ Passed: The title clearly summarizes the main change: introducing KL annealing to BNN VI training, which aligns with the primary feature additions across the codebase.
  • Description check ✅ Passed: The description provides comprehensive details on the feature (kl_annealing_fraction parameter, linear warm-up, composition with kl_tau, validation rules), the implementation approach (traced JAX scalar, lax.scan), and the test plan, matching and exceeding the template structure.
  • Linked Issues check ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check ✅ Passed: Check skipped because no linked issues were found for this pull request.



coderabbitai Bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (4)
tests/test_model.py (3)

1849-1852: assert_array_equal on float losses across two SVI runs is brittle.

Even though both branches follow the same code path (kl_annealing_active = False), comparing bnn_none._approx_history and bnn_zero._approx_history with assert_array_equal requires bit-identical floats across two independent JAX traces/compiles. JAX/XLA generally honors this for the same code on the same device, but trivial differences (e.g. CPU vs. GPU CI runners, future XLA layout changes, parallel reduction orderings) can produce ULP-level deltas and make the test flaky.

np.testing.assert_allclose(..., rtol=0, atol=0) is identical in semantics today but lets you loosen later without rewriting the assertion; alternatively assert_allclose(..., rtol=1e-6) would be a safer floor.
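
For example (the history attributes are taken from the test snippet discussed above):

import numpy as np

np.testing.assert_allclose(bnn_none._approx_history, bnn_zero._approx_history, rtol=1e-6, atol=0)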

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_model.py` around lines 1849 - 1852, Replace the brittle bit-exact
comparison between bnn_none._approx_history and bnn_zero._approx_history with a
tolerance-based comparison: change the np.testing.assert_array_equal(...) call
to np.testing.assert_allclose(...) (or assert_allclose(..., rtol=1e-6, atol=0) /
at least assert_allclose(..., rtol=0, atol=0) to allow future loosening) so
floating-point ULP differences from separate JAX SVI runs won’t make the test
flaky.

1822-1894: KL annealing tests don't exercise the probability path.

These four tests cover construction, equivalence of inactive paths, divergence vs. kl_tau-only, and serialization — but none of them call bnn.sample_proba(...) after bnn.update(...) to verify the trained model still produces well-formed [0, 1] probabilities. Since KL annealing directly affects the posterior the predictive distribution is drawn from, a smoke check on the probability output would catch regressions where annealing leaves mu/sigma in a bad state (NaN, infinite, out-of-range).

Consider augmenting at least test_kl_annealing_interacts_with_kl_tau (and ideally the round-trip test) with the same probability/finiteness assertions used in test_vi_training_options (lines 900-902):

result = bnn_tau_anneal.sample_proba(context=context)
assert all(0 <= p[0] <= 1 for p in result)
assert all(np.isfinite(p[1]) for p in result)

As per coding guidelines: **/*test*.py: Always apply callable cost and probability logic testing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_model.py` around lines 1822 - 1894, Add probability/finiteness
smoke checks after updating models in the KL-annealing tests: call sample_proba
on the trained instances (e.g., bnn_tau_anneal and bnn_tau_only in
test_kl_annealing_interacts_with_kl_tau, and bnn or bnn2 in
test_kl_annealing_fraction_serialization_round_trip) and assert outputs are
valid probabilities and finite (e.g., each p in result has 0<=p[0]<=1 and p[1]
is finite). This mirrors the checks in test_vi_training_options and ensures
kl_annealing doesn't produce NaN/inf or out-of-range probabilities. Ensure you
call sample_proba after the update() or after reconstruction so the smoke check
validates the trained/post-serialization model.

793-811: test_lr_scheduler_valid only checks the optimizer is constructed; no end-to-end exercise.

The test confirms bnn._obj_optimizer is not None but never runs bnn.update(...) or bnn.sample_proba(...). That means a scheduler that builds successfully but produces NaN gradients, mis-shaped learning rates, or otherwise breaks SVI updates would still pass. Given the coding guidelines (**/*test*.py: Always apply callable cost and probability logic testing) and the fact that test_vi_training_options already exercises a tiny update + sample_proba round-trip, mirroring that pattern here would noticeably strengthen the LR-scheduler coverage.
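
A hypothetical extension along those lines (make_bnn_with_lr_scheduler, the context/rewards fixtures, and the update/sample_proba call signatures are assumptions based on the other tests, not verified API):

import numpy as np

def test_lr_scheduler_valid_end_to_end(context, rewards):
    # Reuse whatever test_lr_scheduler_valid already builds for the scheduler-configured BNN.
    bnn = make_bnn_with_lr_scheduler()                     # hypothetical helper
    bnn.update(context=context, rewards=rewards)           # tiny synthetic batch
    assert np.all(np.isfinite(bnn._approx_history))        # SVI losses stayed finite
    result = bnn.sample_proba(context=context)
    assert all(0 <= p[0] <= 1 and np.isfinite(p[1]) for p in result)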

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_model.py` around lines 793 - 811, Extend test_lr_scheduler_valid
to perform a small end-to-end exercise after constructing the BNN: call
BayesianNeuralNetwork.cold_start as currently done, then run a minimal update
(e.g., bnn.update with a tiny synthetic batch or the same fixtures used in
test_vi_training_options) and invoke bnn.sample_proba on a small input to ensure
training and inference succeed; assert the update returns a finite loss (not
NaN/inf) and that sample_proba returns probabilities of the expected shape and
finite values. This uses the existing test_lr_scheduler_valid, bnn.update, and
bnn.sample_proba symbols so the scheduler is validated during an actual SVI step
and inference round-trip.
pybandits/model.py (1)

1530-1530: n_samples: PositiveInt is misleading — callers pass a traced JAX scalar.

_create_update_model.model invokes kl_scale(n_samples, kl_annealing_factor) with n_samples = x.shape[0], which is a traced JAX integer when the function is JIT-compiled (e.g. inside lax.scan). A traced scalar is not a Python int and won't satisfy any PositiveInt runtime validator. There is no validate_call decorator here so it does not actually fail today, but the annotation is misleading and will trip up anyone who later adds Pydantic validation.

Consider widening the annotation to something like Union[PositiveInt, jax.Array] or simply int with a comment noting traced scalars are accepted.
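
For instance (a sketch in which only the annotation changes; the pydantic import path is assumed):

from typing import Union
import jax
from pydantic import PositiveInt

def _kl_scale_ctx(self, n_samples: Union[PositiveInt, jax.Array], kl_annealing_factor: float = 1.0):
    # n_samples may be a plain int or a JAX-traced scalar (e.g. x.shape[0] under jit/lax.scan).
    ...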

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pybandits/model.py` at line 1530, The type annotation for _kl_scale_ctx's
parameter n_samples is misleading because callers pass JAX-traced scalars (e.g.
from _create_update_model.model calling kl_scale with x.shape[0]), so change the
annotation on def _kl_scale_ctx to accept traced JAX values—either widen it to
Union[PositiveInt, jax.Array] or to plain int—and add a short comment on the
parameter stating that JAX traced scalars (jax.Array) are allowed to avoid
future Pydantic/validation surprises; ensure related uses such as kl_scale and
calls in _create_update_model.model remain type-compatible.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 343828ea-f7fd-4f54-8772-c6257cdf38ff

📥 Commits

Reviewing files that changed from the base of the PR and between 7af2bb0 and 5a7421e.

📒 Files selected for processing (3)
  • pybandits/model.py
  • pyproject.toml
  • tests/test_model.py

Comment thread pybandits/model.py
Comment on lines +1135 to +1150
fn = getattr(optax, name, None)
if fn is None or not callable(fn):
    raise ValueError(f"Invalid {kind}: '{name}' is not a callable attribute of optax.")
expected = self._optax_return_types[kind]
return_annotation = inspect.signature(fn).return_annotation
if isinstance(expected, type):
    # e.g. GradientTransformationExtraArgs — check via issubclass
    valid = isinstance(return_annotation, type) and issubclass(return_annotation, expected)
else:
    # e.g. optax.Schedule (a generic alias) — check for equality
    valid = return_annotation == expected
if not valid:
    raise ValueError(
        f"Invalid {kind}: '{name}' does not return {expected} (got return annotation: {return_annotation})."
    )
return fn

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify that the optax functions exercised in tests still carry the expected
# concrete return annotations on the pinned optax 0.1.x range.
python - <<'PY'
import inspect, optax
for name in [
    "sgd", "adam",
    "exponential_decay", "cosine_decay_schedule",
    "warmup_cosine_decay_schedule", "linear_schedule",
]:
    fn = getattr(optax, name, None)
    ann = inspect.signature(fn).return_annotation if fn else "<missing>"
    print(f"{name}: {ann!r} (type={type(ann).__name__})")
print("optax.Schedule =", optax.Schedule)
print("optax.GradientTransformation =", optax.GradientTransformation)
PY

Repository: PlaytikaOSS/pybandits

Length of output: 183


🏁 Script executed:

# First, explore the structure and find the requirements
fd -type f -name "setup.py" -o -name "pyproject.toml" -o -name "requirements*.txt"

Repository: PlaytikaOSS/pybandits

Length of output: 236


🏁 Script executed:

# Read the relevant section of model.py around lines 1135-1150
sed -n '1100,1160p' pybandits/model.py | cat -n

Repository: PlaytikaOSS/pybandits

Length of output: 2847


🏁 Script executed:

# Check the full _resolve_optax_fn method and surrounding context
sed -n '1110,1160p' pybandits/model.py | cat -n

Repository: PlaytikaOSS/pybandits

Length of output: 2406


🏁 Script executed:

find . -type f \( -name "setup.py" -o -name "pyproject.toml" \)

Repository: PlaytikaOSS/pybandits

Length of output: 82


🏁 Script executed:

# Search for optax version constraints
rg "optax" --type python -i | grep -i "version\|require" | head -20

Repository: PlaytikaOSS/pybandits

Length of output: 96


🏁 Script executed:

# Look for tests of _resolve_optax_fn
rg "_resolve_optax_fn|optimizer_name|lr_scheduler_name" --type python

Repository: PlaytikaOSS/pybandits

Length of output: 96


🏁 Script executed:

# Read pyproject.toml for dependencies
cat pyproject.toml

Repository: PlaytikaOSS/pybandits

Length of output: 3089


🏁 Script executed:

# Search for tests and usage of resolve_optax_fn without type filter
rg "_resolve_optax_fn" -A 5 -B 2

Repository: PlaytikaOSS/pybandits

Length of output: 1359


🏁 Script executed:

# Search for optimizer_name and lr_scheduler_name usage (which triggers _resolve_optax_fn)
rg "optimizer_name|lr_scheduler_name" -i -A 3 -B 1

Repository: PlaytikaOSS/pybandits

Length of output: 47


🏁 Script executed:

# Find test files
find . -name "*test*.py" -type f

Repository: PlaytikaOSS/pybandits

Length of output: 523


🏁 Script executed:

# Search for tests that use optimizer_type or lr_scheduler_type
rg "optimizer_type|lr_scheduler_type" -A 2 -B 1

Repository: PlaytikaOSS/pybandits

Length of output: 12176


🏁 Script executed:

# Look for actual test cases involving BayesianNeuralNetworkVI or similar models
rg "BayesianNeuralNetworkVI|adam|sgd" --max-count=20

Repository: PlaytikaOSS/pybandits

Length of output: 2209


🏁 Script executed:

# Read the test_model.py file to see actual tests
sed -n '1,100p' tests/test_model.py | head -50

Repository: PlaytikaOSS/pybandits

Length of output: 1855


🏁 Script executed:

# Look for the test with dummy_optimizer that should fail
rg "dummy_optimizer" -A 5 -B 5

Repository: PlaytikaOSS/pybandits

Length of output: 1163


🏁 Script executed:

# Check if there are any instances where annotations might be missing or string-based
rg "from __future__ import annotations"

Repository: PlaytikaOSS/pybandits

Length of output: 47


🏁 Script executed:

# Check if there's any try-except around _resolve_optax_fn calls
rg "_resolve_optax_fn" -B 3 -A 8

Repository: PlaytikaOSS/pybandits

Length of output: 1987


🏁 Script executed:

# Look for any eval_str usage in the codebase
rg "eval_str|get_type_hints"

Repository: PlaytikaOSS/pybandits

Length of output: 1214


🏁 Script executed:

# Check imports in model.py to see how inspect is used
sed -n '1,50p' pybandits/model.py | cat -n

Repository: PlaytikaOSS/pybandits

Length of output: 2502


🏁 Script executed:

# Check the exact lines where _resolve_optax_fn is called to see error handling
sed -n '1570,1630p' pybandits/model.py | cat -n

Repository: PlaytikaOSS/pybandits

Length of output: 3433


🏁 Script executed:

# Let's create a simple test to understand the actual behavior
python3 << 'PY'
import inspect

# Simulate what happens with missing annotation
def func_no_annotation():
    return None

# Simulate what happens with string annotation
def func_str_annotation():
    return "SomeType"
func_str_annotation.__annotations__ = {"return": "SomeType"}

# Test the logic
sig = inspect.signature(func_no_annotation)
print(f"Missing annotation: {sig.return_annotation!r}")
print(f"Is type? {isinstance(sig.return_annotation, type)}")
print()

sig2 = inspect.signature(func_str_annotation)
print(f"String annotation: {sig2.return_annotation!r}")
print(f"Is type? {isinstance(sig2.return_annotation, type)}")
print(f"Equals 'SomeType'? {sig2.return_annotation == 'SomeType'}")

# What about empty
print(f"\ninspect.Signature.empty: {inspect.Signature.empty!r}")
print(f"Is type? {isinstance(inspect.Signature.empty, type)}")
PY

Repository: PlaytikaOSS/pybandits

Length of output: 259


🏁 Script executed:

# This is crucial: what happens when issubclass is called with inspect._empty?
python3 << 'PY'
import inspect

empty = inspect.Signature.empty

# Try to call issubclass with inspect._empty
try:
    result = issubclass(empty, object)
    print(f"issubclass(inspect._empty, object) = {result}")
except Exception as e:
    print(f"Error: {type(e).__name__}: {e}")

# What if we try with a real type expected?
class FakeType:
    pass

try:
    result = issubclass(empty, FakeType)
    print(f"issubclass(inspect._empty, FakeType) = {result}")
except Exception as e:
    print(f"Error: {type(e).__name__}: {e}")
PY

Repository: PlaytikaOSS/pybandits

Length of output: 152


🏁 Script executed:

# Check which optax functions exist and whether they have annotations by looking at optax source docs
web_search "optax 0.1 sgd adam exponential_decay cosine_decay_schedule return annotation"

Repository: PlaytikaOSS/pybandits

Length of output: 116


🏁 Script executed:

# Let's look at the complete _arrange_update_kwargs method to understand validation flow
sed -n '1450,1480p' pybandits/model.py | cat -n

Repository: PlaytikaOSS/pybandits

Length of output: 1861


🏁 Script executed:

# Check if there's any docstring or comment explaining why annotation-based validation was chosen
sed -n '1110,1155p' pybandits/model.py

Repository: PlaytikaOSS/pybandits

Length of output: 1885


🏁 Script executed:

# Verify the full flow to see if there's error handling anywhere else
sed -n '1560,1600p' pybandits/model.py | cat -n

Repository: PlaytikaOSS/pybandits

Length of output: 2141


Annotation-based validation in _resolve_optax_fn is fragile to changes in optax internals.

The validation depends on three brittle assumptions:

  1. Every supported optax callable carries a precise return annotation. If a function lacks an annotation, inspect.signature(fn).return_annotation returns inspect.Signature.empty, and the issubclass check fails with a misleading "does not return …" error message rather than a clear "missing annotation" error.

  2. Annotations are concrete objects, not strings. If a future optax release adds from __future__ import annotations (PEP 563), return annotations become strings. The isinstance(return_annotation, type) check fails, and the equality check also fails, causing valid optimizers/schedulers to be rejected at import time.

  3. Schedule annotations match optax.Schedule by ==. This generic alias equality is brittle; if optax changes its type alias or directly annotates functions with the underlying Callable[...], the check fails even though the function is functionally compatible.

While tests confirm that optax 0.1.x functions carry the expected annotations, a future optax upgrade (e.g., adding PEP 563 or restructuring type hints) could silently break user configs at import time without clear guidance.

Consider implementing duck-typed validation instead (e.g. instantiate with dummy arguments: fn(learning_rate=0.01) for optimizers, fn(0) for schedules) or handle missing/unparseable annotations gracefully so import-time failures do not occur for minor optax versions.
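
One possible shape for that fallback (a sketch under the assumption that optax optimizer factories accept a learning_rate keyword; not a drop-in patch for _resolve_optax_fn):

import inspect

def _validate_optax_optimizer(fn, name):
    annotation = inspect.signature(fn).return_annotation
    if annotation is inspect.Signature.empty or isinstance(annotation, str):
        # Annotation missing or stringified (PEP 563): fall back to duck typing.
        try:
            tx = fn(learning_rate=1e-2)
        except Exception as exc:
            raise ValueError(f"Invalid optimizer: '{name}' could not be instantiated.") from exc
        if not (hasattr(tx, "init") and hasattr(tx, "update")):
            raise ValueError(f"Invalid optimizer: '{name}' does not produce a GradientTransformation-like object.")
    return fn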

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pybandits/model.py` around lines 1135 - 1150, The current annotation-based
checks in _resolve_optax_fn (which references self._optax_return_types and optax
types like optax.Schedule / GradientTransformationExtraArgs) are brittle; change
the function to first tolerate missing or string annotations (treat
inspect.Signature.empty or str annotations as "unreliable") and then perform a
pragmatic duck-typed validation fallback: try to call the candidate fn with
small dummy inputs appropriate to the kind (e.g., for "schedule" call fn(0) or
fn(learning_rate=0.01) for optimizer factories) and verify the returned object
supports the expected protocol (e.g., has .init and .update for
GradientTransformation-like objects, or is callable/returns numbers for
schedules), catching any exceptions and raising a clear ValueError indicating
either "missing/unparseable annotation" or "callable did not produce expected
interface" and returning fn on success; keep using the name checks and callable
guard but remove strict issubclass/== checks against return_annotation.

Comment thread pybandits/model.py
Comment on lines +2037 to +2048
svi_state = svi.init(subkey, x_jnp, y_jnp, 1.0)
if kl_annealing_active:
    kl_warmup_f = jnp.float32(kl_warmup_steps)

    def _svi_body(state, step):  # step is a traced int32 scalar (global index)
        kl_annealing_factor = jnp.minimum(1.0, (step.astype(jnp.float32) + 1.0) / kl_warmup_f)
        state, loss = svi.update(state, x_jnp, y_jnp, kl_annealing_factor)
        return state, loss

    _run_epoch = jax.jit(
        lambda state, steps: jax.lax.scan(_svi_body, state, steps),
    )  # steps is a 1-D int32 array; scan feeds each element as xs to _svi_body

⚠️ Potential issue | 🟡 Minor

KL annealing factor never reaches 0; starts at 1/kl_warmup_steps.

The PR description states that the annealing factor "linearly warms the KL term in the ELBO from 0 to 1 over the first kl_annealing_fraction * total_steps steps". The current formula is

kl_annealing_factor = jnp.minimum(1.0, (step.astype(jnp.float32) + 1.0) / kl_warmup_f)

So at step=0 the factor is 1/kl_warmup_steps (and at kl_warmup_steps == 1 the factor is already 1.0, i.e. effectively no warmup). If the goal is "from 0 to 1", drop the + 1.0 so the first step starts at exactly 0 and the factor reaches 1.0 only after kl_warmup_steps updates:

♻️ Suggested fix
-            def _svi_body(state, step):  # step is a traced int32 scalar (global index)
-                kl_annealing_factor = jnp.minimum(1.0, (step.astype(jnp.float32) + 1.0) / kl_warmup_f)
+            def _svi_body(state, step):  # step is a traced int32 scalar (global index)
+                kl_annealing_factor = jnp.minimum(1.0, step.astype(jnp.float32) / kl_warmup_f)
                 state, loss = svi.update(state, x_jnp, y_jnp, kl_annealing_factor)
                 return state, loss

If the off-by-one is intentional, please update the PR description / _kl_scale_ctx docstring to reflect that the factor is in (0, 1] and starts at 1/N.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pybandits/model.py` around lines 2037 - 2048, The kl annealing currently
computes kl_annealing_factor in _svi_body as (step.astype(jnp.float32) + 1.0) /
kl_warmup_f which makes the first step > 0; change the formula to
(step.astype(jnp.float32)) / kl_warmup_f so the first step yields 0 and the
factor linearly reaches 1.0 after kl_warmup_f steps; update the expression
computing kl_annealing_factor in _svi_body (which is passed into svi.update)
accordingly and ensure jnp.minimum(1.0, ...) is still applied.

Comment thread pybandits/model.py
Comment on lines 2056 to +2060

_run_epoch = jax.jit(
    lambda state, n: jax.lax.scan(_svi_body, state, None, length=n),
    static_argnums=(1,),
)

⚠️ Potential issue | 🟡 Minor

static_argnums=(1,) will recompile when an epoch has a different epoch_steps.

In the inactive-annealing branch, _run_epoch is JIT-compiled with n marked static. The epoch-construction logic above (lines 1996-2002) appends a final remaining epoch when num_steps % steps_per_epoch != 0, so the last _run_epoch call uses a different n and triggers an additional XLA compilation — defeating the very reason the comment cites for moving to lax.scan. This is a minor performance issue (one extra compile per training run) but worth noting.

A simple mitigation is to align epoch_steps_list to a uniform step count (drop or absorb the remainder), or to pass n as an int32 array length argument that matches the active-path style.
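
For the uniform-epoch option, a sketch (this drops the remainder steps, which may or may not be acceptable for the training budget):

# Keep every epoch the same length so the jitted epoch runner compiles exactly once.
num_full_epochs = max(1, num_steps // steps_per_epoch)
epoch_steps_list = [steps_per_epoch] * num_full_epochs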

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pybandits/model.py` around lines 2056 - 2060, The JIT for _run_epoch
currently marks the epoch length argument `n` static (static_argnums=(1,)), so
when the final epoch uses a different step count (because epoch_steps_list can
include a smaller `remaining` epoch) it triggers an extra XLA compile; change to
use a non-static, numeric-length argument or make all epochs the same length.
Concretely: either remove `n` from static_argnums on the jax.jit of _run_epoch
(so pass an int32/array length like the active path does) or normalize
`epoch_steps_list` (drop/absorb the remainder so every entry equals
steps_per_epoch) so _run_epoch is always called with the same `n`; update
references to _run_epoch, _svi_body, epoch_steps_list, num_steps and
steps_per_epoch accordingly.

Comment thread pybandits/model.py

def model(x: jax.Array, y: jax.Array):
    N = x.shape[0]
def model(x: jax.Array, y: jax.Array, kl_annealing_factor=1.0):
Collaborator

Add Float01 typehint for kl_annealing_factor
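
E.g. (assuming Float01 is the pybandits [0, 1]-constrained float type and is importable from pybandits.base):

from pybandits.base import Float01

def model(x: jax.Array, y: jax.Array, kl_annealing_factor: Float01 = 1.0):
    N = x.shape[0]
    ...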

Comment thread pybandits/model.py
total_steps = sum(epoch_steps_list)
kl_annealing_fraction = self._update_kwargs.get("kl_annealing_fraction")
kl_annealing_active = kl_annealing_fraction not in (None, 0.0)
kl_warmup_steps = max(1, int(np.ceil(kl_annealing_fraction * total_steps))) if kl_annealing_active else None
Collaborator

I think you can drop the np.ceil and just directly apply int().

Comment thread pybandits/model.py
if kl_tau_active or kl_annealing_active:
    kl_scale = self._kl_scale_ctx

    def guide(x, y, kl_annealing_factor, *args, **kwargs):
Collaborator

Add typehints

Comment thread pybandits/model.py
n_samples : PositiveInt
    Number of data points in the current batch.
kl_annealing_factor : float or jax scalar, optional
    Multiplicative factor in [0, 1]. Default 1.0 (no annealing).
Collaborator

I think you told me it is in (0, 1].

Comment thread pybandits/model.py
global_step_offset = 0
for epoch_idx, epoch_steps in enumerate(epoch_steps_list):
    svi_state, epoch_losses = _run_epoch(svi_state, epoch_steps)
    if kl_annealing_active:
Collaborator

Instead of calculating steps, wouldn't it be better to precompute here the kl_factor and pass it to run_epoch?
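
A sketch of that alternative (names reuse the snippets above; untested, and the start offset follows whatever is settled in the off-by-one thread):

# Precompute the per-step factors for this epoch and let the scan body consume one per step.
global_steps = jnp.arange(global_step_offset, global_step_offset + epoch_steps, dtype=jnp.float32)
kl_factors = jnp.minimum(1.0, global_steps / kl_warmup_f)
svi_state, epoch_losses = _run_epoch(svi_state, kl_factors)  # _svi_body then takes a factor, not a step index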

Comment thread tests/test_model.py
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("bad_value", [-0.1, 1.5, 2.0, -1.0])
Collaborator

Put all these tests in a designated class.

Comment thread tests/test_model.py
@pytest.mark.parametrize("bad_value", [-0.1, 1.5, 2.0, -1.0])
def test_kl_annealing_fraction_range_validation(bad_value: float):
    """Out-of-range kl_annealing_fraction raises ValueError at construction time."""
    with pytest.raises(ValueError, match="kl_annealing_fraction"):
Collaborator

Please avoid hard-coded constants in the tests to keep them easily maintainable.
Try using pytest.mark.parametrize and hypothesis's given for better coverage (see the sketch below).
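
For instance (a sketch; the boundary filter encodes the assumption that the valid range is [0, 1]):

import pytest
from hypothesis import given, strategies as st

@given(bad_value=st.floats(min_value=-10, max_value=10, allow_nan=False).filter(lambda v: not 0 <= v <= 1))
def test_kl_annealing_fraction_range_validation(bad_value):
    with pytest.raises(ValueError, match="kl_annealing_fraction"):
        ...  # construct the BNN with kl_annealing_fraction=bad_value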

