@@ -33,8 +33,10 @@ typedef torch::Tensor Tensor;
3333// CUDA kernel wrappers
3434#include "modules.cpp"
3535
36- auto DTYPE = torch::kBFloat16 ;
37- auto DTYPE_FP32 = torch::kFloat32 ; // master weights
36+ // get dtype based on bf16 flag
37+ inline torch::ScalarType get_dtype (bool bf16 ) {
38+ return bf16 ? torch::kBFloat16 : torch::kFloat32 ;
39+ }
3840
3941namespace pufferlib {
4042
@@ -186,16 +188,17 @@ typedef struct {
186188} TrainGraph;
187189
188190TrainGraph create_train_graph (int minibatch_segments, int horizon, int input_size,
189- int num_layers, int hidden_size, int expansion_factor, int num_atns) {
191+ int num_layers, int hidden_size, int expansion_factor, int num_atns, bool bf16 ) {
190192 TrainGraph g;
191- auto options = torch::TensorOptions ().dtype (DTYPE).device (torch::kCUDA );
193+ auto dtype = get_dtype (bf16 );
194+ auto options = torch::TensorOptions ().dtype (dtype).device (torch::kCUDA );
192195 g.mb_obs = torch::zeros ({minibatch_segments, horizon, input_size}, options);
193196 g.mb_state = torch::zeros ({num_layers, minibatch_segments, 1 , hidden_size * expansion_factor}, options);
194197 g.mb_newvalue = torch::zeros ({minibatch_segments, horizon, 1 }, options);
195198 g.mb_ratio = torch::zeros ({minibatch_segments, horizon}, options);
196199 g.mb_actions = torch::zeros ({minibatch_segments, horizon, num_atns}, options).to (torch::kInt64 );
197200 g.mb_logprobs = torch::zeros ({minibatch_segments, horizon}, options);
198- g.mb_advantages = torch::zeros ({minibatch_segments, horizon}, options.dtype (torch::kFloat32 )); // fp32 precision
201+ g.mb_advantages = torch::zeros ({minibatch_segments, horizon}, options.dtype (torch::kFloat32 )); // always fp32 for precision
199202 g.mb_prio = torch::zeros ({minibatch_segments, 1 }, options);
200203 g.mb_values = torch::zeros ({minibatch_segments, horizon}, options);
201204 g.mb_returns = torch::zeros ({minibatch_segments, horizon}, options);
@@ -213,16 +216,17 @@ typedef struct {
213216 Tensor importance;
214217} RolloutBuf;
215218
216- RolloutBuf create_rollouts (int horizon, int segments, int input_size, int num_atns) {
219+ RolloutBuf create_rollouts (int horizon, int segments, int input_size, int num_atns, bool bf16 ) {
217220 RolloutBuf r;
218- r.observations = torch::zeros ({horizon, segments, input_size}, torch::dtype (DTYPE).device (torch::kCUDA ));
221+ auto dtype = get_dtype (bf16 );
222+ r.observations = torch::zeros ({horizon, segments, input_size}, torch::dtype (dtype).device (torch::kCUDA ));
219223 r.actions = torch::zeros ({horizon, segments, num_atns}, torch::dtype (torch::kFloat64 ).device (torch::kCUDA ));
220- r.values = torch::zeros ({horizon, segments}, torch::dtype (DTYPE ).device (torch::kCUDA ));
221- r.logprobs = torch::zeros ({horizon, segments}, torch::dtype (DTYPE ).device (torch::kCUDA ));
222- r.rewards = torch::zeros ({horizon, segments}, torch::dtype (DTYPE ).device (torch::kCUDA ));
223- r.terminals = torch::zeros ({horizon, segments}, torch::dtype (DTYPE ).device (torch::kCUDA ));
224- r.ratio = torch::zeros ({horizon, segments}, torch::dtype (DTYPE ).device (torch::kCUDA ));
225- r.importance = torch::zeros ({horizon, segments}, torch::dtype (DTYPE ).device (torch::kCUDA ));
224+ r.values = torch::zeros ({horizon, segments}, torch::dtype (dtype ).device (torch::kCUDA ));
225+ r.logprobs = torch::zeros ({horizon, segments}, torch::dtype (dtype ).device (torch::kCUDA ));
226+ r.rewards = torch::zeros ({horizon, segments}, torch::dtype (dtype ).device (torch::kCUDA ));
227+ r.terminals = torch::zeros ({horizon, segments}, torch::dtype (dtype ).device (torch::kCUDA ));
228+ r.ratio = torch::zeros ({horizon, segments}, torch::dtype (dtype ).device (torch::kCUDA ));
229+ r.importance = torch::zeros ({horizon, segments}, torch::dtype (dtype ).device (torch::kCUDA ));
226230 return r;
227231}
228232
@@ -271,6 +275,7 @@ typedef struct {
271275 bool kernels;
272276 bool profile;
273277 bool use_omp;
278+ bool bf16 ; // bfloat16 mixed precision training
274279} HypersT;
275280
276281typedef struct {
@@ -442,18 +447,21 @@ void train_forward_call(TrainGraph& graph, PolicyMinGRU* policy_bf16, PolicyMinG
442447 loss = pg_loss + hypers.vf_coef *v_loss - hypers.ent_coef *entropy;
443448 }
444449
445- // computes gradients on bf16 weights
450+ // computes gradients on bf16 weights (or fp32 if not using bf16)
446451 loss.backward ();
447452
448453 // copy gradients from bf16 to fp32, then optimizer step on fp32 master weights
449- copy_gradients_to_fp32 (policy_bf16, policy_fp32);
454+ if (hypers.bf16 ) {
455+ copy_gradients_to_fp32 (policy_bf16, policy_fp32);
456+ }
450457 clip_grad_norm_ (policy_fp32->parameters (), hypers.max_grad_norm );
451458 muon->step ();
452459 muon->zero_grad ();
453- policy_bf16->zero_grad (); // also need to clear bf16 gradients
454-
455- // sync updated fp32 weights back to bf16 for next forward pass
456- sync_policy_weights (policy_bf16, policy_fp32);
460+ if (hypers.bf16 ) {
461+ policy_bf16->zero_grad (); // also need to clear bf16 gradients
462+ // sync updated fp32 weights back to bf16 for next forward pass
463+ sync_policy_weights (policy_bf16, policy_fp32);
464+ }
457465}
458466
459467// Capture
@@ -609,18 +617,21 @@ std::unique_ptr<pufferlib::PuffeRL> create_pufferl_impl(HypersT& hypers, const s
609617 auto [enc_fp32, dec_fp32] = create_encoder_decoder ();
610618 PolicyMinGRU* policy_fp32 = new PolicyMinGRU (enc_fp32, dec_fp32, input_size, act_n, hidden_size, expansion_factor, num_layers, kernels);
611619 policy_fp32->to (torch::kCUDA );
612- policy_fp32->to (DTYPE_FP32 );
620+ policy_fp32->to (torch::kFloat32 );
613621 pufferl->policy_fp32 = policy_fp32;
614622
615- // Create bf16 working policy (for forward/backward - fast Tensor Core ops)
616- auto [enc_bf16, dec_bf16] = create_encoder_decoder ();
617- PolicyMinGRU* policy_bf16 = new PolicyMinGRU (enc_bf16, dec_bf16, input_size, act_n, hidden_size, expansion_factor, num_layers, kernels);
618- policy_bf16->to (torch::kCUDA );
619- policy_bf16->to (DTYPE);
620- pufferl->policy_bf16 = policy_bf16;
621-
622- // Sync bf16 weights from fp32 initially
623- sync_policy_weights (policy_bf16, policy_fp32);
623+ if (hypers.bf16 ) {
624+ // create bf16 working policy (for fwd/bwd)
625+ auto [enc_bf16, dec_bf16] = create_encoder_decoder ();
626+ PolicyMinGRU* policy_bf16 = new PolicyMinGRU (enc_bf16, dec_bf16, input_size, act_n, hidden_size, expansion_factor, num_layers, kernels);
627+ policy_bf16->to (torch::kCUDA );
628+ policy_bf16->to (torch::kBFloat16 );
629+ pufferl->policy_bf16 = policy_bf16;
630+ sync_policy_weights (policy_bf16, policy_fp32); // initial sync
631+ } else {
632+ // just use same policy for both
633+ pufferl->policy_bf16 = policy_fp32;
634+ }
624635
625636 // Optimizer uses fp32 master weights for precise gradient accumulation
626637 float lr = hypers.lr ;
@@ -641,17 +652,18 @@ std::unique_ptr<pufferlib::PuffeRL> create_pufferl_impl(HypersT& hypers, const s
641652 printf (" DEBUG: num_envs=%d, total_agents=%d, segments=%d, batch=%d, num_buffers=%d\n " ,
642653 vec->size , total_agents, segments, batch, num_buffers);
643654
644- pufferl->rollouts = create_rollouts (horizon, total_agents, input_size, num_action_heads);
655+ pufferl->rollouts = create_rollouts (horizon, total_agents, input_size, num_action_heads, hypers.bf16 );
645656 pufferl->train_buf = create_train_graph (minibatch_segments, horizon, input_size,
646- policy_bf16 ->num_layers , policy_bf16 ->hidden_size , policy_bf16 ->expansion_factor , num_action_heads);
657+ policy_fp32 ->num_layers , policy_fp32 ->hidden_size , policy_fp32 ->expansion_factor , num_action_heads, hypers.bf16 );
647658
648- pufferl->adv_mean = torch::zeros ({1 }, torch::dtype (DTYPE).device (torch::kCUDA ));
649- pufferl->adv_std = torch::ones ({1 }, torch::dtype (DTYPE).device (torch::kCUDA ));
659+ // always fp32 since advantages are computed in fp32
660+ pufferl->adv_mean = torch::zeros ({1 }, torch::dtype (torch::kFloat32 ).device (torch::kCUDA ));
661+ pufferl->adv_std = torch::ones ({1 }, torch::dtype (torch::kFloat32 ).device (torch::kCUDA ));
650662
651663 // Per-buffer states: each is {num_layers, block_size, hidden} for contiguous access
652664 pufferl->buffer_states .resize (num_buffers);
653665 for (int i = 0 ; i < num_buffers; i++) {
654- pufferl->buffer_states [i] = policy_bf16->initial_state (batch, torch::kCUDA );
666+ pufferl->buffer_states [i] = pufferl->policy_bf16 ->initial_state (batch, torch::kCUDA );
655667 }
656668
657669 if (hypers.cudagraphs ) {
@@ -830,9 +842,10 @@ void train_impl(PuffeRL& pufferl) {
830842 cudaEventCreate (&start);
831843 cudaEventCreate (&stop);
832844
845+ auto dtype = get_dtype (hypers.bf16 );
833846 Tensor mb_state = torch::zeros (
834847 {policy_bf16->num_layers , minibatch_segments, 1 , (int64_t )(policy_bf16->hidden_size *policy_bf16->expansion_factor )},
835- torch::dtype (DTYPE ).device (rollouts.values .device ())
848+ torch::dtype (dtype ).device (rollouts.values .device ())
836849 );
837850
838851 // Temporary: random indices and uniform weights
@@ -870,8 +883,8 @@ void train_impl(PuffeRL& pufferl) {
870883 // Update global ratio and values in-place (matches Python)
871884 // Buffers are {horizon, segments}, so index_copy_ along dim 1 (segments)
872885 // Source is {minibatch_segments, horizon}, need to transpose to {horizon, minibatch_segments}
873- pufferl.rollouts .ratio .index_copy_ (1 , idx, graph.mb_ratio .detach ().squeeze (-1 ).to (DTYPE ).transpose (0 , 1 ));
874- pufferl.rollouts .values .index_copy_ (1 , idx, graph.mb_newvalue .detach ().squeeze (-1 ).to (DTYPE ).transpose (0 , 1 ));
886+ pufferl.rollouts .ratio .index_copy_ (1 , idx, graph.mb_ratio .detach ().squeeze (-1 ).to (dtype ).transpose (0 , 1 ));
887+ pufferl.rollouts .values .index_copy_ (1 , idx, graph.mb_newvalue .detach ().squeeze (-1 ).to (dtype ).transpose (0 , 1 ));
875888
876889 }
877890 pufferl.epoch += 1 ;
0 commit comments