The current implementation of `policy_gradient_loss` is:

```python
log_pi_a_t = distributions.softmax().logprob(a_t, logits_t)
adv_t = jax.lax.select(use_stop_gradient, jax.lax.stop_gradient(adv_t), adv_t)
loss_per_timestep = -log_pi_a_t * adv_t
```
It's good that the gradients are already stopped around the advantages, but they should also be stopped around the actions to ensure an unbiased gradient estimator.
This is important when the actions are sampled as part of the training graph (MPO-style algorithms, imagination training with world models) rather than coming from a replay buffer, and the actor distribution implements a gradient for `sample()` (e.g. a reparameterised gaussian, or a straight-through categorical).
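A minimal sketch of the proposed fix, reusing the existing `use_stop_gradient` guard. The function signature, import path, and final mean reduction are assumptions for illustration, not taken from the snippet above:

```python
import jax
import jax.numpy as jnp
from rlax._src import distributions  # assumed import path for the snippet's `distributions`


def policy_gradient_loss(logits_t, a_t, adv_t, use_stop_gradient=True):
  # Proposed: also stop gradients through the sampled actions, so a
  # differentiable sample() cannot leak gradients into the estimator.
  a_t = jax.lax.select(use_stop_gradient, jax.lax.stop_gradient(a_t), a_t)
  log_pi_a_t = distributions.softmax().logprob(a_t, logits_t)
  # Unchanged: gradients are already stopped around the advantages.
  adv_t = jax.lax.select(use_stop_gradient, jax.lax.stop_gradient(adv_t), adv_t)
  loss_per_timestep = -log_pi_a_t * adv_t
  return jnp.mean(loss_per_timestep)  # assumed reduction
```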
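A quick way to see the bias: with a reparameterised gaussian sample, the score term cancels in the backward pass, and in this one-dimensional example the policy gradient collapses to exactly zero unless the sample is treated as data. All names below are hypothetical, for illustration only:

```python
import jax


def loss_leaky(mu, eps, adv):
  a_t = mu + eps                   # reparameterised sample: carries d(a_t)/d(mu) = 1
  log_pi = -0.5 * (a_t - mu) ** 2  # unit-variance gaussian log-prob (up to a constant)
  return -log_pi * adv


def loss_stopped(mu, eps, adv):
  a_t = jax.lax.stop_gradient(mu + eps)  # treat the sampled action as data
  log_pi = -0.5 * (a_t - mu) ** 2
  return -log_pi * adv


mu, eps, adv = 0.0, 0.1, 1.0
print(jax.grad(loss_leaky)(mu, eps, adv))    # 0.0  -- (a_t - mu) is constant in mu
print(jax.grad(loss_stopped)(mu, eps, adv))  # -0.1 -- the REINFORCE gradient -(a_t - mu) * adv
```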