Binary file added docs/images/perf_metrics_TB.png
Binary file added docs/images/perf_metrics_perfetto.png
310 changes: 273 additions & 37 deletions docs/metrics.md
<!-- DO NOT REMOVE! Placeholder for TOC. -->

# Metrics
Tunix provides a comprehensive observability stack for training LLMs,
encompassing everything from basic training metrics to detailed execution
traces. This section is organized around three pillars:

* **[Collected Metrics](#collected-metrics)**: describing a rich
set of system, model, and RL-specific performance metrics out-of-the-box.
* **[Metric Loggers](#metric-loggers)**: describing a flexible,
protocol-based logging system that allows you to seamlessly integrate with
your preferred logging service (e.g., TensorBoard, Wandb, CLU) or create
custom backends.
* **[Performance Metric Tracing](#performance-metric-tracing)**: describing a
built-in, lightweight tracing system that generates detailed execution
timelines for deep performance analysis and visualization in Perfetto.

## Collected Metrics

Tunix collects the following metrics automatically so you can
monitor performance, convergence, and resource utilization.

### Common Metrics (SFT & RL)

These metrics are collected for both Supervised Fine-Tuning (SFT) and
Reinforcement Learning (RL) jobs:

* **`loss`**: The training loss for the current step.
* **`perplexity`**: The perplexity of the model on the training batch
(exp(loss)).
* **`learning_rate`**: The current learning rate from the optimizer.
* **`step_time_sec`**: The time taken to execute a single training step (in
seconds).
* **`steps_per_sec`**: The training speed, measured in steps per second.
* **`tflops_per_step`**: The estimated Trillion Floating Point Operations
(TFLOPs) performed per step (if supported by the hardware/backend).
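
The `loss`/`perplexity` relationship stated above can be reproduced directly; the loss value here is purely illustrative:

```python
import math

# Illustrative cross-entropy loss (nats per token); perplexity = exp(loss).
loss = 1.25
perplexity = math.exp(loss)
```

A lower loss therefore always corresponds to a lower perplexity.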

### RL-Specific Metrics (PPO/GRPO)

For Reinforcement Learning jobs, Tunix collects additional metrics related to
the RL algorithm (e.g., PPO), reward modeling, and generation.

#### Rewards & Scores
* **`rewards/sum`**: The sum of rewards for a trajectory.
* **`rewards/mean`**, **`rewards/max`**, **`rewards/min`**: Statistics of the
rewards across the batch.
* **`score/mean`**, **`score/max`**, **`score/min`**: Statistics of the raw
scores from the reward model (before any algorithm-specific modifications
like KL penalty).
* **`reward_kl_penalty`**: The KL divergence penalty applied to the reward
(if applicable).
* **`rewards/<reward_fn_name>`**: If using multiple reward functions,
individual reward components are logged by name.
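
The reward statistics above are simple reductions over per-trajectory rewards; a standalone sketch with illustrative values:

```python
# Illustrative per-trajectory rewards for one batch.
rewards = [0.5, 1.0, -0.25, 0.75]

reward_metrics = {
    "rewards/sum": sum(rewards),
    "rewards/mean": sum(rewards) / len(rewards),
    "rewards/max": max(rewards),
    "rewards/min": min(rewards),
}
```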

#### Policy & Value (PPO)
* **`advantages/mean`**, **`advantages/max`**, **`advantages/min`**:
Statistics of the advantages.
* **`returns/mean`**, **`returns/max`**, **`returns/min`**: Statistics of
the returns.
* **`values/mean`**, **`values/max`**, **`values/min`**: Statistics of the
value function estimates.
* **`pg_clipfrac`**: The fraction of the batch where the policy gradient was
clipped.
* **`vf_clipfrac`**: The fraction of the batch where the value function update
was clipped.
* **`loss/entropy`**: The entropy of the policy (if entropy regularization is
enabled).
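
For intuition, `pg_clipfrac` can be thought of as the fraction of policy probability ratios falling outside the PPO clipping window; the ratios and `eps` below are illustrative:

```python
# Illustrative policy probability ratios and PPO clipping range.
ratios = [0.7, 0.95, 1.05, 1.4]
eps = 0.2

# Fraction of the batch falling outside [1 - eps, 1 + eps].
pg_clipfrac = sum(r < 1 - eps or r > 1 + eps for r in ratios) / len(ratios)
```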

#### Generation & Data
* **`prompts`**: The input prompts used for generation.
* **`completions`**: The text completions generated by the model.
* **`completions/mean_length`**, **`completions/max_length`**,
**`completions/min_length`**: Statistics on the length of generated
completions.
* **`trajectory_ids`**: Unique identifiers for the trajectories.
* **`actor_dequeue_time`**: Time spent waiting for data from the rollout
workers (if async rollout is enabled).
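
The completion-length statistics are computed the same way; in this standalone sketch, word counts stand in for token counts:

```python
# Illustrative generated completions; word counts stand in for token counts.
completions = ["The answer is 42.", "Yes.", "Let me think step by step."]
lengths = [len(c.split()) for c in completions]

completion_metrics = {
    "completions/mean_length": sum(lengths) / len(lengths),
    "completions/max_length": max(lengths),
    "completions/min_length": min(lengths),
}
```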

## Metric Loggers

Tunix provides a flexible, protocol-based logging system that allows you to
integrate any logging service or library.

The primary interface for logging is the `MetricsLogger`. It is configured
using `MetricsLoggerOptions`. Below is an example of how to configure the
`MetricsLogger`. **Note**: The exact fields that need to be configured depend
on the backends, which typically default based on the execution environment. See
[Logging Backends Supported](#logging-backends-supported) for details on
backend-specific configurations.


```python
from tunix.sft import metrics_logger

options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/tensorboard",  # illustrative values
    project_name="my-project",
)
logger = metrics_logger.MetricsLogger(metrics_logger_options=options)
```

### Enabling Metrics in Jobs

Once you have your `MetricsLoggerOptions` configured, you can pass it to your
SFT or RL job via the training configuration.

#### Supervised Fine-Tuning (SFT)

```python
from tunix.sft import peft_trainer

# Sketch: the training config that carries `metrics_logging_options` is
# handed to the trainer; the remaining arguments are elided here.
trainer = peft_trainer.PeftTrainer(
    model=model,
    optimizer=optimizer,
    training_config=training_config,
)
```

#### Reinforcement Learning (RL)

For RL, pass the `metrics_logging_options` to the `RLTrainingConfig`, which is
then used in `ClusterConfig`.

```python
from tunix.rl import rl_cluster

training_config = rl_cluster.RLTrainingConfig(
    metrics_logging_options=metric_logger_options,
    # ... other configurations
)

cluster_config = rl_cluster.ClusterConfig(
    training_config=training_config,
    # ... other configurations
)

cluster = rl_cluster.RLCluster(
    actor=actor_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)
```

### Logging Backends Supported

Tunix supports several logging backends out of the box, powered by
[`metrax`](https://github.com/google/metrax/). The default backend selection
depends on the execution environment.

#### Wandb

[Weights & Biases](https://wandb.ai/) is a supported backend for experiment tracking. ([Backend Code](https://github.com/google/metrax/blob/main/src/metrax/logging/wandb_backend.py))

* **Availability**: *Enabled by default* in external environments (if `wandb`
is installed).
* **Configuration**:
* `project_name`: Sets the Wandb project name (default: "tunix").
* `run_name`: Sets the specific run name. If not provided, it defaults to
a timestamp (e.g., `2025-01-14_08-40-01`). **Note:** Wandb distinguishes
between a run name and a run id. Runs with the same name are tracked as
separate entities differentiated by their run id.
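
The default run name format shown above can be produced with a standard `strftime` pattern; the exact pattern Tunix uses is an assumption inferred from the example:

```python
from datetime import datetime

# Pattern inferred from the example run name; a fixed datetime keeps the
# output deterministic for illustration.
run_name = datetime(2025, 1, 14, 8, 40, 1).strftime("%Y-%m-%d_%H-%M-%S")
print(run_name)  # 2025-01-14_08-40-01
```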

#### TensorBoard

[TensorBoard](https://www.tensorflow.org/tensorboard) is also supported;
aggregate metrics are written as summaries to the configured `log_dir`.

### Custom Backends

You can integrate any other logging service by implementing a class that
conforms to the `metrax.logging.LoggingBackend` protocol.

#### 1. The Protocol

Your custom backend class only needs to implement `log_scalar` and `close`.
Explicit inheritance from a base class is not required since Metrax uses
Python's structural typing (duck typing).

```python
from typing import Protocol


class LoggingBackend(Protocol):
  """Structural interface expected by metrax (signatures sketched)."""

  def log_scalar(self, tag: str, value: float, step: int) -> None:
    ...

  def close(self) -> None:
    ...
```

#### 2. Implementing a Backend

A minimal backend that prints metrics to stdout (a sketch; only the two
protocol methods are required):

```python
class SimplePrintBackend:
  """No inheritance required: matching is structural (duck typing)."""

  def log_scalar(self, tag: str, value: float, step: int) -> None:
    print(f"step={step} {tag}={value}")

  def close(self) -> None:
    pass
```

#### 3. Using Your Custom Backend

To use your custom backend, you must pass a **factory** (a callable that returns
an instance) to `MetricsLoggerOptions`. This ensures configuration objects
remain serializable and safe to copy.

##### Case A: Simple Backend (No Arguments)

If your backend class requires no arguments in its `__init__`, you can simply
pass the class itself.

```python
options = metrics_logger.MetricsLoggerOptions(
    # Register `SimplePrintBackend` (the class itself) as a backend factory
    # here; the remaining fields are the same as in the earlier example.
)

logger = metrics_logger.MetricsLogger(metrics_logger_options=options)
```

## Performance Metric Tracing

Tunix provides a highly lightweight performance tracing and metrics collection
system designed specifically for RL workflows. Unlike detailed profiling tools
(e.g., xprof or standard JAX profiling), which collect exhaustive low-level
details but incur significant overhead and are typically only used for short
debugging intervals (e.g., 10s of seconds), this tracing system is designed
with minimal overhead. It can safely be left enabled for the entire duration
of your training run. It allows you to monitor the execution time of different
stages (e.g., rollout, actor training, reference inference) across both host
and device timelines.
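
Conceptually, the tracer records named spans (start time plus duration) on each timeline. A minimal standalone sketch of that idea, not the actual Tunix implementation:

```python
import time
from contextlib import contextmanager

spans = []  # (name, start_seconds, duration_seconds)

@contextmanager
def span(name):
    start = time.perf_counter()
    try:
        yield
    finally:
        spans.append((name, start, time.perf_counter() - start))

with span("rollout"):
    time.sleep(0.01)  # stand-in for generation work
```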

> **Note:** Performance metric tracing is currently only supported for the GRPO
> main entry point.

There are currently two versions of the performance metrics system:

* **Original Version (v1)** ([Code](https://github.com/google/tunix/blob/main/perf/export.py)):
Uses `PerfSpanQuery` to extract spans and compute metrics (e.g., rollout
time, wait time).
* **Experimental Version (v2)** ([Code](https://github.com/google/tunix/blob/main/perf/experimental/export.py)):
A more flexible version that tracks `Timeline` objects and is planned to
replace v1.

Both versions can export metrics to your custom export function and write
detailed Perfetto trace files that can be visualized at
[ui.perfetto.dev](https://ui.perfetto.dev/).

### Using Performance Metrics via CLI

When running Tunix via the CLI, you can configure performance metrics by
providing a `perf_metrics_options` dictionary inside your `rl_training_config`.

```yaml
rl_training_config:
perf_metrics_options:
enable_perf_v1: true # Enable v1 (default: true)
enable_perf_v2: false # Enable v2 (default: false)
enable_trace_writer: true # Enable writing Perfetto trace files (default: true)
log_dir: "/tmp/perf_trace" # Directory to write the trace files to
custom_export_fn_path: "path.to.my.custom_fn" # Optional path to a custom v1 export function
custom_export_fn_path_v2: "path.to.my.custom_fn_v2" # Optional path to a custom v2 export function
```

The CLI automatically parses these options, initializes the appropriate export
functions, and registers them with the training cluster.

Note that `enable_perf_v1` and `enable_perf_v2` can be toggled independently,
allowing you to use one or both systems simultaneously. If you wish to use a
custom export function instead of the defaults, you must provide the fully
qualified import path to your function via `custom_export_fn_path` (for v1) and
`custom_export_fn_path_v2` (for v2).
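
A fully qualified path such as `path.to.my.custom_fn` is conventionally resolved into a callable along these lines (a minimal sketch, not Tunix's actual loader):

```python
import importlib

def resolve_callable(dotted_path: str):
    """Resolve 'pkg.module.attr' to the attribute it names."""
    module_path, _, attr = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_path), attr)

# Demonstrated with a stdlib function rather than a real export function.
sqrt = resolve_callable("math.sqrt")
```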

### Using Performance Metrics via Code

If you are initializing the `RLCluster` programmatically, you must construct a
`PerfMetricsConfig` and pass it to the cluster manually.

#### Original Version (v1)

In v1, use `PerfMetricsExport.from_cluster_config()` to generate a default
export function. This function automatically computes various duration metrics
based on the cluster's mesh topology, e.g., whether the rollout and actor
models are collocated or placed on different TPU meshes
([code](https://github.com/google/tunix/blob/main/perf/export.py;l=102)).
The metrics are aggregated per `global_step` and logged through the
[Metric Logger](#metric-loggers) to the desired output. For example, with
TensorBoard enabled:

![Perf Metrics Tensorboard](images/perf_metrics_TB.png)

By default, v1 also writes detailed execution traces to a Perfetto
proto-formatted trace file. It reads `perf_metrics_options` from the cluster
configuration to initialize the trace writer. You can specify the output
directory by configuring `log_dir` within `PerfMetricsOptions` inside your
`RLTrainingConfig`.

```python
from tunix.perf import metrics as perf_metrics
from tunix.perf import export as perf_export
from tunix.rl import rl_cluster
from tunix.sft import metrics_logger


# 1. Define metric logger options (for collecting aggregate perf metrics).
metric_logger_options = metrics_logger.MetricsLoggerOptions(
log_dir="/tmp/tensorboard/grpo",
project_name="my-rl-project",
)

training_config = rl_cluster.RLTrainingConfig(
metrics_logging_options=metric_logger_options,
# ... other configurations
)

cluster_config = rl_cluster.ClusterConfig(
training_config=training_config,
# ... other configurations
)

# 2. Create a PerfMetricsConfig object.
perf_config = perf_metrics.PerfMetricsConfig()
perf_config.custom_export_fn = (
perf_export.PerfMetricsExport.from_cluster_config(cluster_config)
)

# 3. Pass the config to the RLCluster.
cluster = rl_cluster.RLCluster(
actor=actor_model,
tokenizer=tokenizer,
cluster_config=cluster_config,
perf_config=perf_config,
)
```

#### Experimental Version (v2)

For the experimental version, you can use the default export function, which
writes the raw timelines to a Perfetto trace file via the `PerfMetricsExport`
class. You will need to set `trace_dir` to the location where the file should
be written. Note that v2 is still experimental, and additional capabilities
such as exporting aggregated metrics to TensorBoard are a work in progress.
Once that functionality is complete, v2 will replace the original version.

```python
from tunix.perf import metrics as perf_metrics
from tunix.perf.experimental import export as perf_export_v2
from tunix.rl import rl_cluster

# 1. Create a PerfMetricsConfig object.
perf_config = perf_metrics.PerfMetricsConfig()

# 2. Create the v2 metrics export function, specifying the trace directory.
perf_config.custom_export_fn_v2 = (
perf_export_v2.PerfMetricsExport(trace_dir="/tmp/perf_trace").export_metrics
)

# 3. Pass the config to the RLCluster.
cluster = rl_cluster.RLCluster(
actor=actor_model,
tokenizer=tokenizer,
cluster_config=cluster_config,
perf_config=perf_config,
)
```

### Custom Export Functions

If you want to compute custom metrics from the collected spans instead of using
the defaults, you can define and provide your own export function.

**Custom Export Function for v1:**
A v1 export function takes a `PerfSpanQuery` and returns a dictionary of
metrics.

```python
from tunix.perf import metrics as perf_metrics

def my_custom_export_fn(query: perf_metrics.PerfSpanQuery) -> perf_metrics.MetricsT:
# Example: query main thread for the latest 'global_step' group
global_steps = query().main().last_group("global_step").get()
if global_steps:
# MetricsT maps metric names to (value, optional_aggregation_fn)
return {"perf/custom_step_time": (global_steps[0].duration, None)}
return {}

perf_config.custom_export_fn = my_custom_export_fn
```

**Custom Export Function for v2:**
A v2 export function takes a mapping of timeline IDs to their respective
`Timeline` objects.

```python
from tunix.perf import metrics as perf_metrics
from tunix.perf.experimental import tracer

def my_custom_export_fn_v2(timelines: dict[str, tracer.Timeline]) -> perf_metrics.MetricsT:
# Example: iterate over host and device timelines
for tl_id, timeline in timelines.items():
pass # Analyze timeline.root span
return {}

perf_config.custom_export_fn_v2 = my_custom_export_fn_v2
```

### Visualizing with Perfetto

If you have enabled the trace writer (by setting `enable_trace_writer: true` via
the CLI or by specifying `trace_dir` in your configuration), a proto-formatted
file (e.g., `perfetto_trace_1771973518.pb`) containing the raw spans and
timelines will be saved to the specified directory (which defaults to
`/tmp/perf_traces` in v1). To view the trace, download the file to your local
machine and drag-and-drop it into the
[Perfetto UI](https://ui.perfetto.dev/). The interface allows you to
interactively zoom, pan, and query the execution trace, as shown below:

![Perf Metrics Perfetto](images/perf_metrics_perfetto.png)