Skip to content

Commit e7e97f6

Browse files
authored
Add files via upload
1 parent edd8321 commit e7e97f6

File tree

1 file changed

+292
-0
lines changed

1 file changed

+292
-0
lines changed

CameraCtrl-svd/inference2.py

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
import argparse
2+
import json
3+
import os
4+
import torch
5+
import numpy as np
6+
import imageio
7+
from tqdm import tqdm
8+
from omegaconf import OmegaConf
9+
from PIL import Image
10+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
11+
from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler
12+
from diffusers.utils.import_utils import is_xformers_available
13+
from packaging import version as pver
14+
15+
from cameractrl.pipelines.pipeline_animation import StableVideoDiffusionPipelinePoseCond
16+
from cameractrl.models.unet import UNetSpatioTemporalConditionModelPoseCond
17+
from cameractrl.models.pose_adaptor import CameraPoseEncoder
18+
from cameractrl.utils.util import save_videos_grid
19+
20+
21+
class Camera(object):
    """Pinhole camera parsed from one row of the trajectory file.

    Row layout (as consumed here): entry[1:5] holds the normalized
    intrinsics fx, fy, cx, cy and entry[7:] holds a row-major 3x4
    world-to-camera matrix.  Both the 4x4 world-to-camera matrix
    (`w2c_mat`) and its inverse (`c2w_mat`) are kept.
    """

    def __init__(self, entry):
        # Intrinsics come straight out of columns 1..4.
        self.fx, self.fy, self.cx, self.cy = entry[1:5]
        # Promote the 3x4 extrinsics to a homogeneous 4x4 matrix.
        w2c = np.eye(4)
        w2c[:3, :] = np.array(entry[7:]).reshape(3, 4)
        self.w2c_mat = w2c
        self.c2w_mat = np.linalg.inv(w2c)
33+
34+
35+
def setup_for_distributed(is_master):
    """Silence the builtin ``print`` on non-master processes.

    After this call, a plain ``print(...)`` emits output only when
    ``is_master`` is true; passing ``force=True`` prints on every rank.
    """
    import builtins

    original_print = builtins.print

    def rank_aware_print(*args, **kwargs):
        # Strip `force` before delegating so the real print never sees it;
        # `force=True` overrides the master-only gate.
        if kwargs.pop('force', False) or is_master:
            original_print(*args, **kwargs)

    builtins.print = rank_aware_print
48+
49+
50+
def custom_meshgrid(*args):
    """``torch.meshgrid`` with 'ij' indexing across torch versions.

    torch >= 1.10 takes an explicit ``indexing`` keyword to keep the
    historical 'ij' behavior; older releases reject the keyword and
    default to 'ij' already.
    ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    """
    from packaging import version as pver
    if pver.parse(torch.__version__) >= pver.parse('1.10'):
        return torch.meshgrid(*args, indexing='ij')
    return torch.meshgrid(*args)
56+
57+
58+
def get_relative_pose(cam_params, zero_first_frame_scale):
    """Convert absolute camera poses into poses relative to the first frame.

    Camera translations are first rescaled so the average camera-to-origin
    distance is ~1.6, then every pose is re-expressed in the coordinate
    frame of the (rescaled) first camera.

    Args:
        cam_params: sequence of objects exposing a 4x4 ``c2w_mat``
            camera-to-world matrix (e.g. ``Camera`` instances).
        zero_first_frame_scale: when True, the first camera is placed at
            the origin of the target frame; otherwise it keeps its distance
            to the world origin, offset along -y.

    Returns:
        np.ndarray of shape [n_frames, 4, 4], float32 camera-to-world poses,
        with the first pose equal to the canonical target frame.
    """
    abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]

    # Average distance from the cameras to the world origin.
    distances = [np.linalg.norm(c2w[:3, 3]) for c2w in abs_c2ws]
    avg_distance = np.mean(distances)

    # Normalize so the average camera-to-origin distance is approximately
    # 1.6.  Guard the degenerate case where every camera sits at the origin
    # (avg_distance == 0), which would otherwise produce inf/NaN poses.
    scale_factor = 1.6 / avg_distance if avg_distance > 0 else 1.0

    # Scale all camera positions (rotations are untouched).
    scaled_c2ws = []
    for c2w in abs_c2ws:
        scaled_c2w = c2w.copy()
        scaled_c2w[:3, 3] *= scale_factor
        scaled_c2ws.append(scaled_c2w)

    source_cam_c2w = scaled_c2ws[0]

    if zero_first_frame_scale:
        cam_to_origin = 0
    else:
        cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])

    # Target frame: identity rotation, first camera sitting `cam_to_origin`
    # below the origin along y.
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ])

    # Map every subsequent pose into the first camera's (target) frame.
    abs2rel = target_cam_c2w @ np.linalg.inv(scaled_c2ws[0])
    ret_poses = [target_cam_c2w, ] + [abs2rel @ c2w for c2w in scaled_c2ws[1:]]
    return np.array(ret_poses, dtype=np.float32)
96+
97+
98+
def ray_condition(K, c2w, H, W, device):
    """Build per-pixel Plücker ray embeddings from intrinsics and poses.

    Args:
        K: [B, V, 4] intrinsics per frame as (fx, fy, cx, cy), in pixels.
        c2w: [B, V, 4, 4] camera-to-world matrices (V = number of frames).
        H, W: image height and width in pixels.
        device: device on which the pixel grid is created.

    Returns:
        [B, V, H, W, 6] tensor: (rays_o x rays_d, rays_d) per pixel.
    """
    # c2w: B, V, 4, 4
    # K: B, V, 4
    # V: # of video frames

    B = K.shape[0]  # batch size

    # Pixel-center grid ('ij' indexing): j indexes rows (H), i indexes
    # columns (W); the +0.5 below shifts to pixel centers.
    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HxW]

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # each [B, V, 1]

    # Back-project pixel centers to camera-space directions; the
    # [B, 1, HW] grid broadcasts against the [B, V, 1] intrinsics.
    zs = torch.ones_like(i)  # [B, 1, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3

    # Rotate directions into world space: d @ R^T == R @ d per pixel.
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, HW, 3
    rays_o = c2w[..., :3, 3]  # B, V, 3 (camera centers in world space)
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, HW, 3
    # Plücker moment: origin cross direction.
    rays_dxo = torch.linalg.cross(rays_o, rays_d)
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
    return plucker
130+
131+
132+
def get_pipeline(ori_model_path, unet_subfolder, down_block_types, up_block_types, pose_encoder_kwargs,
                 attention_processor_kwargs, pose_adaptor_ckpt, enable_xformers, device):
    """Assemble the pose-conditioned SVD pipeline with pose-adaptor weights.

    Loads the base SVD components from `ori_model_path`, builds the camera
    pose encoder, installs the pose-conditioned attention processors on the
    UNet, loads the pose-adaptor checkpoint into both, and returns the
    assembled pipeline moved to `device`.

    Args:
        ori_model_path: root of the pretrained SVD model (HF layout).
        unet_subfolder: subfolder holding the UNet weights.
        down_block_types / up_block_types: UNet block-type overrides.
        pose_encoder_kwargs: constructor kwargs for CameraPoseEncoder.
        attention_processor_kwargs: kwargs for the pose-cond attn processors.
        pose_adaptor_ckpt: checkpoint with 'pose_encoder_state_dict' and
            'attention_processor_state_dict' entries.
        enable_xformers: request xformers attention (used only if available).
        device: target device string, e.g. "cuda:0".

    Returns:
        A StableVideoDiffusionPipelinePoseCond ready for inference.
    """
    noise_scheduler = EulerDiscreteScheduler.from_pretrained(ori_model_path, subfolder="scheduler")
    feature_extractor = CLIPImageProcessor.from_pretrained(ori_model_path, subfolder="feature_extractor")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(ori_model_path, subfolder="image_encoder")
    vae = AutoencoderKLTemporalDecoder.from_pretrained(ori_model_path, subfolder="vae")
    unet = UNetSpatioTemporalConditionModelPoseCond.from_pretrained(ori_model_path,
                                                                    subfolder=unet_subfolder,
                                                                    down_block_types=down_block_types,
                                                                    up_block_types=up_block_types)
    pose_encoder = CameraPoseEncoder(**pose_encoder_kwargs)
    print("Setting the attention processors")
    # xformers is only enabled when both requested and importable.
    unet.set_pose_cond_attn_processor(enable_xformers=(enable_xformers and is_xformers_available()), **attention_processor_kwargs)
    print(f"Loading weights of camera encoder and attention processor from {pose_adaptor_ckpt}")
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — only load checkpoints from trusted sources.
    ckpt_dict = torch.load(pose_adaptor_ckpt, map_location=unet.device)
    pose_encoder_state_dict = ckpt_dict['pose_encoder_state_dict']
    # load_state_dict returns (missing_keys, unexpected_keys); both must be
    # empty for the pose encoder.
    pose_encoder_m, pose_encoder_u = pose_encoder.load_state_dict(pose_encoder_state_dict)
    assert len(pose_encoder_m) == 0 and len(pose_encoder_u) == 0
    attention_processor_state_dict = ckpt_dict['attention_processor_state_dict']
    # strict=False: the checkpoint covers only the attention processors, so
    # missing keys are expected — but no unexpected keys are tolerated.
    _, attention_processor_u = unet.load_state_dict(attention_processor_state_dict, strict=False)
    assert len(attention_processor_u) == 0
    print("Loading done")
    vae.to(device)
    image_encoder.to(device)
    unet.to(device)
    pipeline = StableVideoDiffusionPipelinePoseCond(
        vae=vae,
        image_encoder=image_encoder,
        unet=unet,
        scheduler=noise_scheduler,
        feature_extractor=feature_extractor,
        pose_encoder=pose_encoder
    )
    pipeline = pipeline.to(device)
    return pipeline
167+
168+
169+
def main(args):
    """Run pose-conditioned SVD inference over a shard of the prompt list.

    Loads camera poses from `args.trajectory_file`, builds Plücker ray
    embeddings, shards `args.prompt_file` across `args.n_procs` ranks, and
    for each (image, caption) prompt writes per-frame PNGs, an MP4 video,
    and the resized reference image under `args.out_root`.
    """
    os.makedirs(os.path.join(args.out_root, 'generated_videos'), exist_ok=True)
    os.makedirs(os.path.join(args.out_root, 'reference_images'), exist_ok=True)
    rank = args.local_rank
    # Only rank 0 prints (unless force=True is passed).
    setup_for_distributed(rank == 0)
    gpu_id = rank % torch.cuda.device_count()
    model_configs = OmegaConf.load(args.model_config)
    device = f"cuda:{gpu_id}"
    print(f'Constructing pipeline')
    pipeline = get_pipeline(args.ori_model_path, model_configs['unet_subfolder'], model_configs['down_block_types'],
                            model_configs['up_block_types'], model_configs['pose_encoder_kwargs'],
                            model_configs['attention_processor_kwargs'], args.pose_adaptor_ckpt, args.enable_xformers, device)
    print('Done')

    print('Loading K, R, t matrix')
    # First line of the trajectory file is a header and is skipped.
    with open(args.trajectory_file, 'r') as f:
        poses = f.readlines()
    poses = [pose.strip().split(' ') for pose in poses[1:]]
    cam_params = [[float(x) for x in pose] for pose in poses]
    cam_params = [Camera(cam_param) for cam_param in cam_params]

    # Adjust the normalized focal lengths so the pose aspect ratio matches
    # the sample aspect ratio (letterbox-style rescale of fx or fy).
    sample_wh_ratio = args.image_width / args.image_height
    pose_wh_ratio = args.original_pose_width / args.original_pose_height
    if pose_wh_ratio > sample_wh_ratio:
        resized_ori_w = args.image_height * pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fx = resized_ori_w * cam_param.fx / args.image_width
    else:
        resized_ori_h = args.image_width / pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fy = resized_ori_h * cam_param.fy / args.image_height
    # Denormalize intrinsics to pixel units: (fx, fy, cx, cy) per frame.
    intrinsic = np.asarray([[cam_param.fx * args.image_width,
                             cam_param.fy * args.image_height,
                             cam_param.cx * args.image_width,
                             cam_param.cy * args.image_height]
                            for cam_param in cam_params], dtype=np.float32)
    K = torch.as_tensor(intrinsic)[None]  # [1, n_frame, 4]
    c2ws = get_relative_pose(cam_params, zero_first_frame_scale=True)
    c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
    plucker_embedding = ray_condition(K, c2ws, args.image_height, args.image_width, device='cpu')  # b f h w 6
    # Reorder to channels-first [b, f, 6, h, w] and move to the GPU.
    plucker_embedding = plucker_embedding.permute(0, 1, 4, 2, 3).contiguous().to(device=device)

    prompt_dict = json.load(open(args.prompt_file, 'r'))
    prompt_images = prompt_dict['image_paths']
    prompt_captions = prompt_dict['captions']
    # Shard the prompts over n_procs ranks: the first `remainder` ranks get
    # one extra prompt.  (The comprehension's `gpu_id` is comprehension-local
    # and does not clobber the outer `gpu_id`.)
    N = int(len(prompt_images) // args.n_procs)
    remainder = int(len(prompt_images) % args.n_procs)
    prompts_per_gpu = [N + 1 if gpu_id < remainder else N for gpu_id in range(args.n_procs)]
    low_idx = sum(prompts_per_gpu[:gpu_id])
    high_idx = low_idx + prompts_per_gpu[gpu_id]
    prompt_images = prompt_images[low_idx: high_idx]
    prompt_captions = prompt_captions[low_idx: high_idx]
    print(f"rank {rank} / {torch.cuda.device_count()}, number of prompts: {len(prompt_images)}")

    # Fixed seed so every rank generates reproducibly.
    generator = torch.Generator(device=device)
    generator.manual_seed(42)

    for prompt_image, prompt_caption in tqdm(zip(prompt_images, prompt_captions)):
        # Output filenames are derived from the caption.
        save_name = "_".join(prompt_caption.split(" "))
        condition_image = Image.open(prompt_image)
        # 1) Request output_type='pil' so the pipeline returns PIL.Images.
        with torch.no_grad():
            result = pipeline(
                image=condition_image,
                pose_embedding=plucker_embedding,
                height=args.image_height,
                width=args.image_width,
                num_frames=args.num_frames,
                num_inference_steps=args.num_inference_steps,
                min_guidance_scale=args.min_guidance_scale,
                max_guidance_scale=args.max_guidance_scale,
                do_image_process=True,
                generator=generator,
                output_type='pil'
            )
        # result.frames is [[PIL1, PIL2, ..., PILN]] for batch_size=1
        frames = result.frames[0]

        # 2) Save each frame as an individual PNG.
        frame_dir = os.path.join(args.out_root, 'generated_frames', save_name)
        os.makedirs(frame_dir, exist_ok=True)
        for idx, frame in enumerate(frames):
            frame.save(os.path.join(frame_dir, f"frame_{idx:03d}.png"))

        # 3) Convert to a list of NumPy arrays -> (H, W, 3), uint8.
        frames_np = [np.array(f) for f in frames]  # each element shape=(H, W, 3)

        # 4) Write the video with imageio (PIL->ndarray keeps RGB order,
        #    so there is no channel-order problem).
        video_path = os.path.join(args.out_root, 'generated_videos', f"{save_name}.mp4")
        imageio.mimsave(video_path, frames_np, fps=30)

        # 5) Also save the condition image, resized to the sample size.
        resized_condition_image = condition_image.resize((args.image_width, args.image_height))
        resized_condition_image.save(os.path.join(args.out_root, 'reference_images', f'{save_name}.png'))
263+
264+
265+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Output / model paths.
    parser.add_argument("--out_root", type=str)
    parser.add_argument("--image_height", type=int, default=576)
    parser.add_argument("--image_width", type=int, default=320)
    parser.add_argument("--num_frames", type=int, default=14, help="14 for svd and 25 for svd-xt", choices=[14, 25])
    parser.add_argument("--ori_model_path", type=str)
    parser.add_argument("--unet_subfolder", type=str, default='unet')
    parser.add_argument("--enable_xformers", action='store_true')
    parser.add_argument("--pose_adaptor_ckpt", default=None)
    parser.add_argument("--num_inference_steps", type=int, default=25)
    # Guidance-scale tuning note: lowering 1.0 -> 0.5 degraded results
    # (faces stopped looking human), so the default is kept at 1.0.
    parser.add_argument("--min_guidance_scale", type=float, default=1.0)
    # Same experiment for the upper bound (3.0 -> 2.0); default kept at 3.0.
    parser.add_argument("--max_guidance_scale", type=float, default=3.0)
    parser.add_argument("--prompt_file", required=True, help='prompts path, json or txt')
    parser.add_argument("--trajectory_file", required=True)
    parser.add_argument("--original_pose_width", type=int, default=720)
    parser.add_argument("--original_pose_height", type=int, default=1280)
    parser.add_argument("--model_config", required=True)
    # Number of shards the prompt list is split into (see main()).
    parser.add_argument("--n_procs", type=int, default=8)

    # DDP args
    parser.add_argument("--world_size", default=1, type=int,
                        help="number of the distributed processes.")
    # argparse exposes --local-rank as args.local_rank.
    parser.add_argument('--local-rank', type=int, default=-1,
                        help='Replica rank on the current node. This field is required '
                             'by `torch.distributed.launch`.')
    args = parser.parse_args()
    main(args)

0 commit comments

Comments
 (0)