forked from THUDM/slime
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfully_async_rollout.py
More file actions
263 lines (208 loc) · 9.35 KB
/
fully_async_rollout.py
File metadata and controls
263 lines (208 loc) · 9.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
import asyncio
import atexit
import queue
import threading
import time
# Import core functions from sglang_rollout directly to avoid code duplication
from slime.rollout.sglang_rollout import GenerateState, generate_and_rm_group
from slime.utils.async_utils import run
from slime.utils.types import Sample
# Global worker manager: a single process-wide AsyncRolloutWorker shared by all
# rollout calls in this process. Creation/teardown is guarded by _worker_lock so
# concurrent callers cannot race to build two workers.
_global_worker = None
_worker_lock = threading.Lock()
def get_global_worker(args, data_buffer):
    """Return the process-wide AsyncRolloutWorker, (re)creating it on demand.

    A fresh worker is built when none exists yet, when the previous worker was
    never started, or when its thread has died (e.g. after a crash), so callers
    always receive a live worker.

    Args:
        args: rollout configuration; must provide ``sglang_server_concurrency``
            and the fields read by ``AsyncRolloutWorker``.
        data_buffer: buffer the worker pulls sample groups from.

    Returns:
        The running ``AsyncRolloutWorker`` singleton.
    """
    global _global_worker
    with _worker_lock:
        # Guard worker_thread against None explicitly: a constructed-but-never-
        # started worker has worker_thread == None, and calling .is_alive() on
        # it would raise AttributeError instead of rebuilding the worker.
        if (
            _global_worker is None
            or _global_worker.worker_thread is None
            or not _global_worker.worker_thread.is_alive()
        ):
            print("Creating new global async worker...")
            _global_worker = AsyncRolloutWorker(args, data_buffer, concurrency=args.sglang_server_concurrency)
            _global_worker.start()
        return _global_worker
def stop_global_worker():
    """Shut down and discard the process-wide worker, if one exists."""
    global _global_worker
    with _worker_lock:
        worker = _global_worker
        if worker is None:
            return
        worker.stop()
        _global_worker = None
class AsyncRolloutWorker:
    """
    Simplified asynchronous rollout worker, using threads instead of processes.

    Runs an asyncio event loop on a dedicated daemon thread that continuously
    pulls sample groups from ``data_buffer``, processes them with
    ``generate_and_rm_group``, and publishes ``(group_id, group)`` results on
    ``output_queue``. Supports continuous running, independent of the rollout
    function's lifecycle.
    """

    def __init__(self, args, data_buffer, concurrency=10):
        self.args = args
        self.data_buffer = data_buffer  # Directly save data_buffer reference
        # NOTE(review): `concurrency` is stored but never read in this class;
        # the actual task cap is args.rollout_batch_size (see loop below).
        self.concurrency = concurrency
        self.running = True
        # Bounded output queue of (group_id, group) tuples; producers block
        # once 1000 results are pending.
        self.output_queue = queue.Queue(maxsize=1000)  # Continuous output queue
        self.worker_thread = None  # created lazily by start()
        self.state = GenerateState(args)

    async def continuous_worker_loop(self):
        """Continuous work loop - constantly get data from data_buffer and process"""
        print("Continuous async rollout worker started")
        active_tasks = set()
        # Cap on concurrently running generation tasks.
        max_concurrent_tasks = self.args.rollout_batch_size
        # Monotonic id assigned to each fetched group; consumers use it to
        # track completion order.
        group_id_counter = 0
        while self.running:
            try:
                # Clean up completed tasks
                if active_tasks:
                    done_tasks = {task for task in active_tasks if task.done()}
                    for task in done_tasks:
                        try:
                            task.result()  # Results are already handled in callbacks
                        except Exception as e:
                            print(f"Task failed with exception: {e}")
                    active_tasks -= done_tasks
                # If active task count hasn't reached limit, try to get new data and start tasks
                while len(active_tasks) < max_concurrent_tasks and self.running:
                    samples = self.data_buffer.get_samples(1)
                    for group in samples:
                        group_id = group_id_counter
                        group_id_counter += 1
                        # Create new async task
                        task = asyncio.create_task(
                            generate_and_rm_group(
                                self.args,
                                group,
                                sampling_params=self.state.sampling_params.copy(),
                                evaluation=False,
                            )
                        )

                        # Add completion callback. `gid` is bound as a closure
                        # argument so each callback keeps its own group id.
                        # NOTE(review): if the task raised, done_task.result()
                        # re-raises inside the callback and nothing is put on
                        # output_queue — the group is silently lost; confirm
                        # this is the intended failure policy.
                        def make_callback(gid):
                            def task_done_callback(done_task):
                                result = done_task.result()
                                self.output_queue.put((gid, result))

                            return task_done_callback

                        task.add_done_callback(make_callback(group_id))
                        active_tasks.add(task)
                    # NOTE(review): this break makes the inner `while` run its
                    # body at most once per outer iteration (at most one fetch
                    # per ~1s sleep), so the loop behaves like an `if` — verify
                    # this matches the intended ramp-up rate.
                    break
                # Brief sleep to avoid busy waiting
                await asyncio.sleep(1)
            except Exception as e:
                print(f"Error in continuous worker loop: {e}")
                await asyncio.sleep(1)
        # Drain: once `running` is cleared, wait for in-flight tasks to finish.
        if active_tasks:
            print(f"Waiting for {len(active_tasks)} continuous tasks to complete...")
            await asyncio.wait(active_tasks)
        print("Continuous async rollout worker stopped")

    def worker_thread_func(self):
        """Worker function running in independent thread; owns its event loop."""
        asyncio.run(self.continuous_worker_loop())

    def start(self):
        """Start continuous work mode (no-op if the thread is already alive)."""
        if self.worker_thread is None or not self.worker_thread.is_alive():
            # Daemon thread so a hung worker cannot block interpreter exit.
            self.worker_thread = threading.Thread(target=self.worker_thread_func, daemon=True)
            self.worker_thread.start()
            print("Started continuous async worker thread")

    def stop(self):
        """Stop worker thread, waiting up to 5s for it to wind down."""
        self.running = False
        if self.worker_thread and self.worker_thread.is_alive():
            self.worker_thread.join(timeout=5)
        print("Stopped async worker thread")

    def get_completed_groups(self) -> list[tuple]:
        """Drain and return all currently queued (group_id, group) results."""
        completed = []
        while True:
            try:
                result = self.output_queue.get_nowait()
                completed.append(result)
            except queue.Empty:
                break
        return completed

    def get_queue_size(self) -> int:
        """Get current output queue size (approximate, per queue.Queue.qsize)."""
        return self.output_queue.qsize()
async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[list[Sample]]:
    """
    Simplified asynchronous rollout generation - using global continuous worker.

    Collects ``args.rollout_batch_size`` completed sample groups from the
    global worker's output queue, returning aborted groups to ``data_buffer``
    instead of forwarding them, and returns the collected groups sorted by the
    first sample's ``index``.
    """
    assert args.rollout_global_dataset
    # Get global worker, which will run continuously
    worker = get_global_worker(args, data_buffer)
    # Simplified: directly use rollout_batch_size as target
    target_data_size = args.rollout_batch_size

    data = []
    completed_groups = {}  # group_id -> group, staged until counted into `data`
    do_print = True  # log the first accepted sample once, for inspection

    print(f"Starting async rollout generation for {target_data_size} groups")
    print(f"Global worker queue size: {worker.get_queue_size()}")

    # Main loop: collect results from global worker's output queue
    start_time = time.time()
    last_progress_time = start_time
    no_progress_timeout = 30.0  # Warn if no progress for 30 seconds

    while len(data) < target_data_size:
        # Collect completed results
        completed = worker.get_completed_groups()
        made_progress = False
        for group_id, group in completed:
            completed_groups[group_id] = group
            made_progress = True
        if made_progress:
            last_progress_time = time.time()

        # Process completed groups in order (try to maintain order, but not strict requirement)
        processed_any = False
        # Process all available completed groups
        available_ids = list(completed_groups.keys())
        for group_id in available_ids:
            if len(data) >= target_data_size:
                # NOTE(review): groups still left in `completed_groups` at this
                # point are dropped when the function returns — they are neither
                # trained on nor returned to the buffer; confirm this is
                # acceptable (they may be lost work).
                break
            group = completed_groups.pop(group_id)
            # If any sample in the group was aborted, return the whole group to the data buffer
            # and do not forward it to the training engine.
            try:
                any_aborted = any([sample.status == Sample.Status.ABORTED for sample in group])
            except Exception:
                # NOTE(review): broad catch treats a malformed group as
                # "not aborted" and forwards it to training — verify.
                any_aborted = False
            if any_aborted:
                try:
                    # add back to buffer so it can be retried or handled by buffer policy
                    data_buffer.add_samples([group])
                    print(f"Returned aborted group {group_id} to data buffer", flush=True)
                except Exception as e:
                    print(f"Failed to return aborted group {group_id} to buffer: {e}", flush=True)
                # don't count as processed for training
                continue
            if do_print:
                print(
                    f"First rollout sample: {[group[0].prompt + group[0].response]}, "
                    f"label: {group[0].label}, reward: {group[0].reward}",
                    flush=True,
                )
                do_print = False
            # Simplified: directly add samples, no filters used
            data.append(group)
            processed_any = True

        # Check progress
        current_time = time.time()
        if current_time - last_progress_time > no_progress_timeout:
            print(
                f"Warning: No progress for {no_progress_timeout}s. "
                f"Queue size: {worker.get_queue_size()}, "
                f"Collected: {len(data)}/{target_data_size}"
            )
            last_progress_time = current_time

        # If no results were processed, brief sleep to avoid busy waiting
        if not processed_any:
            await asyncio.sleep(0.01)

    duration = time.time() - start_time
    print(f"Rollout completed in {duration:.2f}s! Global worker queue size: {worker.get_queue_size()}")
    if data:
        print(
            f"Finish rollout: {[data[-1][0].prompt + data[-1][0].response]}, "
            f"label: {data[-1][0].label}, reward: {data[-1][0].reward}",
            flush=True,
        )
    # Restore dataset order: sort by the first sample's index within each group.
    data = sorted(data, key=lambda group: group[0].index)
    return data
def generate_rollout_fully_async(args, rollout_id, data_buffer, evaluation=False):
    """Synchronous entry point: drive the async rollout to completion.

    Evaluation mode is not implemented here; requesting it is an error.
    Returns the list of completed sample groups from the async generator.
    """
    if evaluation:
        raise ValueError("Evaluation mode not supported in simple async rollout")
    return run(generate_rollout_async(args, rollout_id, data_buffer))
# Register exit cleanup function so the global worker's daemon thread is
# stopped cleanly when the interpreter exits.
atexit.register(stop_global_worker)