
Commit 576bc48

committed
stage
1 parent 0dfbc58 · commit 576bc48


Some content is hidden: large commits have some content hidden by default, so only a subset of the changed files appears below.

112 files changed: +792 −1962 lines

.pre-commit-config.yaml

Lines changed: 6 additions & 6 deletions
@@ -11,12 +11,12 @@ repos:
       - id: check-merge-conflict
       - id: detect-private-key
 
-  # - repo: https://github.com/psf/black
-  #   rev: 23.7.0
-  #   hooks:
-  #     - id: black
-  #       language_version: python3.10
-  #       args: [--line-length=100]
+  - repo: https://github.com/psf/black
+    rev: 23.7.0
+    hooks:
+      - id: black
+        language_version: python3.10
+        args: [--line-length=999999]
 
   # - repo: https://github.com/pycqa/isort
   #   rev: 5.12.0
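
For context on this commit's large net deletion: raising black's line length to an effectively unlimited value makes it unwrap expressions that were previously split across several lines, as in the hunks below. A minimal sketch of that effect (illustrative only, not part of the commit; assumes the `black` PyPI package is installed):

    # Illustrative only; requires `pip install black`.
    import black

    # A wrapped call of the kind collapsed throughout this commit.
    src = (
        "resource_pool_manager = ResourcePoolManager(\n"
        "    resource_pool_spec=resource_pool_spec, mapping=mapping\n"
        ")\n"
    )

    # Mirror the hook's `args: [--line-length=999999]`.
    mode = black.Mode(line_length=999999)
    print(black.format_str(src, mode=mode))
    # -> resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)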

ajet/__init__.py

Lines changed: 1 addition & 8 deletions
@@ -4,13 +4,6 @@
 from ajet.workflow import Workflow
 from ajet.utils.vsdb import vscode_conditional_breakpoint as bp
 
-__all__ = [
-    "Workflow",
-    "WorkflowTask",
-    "WorkflowOutput",
-    "AjetTuner",
-    "AgentJetJob",
-    "bp"
-]
+__all__ = ["Workflow", "WorkflowTask", "WorkflowOutput", "AjetTuner", "AgentJetJob", "bp"]
 
 __version__ = "0.1.0"

ajet/backbone/main_trinity.py

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ def patched_trainer_get_actor(cls, config: Config):
 
     if ajet_config.ajet.enable_experimental_interchange_server:
         from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
+
         start_interchange_server(ajet_config)
 
 

ajet/backbone/main_verl.py

Lines changed: 4 additions & 11 deletions
@@ -66,16 +66,10 @@ def run_ppo(config) -> None:
 
     # Create a remote instance of the TaskRunner class, and
     # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
-    if (
-        is_cuda_available
-        and config.trainer.get("profile_steps") is not None
-        and len(config.trainer.get("profile_steps", [])) > 0
-    ):
+    if is_cuda_available and config.trainer.get("profile_steps") is not None and len(config.trainer.get("profile_steps", [])) > 0:
         from verl.utils.import_utils import is_nvtx_available
 
-        assert (
-            is_nvtx_available()
-        ), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'"
+        assert is_nvtx_available(), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'"
         nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options)
         runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
     else:
@@ -223,9 +217,7 @@ def run(self, config):
             num_examine=1,
             **config.reward_model.get("reward_kwargs", {}),
         )
-        resource_pool_manager = ResourcePoolManager(
-            resource_pool_spec=resource_pool_spec, mapping=mapping
-        )
+        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
 
         from verl.utils.dataset.rl_dataset import collate_fn
 
@@ -248,6 +240,7 @@ def run(self, config):
 
         if config.ajet.enable_experimental_interchange_server:
             from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
+
             start_interchange_server(config)
 
         # Initialize the PPO trainer.

ajet/backbone/main_vllm.py

Lines changed: 8 additions & 16 deletions
@@ -82,9 +82,7 @@ def submit_chat_completions(self, messages, sampling_params, request_id, tools=[
                     "request_id": completion.id,
                     "content": message["content"],
                     "tool_calls": message.get("tool_calls", None),
-                    "tokens": [
-                        TokenAndProbVllmDebug(t) for t in completion.choices[0].logprobs.content  # type: ignore
-                    ],
+                    "tokens": [TokenAndProbVllmDebug(t) for t in completion.choices[0].logprobs.content],  # type: ignore
                 }
             )
         return messages
@@ -130,13 +128,12 @@ async def submit_chat_completions_async(self, messages, sampling_params, request
                     "request_id": completion.id,
                     "content": message["content"],
                     "tool_calls": message.get("tool_calls", None),
-                    "tokens": [
-                        TokenAndProbVllmDebug(t) for t in completion.choices[0].logprobs.content  # type: ignore
-                    ],
+                    "tokens": [TokenAndProbVllmDebug(t) for t in completion.choices[0].logprobs.content],  # type: ignore
                 }
             )
         return messages
 
+
 def run(config):
     from ajet.task_reader import RouterTaskReader
 
@@ -147,9 +144,7 @@ def run(config):
     vllm_port = config.ajet.debug.debug_vllm_port
 
     # --------- init ---------
-    async_rollout_manager = ChatCompletionScheduler(
-        config=config, url=f"http://localhost:{vllm_port}/v1"
-    )
+    async_rollout_manager = ChatCompletionScheduler(config=config, url=f"http://localhost:{vllm_port}/v1")
     parallel_env = VerlRolloutManager(
         config=config,
         async_rollout_manager=async_rollout_manager,
@@ -159,16 +154,13 @@ def run(config):
         tokenizer=async_rollout_manager.tokenizer,
     )
 
-
     task_reader = RouterTaskReader(
         config.ajet.task_reader.type,
         config.ajet.task_reader,
     )
     tasks = task_reader.get_validation_tasks()
     logger.info(tasks[:n_task])
-    ctx_tracker = parallel_env.rollout(
-        tasks=tasks[:n_task], mode="sample", epoch="1"
-    )  # "sample" or "validate"
+    ctx_tracker = parallel_env.rollout(tasks=tasks[:n_task], mode="sample", epoch="1")  # "sample" or "validate"
     _ = parallel_env.to_dataproto(ctx_tracker)
 
 
@@ -179,13 +171,15 @@
 )
 def main(config):
     from omegaconf import OmegaConf
+
     OmegaConf.resolve(config)
     runtime_env = get_runtime_env(config)
     os.environ.update(runtime_env["env_vars"])
     # atexit.register(lambda: print("Process exiting, performing cleanup..."))
 
     if config.ajet.enable_experimental_interchange_server:
         from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
+
         start_interchange_server(config)
 
     def companion_launch():
@@ -198,9 +192,7 @@ def companion_launch():
         tensor_parallel_size = config.ajet.debug.debug_tensor_parallel_size
         n_avail_gpus = torch.cuda.device_count()
         if tensor_parallel_size > n_avail_gpus:
-            logger.info(
-                f"Warning: tensor_parallel_size {tensor_parallel_size} is greater than available GPUs {n_avail_gpus}. Setting tensor_parallel_size to {n_avail_gpus}."
-            )
+            logger.info(f"Warning: tensor_parallel_size {tensor_parallel_size} is greater than available GPUs {n_avail_gpus}. Setting tensor_parallel_size to {n_avail_gpus}.")
            tensor_parallel_size = n_avail_gpus
         gpu_memory_utilization = config.actor_rollout_ref.rollout.gpu_memory_utilization
         max_num_seqs = config.actor_rollout_ref.rollout.max_num_seqs
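
The last hunk above clamps tensor parallelism to the visible device count before launching vLLM. As a standalone sketch of that pattern (the helper name is hypothetical, not from the file):

    import torch

    # Hypothetical helper mirroring companion_launch's clamp above: never
    # request more tensor-parallel ranks than there are visible GPUs.
    def clamp_tensor_parallel(tensor_parallel_size: int) -> int:
        n_avail_gpus = torch.cuda.device_count()
        if tensor_parallel_size > n_avail_gpus:
            print(f"Warning: tensor_parallel_size {tensor_parallel_size} > {n_avail_gpus} available GPUs; clamping.")
            tensor_parallel_size = n_avail_gpus
        return tensor_parallel_size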

ajet/backbone/trainer_trinity.py

Lines changed: 12 additions & 37 deletions
@@ -66,6 +66,7 @@ def __init__(
 
     def convert_task(self, task: TrinityTask):
         from ajet.schema.task import Task
+
         assert isinstance(task.raw_task, dict)
         return dict_to_ajet_task(task.raw_task)
 
@@ -150,16 +151,10 @@ async def run_async(self):
                     "madness": tracker.reward_structure.madness,
                 }
 
-                if (
-                    len(response_ids) + len(prompt_ids) == len(input_ids)
-                    and len(logprobs) == len(response_ids)
-                    and len(logprobs) > 0
-                ):
+                if len(response_ids) + len(prompt_ids) == len(input_ids) and len(logprobs) == len(response_ids) and len(logprobs) > 0:
                     exp = Experience(
                         tokens=input_ids,  # [seq_length] prompt + response
-                        prompt_length=len(
-                            prompt_ids
-                        ),  # Length of the prompt in tokens, used for generating attention masks
+                        prompt_length=len(prompt_ids),  # Length of the prompt in tokens, used for generating attention masks
                         logprobs=logprobs,  # [resp_length]
                         reward=reward,  #
                         # advantages=None,
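
The collapsed condition in the hunk above is a token-accounting check: prompt plus response tokens must reassemble the full sequence, with exactly one logprob per response token. A self-contained restatement (the function name is illustrative, not code from this commit):

    # Illustrative restatement of the guard above.
    def can_build_experience(prompt_ids, response_ids, input_ids, logprobs) -> bool:
        return (
            len(response_ids) + len(prompt_ids) == len(input_ids)
            and len(logprobs) == len(response_ids)
            and len(logprobs) > 0
        )

    # Example: 2 prompt tokens + 3 response tokens == 5 total, 3 logprobs.
    assert can_build_experience([1, 2], [3, 4, 5], [1, 2, 3, 4, 5], [-0.1, -0.2, -0.3])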
@@ -211,19 +206,11 @@ def __init__(self, config):
         if "train" in self.split:
             dataset_segments.append(task_to_standard_dataset(task_reader.get_training_tasks()))
         if "val" in self.split:
-            dataset_segments.append(
-                task_to_standard_dataset(task_reader.get_validation_tasks())
-            )
+            dataset_segments.append(task_to_standard_dataset(task_reader.get_validation_tasks()))
         if not dataset_segments:
-            raise ValueError(
-                f"Unsupported split '{self.split}'. Expected to contain 'train' or 'val'."
-            )
+            raise ValueError(f"Unsupported split '{self.split}'. Expected to contain 'train' or 'val'.")
 
-        concatenated_dataset = (
-            dataset_segments[0]
-            if len(dataset_segments) == 1
-            else datasets.concatenate_datasets(dataset_segments)
-        )
+        concatenated_dataset = dataset_segments[0] if len(dataset_segments) == 1 else datasets.concatenate_datasets(dataset_segments)
 
         self.dataset = _HFBatchReader(
             concatenated_dataset,
@@ -271,15 +258,9 @@ class SwanlabMonitor(Monitor):
     """
 
     def __init__(self, project: str, group: str, name: str, role: str, config) -> None:
-        assert (
-            swanlab is not None
-        ), "swanlab is not installed. Please install it to use SwanlabMonitor."
-
-        monitor_args = (
-            (config.monitor.monitor_args or {})
-            if config and getattr(config, "monitor", None)
-            else {}
-        )
+        assert swanlab is not None, "swanlab is not installed. Please install it to use SwanlabMonitor."
+
+        monitor_args = (config.monitor.monitor_args or {}) if config and getattr(config, "monitor", None) else {}
 
         # Optional API login via code if provided; otherwise try environment, then rely on prior `swanlab login`.
         api_key = os.environ.get("SWANLAB_API_KEY")
@@ -331,9 +312,7 @@ def __init__(self, project: str, group: str, name: str, role: str, config) -> None:
         self.data_dashboard_url = run_info["cloud"]["experiment_url"]
 
     def log_table(self, table_name: str, experiences_table, step: int):
-        assert (
-            swanlab is not None
-        ), "swanlab is not installed. Please install it to use SwanlabMonitor."
+        assert swanlab is not None, "swanlab is not installed. Please install it to use SwanlabMonitor."
 
         # Convert pandas DataFrame to SwanLab ECharts Table
         headers: List[str] = list(experiences_table.columns)
@@ -351,9 +330,7 @@ def log_table(self, table_name: str, experiences_table, step: int):
     def log(self, data: dict, step: int, commit: bool = False) -> None:
         """Log metrics."""
         # SwanLab doesn't use commit flag; keep signature for compatibility
-        assert (
-            swanlab is not None
-        ), "swanlab is not installed. Please install it to use SwanlabMonitor."
+        assert swanlab is not None, "swanlab is not installed. Please install it to use SwanlabMonitor."
         swanlab.log(data, step=step)
         self.console_logger.info(f"Step {step}: {data}")
 
@@ -372,9 +349,7 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
         test_robot_data = {}
         test_robot_data["step"] = step
         test_robot_data["data_dashboard_url"] = self.data_dashboard_url
-        test_robot_data["reward_for_test_robot"] = data[
-            "experience_pipeline/group_advantages/reward_mean/mean"
-        ]
+        test_robot_data["reward_for_test_robot"] = data["experience_pipeline/group_advantages/reward_mean/mean"]
         _test_if_test_mode(key="reward_probe", value=test_robot_data, config=ajet_config)
 
     def close(self) -> None:
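
The repeated `assert swanlab is not None` guards above are the usual complement of an optional import. A sketch of the import side of that pattern (an illustrative reconstruction; the actual import block is not shown in this diff):

    # Typical optional-dependency import guard consistent with the asserts above.
    try:
        import swanlab
    except ImportError:
        swanlab = None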
