@@ -1474,7 +1474,7 @@ __global__ void ppo_loss_forward_kernel_optimized(
14741474 const T* __restrict__ values_pred,
14751475 const int64_t * __restrict__ actions,
14761476 const T* __restrict__ old_logprobs,
1477- const T * __restrict__ advantages,
1477+ const float * __restrict__ advantages,
14781478 const T* __restrict__ prio,
14791479 const T* __restrict__ values,
14801480 const T* __restrict__ returns,
@@ -1541,8 +1541,8 @@ __global__ void ppo_loss_forward_kernel_optimized(
15411541 float old_logp = float (old_logprobs[nt]);
15421542 float adv = float (advantages[nt]);
15431543 float w = float (prio[n]);
1544- float adv_std = sqrtf (adv_var[0 ]);
1545- float adv_normalized = (adv - adv_mean[0 ]) / (adv_std + 1e-8f );
1544+ float adv_std = sqrtf (float ( adv_var[0 ]) );
1545+ float adv_normalized = (adv - float ( adv_mean[0 ]) ) / (adv_std + 1e-8f );
15461546
15471547 float logratio = new_logp - old_logp;
15481548 float ratio = __expf (logratio);
@@ -1588,14 +1588,14 @@ __global__ void ppo_loss_forward_kernel_optimized(
15881588
15891589template <typename T>
15901590__global__ void ppo_loss_backward_kernel_optimized (
1591- T * __restrict__ grad_logits,
1592- T * __restrict__ grad_values_pred,
1591+ float * __restrict__ grad_logits,
1592+ float * __restrict__ grad_values_pred,
15931593 const float * __restrict__ grad_loss,
15941594 const T* __restrict__ logits,
15951595 const T* __restrict__ values_pred,
15961596 const int64_t * __restrict__ actions,
15971597 const T* __restrict__ old_logprobs,
1598- const T * __restrict__ advantages,
1598+ const float * __restrict__ advantages,
15991599 const T* __restrict__ prio,
16001600 const T* __restrict__ values,
16011601 const T* __restrict__ returns,
@@ -1665,8 +1665,8 @@ __global__ void ppo_loss_backward_kernel_optimized(
16651665 float v_clipped = val + fmaxf (-vf_clip_coef, fminf (vf_clip_coef, v_error));
16661666
16671667 // normalize advantage
1668- float adv_std = sqrtf (adv_var[0 ]);
1669- float adv_normalized = (adv - adv_mean[0 ]) / (adv_std + 1e-8f );
1668+ float adv_std = sqrtf (float ( adv_var[0 ]) );
1669+ float adv_normalized = (adv - float ( adv_mean[0 ]) ) / (adv_std + 1e-8f );
16701670
16711671 // loss gradient scaling
16721672 float dL = grad_loss[0 ] * inv_NT;
@@ -1686,7 +1686,7 @@ __global__ void ppo_loss_backward_kernel_optimized(
16861686 } else {
16871687 d_val_pred = val_pred - ret;
16881688 }
1689- grad_values_pred[values_idx] = T ( dL * vf_coef * d_val_pred) ;
1689+ grad_values_pred[values_idx] = dL * vf_coef * d_val_pred;
16901690
16911691 // policy loss gradient
16921692 float ratio_clipped = fmaxf (1 .0f - clip_coef, fminf (1 .0f + clip_coef, ratio));
@@ -1710,7 +1710,7 @@ __global__ void ppo_loss_backward_kernel_optimized(
17101710 d_logit -= p * d_new_logp;
17111711
17121712 d_logit += d_entropy_term * p * (-entropy - logp);
1713- grad_logits[logits_base + a * logits_stride_a] = T ( d_logit) ;
1713+ grad_logits[logits_base + a * logits_stride_a] = d_logit;
17141714 }
17151715}
17161716
@@ -1724,12 +1724,12 @@ inline void launch_ppo_loss_forward_optimized(
17241724 const T* values_pred,
17251725 const int64_t * actions,
17261726 const T* old_logprobs,
1727- const T * advantages,
1727+ const float * advantages, // always fp32 for precision
17281728 const T* prio,
17291729 const T* values,
17301730 const T* returns,
1731- const float * adv_mean,
1732- const float * adv_var,
1731+ const float * adv_mean, // keep fp32
1732+ const float * adv_var, // keep fp32
17331733 float clip_coef,
17341734 float vf_clip_coef,
17351735 float vf_coef,
@@ -1784,19 +1784,19 @@ inline void launch_ppo_loss_forward_optimized(
17841784
17851785template <typename T>
17861786void launch_ppo_loss_backward_optimized (
1787- T * grad_logits,
1788- T * grad_values_pred,
1787+ float * grad_logits,
1788+ float * grad_values_pred,
17891789 const float * grad_loss,
17901790 const T* logits,
17911791 const T* values_pred, // added: need to read val_pred directly
17921792 const int64_t * actions,
17931793 const T* old_logprobs,
1794- const T * advantages,
1794+ const float * advantages,
17951795 const T* prio,
17961796 const T* values,
17971797 const T* returns,
17981798 const float * adv_mean,
1799- const float * adv_var, // variance, not std
1799+ const float * adv_var,
18001800 float clip_coef,
18011801 float vf_clip_coef,
18021802 float vf_coef,
@@ -2011,7 +2011,7 @@ __global__ void ppo_loss_backward_kernel(
20112011
20122012 // === Retrieve saved values from forward pass ===
20132013 const double * saved = saved_for_backward + idx * 5 ;
2014- double new_logp = saved[0 ]; // new log prob of selected action
2014+ // double new_logp = saved[0]; // new log prob of selected action
20152015 double ratio = saved[1 ]; // exp(new_logp - old_logp)
20162016 double val_pred = saved[2 ]; // value prediction
20172017 double v_clipped = saved[3 ]; // clipped value target
@@ -2842,10 +2842,10 @@ void launch_ppo_loss_backward_optimized_float(float* grad_logits, float* grad_va
28422842 launch_ppo_loss_backward_optimized<float >(grad_logits, grad_values_pred, grad_loss, logits, values_pred, actions, old_logprobs, advantages, prio, values, returns, adv_mean, adv_var, clip_coef, vf_clip_coef, vf_coef, ent_coef, T_seq, A, N, logits_stride_n, logits_stride_t , logits_stride_a, values_stride_n, values_stride_t , stream);
28432843}
28442844
2845- void launch_ppo_loss_forward_optimized_bf16 (float * loss_output, double * saved_for_backward, at::BFloat16* ratio_out, at::BFloat16* newvalue_out, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t * actions, const at::BFloat16* old_logprobs, const at::BFloat16 * advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float * adv_mean, const float * adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t , int logits_stride_a, int values_stride_n, int values_stride_t , cudaStream_t stream) {
2845+ void launch_ppo_loss_forward_optimized_bf16 (float * loss_output, double * saved_for_backward, at::BFloat16* ratio_out, at::BFloat16* newvalue_out, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t * actions, const at::BFloat16* old_logprobs, const float * advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float * adv_mean, const float * adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t , int logits_stride_a, int values_stride_n, int values_stride_t , cudaStream_t stream) {
28462846 launch_ppo_loss_forward_optimized<at::BFloat16>(loss_output, saved_for_backward, ratio_out, newvalue_out, logits, values_pred, actions, old_logprobs, advantages, prio, values, returns, adv_mean, adv_var, clip_coef, vf_clip_coef, vf_coef, ent_coef, T_seq, A, N, logits_stride_n, logits_stride_t , logits_stride_a, values_stride_n, values_stride_t , stream);
28472847}
2848- void launch_ppo_loss_backward_optimized_bf16 (at::BFloat16 * grad_logits, at::BFloat16 * grad_values_pred, const float * grad_loss, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t * actions, const at::BFloat16* old_logprobs, const at::BFloat16 * advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float * adv_mean, const float * adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t , int logits_stride_a, int values_stride_n, int values_stride_t , cudaStream_t stream) {
2848+ void launch_ppo_loss_backward_optimized_bf16 (float * grad_logits, float * grad_values_pred, const float * grad_loss, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t * actions, const at::BFloat16* old_logprobs, const float * advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float * adv_mean, const float * adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t , int logits_stride_a, int values_stride_n, int values_stride_t , cudaStream_t stream) {
28492849 launch_ppo_loss_backward_optimized<at::BFloat16>(grad_logits, grad_values_pred, grad_loss, logits, values_pred, actions, old_logprobs, advantages, prio, values, returns, adv_mean, adv_var, clip_coef, vf_clip_coef, vf_coef, ent_coef, T_seq, A, N, logits_stride_n, logits_stride_t , logits_stride_a, values_stride_n, values_stride_t , stream);
28502850}
28512851
0 commit comments