-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgbuffer_masking_utils.py
More file actions
67 lines (54 loc) · 2.41 KB
/
gbuffer_masking_utils.py
File metadata and controls
67 lines (54 loc) · 2.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
FrameDiffuser G-buffer Masking Utilities.
Handles sky region detection and basecolor filling.
Copyright (c) 2025 Ole Beisswenger, Jan-Niklas Dihlmann, Hendrik Lensch
Licensed under MIT License.
"""
import torch
import numpy as np
from PIL import Image
def _tensor_to_uint8(tensor):
    """Map a tensor from the [-1, 1] range to a uint8 numpy array in [0, 255].

    Values outside [-1, 1] (e.g. slight overshoot from a decoder) are clipped
    first: casting an out-of-range float to uint8 is not well defined in numpy
    and can wrap a near-white value around to near-black, which would then be
    misdetected as sky.
    """
    # detach() so tensors that require grad can still be converted to numpy.
    tensor = tensor.detach().cpu()
    if tensor.dtype == torch.bfloat16:
        # numpy has no bfloat16 dtype, so .numpy() would raise on it.
        tensor = tensor.float()
    scaled = (tensor.numpy() + 1) * 127.5
    return np.clip(scaled, 0, 255).astype(np.uint8)


def _to_rgb_array(image):
    """Convert a color image (tensor, PIL Image, or array-like) to a numpy array.

    Tensors are moved from channel-first (C, H, W) to channel-last (H, W, C)
    and rescaled from [-1, 1] to uint8. PIL images are forced to RGB. Any other
    input is passed through ``np.array`` unchanged.
    """
    if isinstance(image, torch.Tensor):
        return _tensor_to_uint8(image).transpose(1, 2, 0)
    if isinstance(image, Image.Image):
        return np.array(image.convert('RGB'))
    return np.array(image)


def _to_depth_array(depth_image):
    """Convert a depth buffer (tensor, PIL Image, or array-like) to a numpy array.

    Multi-channel inputs are reduced to their first channel so that, for
    well-formed inputs, the result is a single (H, W) plane.
    """
    if isinstance(depth_image, torch.Tensor):
        depth = _tensor_to_uint8(depth_image)
        # (C, H, W) -> first channel. An (H, W) tensor is already a single
        # plane; indexing [0] on it would wrongly select the first row.
        return depth[0] if depth.ndim == 3 else depth
    if isinstance(depth_image, Image.Image):
        return np.array(depth_image.convert('L'))
    depth = np.array(depth_image)
    # (H, W, C) -> first channel, mirroring the tensor branch.
    return depth[..., 0] if depth.ndim == 3 else depth


def fill_black_pixels_in_basecolor(
    base_color_image, depth_image, final_image, *, threshold=10
):
    """
    Fill black pixels in basecolor where both basecolor AND depth are black.

    This handles sky regions where deferred rendering provides no geometric data.

    Args:
        base_color_image: BaseColor G-buffer (PIL Image, numpy array, or torch
            Tensor). Tensors are expected as (C, H, W) with values in [-1, 1].
        depth_image: Depth G-buffer (PIL Image, numpy array, or torch Tensor).
            Tensors may be (C, H, W) or (H, W); numpy arrays may be (H, W) or
            (H, W, C). Only the first channel of multi-channel input is used.
        final_image: Final rendered image to sample sky colors from. Must have
            the same height and width as base_color_image.
        threshold: uint8 values strictly below this count as black. For the
            basecolor, every channel must be below it. Defaults to 10.

    Returns:
        PIL Image with sky regions filled from final_image.

    Raises:
        ValueError: If the inputs do not share the same height and width.
    """
    base_color_np = _to_rgb_array(base_color_image)
    depth_np = _to_depth_array(depth_image)
    final_np = _to_rgb_array(final_image)

    # Fail loudly instead of letting numpy broadcast a mismatched mask or
    # raise an obscure boolean-indexing error.
    size = base_color_np.shape[:2]
    if depth_np.shape != size or final_np.shape[:2] != size:
        raise ValueError(
            "G-buffer size mismatch: "
            f"basecolor {size}, depth {depth_np.shape}, final {final_np.shape[:2]}"
        )

    # Sky pixels carry no geometry: basecolor is black in every channel AND
    # depth is black.
    base_color_black = np.all(base_color_np < threshold, axis=2)
    depth_black = depth_np < threshold
    mask = base_color_black & depth_black

    # Copy so the caller's array (numpy input) is never modified in place.
    result = base_color_np.copy()
    result[mask] = final_np[mask]
    return Image.fromarray(result)