Skip to content

Commit 07ccaea

Browse files
committed
Un-nest graphing
1 parent e359766 commit 07ccaea

File tree

4 files changed

+489
-544
lines changed

4 files changed

+489
-544
lines changed

pufferlib/extensions/env_binding.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,14 @@ void static_vec_omp_step(StaticVec* vec) {
142142
}
143143
}
144144

145+
// Step every buffer sequentially on the calling thread (no OpenMP fan-out).
// For each buffer in order: flip its state to OMP_RUNNING to hand it to its
// worker, then spin until the worker flips it back to OMP_WAITING.
void static_vec_seq_step(StaticVec* vec) {
    StaticThreading* threading = vec->threading;
    int num_buffers = vec->buffers;
    for (int b = 0; b < num_buffers; b++) {
        // Release the buffer to its worker thread.
        atomic_store(&threading->buffer_states[b], OMP_RUNNING);
        // Busy-wait until the worker reports completion.
        do {
        } while (atomic_load(&threading->buffer_states[b]) != OMP_WAITING);
    }
}
152+
145153
// Optional: Initialize all envs at once (for shared state, variable agents per env, etc.)
146154
// Default implementation creates envs until total_agents is reached
147155
#ifndef MY_VEC_INIT

pufferlib/extensions/env_binding.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ void static_vec_log(StaticVec* vec, Dict* out);
109109
void create_static_threads(StaticVec* vec, int num_threads, int horizon,
110110
void* ctx, net_callback_fn net_callback, thread_init_fn thread_init);
111111
void static_vec_omp_step(StaticVec* vec);
112+
void static_vec_seq_step(StaticVec* vec);
112113

113114
// Env info
114115
int get_obs_size(void);

pufferlib/extensions/models.cpp

Lines changed: 128 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,8 @@ struct RNN : public nn::Module {
138138
virtual Tensor initial_state(int batch_size, torch::Device device, torch::Dtype dtype) = 0;
139139
};
140140

141-
class MinGRU : public RNN {
142-
public:
143-
int hidden;
144-
int num_layers;
141+
struct MinGRU : public RNN {
142+
int hidden, num_layers;
145143
bool kernels;
146144
vector<nn::Linear> layers;
147145

@@ -196,8 +194,7 @@ class MinGRU : public RNN {
196194
}
197195
};
198196

199-
class Policy : public nn::Module {
200-
public:
197+
struct Policy : public nn::Module {
201198
int input, hidden, num_atns;
202199
shared_ptr<Encoder> encoder{nullptr};
203200
shared_ptr<Decoder> decoder{nullptr};
@@ -448,69 +445,143 @@ Tensor logcumsumexp_cpp(Tensor x) {
448445
return x.exp().cumsum(1).log();
449446
}
450447

451-
Tensor fused_ppo_loss_cpp(
452-
Tensor logits,
453-
Tensor newvalue,
454-
Tensor actions,
455-
Tensor old_logprobs,
456-
Tensor advantages,
457-
Tensor prio,
458-
Tensor values,
459-
Tensor returns,
460-
Tensor adv_mean,
461-
Tensor adv_std,
462-
float clip_coef,
463-
float vf_clip_coef,
464-
float vf_coef,
465-
float ent_coef
466-
) {
467-
auto segments = logits.size(0);
468-
auto horizon = logits.size(1);
448+
// Sample from multi-head discrete distribution
449+
// Returns {actions (B, heads), total_logprob (B,)}
450+
vector<Tensor> sample_discrete_cpp(Tensor logits, Tensor act_sizes_cpu, int num_heads) {
451+
logits = torch::nan_to_num(logits, 1e-8, 1e-8, 1e-8);
452+
auto split = torch::split(logits, c10::IntArrayRef(act_sizes_cpu.data_ptr<int64_t>(), num_heads), 1);
453+
vector<Tensor> actions_vec, logprobs_vec;
454+
for (int i = 0; i < num_heads; i++) {
455+
auto log_probs = torch::log_softmax(split[i], 1);
456+
auto action = at::multinomial(log_probs.exp(), 1, true);
457+
actions_vec.push_back(action);
458+
logprobs_vec.push_back(log_probs.gather(1, action));
459+
}
460+
return {torch::cat(actions_vec, 1), torch::cat(logprobs_vec, 1).sum(1)};
461+
}
469462

470-
auto flat_logits = logits.reshape({-1, logits.size(-1)});
471-
auto flat_actions = actions.reshape({-1});
472-
auto logprobs_new = torch::log_softmax(flat_logits, 1);
463+
// Sample from continuous Normal distribution
464+
// Returns {actions (B, D), total_logprob (B,)}
465+
vector<Tensor> sample_continuous_cpp(Tensor mean, Tensor logstd) {
466+
auto std = logstd.exp();
467+
auto actions = mean + std * torch::randn_like(mean);
468+
auto log_prob = -0.5 * ((actions - mean) / std).pow(2) - 0.5 * std::log(2 * M_PI) - logstd;
469+
return {actions, log_prob.sum(1)};
470+
}
473471

474-
auto probs_new = logprobs_new.exp();
475-
auto entropy = -(probs_new * logprobs_new).sum(1).mean();
472+
// Compute logprob + entropy for multi-head discrete actions
473+
// Returns {logprob (batch,), entropy scalar}
474+
vector<Tensor> discrete_logprob_entropy_cpp(Tensor logits, Tensor actions, Tensor act_sizes_cpu, int num_heads) {
475+
logits = torch::nan_to_num(logits, 1e-8, 1e-8, 1e-8);
476+
auto split = torch::split(logits, c10::IntArrayRef(act_sizes_cpu.data_ptr<int64_t>(), num_heads), 1);
477+
int batch = logits.size(0);
478+
vector<Tensor> logprobs_vec, entropies_vec;
479+
for (int h = 0; h < num_heads; h++) {
480+
auto log_probs = torch::log_softmax(split[h], 1);
481+
auto probs = log_probs.exp();
482+
auto head_actions = actions.select(-1, h).reshape({batch}).to(torch::kInt64);
483+
logprobs_vec.push_back(log_probs.gather(1, head_actions.unsqueeze(1)));
484+
entropies_vec.push_back(-(probs * log_probs).sum(1, true));
485+
}
486+
auto logprob = torch::cat(logprobs_vec, 1).sum(1);
487+
auto entropy = torch::cat(entropies_vec, 1).sum(1).mean();
488+
return {logprob, entropy};
489+
}
476490

477-
auto newlogprob_flat = logprobs_new.gather(1, flat_actions.unsqueeze(1)).squeeze(1);
478-
auto newlogprob = newlogprob_flat.reshape({segments, horizon});
479-
auto logratio = newlogprob - old_logprobs;
480-
auto ratio_new = logratio.exp();
491+
// Compute logprob + entropy for continuous Normal actions
492+
// Returns {logprob (batch,), entropy scalar}
493+
vector<Tensor> continuous_logprob_entropy_cpp(Tensor mean, Tensor logstd, Tensor actions) {
494+
auto std = logstd.exp();
495+
auto normalized = (actions.to(mean.dtype()) - mean) / std;
496+
auto log_prob = -0.5 * normalized.pow(2) - 0.5 * std::log(2 * M_PI) - logstd;
497+
auto logprob = log_prob.sum(1);
498+
constexpr float HALF_1_PLUS_LOG_2PI = 1.4189385332046727f;
499+
auto entropy = (HALF_1_PLUS_LOG_2PI + logstd).sum(1).mean();
500+
return {logprob, entropy};
501+
}
481502

482-
auto adv_normalized = prio.unsqueeze(1) * (advantages - adv_mean) / (adv_std + 1e-8);
483-
auto pg_loss1 = -adv_normalized * ratio_new;
484-
auto pg_loss2 = -adv_normalized * torch::clamp(ratio_new, 1.0 - clip_coef, 1.0 + clip_coef);
503+
// PPO clipped loss with clipped value loss
504+
Tensor ppo_loss_cpp(Tensor ratio, Tensor advantages, Tensor prio,
505+
Tensor newvalue, Tensor values, Tensor returns, Tensor entropy,
506+
float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef) {
507+
auto adv_normalized = prio * (advantages - advantages.mean()) / (advantages.std() + 1e-8);
508+
auto pg_loss1 = -adv_normalized * ratio;
509+
auto pg_loss2 = -adv_normalized * torch::clamp(ratio, 1.0 - clip_coef, 1.0 + clip_coef);
485510
auto pg_loss = torch::max(pg_loss1, pg_loss2).mean();
486511

487-
auto nv = newvalue.view(returns.sizes());
488-
auto v_clipped = values + torch::clamp(nv - values, -vf_clip_coef, vf_clip_coef);
489-
auto v_loss_unclipped = (nv - returns).pow(2);
490-
auto v_loss_clipped = (v_clipped - returns).pow(2);
491-
auto v_loss = 0.5 * torch::max(v_loss_unclipped, v_loss_clipped).mean();
512+
newvalue = newvalue.view(returns.sizes());
513+
auto v_clipped = values + torch::clamp(newvalue - values, -vf_clip_coef, vf_clip_coef);
514+
auto v_loss = 0.5 * torch::max((newvalue - returns).pow(2), (v_clipped - returns).pow(2)).mean();
492515

493516
return pg_loss + vf_coef * v_loss - ent_coef * entropy;
494517
}
495518

496-
// Reference implementation for sample_logits (for correctness testing)
497-
vector<Tensor> sample_logits_cpp(
498-
Tensor logits
499-
) {
500-
// nan_to_num
501-
auto clean_logits = torch::nan_to_num(logits);
502-
503-
// log_softmax
504-
auto log_probs = torch::log_softmax(clean_logits, 1);
505-
506-
// multinomial sampling
507-
auto probs = log_probs.exp();
508-
auto actions = torch::multinomial(probs, 1, /*replacement=*/false).squeeze(1);
519+
// Dispatch: sample actions using kernel or cpp path, write to output buffers
520+
void sample_actions(Logits& logits, Tensor value,
521+
Tensor actions_out, Tensor logprobs_out, Tensor values_out,
522+
Tensor act_sizes, Tensor act_sizes_cpu,
523+
bool is_continuous, bool kernels, uint64_t rng_seed, Tensor rng_offset) {
524+
if (kernels) {
525+
Tensor logstd = logits.logstd.defined() ? logits.logstd : Tensor();
526+
sample_logits(logits.mean, logstd, value, actions_out, logprobs_out,
527+
values_out, act_sizes, rng_seed, rng_offset);
528+
} else {
529+
vector<Tensor> result;
530+
if (is_continuous) {
531+
result = sample_continuous_cpp(logits.mean, logits.logstd);
532+
} else {
533+
result = sample_discrete_cpp(logits.mean, act_sizes_cpu, actions_out.size(1));
534+
}
535+
actions_out.copy_(result[0].to(torch::kFloat64), false);
536+
logprobs_out.copy_(result[1], false);
537+
values_out.copy_(value.flatten(), false);
538+
}
539+
}
509540

510-
// gather logprobs
511-
auto sampled_logprobs = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1);
541+
// Dispatch: compute PPO loss using kernel or cpp path
542+
// Writes ratio and newvalue to output buffers as side effect
543+
Tensor compute_train_loss(Logits& logits, Tensor newvalue,
544+
Tensor actions, Tensor old_logprobs, Tensor advantages, Tensor prio,
545+
Tensor values, Tensor returns,
546+
Tensor ratio_out, Tensor newvalue_out,
547+
Tensor act_sizes, Tensor act_sizes_cpu,
548+
int minibatch_size, int horizon,
549+
float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef,
550+
bool is_continuous, bool kernels) {
551+
if (kernels) {
552+
Tensor logstd_safe = logits.logstd.defined() ? logits.logstd : torch::empty({0}, logits.mean.options());
553+
auto [adv_var, adv_mean] = torch::var_mean(advantages);
554+
return fused_ppo_loss_optimized(
555+
logits.mean, logstd_safe, newvalue,
556+
actions, old_logprobs, advantages, prio, values, returns,
557+
adv_mean, adv_var, // variance, not std - kernel does sqrtf
558+
ratio_out, newvalue_out,
559+
act_sizes, clip_coef, vf_clip_coef, vf_coef, ent_coef
560+
)[0];
561+
} else {
562+
int num_heads = actions.size(-1);
563+
int batch = minibatch_size;
564+
int segments = batch / horizon;
565+
566+
vector<Tensor> result;
567+
if (is_continuous) {
568+
TORCH_CHECK(logits.logstd.defined() && logits.logstd.numel() > 0,
569+
"logstd must be defined for continuous actions");
570+
result = continuous_logprob_entropy_cpp(
571+
logits.mean.reshape({batch, -1}), logits.logstd.reshape({batch, -1}),
572+
actions.reshape({batch, -1}));
573+
} else {
574+
result = discrete_logprob_entropy_cpp(
575+
logits.mean.reshape({batch, -1}), actions, act_sizes_cpu, num_heads);
576+
}
577+
Tensor ratio = (result[0].reshape({segments, horizon}) - old_logprobs).exp();
578+
ratio_out.copy_(ratio, false);
579+
newvalue_out.copy_(newvalue, false);
512580

513-
return {actions, sampled_logprobs};
581+
return ppo_loss_cpp(ratio, advantages, prio,
582+
newvalue, values, returns, result[1],
583+
clip_coef, vf_clip_coef, vf_coef, ent_coef);
584+
}
514585
}
515586

516587
// Reference implementation for testing

0 commit comments

Comments
 (0)