|
| 1 | +import argparse |
| 2 | +import json |
| 3 | +import os |
| 4 | +import torch |
| 5 | +import numpy as np |
| 6 | +import imageio |
| 7 | +from tqdm import tqdm |
| 8 | +from omegaconf import OmegaConf |
| 9 | +from PIL import Image |
| 10 | +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection |
| 11 | +from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler |
| 12 | +from diffusers.utils.import_utils import is_xformers_available |
| 13 | +from packaging import version as pver |
| 14 | + |
| 15 | +from cameractrl.pipelines.pipeline_animation import StableVideoDiffusionPipelinePoseCond |
| 16 | +from cameractrl.models.unet import UNetSpatioTemporalConditionModelPoseCond |
| 17 | +from cameractrl.models.pose_adaptor import CameraPoseEncoder |
| 18 | +from cameractrl.utils.util import save_videos_grid |
| 19 | + |
| 20 | + |
class Camera(object):
    """Pinhole camera parsed from one line of the trajectory file.

    ``entry`` layout: entry[1:5] holds the intrinsics (fx, fy, cx, cy);
    entry[7:] holds the 12 row-major values of the 3x4 world-to-camera
    matrix.  Entries 5-6 are ignored here (presumably distortion terms —
    verify against the trajectory format).  Exposes the homogeneous 4x4
    world-to-camera matrix (``w2c_mat``) and its inverse (``c2w_mat``).
    """

    def __init__(self, entry):
        # Unpack intrinsics in one shot.
        self.fx, self.fy, self.cx, self.cy = entry[1:5]
        # Promote the 3x4 extrinsic to a homogeneous 4x4 matrix.
        w2c_3x4 = np.array(entry[7:]).reshape(3, 4)
        w2c = np.eye(4)
        w2c[:3, :] = w2c_3x4
        self.w2c_mat = w2c
        self.c2w_mat = np.linalg.inv(w2c)
| 33 | + |
| 34 | + |
def setup_for_distributed(is_master):
    """Silence ``print`` on non-master processes.

    Replaces the builtin ``print`` with a wrapper that only emits output when
    ``is_master`` is true or the call passes ``force=True``.
    """
    import builtins
    original_print = builtins.print

    def print(*args, **kwargs):
        # `force` must be popped unconditionally so it never reaches the
        # real print().
        if kwargs.pop('force', False) or is_master:
            original_print(*args, **kwargs)

    builtins.print = print
| 48 | + |
| 49 | + |
def custom_meshgrid(*args):
    """``torch.meshgrid`` with 'ij' (matrix) indexing on every torch version.

    ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html
    """
    if pver.parse(torch.__version__) >= pver.parse('1.10'):
        return torch.meshgrid(*args, indexing='ij')
    # torch < 1.10 has no `indexing` kwarg and always behaves like 'ij'.
    return torch.meshgrid(*args)
| 56 | + |
| 57 | + |
def get_relative_pose(cam_params, zero_first_frame_scale):
    """Convert absolute camera poses into poses relative to the first frame.

    The trajectory is rescaled so the mean camera-to-origin distance becomes
    1.6, then every pose is re-expressed in the first camera's canonical
    frame.

    Args:
        cam_params: sequence of objects exposing ``c2w_mat`` (4x4
            camera-to-world matrices).
        zero_first_frame_scale: if True the first camera is pinned to the
            origin; otherwise it keeps its (rescaled) distance to the origin
            along -y.

    Returns:
        np.ndarray of shape (len(cam_params), 4, 4), dtype float32, where
        index 0 is the canonical first-frame pose.
    """
    abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]

    # Normalize the trajectory scale: average camera-to-origin distance -> 1.6.
    distances = [np.linalg.norm(c2w[:3, 3]) for c2w in abs_c2ws]
    avg_distance = np.mean(distances)
    # Guard against a degenerate trajectory where every camera sits at the
    # origin (the original code divided by zero here).
    scale_factor = 1.6 / avg_distance if avg_distance > 0 else 1.0

    # Scale all camera positions (rotations are untouched).
    scaled_c2ws = []
    for c2w in abs_c2ws:
        scaled_c2w = c2w.copy()
        scaled_c2w[:3, 3] *= scale_factor
        scaled_c2ws.append(scaled_c2w)

    source_cam_c2w = scaled_c2ws[0]
    if zero_first_frame_scale:
        cam_to_origin = 0
    else:
        cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])

    # Canonical first-frame pose: identity rotation, translated by
    # -cam_to_origin along y.
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ])

    # Map every absolute pose into the canonical first-camera frame.
    abs2rel = target_cam_c2w @ np.linalg.inv(source_cam_c2w)
    ret_poses = [target_cam_c2w] + [abs2rel @ c2w for c2w in scaled_c2ws[1:]]
    return np.array(ret_poses, dtype=np.float32)
| 96 | + |
| 97 | + |
def ray_condition(K, c2w, H, W, device):
    """Compute per-pixel Plücker embeddings for a camera trajectory.

    Args:
        K: camera intrinsics, shape [B, V, 4], holding (fx, fy, cx, cy) in
            pixels; V is the number of video frames.
        c2w: camera-to-world matrices, shape [B, V, 4, 4].
        H, W: output spatial resolution in pixels.
        device: device on which the pixel grid is created.

    Returns:
        Tensor of shape [B, V, H, W, 6]: per pixel the Plücker coordinates
        (o x d, d) of the viewing ray through that pixel.
    """
    B = K.shape[0]  # batch size

    # Pixel grid; j indexes rows (height), i indexes columns (width).
    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HxW]; +0.5 -> pixel centers
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HxW]

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # each [B, V, 1]

    # Back-project pixels to camera-space directions (pinhole model);
    # broadcasting [B, 1, HxW] against [B, V, 1] yields [B, V, HxW].
    zs = torch.ones_like(i)  # [B, 1, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)  # [B, V, HxW]

    directions = torch.stack((xs, ys, zs), dim=-1)  # [B, V, HxW, 3]
    directions = directions / directions.norm(dim=-1, keepdim=True)  # unit-length directions

    # Rotate directions into world space: d_world = R @ d_cam.
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # [B, V, HxW, 3]
    rays_o = c2w[..., :3, 3]  # ray origins (camera centers), [B, V, 3]
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # [B, V, HxW, 3]
    # Plücker moment: m = o x d.
    rays_dxo = torch.linalg.cross(rays_o, rays_d)
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # [B, V, H, W, 6]
    return plucker
| 130 | + |
| 131 | + |
def get_pipeline(ori_model_path, unet_subfolder, down_block_types, up_block_types, pose_encoder_kwargs,
                 attention_processor_kwargs, pose_adaptor_ckpt, enable_xformers, device):
    """Assemble the pose-conditioned Stable Video Diffusion pipeline.

    Loads the pretrained SVD components (scheduler, feature extractor, image
    encoder, VAE, UNet) from ``ori_model_path``, builds the camera pose
    encoder, installs the pose-conditioned attention processors on the UNet,
    and restores pose-adaptor weights from ``pose_adaptor_ckpt``.

    Args:
        ori_model_path: root of the pretrained SVD model in diffusers layout.
        unet_subfolder: subfolder holding the UNet weights.
        down_block_types: UNet down-block type overrides.
        up_block_types: UNet up-block type overrides.
        pose_encoder_kwargs: kwargs forwarded to ``CameraPoseEncoder``.
        attention_processor_kwargs: kwargs forwarded to
            ``unet.set_pose_cond_attn_processor``.
        pose_adaptor_ckpt: checkpoint containing ``pose_encoder_state_dict``
            and ``attention_processor_state_dict``.
        enable_xformers: request xformers attention (used only if available).
        device: target device for all modules.

    Returns:
        A ``StableVideoDiffusionPipelinePoseCond`` moved to ``device``.
    """
    noise_scheduler = EulerDiscreteScheduler.from_pretrained(ori_model_path, subfolder="scheduler")
    feature_extractor = CLIPImageProcessor.from_pretrained(ori_model_path, subfolder="feature_extractor")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(ori_model_path, subfolder="image_encoder")
    vae = AutoencoderKLTemporalDecoder.from_pretrained(ori_model_path, subfolder="vae")
    unet = UNetSpatioTemporalConditionModelPoseCond.from_pretrained(ori_model_path,
                                                                    subfolder=unet_subfolder,
                                                                    down_block_types=down_block_types,
                                                                    up_block_types=up_block_types)
    pose_encoder = CameraPoseEncoder(**pose_encoder_kwargs)
    print("Setting the attention processors")
    # xformers is only enabled when both requested and actually installed.
    unet.set_pose_cond_attn_processor(enable_xformers=(enable_xformers and is_xformers_available()), **attention_processor_kwargs)
    print(f"Loading weights of camera encoder and attention processor from {pose_adaptor_ckpt}")
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints here.
    ckpt_dict = torch.load(pose_adaptor_ckpt, map_location=unet.device)
    pose_encoder_state_dict = ckpt_dict['pose_encoder_state_dict']
    # load_state_dict returns (missing_keys, unexpected_keys); the pose
    # encoder must match the checkpoint exactly.
    pose_encoder_m, pose_encoder_u = pose_encoder.load_state_dict(pose_encoder_state_dict)
    assert len(pose_encoder_m) == 0 and len(pose_encoder_u) == 0
    attention_processor_state_dict = ckpt_dict['attention_processor_state_dict']
    # strict=False: the checkpoint only covers the attention-processor subset
    # of the UNet, but it must not contain keys the UNet does not know.
    _, attention_processor_u = unet.load_state_dict(attention_processor_state_dict, strict=False)
    assert len(attention_processor_u) == 0
    print("Loading done")
    vae.to(device)
    image_encoder.to(device)
    unet.to(device)
    pipeline = StableVideoDiffusionPipelinePoseCond(
        vae=vae,
        image_encoder=image_encoder,
        unet=unet,
        scheduler=noise_scheduler,
        feature_extractor=feature_extractor,
        pose_encoder=pose_encoder
    )
    # pipeline.to also moves the pose encoder (not moved individually above).
    pipeline = pipeline.to(device)
    return pipeline
| 167 | + |
| 168 | + |
def main(args):
    """Generate pose-conditioned videos for the prompts assigned to this rank.

    Builds the pose-conditioned SVD pipeline, converts the camera trajectory
    file into per-frame Plücker embeddings, shards the prompt list across
    ``args.n_procs`` processes, and for each prompt writes the generated
    frames (PNG), an mp4 video, and the resized reference image under
    ``args.out_root``.
    """
    os.makedirs(os.path.join(args.out_root, 'generated_videos'), exist_ok=True)
    os.makedirs(os.path.join(args.out_root, 'reference_images'), exist_ok=True)
    rank = args.local_rank
    setup_for_distributed(rank == 0)
    gpu_id = rank % torch.cuda.device_count()
    model_configs = OmegaConf.load(args.model_config)
    device = f"cuda:{gpu_id}"
    print('Constructing pipeline')
    pipeline = get_pipeline(args.ori_model_path, model_configs['unet_subfolder'], model_configs['down_block_types'],
                            model_configs['up_block_types'], model_configs['pose_encoder_kwargs'],
                            model_configs['attention_processor_kwargs'], args.pose_adaptor_ckpt,
                            args.enable_xformers, device)
    print('Done')

    print('Loading K, R, t matrix')
    # Trajectory file: one header line, then one camera per line; each line is
    # parsed by Camera (entry[1:5] intrinsics, entry[7:] a 3x4 w2c matrix).
    with open(args.trajectory_file, 'r') as f:
        poses = f.readlines()
    poses = [pose.strip().split(' ') for pose in poses[1:]]
    cam_params = [Camera([float(x) for x in pose]) for pose in poses]

    # Rescale the normalized focal lengths so that the sampled video's aspect
    # ratio matches the original pose sequence (resize-then-crop convention).
    sample_wh_ratio = args.image_width / args.image_height
    pose_wh_ratio = args.original_pose_width / args.original_pose_height
    if pose_wh_ratio > sample_wh_ratio:
        resized_ori_w = args.image_height * pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fx = resized_ori_w * cam_param.fx / args.image_width
    else:
        resized_ori_h = args.image_width / pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fy = resized_ori_h * cam_param.fy / args.image_height
    intrinsic = np.asarray([[cam_param.fx * args.image_width,
                             cam_param.fy * args.image_height,
                             cam_param.cx * args.image_width,
                             cam_param.cy * args.image_height]
                            for cam_param in cam_params], dtype=np.float32)
    K = torch.as_tensor(intrinsic)[None]  # [1, n_frame, 4]
    c2ws = get_relative_pose(cam_params, zero_first_frame_scale=True)
    c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
    plucker_embedding = ray_condition(K, c2ws, args.image_height, args.image_width, device='cpu')  # [1, f, h, w, 6]
    plucker_embedding = plucker_embedding.permute(0, 1, 4, 2, 3).contiguous().to(device=device)

    # fix: use a context manager so the prompt file handle is closed.
    with open(args.prompt_file, 'r') as pf:
        prompt_dict = json.load(pf)
    prompt_images = prompt_dict['image_paths']
    prompt_captions = prompt_dict['captions']
    # Shard prompts across processes: the first `remainder` ranks take one
    # extra prompt each.
    n_per_proc, remainder = divmod(len(prompt_images), args.n_procs)
    prompts_per_gpu = [n_per_proc + 1 if proc_idx < remainder else n_per_proc
                       for proc_idx in range(args.n_procs)]
    low_idx = sum(prompts_per_gpu[:gpu_id])
    high_idx = low_idx + prompts_per_gpu[gpu_id]
    prompt_images = prompt_images[low_idx: high_idx]
    prompt_captions = prompt_captions[low_idx: high_idx]
    print(f"rank {rank} / {torch.cuda.device_count()}, number of prompts: {len(prompt_images)}")

    # Fixed seed for reproducible sampling across runs.
    generator = torch.Generator(device=device)
    generator.manual_seed(42)

    for prompt_image, prompt_caption in tqdm(zip(prompt_images, prompt_captions), total=len(prompt_images)):
        save_name = "_".join(prompt_caption.split(" "))
        condition_image = Image.open(prompt_image)
        # output_type='pil' makes the pipeline return a list of PIL images.
        with torch.no_grad():
            result = pipeline(
                image=condition_image,
                pose_embedding=plucker_embedding,
                height=args.image_height,
                width=args.image_width,
                num_frames=args.num_frames,
                num_inference_steps=args.num_inference_steps,
                min_guidance_scale=args.min_guidance_scale,
                max_guidance_scale=args.max_guidance_scale,
                do_image_process=True,
                generator=generator,
                output_type='pil'
            )
        # result.frames is [[PIL1, PIL2, ..., PILN]] for batch_size=1
        frames = result.frames[0]

        # Save every frame as an individual PNG.
        frame_dir = os.path.join(args.out_root, 'generated_frames', save_name)
        os.makedirs(frame_dir, exist_ok=True)
        for idx, frame in enumerate(frames):
            frame.save(os.path.join(frame_dir, f"frame_{idx:03d}.png"))

        # Convert to (H, W, 3) uint8 arrays and write the mp4; writing RGB
        # frames directly avoids the earlier channel-order issue.
        frames_np = [np.array(f) for f in frames]
        video_path = os.path.join(args.out_root, 'generated_videos', f"{save_name}.mp4")
        imageio.mimsave(video_path, frames_np, fps=30)

        # Also keep the (resized) conditioning image for reference.
        resized_condition_image = condition_image.resize((args.image_width, args.image_height))
        resized_condition_image.save(os.path.join(args.out_root, 'reference_images', f'{save_name}.png'))
| 263 | + |
| 264 | + |
if __name__ == '__main__':
    # CLI for pose-conditioned video generation; defaults follow SVD
    # (14 frames) at a 320x576 portrait resolution.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--out_root", type=str)
    arg_parser.add_argument("--image_height", type=int, default=576)
    arg_parser.add_argument("--image_width", type=int, default=320)
    arg_parser.add_argument("--num_frames", type=int, default=14, help="14 for svd and 25 for svd-xt", choices=[14, 25])
    arg_parser.add_argument("--ori_model_path", type=str)
    arg_parser.add_argument("--unet_subfolder", type=str, default='unet')
    arg_parser.add_argument("--enable_xformers", action='store_true')
    arg_parser.add_argument("--pose_adaptor_ckpt", default=None)
    arg_parser.add_argument("--num_inference_steps", type=int, default=25)
    arg_parser.add_argument("--min_guidance_scale", type=float, default=1.0)  # guidance-scale tuning note: lowering 1.0 -> 0.5 made faces stop looking human
    arg_parser.add_argument("--max_guidance_scale", type=float, default=3.0)  # guidance-scale tuning note: 3.0 -> 2.0 was also tried
    arg_parser.add_argument("--prompt_file", required=True, help='prompts path, json or txt')
    arg_parser.add_argument("--trajectory_file", required=True)
    arg_parser.add_argument("--original_pose_width", type=int, default=720)
    arg_parser.add_argument("--original_pose_height", type=int, default=1280)
    arg_parser.add_argument("--model_config", required=True)
    arg_parser.add_argument("--n_procs", type=int, default=8)

    # DDP args
    arg_parser.add_argument("--world_size", default=1, type=int,
                            help="number of the distributed processes.")
    arg_parser.add_argument('--local-rank', type=int, default=-1,
                            help='Replica rank on the current node. This field is required '
                                 'by `torch.distributed.launch`.')
    main(arg_parser.parse_args())
0 commit comments