Commit 9e1f1b4

Jonah's bf16

Merge: 2 parents f027552 + 95c05d2

File tree: 10 files changed (+284 / -162 lines)

pufferlib/config/default.ini

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ seed = 42
 [train]
 name = pufferai
 project = ablations
+bf16 = True
 
 seed = 42
 torch_deterministic = True
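
Note: bf16 training is switched on by default in this ablation config. A hedged sketch of opting back out in a user config, assuming pufferlib's usual ini-override behavior; the [train] section and bf16 key come from the diff, and the comment is an assumption about what the flag gates:

    [train]
    bf16 = False  ; assumed to fall back to the fp32 training path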

pufferlib/extensions/bindings.cpp

Lines changed: 4 additions & 2 deletions
@@ -22,7 +22,7 @@ pybind11::dict log_environments(pybind11::object pufferl_obj) {
 
 Tensor initial_state(pybind11::object pufferl_obj, int64_t batch_size, torch::Device device) {
     auto& pufferl = pufferl_obj.cast<PuffeRL&>();
-    return pufferl.policy->initial_state(batch_size, device);
+    return pufferl.policy_bf16->initial_state(batch_size, device);
 }
 
 void python_vec_recv(pybind11::object pufferl_obj, int buf) {
@@ -127,6 +127,7 @@ std::unique_ptr<pufferlib::PuffeRL> create_pufferl(pybind11::dict kwargs) {
     hypers.kernels = get_config(kwargs, "kernels");
     hypers.profile = get_config(kwargs, "profile");
     hypers.use_omp = get_config(kwargs, "use_omp");
+    hypers.bf16 = get_config(kwargs, "bf16");
 
     std::string env_name = kwargs["env_name"].cast<std::string>();
     Dict* vec_kwargs = py_dict_to_c_dict(kwargs["vec_kwargs"].cast<py::dict>());
@@ -232,7 +233,8 @@ PYBIND11_MODULE(_C, m) {
 
     m.def("create_pufferl", &create_pufferl);
     py::class_<PuffeRL, std::unique_ptr<PuffeRL>>(m, "PuffeRL")
-        .def_readwrite("policy", &PuffeRL::policy)
+        .def_readwrite("policy_bf16", &PuffeRL::policy_bf16)
+        .def_readwrite("policy_fp32", &PuffeRL::policy_fp32)
        .def_readwrite("muon", &PuffeRL::muon)
        .def_readwrite("hypers", &PuffeRL::hypers)
        .def_readwrite("rollouts", &PuffeRL::rollouts);

pufferlib/extensions/cuda/advantage.cu

Lines changed: 59 additions & 24 deletions
@@ -1,21 +1,25 @@
 #include <torch/extension.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <c10/util/BFloat16.h>
 
 namespace pufferlib {
 
-__host__ __device__ void puff_advantage_row_cuda(float* values, float* rewards, float* dones,
-        float* importance, float* advantages, float gamma, float lambda,
+// TIn = input type (bf16 or float), TOut = output type (always float for precision)
+template<typename TIn, typename TOut>
+__host__ __device__ void puff_advantage_row_cuda(const TIn* values, const TIn* rewards, const TIn* dones,
+        const TIn* importance, TOut* advantages, float gamma, float lambda,
         float rho_clip, float c_clip, int horizon) {
     float lastpufferlam = 0;
     for (int t = horizon-2; t >= 0; t--) {
         int t_next = t + 1;
-        float nextnonterminal = 1.0 - dones[t_next];
-        float rho_t = fminf(importance[t], rho_clip);
-        float c_t = fminf(importance[t], c_clip);
-        float delta = rho_t*(rewards[t_next] + gamma*values[t_next]*nextnonterminal - values[t]);
+        float nextnonterminal = 1.0f - float(dones[t_next]);
+        float imp = float(importance[t]);
+        float rho_t = fminf(imp, rho_clip);
+        float c_t = fminf(imp, c_clip);
+        float delta = rho_t*(float(rewards[t_next]) + gamma*float(values[t_next])*nextnonterminal - float(values[t]));
         lastpufferlam = delta + gamma*lambda*c_t*lastpufferlam*nextnonterminal;
-        advantages[t] = lastpufferlam;
+        advantages[t] = TOut(lastpufferlam);
     }
 }
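
Note: the recursion itself is unchanged by this commit; only the storage types move, with every intermediate explicitly widened to fp32. In LaTeX, writing $w_t$ for the importance weights and $d_t$ for the done flags, the row routine computes

    \rho_t = \min(w_t, \bar\rho), \qquad c_t = \min(w_t, \bar c)
    \delta_t = \rho_t \left( r_{t+1} + \gamma V_{t+1} (1 - d_{t+1}) - V_t \right)
    A_t = \delta_t + \gamma \lambda \, c_t (1 - d_{t+1}) A_{t+1}, \qquad A_{H-1} = 0

with horizon $H$; the result is narrowed to the output dtype TOut only at the final store.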

@@ -25,32 +29,42 @@ void vtrace_check_cuda(torch::Tensor values, torch::Tensor rewards,
 
     // Validate input tensors
     torch::Device device = values.device();
-    for (const torch::Tensor& t : {values, rewards, dones, importance, advantages}) {
+    auto input_dtype = values.dtype();
+    for (const torch::Tensor& t : {values, rewards, dones, importance}) {
         TORCH_CHECK(t.dim() == 2, "Tensor must be 2D");
         TORCH_CHECK(t.device() == device, "All tensors must be on same device");
         TORCH_CHECK(t.size(0) == num_steps, "First dimension must match num_steps");
         TORCH_CHECK(t.size(1) == horizon, "Second dimension must match horizon");
-        TORCH_CHECK(t.dtype() == torch::kFloat32, "All tensors must be float32");
+        TORCH_CHECK(t.dtype() == input_dtype, "Input tensors must have matching dtype");
         if (!t.is_contiguous()) {
             t.contiguous();
         }
     }
+    // advantages can be different dtype (fp32 for precision)
+    TORCH_CHECK(advantages.dim() == 2, "Advantages must be 2D");
+    TORCH_CHECK(advantages.device() == device, "Advantages must be on same device");
+    TORCH_CHECK(advantages.size(0) == num_steps, "Advantages first dimension must match");
+    TORCH_CHECK(advantages.size(1) == horizon, "Advantages second dimension must match");
+    if (!advantages.is_contiguous()) {
+        advantages.contiguous();
+    }
 }
 
-// [num_steps, horizon]
-__global__ void puff_advantage_kernel(float* values, float* rewards,
-        float* dones, float* importance, float* advantages, float gamma,
+template<typename TIn, typename TOut>
+__global__ void puff_advantage_kernel(const TIn* values, const TIn* rewards,
+        const TIn* dones, const TIn* importance, TOut* advantages, float gamma,
         float lambda, float rho_clip, float c_clip, int num_steps, int horizon) {
     int row = blockIdx.x*blockDim.x + threadIdx.x;
     if (row >= num_steps) {
         return;
     }
     int offset = row*horizon;
-    puff_advantage_row_cuda(values + offset, rewards + offset, dones + offset,
+    puff_advantage_row_cuda<TIn, TOut>(values + offset, rewards + offset, dones + offset,
         importance + offset, advantages + offset, gamma, lambda, rho_clip, c_clip, horizon);
 }
 
-void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
+template<typename TIn, typename TOut>
+void compute_puff_advantage_cuda_impl(torch::Tensor values, torch::Tensor rewards,
         torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
         double gamma, double lambda, double rho_clip, double c_clip) {
     int num_steps = values.size(0);
@@ -61,16 +75,16 @@ void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
     int threads_per_block = 256;
     int blocks = (num_steps + threads_per_block - 1) / threads_per_block;
 
-    puff_advantage_kernel<<<blocks, threads_per_block>>>(
-        values.data_ptr<float>(),
-        rewards.data_ptr<float>(),
-        dones.data_ptr<float>(),
-        importance.data_ptr<float>(),
-        advantages.data_ptr<float>(),
-        gamma,
-        lambda,
-        rho_clip,
-        c_clip,
+    puff_advantage_kernel<TIn, TOut><<<blocks, threads_per_block>>>(
+        values.data_ptr<TIn>(),
+        rewards.data_ptr<TIn>(),
+        dones.data_ptr<TIn>(),
+        importance.data_ptr<TIn>(),
+        advantages.data_ptr<TOut>(),
+        static_cast<float>(gamma),
+        static_cast<float>(lambda),
+        static_cast<float>(rho_clip),
+        static_cast<float>(c_clip),
         num_steps,
         horizon
     );
@@ -81,6 +95,27 @@ void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
     }
 }
 
+void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
+        torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
+        double gamma, double lambda, double rho_clip, double c_clip) {
+    auto input_dtype = values.dtype();
+    auto output_dtype = advantages.dtype();
+
+    // Support bf16 inputs with fp32 output for precision
+    if (input_dtype == torch::kFloat32 && output_dtype == torch::kFloat32) {
+        compute_puff_advantage_cuda_impl<float, float>(values, rewards, dones, importance, advantages,
+            gamma, lambda, rho_clip, c_clip);
+    } else if (input_dtype == torch::kBFloat16 && output_dtype == torch::kFloat32) {
+        compute_puff_advantage_cuda_impl<at::BFloat16, float>(values, rewards, dones, importance, advantages,
+            gamma, lambda, rho_clip, c_clip);
+    } else if (input_dtype == torch::kBFloat16 && output_dtype == torch::kBFloat16) {
+        compute_puff_advantage_cuda_impl<at::BFloat16, at::BFloat16>(values, rewards, dones, importance, advantages,
+            gamma, lambda, rho_clip, c_clip);
+    } else {
+        TORCH_CHECK(false, "Unsupported dtype combination: inputs must be float32 or bfloat16, advantages must be float32 or bfloat16");
+    }
+}
+
 TORCH_LIBRARY_IMPL(pufferlib, CUDA, m) {
     m.impl("compute_puff_advantage", &compute_puff_advantage_cuda);
 }
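
Note: a hedged usage sketch of the new dispatcher. The signature of compute_puff_advantage_cuda is copied from the diff; the shapes, hyperparameter values, and the assumption that the symbol is visible to (and linked into) the caller are illustrative:

    #include <torch/torch.h>

    namespace pufferlib {
    // Prototype as it appears in advantage.cu above.
    void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
        torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
        double gamma, double lambda, double rho_clip, double c_clip);
    }

    void example() {
        auto bf16 = torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA);
        auto fp32 = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
        int num_steps = 64, horizon = 128;

        auto values     = torch::rand({num_steps, horizon}, bf16);
        auto rewards    = torch::rand({num_steps, horizon}, bf16);
        auto dones      = torch::zeros({num_steps, horizon}, bf16);
        auto importance = torch::ones({num_steps, horizon}, bf16);
        // bf16 inputs paired with an fp32 output hit the
        // <at::BFloat16, float> branch of the dispatcher above.
        auto advantages = torch::zeros({num_steps, horizon}, fp32);

        pufferlib::compute_puff_advantage_cuda(values, rewards, dones, importance,
            advantages, /*gamma=*/0.99, /*lambda=*/0.95, /*rho_clip=*/1.0, /*c_clip=*/1.0);
    }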

pufferlib/extensions/cuda/kernels.cu

Lines changed: 20 additions & 20 deletions
@@ -1474,7 +1474,7 @@ __global__ void ppo_loss_forward_kernel_optimized(
     const T* __restrict__ values_pred,
     const int64_t* __restrict__ actions,
     const T* __restrict__ old_logprobs,
-    const T* __restrict__ advantages,
+    const float* __restrict__ advantages,
     const T* __restrict__ prio,
     const T* __restrict__ values,
     const T* __restrict__ returns,
@@ -1541,8 +1541,8 @@ __global__ void ppo_loss_forward_kernel_optimized(
     float old_logp = float(old_logprobs[nt]);
     float adv = float(advantages[nt]);
     float w = float(prio[n]);
-    float adv_std = sqrtf(adv_var[0]);
-    float adv_normalized = (adv - adv_mean[0]) / (adv_std + 1e-8f);
+    float adv_std = sqrtf(float(adv_var[0]));
+    float adv_normalized = (adv - float(adv_mean[0])) / (adv_std + 1e-8f);
 
     float logratio = new_logp - old_logp;
     float ratio = __expf(logratio);
@@ -1588,14 +1588,14 @@ __global__ void ppo_loss_forward_kernel_optimized(
 
 template<typename T>
 __global__ void ppo_loss_backward_kernel_optimized(
-    T* __restrict__ grad_logits,
-    T* __restrict__ grad_values_pred,
+    float* __restrict__ grad_logits,
+    float* __restrict__ grad_values_pred,
     const float* __restrict__ grad_loss,
     const T* __restrict__ logits,
     const T* __restrict__ values_pred,
     const int64_t* __restrict__ actions,
     const T* __restrict__ old_logprobs,
-    const T* __restrict__ advantages,
+    const float* __restrict__ advantages,
     const T* __restrict__ prio,
     const T* __restrict__ values,
     const T* __restrict__ returns,
@@ -1665,8 +1665,8 @@ __global__ void ppo_loss_backward_kernel_optimized(
     float v_clipped = val + fmaxf(-vf_clip_coef, fminf(vf_clip_coef, v_error));
 
     // normalize advantage
-    float adv_std = sqrtf(adv_var[0]);
-    float adv_normalized = (adv - adv_mean[0]) / (adv_std + 1e-8f);
+    float adv_std = sqrtf(float(adv_var[0]));
+    float adv_normalized = (adv - float(adv_mean[0])) / (adv_std + 1e-8f);
 
     // loss gradient scaling
     float dL = grad_loss[0] * inv_NT;
@@ -1686,7 +1686,7 @@ __global__ void ppo_loss_backward_kernel_optimized(
     } else {
         d_val_pred = val_pred - ret;
     }
-    grad_values_pred[values_idx] = T(dL * vf_coef * d_val_pred);
+    grad_values_pred[values_idx] = dL * vf_coef * d_val_pred;
 
     // policy loss gradient
     float ratio_clipped = fmaxf(1.0f - clip_coef, fminf(1.0f + clip_coef, ratio));
@@ -1710,7 +1710,7 @@ __global__ void ppo_loss_backward_kernel_optimized(
         d_logit -= p * d_new_logp;
 
         d_logit += d_entropy_term * p * (-entropy - logp);
-        grad_logits[logits_base + a * logits_stride_a] = T(d_logit);
+        grad_logits[logits_base + a * logits_stride_a] = d_logit;
     }
 }
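
Note: in the unclipped branch above, the value gradient now stored directly in fp32 is exactly $\partial \mathcal{L} / \partial \hat{v} = dL \cdot c_{vf} \cdot (\hat{v} - R)$, with $dL = \mathrm{grad\_loss}[0] \cdot \mathrm{inv\_NT}$ as computed above; before this commit the same quantity was narrowed back through T(...) into the compute dtype.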

@@ -1724,12 +1724,12 @@ inline void launch_ppo_loss_forward_optimized(
     const T* values_pred,
     const int64_t* actions,
     const T* old_logprobs,
-    const T* advantages,
+    const float* advantages, // always fp32 for precision
     const T* prio,
     const T* values,
     const T* returns,
-    const float* adv_mean,
-    const float* adv_var,
+    const float* adv_mean, // keep fp32
+    const float* adv_var, // keep fp32
     float clip_coef,
     float vf_clip_coef,
     float vf_coef,
@@ -1784,19 +1784,19 @@ inline void launch_ppo_loss_forward_optimized(
 
 template<typename T>
 void launch_ppo_loss_backward_optimized(
-    T* grad_logits,
-    T* grad_values_pred,
+    float* grad_logits,
+    float* grad_values_pred,
     const float* grad_loss,
     const T* logits,
     const T* values_pred, // added: need to read val_pred directly
     const int64_t* actions,
     const T* old_logprobs,
-    const T* advantages,
+    const float* advantages,
     const T* prio,
     const T* values,
     const T* returns,
     const float* adv_mean,
-    const float* adv_var, // variance, not std
+    const float* adv_var,
     float clip_coef,
     float vf_clip_coef,
     float vf_coef,
@@ -2011,7 +2011,7 @@ __global__ void ppo_loss_backward_kernel(
 
     // === Retrieve saved values from forward pass ===
     const double* saved = saved_for_backward + idx * 5;
-    double new_logp = saved[0];  // new log prob of selected action
+    // double new_logp = saved[0]; // new log prob of selected action
     double ratio = saved[1];     // exp(new_logp - old_logp)
     double val_pred = saved[2];  // value prediction
     double v_clipped = saved[3]; // clipped value target
@@ -2842,10 +2842,10 @@ void launch_ppo_loss_backward_optimized_float(float* grad_logits, float* grad_va
     launch_ppo_loss_backward_optimized<float>(grad_logits, grad_values_pred, grad_loss, logits, values_pred, actions, old_logprobs, advantages, prio, values, returns, adv_mean, adv_var, clip_coef, vf_clip_coef, vf_coef, ent_coef, T_seq, A, N, logits_stride_n, logits_stride_t, logits_stride_a, values_stride_n, values_stride_t, stream);
 }
 
-void launch_ppo_loss_forward_optimized_bf16(float* loss_output, double* saved_for_backward, at::BFloat16* ratio_out, at::BFloat16* newvalue_out, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const at::BFloat16* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream) {
+void launch_ppo_loss_forward_optimized_bf16(float* loss_output, double* saved_for_backward, at::BFloat16* ratio_out, at::BFloat16* newvalue_out, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const float* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream) {
     launch_ppo_loss_forward_optimized<at::BFloat16>(loss_output, saved_for_backward, ratio_out, newvalue_out, logits, values_pred, actions, old_logprobs, advantages, prio, values, returns, adv_mean, adv_var, clip_coef, vf_clip_coef, vf_coef, ent_coef, T_seq, A, N, logits_stride_n, logits_stride_t, logits_stride_a, values_stride_n, values_stride_t, stream);
 }
-void launch_ppo_loss_backward_optimized_bf16(at::BFloat16* grad_logits, at::BFloat16* grad_values_pred, const float* grad_loss, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const at::BFloat16* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream) {
+void launch_ppo_loss_backward_optimized_bf16(float* grad_logits, float* grad_values_pred, const float* grad_loss, const at::BFloat16* logits, const at::BFloat16* values_pred, const int64_t* actions, const at::BFloat16* old_logprobs, const float* advantages, const at::BFloat16* prio, const at::BFloat16* values, const at::BFloat16* returns, const float* adv_mean, const float* adv_var, float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef, int T_seq, int A, int N, int logits_stride_n, int logits_stride_t, int logits_stride_a, int values_stride_n, int values_stride_t, cudaStream_t stream) {
     launch_ppo_loss_backward_optimized<at::BFloat16>(grad_logits, grad_values_pred, grad_loss, logits, values_pred, actions, old_logprobs, advantages, prio, values, returns, adv_mean, adv_var, clip_coef, vf_clip_coef, vf_coef, ent_coef, T_seq, A, N, logits_stride_n, logits_stride_t, logits_stride_a, values_stride_n, values_stride_t, stream);
 }
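
Note: the recurring pattern in these kernels is load bf16, widen once, do all arithmetic in fp32, write fp32 gradients. A minimal self-contained CUDA sketch of that pattern; the kernel below is invented for illustration and is not part of this commit:

    #include <c10/util/BFloat16.h>

    // Toy kernel, illustration only: bf16 operands in, fp32 math, fp32 grads out.
    __global__ void scaled_grad_bf16_to_fp32(const at::BFloat16* __restrict__ x,
            float scale, float* __restrict__ grad, int n) {
        int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i >= n) {
            return;
        }
        float xf = float(x[i]);  // widen once on load
        grad[i] = scale * xf;    // fp32 arithmetic, fp32 store (no T(...) narrowing)
    }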
