-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_loader.py
More file actions
379 lines (314 loc) · 13.7 KB
/
data_loader.py
File metadata and controls
379 lines (314 loc) · 13.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
## add the pretraining directory to the path
import sys
sys.path.append('./Video-Pre-Training')
import json
import glob
import os
import random
from multiprocessing import Process, Queue, Event
from typing import List, Dict, Tuple, Optional
import numpy as np
import cv2
import collections
import traceback
from agent import resize_image, AGENT_RESOLUTION, ENV_KWARGS
# Seconds to wait on blocking queue put/get before raising queue.Full/queue.Empty.
QUEUE_TIMEOUT = 10
# Cursor sprite (with alpha channel) composited onto frames at the recorded mouse position.
CURSOR_FILE = "./Video-Pre-Training/cursors/mouse_cursor_white_16x16.png"
print("CURSOR_FILE: ", CURSOR_FILE)
# Height (px) of the original MineRec recordings; used to rescale recorded
# mouse coordinates to the actual video resolution before drawing the cursor.
MINEREC_ORIGINAL_HEIGHT_PX = 720
# If GUI is open, mouse dx/dy need also be adjusted with these scalers.
# If data version is not present, assume it is 1.
# NOTE(review): this table is defined but not referenced anywhere in this
# file — presumably used by callers elsewhere; verify before removing.
MINEREC_VERSION_SPECIFIC_SCALERS = {
    "5.7": 0.5,
    "5.8": 0.5,
    "6.7": 2.0,
    "6.8": 2.0,
    "6.9": 2.0,
}
def composite_images_with_alpha(image1, image2, alpha, x, y):
    """
    Draw image2 over image1 at location (x, y), using alpha as the per-pixel
    opacity for image2. Modifies image1 in-place.

    Args:
        image1: destination HxWx3 uint8 array (modified in place).
        image2: overlay hxwx3 array.
        alpha: hxwx1 array of opacities in [0, 1] matching image2.
        x, y: top-left placement of the overlay in image1 coordinates.
            May be negative or beyond image1's bounds; the overlay is
            clipped accordingly (off-screen portions are skipped).
    """
    # Fix: the original assumed x, y >= 0; a cursor partially off the
    # top/left edge produced wrong slices. Clip the source as well so a
    # partially visible overlay is drawn correctly.
    src_y = max(0, -y)
    src_x = max(0, -x)
    y = max(0, y)
    x = max(0, x)
    # Overlap height/width after clipping to both images' bounds.
    ch = max(0, min(image1.shape[0] - y, image2.shape[0] - src_y))
    cw = max(0, min(image1.shape[1] - x, image2.shape[1] - src_x))
    if ch == 0 or cw == 0:
        return
    alpha = alpha[src_y:src_y + ch, src_x:src_x + cw]
    overlay = image2[src_y:src_y + ch, src_x:src_x + cw, :]
    region = image1[y:y + ch, x:x + cw, :]
    image1[y:y + ch, x:x + cw, :] = (region * (1 - alpha) + overlay * alpha).astype(np.uint8)
def data_loader_worker(tasks_queue, output_queue, quit_workers_event, n_frames=128):
    """
    Worker for the data loader that processes video frames and corresponding JSON data.
    Returns stacked frames along with all corresponding actions.

    Runs as a separate process: pulls (trajectory_id, video_path, json_path)
    tasks from tasks_queue until a None sentinel arrives, then puts a final
    None on output_queue to signal completion.

    Args:
        tasks_queue: Queue containing tasks to process
        output_queue: Queue to put processed data
        quit_workers_event: Event to signal worker termination
        n_frames: Number of consecutive frames to stack together
    """
    # IMREAD_UNCHANGED keeps the 4th (alpha) channel of the cursor sprite.
    cursor_image = cv2.imread(CURSOR_FILE, cv2.IMREAD_UNCHANGED)
    # Assume 16x16
    cursor_image = cursor_image[:16, :16, :]
    cursor_alpha = cursor_image[:, :, 3:] / 255.0
    cursor_image = cursor_image[:, :, :3]
    # Initialize frame buffer: sliding windows of the last n_frames
    # frames/actions (deque drops the oldest automatically).
    frame_buffer = collections.deque(maxlen=n_frames)
    action_buffer = collections.deque(maxlen=n_frames)
    while True:
        task = tasks_queue.get()
        if task is None:
            # Sentinel: no more trajectories for this worker.
            break
        trajectory_id, video_path, json_path = task
        video = cv2.VideoCapture(video_path)
        with open(json_path) as json_file:
            # File is JSONL (one JSON object per line); wrap the lines in
            # brackets to parse them as a single JSON array.
            json_lines = json_file.readlines()
            json_data = "[" + ",".join(json_lines) + "]"
            json_data = json.loads(json_data)
        # NOTE(review): last_action is written below but never read — dead code.
        last_action = None
        last_mouse_x = 0
        last_mouse_y = 0
        # Clear buffers at the start of each trajectory
        frame_buffer.clear()
        action_buffer.clear()
        for i in range(len(json_data)):
            if quit_workers_event.is_set():
                break
            step_data = json_data[i]
            # Convert the new format to action dictionary
            action = {}
            # Handle keyboard inputs ("or []" guards against explicit JSON nulls)
            keyboard_keys = step_data.get("keyboard_keys", []) or []
            for key in keyboard_keys:
                action[f"key.{key.lower()}"] = 1
            # Handle mouse buttons
            mouse_buttons = step_data.get("mouse_buttons", []) or []
            for button in mouse_buttons:
                button_index = {
                    "Button.left": 0,
                    "Button.right": 1,
                    "Button.middle": 2
                }.get(button)
                if button_index is not None:
                    action[f"mouse.{button_index}"] = 1
            # Handle mouse movement: record both deltas (vs previous step)
            # and absolute position.
            mouse_position = step_data.get("mouse_position")
            if mouse_position:
                mouse_x, mouse_y = mouse_position
                action["mouse_dx"] = mouse_x - last_mouse_x
                action["mouse_dy"] = mouse_y - last_mouse_y
                action["mouse_x"] = mouse_x
                action["mouse_y"] = mouse_y
                last_mouse_x = mouse_x
                last_mouse_y = mouse_y
            else:
                action["mouse_dx"] = 0
                action["mouse_dy"] = 0
                action["mouse_x"] = last_mouse_x
                action["mouse_y"] = last_mouse_y
            # Handle mouse scroll
            mouse_scroll = step_data.get("mouse_scroll")
            if mouse_scroll is not None:
                action["scroll"] = mouse_scroll
            # Check if this is a null action (no inputs). Absolute mouse
            # position alone does not count as input.
            is_null_action = not any(v for k, v in action.items() if k not in ["mouse_x", "mouse_y"])
            # Read frame even if this is null so we progress forward
            ret, frame = video.read()
            if ret:
                # Skip null actions as done in the VPT paper
                if is_null_action:
                    continue
                # Add cursor to frame if mouse position is available
                if mouse_position:
                    # Recording coordinates are relative to 720p; rescale to
                    # this video's actual height.
                    camera_scaling_factor = frame.shape[0] / MINEREC_ORIGINAL_HEIGHT_PX
                    cursor_x = int(mouse_position[0] * camera_scaling_factor)
                    cursor_y = int(mouse_position[1] * camera_scaling_factor)
                    composite_images_with_alpha(frame, cursor_image, cursor_alpha, cursor_x, cursor_y)
                # Convert frame color space (in place, via dst=) and resize
                cv2.cvtColor(frame, code=cv2.COLOR_BGR2RGB, dst=frame)
                frame = np.asarray(np.clip(frame, 0, 255), dtype=np.uint8)
                frame = resize_image(frame, AGENT_RESOLUTION)
                # Add frame and action to their respective buffers
                frame_buffer.append(frame)
                action_buffer.append(action)
                # Only output when we have enough frames
                if len(frame_buffer) == n_frames:
                    # Stack frames along a new leading dimension: (n_frames, H, W, C)
                    stacked_frames = np.stack(list(frame_buffer), axis=0)
                    # Create list of all corresponding actions
                    actions = list(action_buffer)
                    # Send processed data to output queue.
                    # NOTE(review): put() with a timeout raises queue.Full if
                    # the consumer stalls for QUEUE_TIMEOUT seconds; that
                    # exception is unhandled here and would kill the worker.
                    output_queue.put((trajectory_id, stacked_frames, actions), timeout=QUEUE_TIMEOUT)
                # Update last action for next iteration
                last_action = action.copy()
            else:
                print(f"Could not read frame from video {video_path}")
        video.release()
        if quit_workers_event.is_set():
            break
    # Signal completion
    output_queue.put(None)
class DataLoader:
    """
    Generator class for loading batches from a dataset

    This only returns a single step at a time per worker; no sub-sequences.
    Idea is that you keep track of the model's hidden state and feed that in,
    along with one sample at a time.

    NOTE(review): data_loader_worker actually emits (trajectory_id,
    stacked_frames, actions) tuples — a stack of n_frames frames per item,
    not a single step — so the "single step" description above looks stale.
    Confirm against the worker before relying on it.

    + Simpler loader code
    + Supports lower end hardware
    - Not very efficient (could be faster)
    - No support for sub-sequences
    - Loads up individual files as trajectory files (i.e. if a trajectory is split into multiple files,
      this code will load it up as a separate item).
    """
    def __init__(self, dataset_dir, n_workers=3, batch_size=2, n_epochs=1, max_queue_size=16):
        # Each batch element is drawn from a different worker's queue, so we
        # need at least batch_size workers.
        assert n_workers >= batch_size, "Number of workers must be equal or greater than batch size"
        self.dataset_dir = dataset_dir
        self.n_workers = n_workers
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.max_queue_size = max_queue_size
        # Trajectory ids are the basenames (without extension) of the *.mp4 files.
        unique_ids = glob.glob(os.path.join(dataset_dir, "*.mp4"))
        unique_ids = list(set([os.path.basename(x).split(".")[0] for x in unique_ids]))
        self.unique_ids = unique_ids
        # Create tuples of (video_path, json_path) for each unique_id
        demonstration_tuples = []
        for unique_id in unique_ids:
            video_path = os.path.join(dataset_dir, unique_id + ".mp4")
            # The events JSONL shares the id but swaps "_recording" for "_events".
            json_path = os.path.join(dataset_dir, unique_id.replace('_recording', '_events') + ".jsonl")
            print(video_path, json_path)
            demonstration_tuples.append((video_path, json_path))
        assert n_workers <= len(demonstration_tuples), f"n_workers should be lower or equal than number of demonstrations {len(demonstration_tuples)}"
        # Repeat dataset for n_epochs times, shuffling the order for
        # each epoch
        self.demonstration_tuples = []
        for i in range(n_epochs):
            random.shuffle(demonstration_tuples)
            self.demonstration_tuples += demonstration_tuples
        self.task_queue = Queue()
        self.n_steps_processed = 0
        for trajectory_id, task in enumerate(self.demonstration_tuples):
            self.task_queue.put((trajectory_id, *task))
        # One None sentinel per worker so each process shuts down cleanly.
        for _ in range(n_workers):
            self.task_queue.put(None)
        # One bounded output queue per worker; __next__ round-robins over them.
        self.output_queues = [Queue(maxsize=max_queue_size) for _ in range(n_workers)]
        self.quit_workers_event = Event()
        # NOTE(review): n_frames is not passed to the workers, so they use
        # data_loader_worker's default (128) — confirm that is intended.
        self.processes = [
            Process(
                target=data_loader_worker,
                args=(
                    self.task_queue,
                    output_queue,
                    self.quit_workers_event,
                ),
                daemon=True
            )
            for output_queue in self.output_queues
        ]
        for process in self.processes:
            process.start()

    def __iter__(self):
        return self

    def __next__(self):
        """Return (frames, actions, episode_ids) lists of length batch_size,
        each element taken round-robin from a different worker's queue.

        NOTE(review): Queue.get(timeout=QUEUE_TIMEOUT) raises queue.Empty if
        a worker stalls; that exception propagates to the caller rather than
        being converted to StopIteration.
        """
        batch_frames = []
        batch_actions = []
        batch_episode_id = []
        for i in range(self.batch_size):
            workitem = self.output_queues[self.n_steps_processed % self.n_workers].get(timeout=QUEUE_TIMEOUT)
            if workitem is None:
                # Stop iteration when first worker runs out of work to do.
                # Yes, this has a chance of cutting out a lot of the work,
                # but this ensures batches will remain diverse, instead
                # of having bad ones in the end where potentially
                # one worker outputs all samples to the same batch.
                raise StopIteration()
            trajectory_id, frame, action = workitem
            batch_frames.append(frame)
            batch_actions.append(action)
            batch_episode_id.append(trajectory_id)
            self.n_steps_processed += 1
        return batch_frames, batch_actions, batch_episode_id

    def __del__(self):
        # NOTE(review): terminates workers without setting
        # quit_workers_event first; consider setting the event so workers
        # can exit their loops cleanly before terminate().
        for process in getattr(self, 'processes', []):
            process.terminate()
            process.join()
def stacked_frame_worker(task_queue, output_queue, quit_event, n_frames):
    """
    Worker process that handles loading and stacking frames from videos

    NOTE(review): this function calls process_frame() and
    get_action_for_frame(), neither of which is defined in this file, so as
    written it raises NameError on the first frame (swallowed by the broad
    except below, which then emits the None sentinel). Appears to be an
    unfinished alternative to data_loader_worker — confirm before use.

    Args:
        task_queue: Queue of (trajectory_id, video_path, json_path) tasks;
            a None item signals shutdown.
        output_queue: Queue receiving (trajectory_id, stacked_frames, action)
            tuples, followed by a final None sentinel.
        quit_event: Event checked between tasks to stop the worker.
        n_frames: number of consecutive frames to stack per output item.
    """
    try:
        while not quit_event.is_set():
            task = task_queue.get()
            if task is None:
                # Sentinel: propagate downstream and stop.
                output_queue.put(None)
                break
            trajectory_id, video_path, json_path = task
            # Load video and events (JSONL: one JSON object per line)
            cap = cv2.VideoCapture(video_path)
            with open(json_path, 'r') as f:
                events = [json.loads(line) for line in f]
            # Sliding window of the most recent n_frames processed frames.
            frame_buffer = collections.deque(maxlen=n_frames)
            frame_idx = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                processed_frame = process_frame(frame)
                frame_buffer.append(processed_frame)
                # Only start outputting once we have enough frames
                if len(frame_buffer) == n_frames:
                    # Find corresponding action for this frame
                    action = get_action_for_frame(frame_idx, events)  # You'll need to implement this
                    # Stack frames along a new dimension
                    stacked_frames = np.stack(list(frame_buffer), axis=0)
                    output_queue.put((trajectory_id, stacked_frames, action))
                frame_idx += 1
            cap.release()
    except Exception as e:
        print(f"Worker failed with exception: {e}")
        traceback.print_exc()
        output_queue.put(None)
def json_action_to_env_action(step_data):
    """
    Converts JSON action to env action.

    Args:
        step_data: dict with "keyboard_keys" (list of key-name strings) and
            "mouse" (dict with a "buttons" list and "dx"/"dy" pixel deltas).

    Returns:
        Tuple (action, is_null_action): action maps env action names to 1 and
        includes a float32 "camera" [dx, dy] array; is_null_action is True
        when every entry (camera included) is zero.
    """
    # 0. Initialise action dictionary
    action = {}
    is_null_action = False
    # 1. Convert keyboard keys
    for key_name in step_data["keyboard_keys"]:
        key_name = key_name.lower()
        if key_name == "w":
            action["forward"] = 1
        elif key_name == "s":
            action["back"] = 1
        elif key_name == "a":
            action["left"] = 1
        elif key_name == "d":
            action["right"] = 1
        elif key_name == "space":
            action["jump"] = 1
        elif key_name == "shift":
            # NOTE(review): "shift" maps to "attack" here (shift is sneak in
            # default Minecraft bindings) — confirm against the recorder's
            # key bindings.
            action["attack"] = 1
        elif key_name == "ctrl":
            action["interact_with_item"] = 1
    # 2. Convert mouse buttons
    for button in step_data["mouse"]["buttons"]:
        if button == 0:
            action["left_mouse"] = 1
        elif button == 1:
            action["right_mouse"] = 1
    # 3. Convert camera actions
    # The input is absolute position, but we want a relative movement vector
    # The game outputs mouse dx/dy in pixel units.
    action["camera"] = np.array(
        [
            step_data["mouse"]["dx"],
            step_data["mouse"]["dy"],
        ],
        dtype=np.float32,
    )
    # Set all zeros if null action. This also avoids having
    # zero mouse movement and a key press being detected as a non-null action.
    # Bug fix: "camera" is a 2-element numpy array, so the original
    # `all(v == 0 for ...)` evaluated the truthiness of a boolean array and
    # raised ValueError ("truth value of an array ... is ambiguous") exactly
    # in the null-action case. np.all() reduces both scalars and the camera
    # array to a single bool.
    if all(np.all(v == 0) for v in action.values()):
        is_null_action = True
    return action, is_null_action