@@ -138,10 +138,8 @@ struct RNN : public nn::Module {
138138 virtual Tensor initial_state (int batch_size, torch::Device device, torch::Dtype dtype) = 0;
139139};
140140
141- class MinGRU : public RNN {
142- public:
143- int hidden;
144- int num_layers;
141+ struct MinGRU : public RNN {
142+ int hidden, num_layers;
145143 bool kernels;
146144 vector<nn::Linear> layers;
147145
@@ -196,8 +194,7 @@ class MinGRU : public RNN {
196194 }
197195};
198196
199- class Policy : public nn ::Module {
200- public:
197+ struct Policy : public nn ::Module {
201198 int input, hidden, num_atns;
202199 shared_ptr<Encoder> encoder{nullptr };
203200 shared_ptr<Decoder> decoder{nullptr };
@@ -448,69 +445,143 @@ Tensor logcumsumexp_cpp(Tensor x) {
448445 return x.exp ().cumsum (1 ).log ();
449446}
450447
451- Tensor fused_ppo_loss_cpp (
452- Tensor logits,
453- Tensor newvalue,
454- Tensor actions,
455- Tensor old_logprobs,
456- Tensor advantages,
457- Tensor prio,
458- Tensor values,
459- Tensor returns,
460- Tensor adv_mean,
461- Tensor adv_std,
462- float clip_coef,
463- float vf_clip_coef,
464- float vf_coef,
465- float ent_coef
466- ) {
467- auto segments = logits.size (0 );
468- auto horizon = logits.size (1 );
448+ // Sample from multi-head discrete distribution
449+ // Returns {actions (B, heads), total_logprob (B,)}
450+ vector<Tensor> sample_discrete_cpp (Tensor logits, Tensor act_sizes_cpu, int num_heads) {
451+ logits = torch::nan_to_num (logits, 1e-8 , 1e-8 , 1e-8 );
452+ auto split = torch::split (logits, c10::IntArrayRef (act_sizes_cpu.data_ptr <int64_t >(), num_heads), 1 );
453+ vector<Tensor> actions_vec, logprobs_vec;
454+ for (int i = 0 ; i < num_heads; i++) {
455+ auto log_probs = torch::log_softmax (split[i], 1 );
456+ auto action = at::multinomial (log_probs.exp (), 1 , true );
457+ actions_vec.push_back (action);
458+ logprobs_vec.push_back (log_probs.gather (1 , action));
459+ }
460+ return {torch::cat (actions_vec, 1 ), torch::cat (logprobs_vec, 1 ).sum (1 )};
461+ }
469462
470- auto flat_logits = logits.reshape ({-1 , logits.size (-1 )});
471- auto flat_actions = actions.reshape ({-1 });
472- auto logprobs_new = torch::log_softmax (flat_logits, 1 );
463+ // Sample from continuous Normal distribution
464+ // Returns {actions (B, D), total_logprob (B,)}
465+ vector<Tensor> sample_continuous_cpp (Tensor mean, Tensor logstd) {
466+ auto std = logstd.exp ();
467+ auto actions = mean + std * torch::randn_like (mean);
468+ auto log_prob = -0.5 * ((actions - mean) / std).pow (2 ) - 0.5 * std::log (2 * M_PI) - logstd;
469+ return {actions, log_prob.sum (1 )};
470+ }
473471
474- auto probs_new = logprobs_new.exp ();
475- auto entropy = -(probs_new * logprobs_new).sum (1 ).mean ();
472+ // Compute logprob + entropy for multi-head discrete actions
473+ // Returns {logprob (batch,), entropy scalar}
474+ vector<Tensor> discrete_logprob_entropy_cpp (Tensor logits, Tensor actions, Tensor act_sizes_cpu, int num_heads) {
475+ logits = torch::nan_to_num (logits, 1e-8 , 1e-8 , 1e-8 );
476+ auto split = torch::split (logits, c10::IntArrayRef (act_sizes_cpu.data_ptr <int64_t >(), num_heads), 1 );
477+ int batch = logits.size (0 );
478+ vector<Tensor> logprobs_vec, entropies_vec;
479+ for (int h = 0 ; h < num_heads; h++) {
480+ auto log_probs = torch::log_softmax (split[h], 1 );
481+ auto probs = log_probs.exp ();
482+ auto head_actions = actions.select (-1 , h).reshape ({batch}).to (torch::kInt64 );
483+ logprobs_vec.push_back (log_probs.gather (1 , head_actions.unsqueeze (1 )));
484+ entropies_vec.push_back (-(probs * log_probs).sum (1 , true ));
485+ }
486+ auto logprob = torch::cat (logprobs_vec, 1 ).sum (1 );
487+ auto entropy = torch::cat (entropies_vec, 1 ).sum (1 ).mean ();
488+ return {logprob, entropy};
489+ }
476490
477- auto newlogprob_flat = logprobs_new.gather (1 , flat_actions.unsqueeze (1 )).squeeze (1 );
478- auto newlogprob = newlogprob_flat.reshape ({segments, horizon});
479- auto logratio = newlogprob - old_logprobs;
480- auto ratio_new = logratio.exp ();
491+ // Compute logprob + entropy for continuous Normal actions
492+ // Returns {logprob (batch,), entropy scalar}
493+ vector<Tensor> continuous_logprob_entropy_cpp (Tensor mean, Tensor logstd, Tensor actions) {
494+ auto std = logstd.exp ();
495+ auto normalized = (actions.to (mean.dtype ()) - mean) / std;
496+ auto log_prob = -0.5 * normalized.pow (2 ) - 0.5 * std::log (2 * M_PI) - logstd;
497+ auto logprob = log_prob.sum (1 );
498+ constexpr float HALF_1_PLUS_LOG_2PI = 1 .4189385332046727f ;
499+ auto entropy = (HALF_1_PLUS_LOG_2PI + logstd).sum (1 ).mean ();
500+ return {logprob, entropy};
501+ }
481502
482- auto adv_normalized = prio.unsqueeze (1 ) * (advantages - adv_mean) / (adv_std + 1e-8 );
483- auto pg_loss1 = -adv_normalized * ratio_new;
484- auto pg_loss2 = -adv_normalized * torch::clamp (ratio_new, 1.0 - clip_coef, 1.0 + clip_coef);
503+ // PPO clipped loss with clipped value loss
504+ Tensor ppo_loss_cpp (Tensor ratio, Tensor advantages, Tensor prio,
505+ Tensor newvalue, Tensor values, Tensor returns, Tensor entropy,
506+ float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef) {
507+ auto adv_normalized = prio * (advantages - advantages.mean ()) / (advantages.std () + 1e-8 );
508+ auto pg_loss1 = -adv_normalized * ratio;
509+ auto pg_loss2 = -adv_normalized * torch::clamp (ratio, 1.0 - clip_coef, 1.0 + clip_coef);
485510 auto pg_loss = torch::max (pg_loss1, pg_loss2).mean ();
486511
487- auto nv = newvalue.view (returns.sizes ());
488- auto v_clipped = values + torch::clamp (nv - values, -vf_clip_coef, vf_clip_coef);
489- auto v_loss_unclipped = (nv - returns).pow (2 );
490- auto v_loss_clipped = (v_clipped - returns).pow (2 );
491- auto v_loss = 0.5 * torch::max (v_loss_unclipped, v_loss_clipped).mean ();
512+ newvalue = newvalue.view (returns.sizes ());
513+ auto v_clipped = values + torch::clamp (newvalue - values, -vf_clip_coef, vf_clip_coef);
514+ auto v_loss = 0.5 * torch::max ((newvalue - returns).pow (2 ), (v_clipped - returns).pow (2 )).mean ();
492515
493516 return pg_loss + vf_coef * v_loss - ent_coef * entropy;
494517}
495518
496- // Reference implementation for sample_logits (for correctness testing)
497- vector<Tensor> sample_logits_cpp (
498- Tensor logits
499- ) {
500- // nan_to_num
501- auto clean_logits = torch::nan_to_num (logits);
502-
503- // log_softmax
504- auto log_probs = torch::log_softmax (clean_logits, 1 );
505-
506- // multinomial sampling
507- auto probs = log_probs.exp ();
508- auto actions = torch::multinomial (probs, 1 , /* replacement=*/ false ).squeeze (1 );
519+ // Dispatch: sample actions using kernel or cpp path, write to output buffers
520+ void sample_actions (Logits& logits, Tensor value,
521+ Tensor actions_out, Tensor logprobs_out, Tensor values_out,
522+ Tensor act_sizes, Tensor act_sizes_cpu,
523+ bool is_continuous, bool kernels, uint64_t rng_seed, Tensor rng_offset) {
524+ if (kernels) {
525+ Tensor logstd = logits.logstd .defined () ? logits.logstd : Tensor ();
526+ sample_logits (logits.mean , logstd, value, actions_out, logprobs_out,
527+ values_out, act_sizes, rng_seed, rng_offset);
528+ } else {
529+ vector<Tensor> result;
530+ if (is_continuous) {
531+ result = sample_continuous_cpp (logits.mean , logits.logstd );
532+ } else {
533+ result = sample_discrete_cpp (logits.mean , act_sizes_cpu, actions_out.size (1 ));
534+ }
535+ actions_out.copy_ (result[0 ].to (torch::kFloat64 ), false );
536+ logprobs_out.copy_ (result[1 ], false );
537+ values_out.copy_ (value.flatten (), false );
538+ }
539+ }
509540
510- // gather logprobs
511- auto sampled_logprobs = log_probs.gather (1 , actions.unsqueeze (1 )).squeeze (1 );
541+ // Dispatch: compute PPO loss using kernel or cpp path
542+ // Writes ratio and newvalue to output buffers as side effect
543+ Tensor compute_train_loss (Logits& logits, Tensor newvalue,
544+ Tensor actions, Tensor old_logprobs, Tensor advantages, Tensor prio,
545+ Tensor values, Tensor returns,
546+ Tensor ratio_out, Tensor newvalue_out,
547+ Tensor act_sizes, Tensor act_sizes_cpu,
548+ int minibatch_size, int horizon,
549+ float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef,
550+ bool is_continuous, bool kernels) {
551+ if (kernels) {
552+ Tensor logstd_safe = logits.logstd .defined () ? logits.logstd : torch::empty ({0 }, logits.mean .options ());
553+ auto [adv_var, adv_mean] = torch::var_mean (advantages);
554+ return fused_ppo_loss_optimized (
555+ logits.mean , logstd_safe, newvalue,
556+ actions, old_logprobs, advantages, prio, values, returns,
557+ adv_mean, adv_var, // variance, not std - kernel does sqrtf
558+ ratio_out, newvalue_out,
559+ act_sizes, clip_coef, vf_clip_coef, vf_coef, ent_coef
560+ )[0 ];
561+ } else {
562+ int num_heads = actions.size (-1 );
563+ int batch = minibatch_size;
564+ int segments = batch / horizon;
565+
566+ vector<Tensor> result;
567+ if (is_continuous) {
568+ TORCH_CHECK (logits.logstd .defined () && logits.logstd .numel () > 0 ,
569+ " logstd must be defined for continuous actions" );
570+ result = continuous_logprob_entropy_cpp (
571+ logits.mean .reshape ({batch, -1 }), logits.logstd .reshape ({batch, -1 }),
572+ actions.reshape ({batch, -1 }));
573+ } else {
574+ result = discrete_logprob_entropy_cpp (
575+ logits.mean .reshape ({batch, -1 }), actions, act_sizes_cpu, num_heads);
576+ }
577+ Tensor ratio = (result[0 ].reshape ({segments, horizon}) - old_logprobs).exp ();
578+ ratio_out.copy_ (ratio, false );
579+ newvalue_out.copy_ (newvalue, false );
512580
513- return {actions, sampled_logprobs};
581+ return ppo_loss_cpp (ratio, advantages, prio,
582+ newvalue, values, returns, result[1 ],
583+ clip_coef, vf_clip_coef, vf_coef, ent_coef);
584+ }
514585}
515586
516587// Reference implementation for testing
0 commit comments