This directory contains examples and configurations for training WanVideo models using the LMMs-Engine-Mini framework.
WanVideo is a family of diffusion-based video generation models that support:
- Text-to-Video (T2V): Generate videos from text descriptions
- Image-to-Video (I2V): Generate videos from a starting image
- Video-to-Video (V2V): Transform existing videos with text guidance
- VACE: Video All-in-one Creation and Editing model
- Fun Controls: Advanced control mechanisms for video generation
The implementation currently supports the following WanVideo model configurations:
- 1.3B T2V: Efficient text-to-video model for quick iterations (480×832 resolution)
- 14B T2V: High-quality text-to-video model (480×832 resolution)
- 14B I2V: Image-to-video model with higher resolution support (720×1280)
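To get a sense of sequence length at these resolutions, the arithmetic below assumes the compression factors commonly described for Wan2.1 (a causal VAE with 4× temporal and 8× spatial downsampling, followed by 1×2×2 patchification in the DiT). These factors are assumptions about the architecture, not values read from this repository:

```python
# Hypothetical sequence-length estimate; the 4x / 8x / 2x factors are
# assumptions based on publicly described Wan2.1 designs.
def dit_tokens(num_frames: int, height: int, width: int,
               t_down: int = 4, s_down: int = 8, patch: int = 2) -> int:
    latent_t = (num_frames - 1) // t_down + 1  # causal VAE keeps the first frame
    latent_h = height // s_down
    latent_w = width // s_down
    return latent_t * (latent_h // patch) * (latent_w // patch)

print(dit_tokens(49, 480, 832))  # 13 * 30 * 52 = 20280
```

Under these assumptions, a 49-frame 480×832 clip becomes roughly 20k DiT tokens, which is why attention efficiency (Flash Attention) matters so much for these models.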
Organize your video dataset in CSV format. An example dataset can be downloaded with:

```bash
modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset
```

For T2V training:

```csv
video,prompt
path/to/video1.mp4,"A beautiful sunset over the mountains"
path/to/video2.mp4,"A cat playing with a ball in a garden"
```

For I2V training:

```csv
video,image,prompt
path/to/video1.mp4,path/to/first_frame1.jpg,"The scene comes to life with movement"
path/to/video2.mp4,path/to/first_frame2.jpg,"Camera pans across the landscape"
```

Save your metadata CSV file (e.g., `data/metadata.csv`) and update the dataset path in the configuration files.
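Before launching a long training run, it can help to sanity-check the metadata file. A minimal standard-library sketch (the column names follow the schema above; the validation rules themselves are illustrative, not part of the framework):

```python
import csv
import os

def check_metadata(csv_path: str, required: tuple = ("video", "prompt")) -> list:
    """Return a list of human-readable problems found in the metadata CSV."""
    problems = []
    with open(csv_path, newline="") as f:
        reader = csv.DictReader(f)
        missing = [c for c in required if c not in (reader.fieldnames or [])]
        if missing:
            return [f"missing columns: {missing}"]
        for i, row in enumerate(reader, start=2):  # line 1 is the header
            if not row["prompt"].strip():
                problems.append(f"line {i}: empty prompt")
            if not os.path.exists(row["video"]):
                problems.append(f"line {i}: video not found: {row['video']}")
    return problems
```

For I2V metadata, pass `required=("video", "image", "prompt")` and add an existence check for the `image` column as well.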
We provide pre-configured YAML files for different model variants:
- `examples/wanvideo/configs/wan2.1_t2v_1.3b.yaml`: Text-to-Video 1.3B model
- `examples/wanvideo/configs/wan2.1_t2v_14b.yaml`: Text-to-Video 14B model
- `examples/wanvideo/configs/wan2.1_i2v_14b.yaml`: Image-to-Video 14B model
Modify the configuration files to match your dataset paths and training requirements.
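If you prefer to script this step, a text-level rewrite avoids any YAML dependency. The `dataset_path` key comes from the configs shown later in this README; the helper itself is illustrative and assumes the key appears once per file:

```python
import re

def set_dataset_path(config_text: str, new_path: str) -> str:
    """Point `dataset_path:` at a new metadata CSV, preserving indentation.

    A plain-text rewrite (no YAML library needed); assumes the key occurs
    once, as in the example configs.
    """
    return re.sub(r"^(\s*dataset_path:\s*).*$",
                  lambda m: m.group(1) + new_path,
                  config_text, count=1, flags=re.MULTILINE)

cfg = "dataset_config:\n  dataset_type: vision\n  dataset_path: data/metadata.csv\n"
print(set_dataset_path(cfg, "data/my_dataset.csv"))
```

Using a callable replacement (rather than a backreference string) keeps the rewrite safe even when the new path contains characters that `re.sub` would otherwise interpret.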
Single GPU:

```bash
python -m lmms_engine.launch.cli --config examples/wanvideo/configs/wan2.1_t2v_1.3b.yaml
```

Multi-GPU with torchrun:

```bash
torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
    --master_addr="127.0.0.1" --master_port="8000" \
    -m lmms_engine.launch.cli --config examples/wanvideo/configs/wan2.1_t2v_1.3b.yaml
```

Multi-GPU with Accelerate (FSDP):

```bash
accelerate launch --use_fsdp \
    -m lmms_engine.launch.cli --config examples/wanvideo/configs/wan2.1_t2v_1.3b.yaml
```

To resume training, update the configuration file to include:
```yaml
trainer_args:
  resume_from_checkpoint: "./output/wan2.1_t2v_1.3b/checkpoint-XXX"
```

Then run the training command again.
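Checkpoint directories follow the `checkpoint-<step>` naming scheme shown above. A small helper to locate the most recent one, assuming that layout:

```python
import os
import re
from typing import Optional

def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the path of the highest-numbered checkpoint-<step> directory."""
    pattern = re.compile(r"^checkpoint-(\d+)$")
    best, best_step = None, -1
    for name in os.listdir(output_dir):
        m = pattern.match(name)
        if m and os.path.isdir(os.path.join(output_dir, name)):
            step = int(m.group(1))
            if step > best_step:
                best, best_step = os.path.join(output_dir, name), step
    return best
```

The numeric comparison matters: a lexicographic sort would rank `checkpoint-900` above `checkpoint-1000`.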
Key parameters in the model configuration:

```yaml
model_config:
  load_from_config:
    model_type: wanvideo
    model_variant: "Wan2.1-T2V-1.3B"  # Model variant identifier
    model_size: "1.3B"                # or "14B"

    # DiT architecture
    dit_hidden_size: 2432             # Hidden dimension
    dit_num_layers: 28                # Number of transformer layers
    dit_num_heads: 19                 # Number of attention heads
    dit_enable_flash_attn: true       # Enable Flash Attention

    # Training settings
    gradient_checkpointing: true
    scheduler_type: "flow_match"      # Scheduler for diffusion

    # Generation settings
    num_frames: 49
    height: 480
    width: 832
    guidance_scale: 5.0
    num_inference_steps: 20
```

Important training parameters:
```yaml
trainer_args:
  output_dir: "./output/wan2.1_t2v_1.3b"
  num_train_epochs: 3
  per_device_train_batch_size: 1
  gradient_accumulation_steps: 8
  learning_rate: 1.0e-5

  # Mixed precision
  bf16: true
  tf32: true

  # Optimizer settings
  optim: "adamw_torch"
  weight_decay: 0.01
  lr_scheduler_type: "cosine"
  warmup_steps: 500

  # Checkpointing
  save_steps: 500
  save_total_limit: 3
  eval_steps: 500
```

LoRA support is planned for future releases. Currently, use gradient checkpointing and mixed precision training to reduce memory usage:
```yaml
model_config:
  load_from_config:
    gradient_checkpointing: true
    attn_implementation: flash_attention_2

trainer_args:
  gradient_checkpointing: true
  bf16: true
  tf32: true
```

For training 14B models across multiple GPUs (as configured in `wan2.1_i2v_14b.yaml`):
```yaml
trainer_args:
  fsdp: "full_shard auto_wrap"
  fsdp_config:
    backward_prefetch: "backward_pre"
    forward_prefetch: true
    use_orig_params: false
    cpu_ram_efficient_loading: true
    sync_module_states: true
    limit_all_gathers: true
    activation_checkpointing: true
    sharding_strategy: "FULL_SHARD"
```

The dataset is configured in the YAML files:
```yaml
dataset_config:
  dataset_type: vision
  dataset_format: csv               # Currently supports CSV format
  dataset_path: data/metadata.csv
  video_sampling_strategy: frame_num
  frame_num: 49                     # Number of frames to sample
  video_backend: qwen_vl_utils      # Video processing backend

# Processor configuration for video preprocessing
processor_config:
  processor_type: wanvideo
  do_resize: true
  size:
    height: 480                     # Target height
    width: 832                      # Target width
  do_normalize: true
  image_mean: [0.5, 0.5, 0.5]
  image_std: [0.5, 0.5, 0.5]
```

The dataset must be in CSV format with the following columns:
For T2V training:

```csv
video,prompt
path/to/video1.mp4,"A scenic mountain landscape"
path/to/video2.mp4,"Urban cityscape at night"
```

For I2V training:

```csv
video,image,prompt
path/to/video1.mp4,path/to/frame1.jpg,"Movement starts from this frame"
path/to/video2.mp4,path/to/frame2.jpg,"Dynamic scene evolution"
```

Training progress is logged to TensorBoard:
```bash
tensorboard --logdir ./output/wan2.1_t2v_1.3b
```

After training, you can load and use the model for inference. The model checkpoints are saved in the `output_dir` specified in your configuration.
```python
import torch

from lmms_engine.models.wanvideo import (
    WanVideoForConditionalGeneration,
    WanVideoProcessor,
    WanVideoConfig,
)

# Load the trained model
checkpoint_path = "./output/wan2.1_t2v_1.3b/checkpoint-XXX"
config = WanVideoConfig.from_pretrained(checkpoint_path)
model = WanVideoForConditionalGeneration.from_pretrained(
    checkpoint_path,
    config=config,
    torch_dtype=torch.bfloat16,  # Use bf16 for efficiency
)
processor = WanVideoProcessor()

# Move model to GPU
model = model.to("cuda")

# Generate video from text prompt
prompt = "A serene lake surrounded by mountains at sunset"
with torch.no_grad():
    outputs = model.generate(
        prompt=prompt,
        num_frames=49,
        height=480,
        width=832,
        num_inference_steps=20,
        guidance_scale=5.0,
    )

# Process and save the generated video
# Note: video post-processing depends on your specific requirements
```
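As a starting point for the post-processing step left open above: diffusion pipelines commonly emit frames as floats in [-1, 1], and the first step before handing them to any video writer (ffmpeg, imageio, etc.) is mapping them to 8-bit values. A standard-library sketch, where the [-1, 1] output range is an assumption about this model rather than a documented guarantee:

```python
def to_uint8(frame, lo: float = -1.0, hi: float = 1.0):
    """Map one frame of floats in [lo, hi] to ints in [0, 255], clipping outliers."""
    scale = 255.0 / (hi - lo)
    out = []
    for row in frame:
        out.append([max(0, min(255, int(round((v - lo) * scale)))) for v in row])
    return out

print(to_uint8([[-1.0, 0.0, 1.0]]))  # [[0, 128, 255]]
```

In practice you would apply this per channel to each generated frame and then encode the resulting uint8 frames with your video library of choice.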
- Enable gradient checkpointing:

  ```yaml
  gradient_checkpointing: true
  ```

- Reduce the batch size and increase gradient accumulation:

  ```yaml
  per_device_train_batch_size: 1
  gradient_accumulation_steps: 16
  ```

- Use LoRA for fine-tuning instead of full training (planned for a future release)
- Enable FSDP for multi-GPU setups
- Ensure Flash Attention is installed:

  ```bash
  pip install flash-attn --no-build-isolation
  ```

- Use mixed precision training (bf16)
- Enable TF32 on Ampere or newer GPUs:

  ```yaml
  tf32: true
  ```
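To see why the 14B variants need FSDP while 1.3B fits comfortably on a single GPU, here is a back-of-the-envelope estimate of the persistent training state for full fine-tuning. It assumes bf16 weights and gradients with fp32 AdamW moments, and ignores activations and any fp32 master weights, so treat it as illustrative arithmetic, not a measurement:

```python
def train_state_gib(n_params: float) -> float:
    """Rough persistent training state for full fine-tuning:
    2 B (bf16 weights) + 2 B (bf16 grads) + 8 B (fp32 Adam m and v) = 12 B/param.
    Activations and fp32 master weights are deliberately excluded."""
    return n_params * 12 / 1024**3

print(f"1.3B: {train_state_gib(1.3e9):.0f} GiB")  # ~15 GiB before activations
print(f"14B:  {train_state_gib(14e9):.0f} GiB")   # ~156 GiB -> must be sharded
```

Roughly 15 GiB of state for 1.3B versus over 150 GiB for 14B, which no single accelerator holds; FULL_SHARD spreads that state across the GPU group.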
If you use WanVideo in your research, please cite:

```bibtex
@article{wanvideo2024,
  title={WanVideo: Unified Video Generation with Diffusion Models},
  author={WanVideo Team},
  year={2024}
}
```

This implementation is provided under the Apache 2.0 License.