Skip to content

Commit 95c05d2

Browse files
committed
working both
1 parent f68f7f7 commit 95c05d2

File tree

3 files changed

+52
-37
lines changed

3 files changed

+52
-37
lines changed

pufferlib/extensions/bindings.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ std::unique_ptr<pufferlib::PuffeRL> create_pufferl(pybind11::dict kwargs) {
129129
hypers.kernels = get_config(kwargs, "kernels");
130130
hypers.profile = get_config(kwargs, "profile");
131131
hypers.use_omp = get_config(kwargs, "use_omp");
132+
hypers.bf16 = get_config(kwargs, "bf16");
132133

133134
std::string env_name = kwargs["env_name"].cast<std::string>();
134135
Dict* vec_kwargs = py_dict_to_c_dict(kwargs["vec_kwargs"].cast<py::dict>());

pufferlib/extensions/pufferlib.cpp

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ typedef torch::Tensor Tensor;
3333
// CUDA kernel wrappers
3434
#include "modules.cpp"
3535

36-
auto DTYPE = torch::kBFloat16;
37-
auto DTYPE_FP32 = torch::kFloat32; // master weights
36+
// get dtype based on bf16 flag
37+
inline torch::ScalarType get_dtype(bool bf16) {
38+
return bf16 ? torch::kBFloat16 : torch::kFloat32;
39+
}
3840

3941
namespace pufferlib {
4042

@@ -186,16 +188,17 @@ typedef struct {
186188
} TrainGraph;
187189

188190
TrainGraph create_train_graph(int minibatch_segments, int horizon, int input_size,
189-
int num_layers, int hidden_size, int expansion_factor, int num_atns) {
191+
int num_layers, int hidden_size, int expansion_factor, int num_atns, bool bf16) {
190192
TrainGraph g;
191-
auto options = torch::TensorOptions().dtype(DTYPE).device(torch::kCUDA);
193+
auto dtype = get_dtype(bf16);
194+
auto options = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
192195
g.mb_obs = torch::zeros({minibatch_segments, horizon, input_size}, options);
193196
g.mb_state = torch::zeros({num_layers, minibatch_segments, 1, hidden_size * expansion_factor}, options);
194197
g.mb_newvalue = torch::zeros({minibatch_segments, horizon, 1}, options);
195198
g.mb_ratio = torch::zeros({minibatch_segments, horizon}, options);
196199
g.mb_actions = torch::zeros({minibatch_segments, horizon, num_atns}, options).to(torch::kInt64);
197200
g.mb_logprobs = torch::zeros({minibatch_segments, horizon}, options);
198-
g.mb_advantages = torch::zeros({minibatch_segments, horizon}, options.dtype(torch::kFloat32)); // fp32 precision
201+
g.mb_advantages = torch::zeros({minibatch_segments, horizon}, options.dtype(torch::kFloat32)); // always fp32 for precision
199202
g.mb_prio = torch::zeros({minibatch_segments, 1}, options);
200203
g.mb_values = torch::zeros({minibatch_segments, horizon}, options);
201204
g.mb_returns = torch::zeros({minibatch_segments, horizon}, options);
@@ -213,16 +216,17 @@ typedef struct {
213216
Tensor importance;
214217
} RolloutBuf;
215218

216-
RolloutBuf create_rollouts(int horizon, int segments, int input_size, int num_atns) {
219+
RolloutBuf create_rollouts(int horizon, int segments, int input_size, int num_atns, bool bf16) {
217220
RolloutBuf r;
218-
r.observations = torch::zeros({horizon, segments, input_size}, torch::dtype(DTYPE).device(torch::kCUDA));
221+
auto dtype = get_dtype(bf16);
222+
r.observations = torch::zeros({horizon, segments, input_size}, torch::dtype(dtype).device(torch::kCUDA));
219223
r.actions = torch::zeros({horizon, segments, num_atns}, torch::dtype(torch::kFloat64).device(torch::kCUDA));
220-
r.values = torch::zeros({horizon, segments}, torch::dtype(DTYPE).device(torch::kCUDA));
221-
r.logprobs = torch::zeros({horizon, segments}, torch::dtype(DTYPE).device(torch::kCUDA));
222-
r.rewards = torch::zeros({horizon, segments}, torch::dtype(DTYPE).device(torch::kCUDA));
223-
r.terminals = torch::zeros({horizon, segments}, torch::dtype(DTYPE).device(torch::kCUDA));
224-
r.ratio = torch::zeros({horizon, segments}, torch::dtype(DTYPE).device(torch::kCUDA));
225-
r.importance = torch::zeros({horizon, segments}, torch::dtype(DTYPE).device(torch::kCUDA));
224+
r.values = torch::zeros({horizon, segments}, torch::dtype(dtype).device(torch::kCUDA));
225+
r.logprobs = torch::zeros({horizon, segments}, torch::dtype(dtype).device(torch::kCUDA));
226+
r.rewards = torch::zeros({horizon, segments}, torch::dtype(dtype).device(torch::kCUDA));
227+
r.terminals = torch::zeros({horizon, segments}, torch::dtype(dtype).device(torch::kCUDA));
228+
r.ratio = torch::zeros({horizon, segments}, torch::dtype(dtype).device(torch::kCUDA));
229+
r.importance = torch::zeros({horizon, segments}, torch::dtype(dtype).device(torch::kCUDA));
226230
return r;
227231
}
228232

@@ -271,6 +275,7 @@ typedef struct {
271275
bool kernels;
272276
bool profile;
273277
bool use_omp;
278+
bool bf16; // bfloat16 mixed precision training
274279
} HypersT;
275280

276281
typedef struct {
@@ -442,18 +447,21 @@ void train_forward_call(TrainGraph& graph, PolicyMinGRU* policy_bf16, PolicyMinG
442447
loss = pg_loss + hypers.vf_coef*v_loss - hypers.ent_coef*entropy;
443448
}
444449

445-
// computes gradients on bf16 weights
450+
// computes gradients on bf16 weights (or fp32 if not using bf16)
446451
loss.backward();
447452

448453
// copy gradients from bf16 to fp32, then optimizer step on fp32 master weights
449-
copy_gradients_to_fp32(policy_bf16, policy_fp32);
454+
if (hypers.bf16) {
455+
copy_gradients_to_fp32(policy_bf16, policy_fp32);
456+
}
450457
clip_grad_norm_(policy_fp32->parameters(), hypers.max_grad_norm);
451458
muon->step();
452459
muon->zero_grad();
453-
policy_bf16->zero_grad(); // also need to clear bf16 gradients
454-
455-
// sync updated fp32 weights back to bf16 for next forward pass
456-
sync_policy_weights(policy_bf16, policy_fp32);
460+
if (hypers.bf16) {
461+
policy_bf16->zero_grad(); // also need to clear bf16 gradients
462+
// sync updated fp32 weights back to bf16 for next forward pass
463+
sync_policy_weights(policy_bf16, policy_fp32);
464+
}
457465
}
458466

459467
// Capture
@@ -609,18 +617,21 @@ std::unique_ptr<pufferlib::PuffeRL> create_pufferl_impl(HypersT& hypers, const s
609617
auto [enc_fp32, dec_fp32] = create_encoder_decoder();
610618
PolicyMinGRU* policy_fp32 = new PolicyMinGRU(enc_fp32, dec_fp32, input_size, act_n, hidden_size, expansion_factor, num_layers, kernels);
611619
policy_fp32->to(torch::kCUDA);
612-
policy_fp32->to(DTYPE_FP32);
620+
policy_fp32->to(torch::kFloat32);
613621
pufferl->policy_fp32 = policy_fp32;
614622

615-
// Create bf16 working policy (for forward/backward - fast Tensor Core ops)
616-
auto [enc_bf16, dec_bf16] = create_encoder_decoder();
617-
PolicyMinGRU* policy_bf16 = new PolicyMinGRU(enc_bf16, dec_bf16, input_size, act_n, hidden_size, expansion_factor, num_layers, kernels);
618-
policy_bf16->to(torch::kCUDA);
619-
policy_bf16->to(DTYPE);
620-
pufferl->policy_bf16 = policy_bf16;
621-
622-
// Sync bf16 weights from fp32 initially
623-
sync_policy_weights(policy_bf16, policy_fp32);
623+
if (hypers.bf16) {
624+
// create bf16 working policy (for fwd/bwd)
625+
auto [enc_bf16, dec_bf16] = create_encoder_decoder();
626+
PolicyMinGRU* policy_bf16 = new PolicyMinGRU(enc_bf16, dec_bf16, input_size, act_n, hidden_size, expansion_factor, num_layers, kernels);
627+
policy_bf16->to(torch::kCUDA);
628+
policy_bf16->to(torch::kBFloat16);
629+
pufferl->policy_bf16 = policy_bf16;
630+
sync_policy_weights(policy_bf16, policy_fp32); // initial sync
631+
} else {
632+
// just use same policy for both
633+
pufferl->policy_bf16 = policy_fp32;
634+
}
624635

625636
// Optimizer uses fp32 master weights for precise gradient accumulation
626637
float lr = hypers.lr;
@@ -641,17 +652,18 @@ std::unique_ptr<pufferlib::PuffeRL> create_pufferl_impl(HypersT& hypers, const s
641652
printf("DEBUG: num_envs=%d, total_agents=%d, segments=%d, batch=%d, num_buffers=%d\n",
642653
vec->size, total_agents, segments, batch, num_buffers);
643654

644-
pufferl->rollouts = create_rollouts(horizon, total_agents, input_size, num_action_heads);
655+
pufferl->rollouts = create_rollouts(horizon, total_agents, input_size, num_action_heads, hypers.bf16);
645656
pufferl->train_buf = create_train_graph(minibatch_segments, horizon, input_size,
646-
policy_bf16->num_layers, policy_bf16->hidden_size, policy_bf16->expansion_factor, num_action_heads);
657+
policy_fp32->num_layers, policy_fp32->hidden_size, policy_fp32->expansion_factor, num_action_heads, hypers.bf16);
647658

648-
pufferl->adv_mean = torch::zeros({1}, torch::dtype(DTYPE).device(torch::kCUDA));
649-
pufferl->adv_std = torch::ones({1}, torch::dtype(DTYPE).device(torch::kCUDA));
659+
// always fp32 since advantages are computed in fp32
660+
pufferl->adv_mean = torch::zeros({1}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
661+
pufferl->adv_std = torch::ones({1}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
650662

651663
// Per-buffer states: each is {num_layers, block_size, hidden} for contiguous access
652664
pufferl->buffer_states.resize(num_buffers);
653665
for (int i = 0; i < num_buffers; i++) {
654-
pufferl->buffer_states[i] = policy_bf16->initial_state(batch, torch::kCUDA);
666+
pufferl->buffer_states[i] = pufferl->policy_bf16->initial_state(batch, torch::kCUDA);
655667
}
656668

657669
if (hypers.cudagraphs) {
@@ -830,9 +842,10 @@ void train_impl(PuffeRL& pufferl) {
830842
cudaEventCreate(&start);
831843
cudaEventCreate(&stop);
832844

845+
auto dtype = get_dtype(hypers.bf16);
833846
Tensor mb_state = torch::zeros(
834847
{policy_bf16->num_layers, minibatch_segments, 1, (int64_t)(policy_bf16->hidden_size*policy_bf16->expansion_factor)},
835-
torch::dtype(DTYPE).device(rollouts.values.device())
848+
torch::dtype(dtype).device(rollouts.values.device())
836849
);
837850

838851
// Temporary: random indices and uniform weights
@@ -870,8 +883,8 @@ void train_impl(PuffeRL& pufferl) {
870883
// Update global ratio and values in-place (matches Python)
871884
// Buffers are {horizon, segments}, so index_copy_ along dim 1 (segments)
872885
// Source is {minibatch_segments, horizon}, need to transpose to {horizon, minibatch_segments}
873-
pufferl.rollouts.ratio.index_copy_(1, idx, graph.mb_ratio.detach().squeeze(-1).to(DTYPE).transpose(0, 1));
874-
pufferl.rollouts.values.index_copy_(1, idx, graph.mb_newvalue.detach().squeeze(-1).to(DTYPE).transpose(0, 1));
886+
pufferl.rollouts.ratio.index_copy_(1, idx, graph.mb_ratio.detach().squeeze(-1).to(dtype).transpose(0, 1));
887+
pufferl.rollouts.values.index_copy_(1, idx, graph.mb_newvalue.detach().squeeze(-1).to(dtype).transpose(0, 1));
875888

876889
}
877890
pufferl.epoch += 1;

pufferlib/pufferl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def __init__(self, config, logger=None, verbose=True):
156156
config['kernels'] = True
157157
config['use_omp'] = True
158158
config['num_buffers'] = 2
159+
config['bf16'] = config.get('bf16', True) # bfloat16 mixed precision training
159160
self.pufferl_cpp = _C.create_pufferl(config)
160161
self.observations = self.pufferl_cpp.rollouts.observations
161162
self.actions = self.pufferl_cpp.rollouts.actions

0 commit comments

Comments
 (0)