Conversation
|
Work done
Waiting for feedback and review :) |
|
Hi @sayakpaul @dhruvrnaik any updates? |
|
@LawJarp-A sorry about the delay on our end. @DN6 will review it soon. |
|
Hi @LawJarp-A I think we would need TeaCache to be implemented in a model agnostic way in order to merge the PR. The First Block Cache implementation is a good reference for this. |
Yep @DN6, I agree. I wanted to first implement it for a single model and get feedback before working on the fully model-agnostic implementation. I'm working on that but haven't pushed it yet. I'll take a look at First Block Cache for reference as well. |
|
@DN6 updated it in a more model agnostic way. |
Added multi-model support; still testing it thoroughly though. |
|
Hi @DN6 @sayakpaul
In the meantime any feedback would be appreciated |
|
Thanks @LawJarp-A!
You can refer to #12569 for testing
Yes, I think that is informative for users. |
sayakpaul left a comment:
Some initial feedback. The most important question: it seems like we need to craft different logic for each model. Can we not keep it model agnostic?
|
I am trying to think of ways we can avoid having a forward method for each model now. Initially that seemed like the best approach, and it was fine when I wrote it for Flux, but Lumina needed multi-stage preprocessing. |
@DN6 @sayakpaul I spent the weekend going over the code again to understand and simplify
I have kept per-model forward functions as you requested, instead of the common adapter pattern I was using before. By the way, below are the images generated with and without cache. |
src/diffusers/hooks/hooks.py
Outdated
```python
# Fallback to default context for backward compatibility with
# pipelines that don't call cache_context()
context = "_default"
```
Should this branch not error out like previous?
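For reference, the stricter behavior being suggested here might look like the following sketch (hypothetical names, not the PR's actual code):

```python
from typing import Optional


def resolve_context(context: Optional[str]) -> str:
    # Fail loudly instead of silently falling back to "_default",
    # so pipelines that forget to call cache_context() surface the bug early.
    if context is None:
        raise ValueError(
            "No cache context set; wrap the denoising loop in cache_context() "
            "before running inference."
        )
    return context
```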
src/diffusers/hooks/teacache.py
Outdated
```python
if prev_mean.item() > 1e-9:
    return ((current - previous).abs().mean() / prev_mean).item()
```
Do we need to make it data-dependent (item() call)? Raising it because it makes torch.compile cry.
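One graph-friendly alternative (a sketch, not the PR's code) keeps the result as a tensor and replaces the data-dependent branch with a clamp, so no host synchronization is needed:

```python
import torch


def rel_distance(current: torch.Tensor, previous: torch.Tensor) -> torch.Tensor:
    # Returning a tensor (no .item()) avoids the host sync that makes
    # torch.compile graph-break; clamp replaces the `prev_mean > 1e-9` check.
    prev_mean = previous.abs().mean()
    return (current - previous).abs().mean() / torch.clamp(prev_mean, min=1e-9)
```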
|
@bot /style |
|
Style bot fixed some files and pushed the changes. |
Pull request overview
This PR implements TeaCache (Timestep Embedding Aware Cache), a training-free caching technique that speeds up diffusion model inference by 1.5x-2.6x by reusing transformer block computations when consecutive timestep embeddings are similar.
Changes:
- Adds TeaCache hook system with model-specific forward implementations for FLUX, Mochi, Lumina2, and CogVideoX models
- Integrates TeaCache with the existing CacheMixin infrastructure for unified cache management
- Implements StateManager improvements for context-aware state isolation (CFG support)
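The context-aware isolation can be pictured as a per-context state dictionary (illustrative sketch; names follow the PR description, not necessarily its code):

```python
class StateManager:
    """Keeps one isolated cache-state object per context, so the CFG
    conditional and unconditional branches never clobber each other."""

    def __init__(self, state_cls):
        self._state_cls = state_cls
        self._states = {}
        self._context = "_default"

    def set_context(self, name: str):
        self._context = name

    def get_state(self):
        # Lazily create an isolated state object for the active context
        return self._states.setdefault(self._context, self._state_cls())
```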
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 6 comments.
Summary per file:
| File | Description |
|---|---|
| src/diffusers/hooks/teacache.py | Core TeaCache implementation with polynomial rescaling, model auto-detection, and specialized forward functions for each supported model |
| src/diffusers/models/cache_utils.py | Integration of TeaCacheConfig into enable_cache/disable_cache methods |
| src/diffusers/hooks/__init__.py | Export TeaCacheConfig, apply_teacache, and StateManager |
| src/diffusers/hooks/hooks.py | StateManager enhancement with default context fallback for backward compatibility |
| src/diffusers/models/transformers/transformer_lumina2.py | Add CacheMixin to Lumina2Transformer2DModel |
| tests/hooks/test_teacache.py | Comprehensive unit tests for config validation, state management, and model detection |
src/diffusers/hooks/teacache.py
Outdated
```python
state.cnt = 0
state.accumulated_rel_l1_distance = 0.0
state.previous_modulated_input = None
state.previous_residual = None
```
The _maybe_reset_state_for_new_inference method doesn't reset cache_dict and uncond_seq_len which are used by Lumina2. This could cause stale cache data to persist across inference runs when using Lumina2 models. Consider calling state.reset() instead of manually resetting individual fields, or add these Lumina2-specific fields to the reset logic.
Suggested change:
```python
state.previous_residual = None
# Reset Lumina2-specific state to avoid stale cache/data between inference runs
if hasattr(state, "cache_dict") and state.cache_dict is not None:
    # Clear in-place to preserve any existing references to the cache dict
    state.cache_dict.clear()
if hasattr(state, "uncond_seq_len"):
    state.uncond_seq_len = None
```
|
Thanks for the review. Taking a look |
|
@LawJarp-A I am guessing the Copilot review comments were resolved? There also seem to be a couple of unresolved comments. |
|
Cc: @LiewFeng would you like to give this a review as well? |
|
@LawJarp-A a gentle ping. |
Noted @sayakpaul |
|
@LawJarp-A sounds good. Let us know whenever you're ready. |
|
@LawJarp-A let us know. |
Yessir. Will ping here |




What does this PR do?
What is TeaCache?
TeaCache (Timestep Embedding Aware Cache) is a training-free caching technique that speeds up diffusion model inference by 1.5x-2.6x by reusing transformer block computations when consecutive timestep embeddings are similar.
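The gist of the skip decision can be sketched in a few lines of plain Python (hypothetical helper names; the polynomial coefficients here are illustrative placeholders, not the tuned per-model values):

```python
# Illustrative (not tuned) coefficients for c[0]*x**4 + ... + c[4]
COEFFS = [0.0, 0.0, 0.0, 1.0, 0.0]  # identity rescaling for this sketch


def rescale(x, c=COEFFS):
    # Polynomial rescaling of the relative L1 distance
    return c[0] * x**4 + c[1] * x**3 + c[2] * x**2 + c[3] * x + c[4]


def rel_l1(current, previous):
    # Mean absolute change of the timestep-embedding values,
    # normalized by the mean magnitude of the previous embedding
    num = sum(abs(c - p) for c, p in zip(current, previous)) / len(current)
    den = sum(abs(p) for p in previous) / len(previous)
    return num / den


def should_skip(current, previous, accumulated, thresh=0.2):
    """Return (skip, new_accumulated); skip means reuse the cached residual."""
    accumulated += rescale(rel_l1(current, previous))
    if accumulated < thresh:
        return True, accumulated   # embeddings barely moved: reuse cache
    return False, 0.0              # recompute the blocks, reset accumulator
```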
Architecture & Design
TeaCache uses a `ModelHook` to intercept transformer forward passes without modifying model code. The algorithm rescales the relative L1 distance between consecutive timestep embeddings with a polynomial: `c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]`.
Key Design Features:
- Integrates with `HookRegistry` and `CacheMixin` for lifecycle management
- Uses `StateManager` with context-aware state for CFG conditional/unconditional branches
Supported Models
All models support automatic coefficient detection based on model class name and config path. Custom coefficients can also be provided via `TeaCacheConfig`.
Benchmark Results (FLUX.1-dev)
Benchmark Results (Lumina2)
Benchmark Results (CogVideoX-2b)
Benchmark Results (Mochi)
Test Hardware: NVIDIA H100
Framework: Diffusers with TeaCache hooks
All tests: Same seed (42) for reproducibility
Usage
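Based on the PR description and the existing `CacheMixin.enable_cache` interface, usage presumably follows this sketch (the import path, model checkpoint, and argument defaults are assumptions, not the merged API):

```python
import torch
from diffusers import FluxPipeline
from diffusers.hooks import TeaCacheConfig  # path assumed from this PR's file list

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Enable TeaCache via the CacheMixin integration added in this PR;
# 0.4 corresponds to roughly 1.8x speedup per the description below.
pipe.transformer.enable_cache(TeaCacheConfig(rel_l1_thresh=0.4))

image = pipe("an astronaut riding a horse", num_inference_steps=28).images[0]

pipe.transformer.disable_cache()
```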
Configuration Options
The `TeaCacheConfig` supports the following parameters:
- `rel_l1_thresh` (float, default=0.2): Threshold for the accumulated relative L1 distance. Recommended values: 0.25 for ~1.5x speedup, 0.4 for ~1.8x, 0.6 for ~2.0x. Mochi models require lower thresholds (0.06-0.09).
- `coefficients` (List[float], optional): Polynomial coefficients for rescaling the L1 distance. Auto-detected based on model type if not provided.
- `num_inference_steps` (int, optional): Total inference steps. Ensures the first/last timesteps are always computed. Auto-detected if not provided.
- `num_inference_steps_callback` (Callable[[], int], optional): Callback returning the total inference steps. Alternative to `num_inference_steps`.
- `current_timestep_callback` (Callable[[], int], optional): Callback returning the current timestep. Used for debugging/statistics.
Files Changed
- `src/diffusers/hooks/teacache.py` - Core implementation with model-specific forward functions
- `src/diffusers/models/cache_utils.py` - CacheMixin integration
- `src/diffusers/hooks/__init__.py` - Export TeaCacheConfig and apply_teacache
- `tests/hooks/test_teacache.py` - Comprehensive unit tests
Fixes # (issue)
#12589
#12635
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@sayakpaul @yiyixuxu @DN6