perf: optimize CausalConv3d for wan autoencoders #12800
Conversation
Testing script:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton.testing


class CausalConv3d_A(nn.Conv3d):
    """Implementation A: fully explicit padding via F.pad."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # F.pad ordering: (W_left, W_right, H_left, H_right, T_left, T_right).
        # All temporal padding goes in front to keep the conv causal.
        self._padding = (
            self.padding[2],
            self.padding[2],
            self.padding[1],
            self.padding[1],
            2 * self.padding[0],
            0,
        )
        self.padding = (0, 0, 0)  # Reset internal padding to 0

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)
        return super().forward(x)


class CausalConv3d_B(nn.Conv3d):
    """Implementation B: explicit temporal padding, implicit spatial padding."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.temporal_padding = 2 * self.padding[0]
        # Keep spatial padding inside the conv; remove only the temporal part
        self.padding = (0, self.padding[1], self.padding[2])

    def forward(self, x, cache_x=None):
        b, c, t, h, w = x.size()
        padding = self.temporal_padding
        if cache_x is not None and self.temporal_padding > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding -= cache_x.shape[2]
        # Manually pad the time dimension (in front, for causality)
        if padding > 0:
            x = torch.cat([x.new_zeros(b, c, padding, h, w), x], dim=2)
        return super().forward(x)


def setup_models(in_channels, out_channels, kernel_size, padding):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_a = CausalConv3d_A(in_channels, out_channels, kernel_size, padding=padding).to(device)
    model_b = CausalConv3d_B(in_channels, out_channels, kernel_size, padding=padding).to(device)
    model_b.load_state_dict(model_a.state_dict())
    return model_a, model_b, device


def test_correctness():
    print("\n=== Running Correctness Test ===")
    B, C, T, H, W = 2, 32, 16, 64, 64
    out_C = 64
    kernel = 3
    pad_val = 1  # results in a causal temporal pad of 2 * 1 = 2
    model_a, model_b, device = setup_models(C, out_C, kernel, padding=(pad_val, pad_val, pad_val))
    model_a.eval()
    model_b.eval()
    x = torch.randn(B, C, T, H, W, device=device)

    # 1. Test without cache
    with torch.no_grad():
        out_a = model_a(x)
        out_b = model_b(x)
    try:
        torch.testing.assert_close(out_a, out_b, rtol=1e-5, atol=1e-5)
        print("[Pass] Outputs are numerically identical (No Cache).")
    except AssertionError as e:
        print("[Fail] Outputs differ!")
        print(e)
        return

    # 2. Test with a 2-frame temporal cache
    cache = torch.randn(B, C, 2, H, W, device=device)
    with torch.no_grad():
        out_a_cache = model_a(x, cache_x=cache)
        out_b_cache = model_b(x, cache_x=cache)
    try:
        torch.testing.assert_close(out_a_cache, out_b_cache, rtol=1e-5, atol=1e-5)
        print("[Pass] Outputs are numerically identical (With Cache).")
    except AssertionError as e:
        print("[Fail] Outputs differ with cache!")
        print(e)


def benchmark_performance():
    print("\n=== Running Performance Benchmark ===")
    if not torch.cuda.is_available():
        print("Skipping benchmark (CUDA not available)")
        return
    B, C, T, H, W = 4, 64, 32, 128, 128
    out_C = 64
    kernel = 3
    # Padding set to (1, 1, 1), so T gets padded by 2, H/W by 1
    model_a, model_b, _ = setup_models(C, out_C, kernel, padding=(1, 1, 1))
    x = torch.randn(B, C, T, H, W, device="cuda")

    def run_a():
        return model_a(x)

    def run_b():
        return model_b(x)

    ms_a = triton.testing.do_bench(run_a, rep=100)
    ms_b = triton.testing.do_bench(run_b, rep=100)
    print(f"Implementation A (F.pad): {ms_a:.3f} ms")
    print(f"Implementation B (Impl. H/W): {ms_b:.3f} ms")
    diff = (ms_a - ms_b) / ms_a * 100
    print(f"Implementation B is {diff:.2f}% faster")


if __name__ == "__main__":
    test_correctness()
    benchmark_performance()
```

result:
@sayakpaul @yiyixuxu @DN6
sayakpaul left a comment
Quite a brain-teaser this 🧠
So we're mainly just materializing the temporal padding, i.e., explicitly padding only along time and NOT all 3 dimensions. This is what mainly leads to the improvements, right?
Do you see any positive effects of this when increasing the input resolution?
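One way to probe this with the harness above (a sketch; the resolution list and batch size are arbitrary, and `setup_models` is the helper from the testing script):

```python
# Sweep spatial resolutions; batch size 1 to keep memory modest at 256x256.
for hw in (64, 128, 256):
    model_a, model_b, _ = setup_models(64, 64, 3, padding=(1, 1, 1))
    x = torch.randn(1, 64, 16, hw, hw, device="cuda")
    ms_a = triton.testing.do_bench(lambda: model_a(x), rep=100)
    ms_b = triton.testing.do_bench(lambda: model_b(x), rep=100)
    print(f"{hw}x{hw}: A {ms_a:.3f} ms | B {ms_b:.3f} ms")
```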
```python
self.temporal_padding = 2 * self.padding[0]
# Keep spatial padding, remove temporal padding from conv layer
self.padding = (0, self.padding[1], self.padding[2])
```
This bit feels a little confusing to me, TBH.
We're assigning a scalar to temporal_padding and then a 3-member tuple to padding. I would have expected temporal_padding to be a 3-member tuple, unless I am missing something obvious.
Perhaps you could help me understand this?
Although the standard nn.Conv3d module has a padding attribute, it does not support causal padding along the temporal dimension. Previously, we removed all internal padding and relied entirely on F.pad. In this implementation, we manually pad only the temporal dimension when necessary, while retaining the module's native padding for the spatial (H/W) dimensions. This can be faster because the conv's internal padding path may be better optimized than a separate F.pad call.
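To make the equivalence concrete, here is a minimal sanity check (a sketch, not part of the PR): a conv's internal spatial padding matches explicit zero-padding of H/W followed by an unpadded conv, which is what lets Implementation B hand the spatial part back to the conv kernel:

```python
import torch
import torch.nn.functional as F

# Internal spatial padding: H/W padded by 1, T by 0.
conv = torch.nn.Conv3d(4, 4, kernel_size=3, padding=(0, 1, 1))
x = torch.randn(1, 4, 8, 16, 16)

# Explicit zero-padding of H/W only, then an unpadded conv.
# F.pad ordering: (W_left, W_right, H_left, H_right, T_left, T_right).
ref = F.conv3d(F.pad(x, (1, 1, 1, 1, 0, 0)), conv.weight, conv.bias)
torch.testing.assert_close(conv(x), ref)
```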
@c8ef could we also check if the performance further improves with torch.compile?
After compilation, the two CausalConv modules have similar performance. However, I believe this optimization is still useful in scenarios where we cannot compile - for example, when compiling the entire VAE takes too much startup time, or when compiling certain modules may negatively impact video quality.
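For reference, the compile check can be sketched on top of the benchmark harness above (names follow the testing script; this is an assumption of how the comparison was run, not the author's exact code):

```python
# Compile both implementations and re-benchmark. torch.compile can fuse
# the pad into surrounding kernels, which likely explains why A and B
# converge after compilation.
compiled_a = torch.compile(model_a)
compiled_b = torch.compile(model_b)
ms_a = triton.testing.do_bench(lambda: compiled_a(x), rep=100)
ms_b = triton.testing.do_bench(lambda: compiled_b(x), rep=100)
print(f"Compiled A: {ms_a:.3f} ms | Compiled B: {ms_b:.3f} ms")
```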
What does this PR do?
By optimizing CausalConv3d, this patch improves the overall performance of the Wan autoencoders by 5-10%.
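To reproduce the end-to-end number, a minimal sketch (the checkpoint id and latent shape are assumptions for illustration, not from this PR):

```python
import time
import torch
from diffusers import AutoencoderKLWan

# Assumed checkpoint; any Wan VAE in diffusers format should work.
vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

# Assumed latent shape: 16 channels, 13 latent frames, 480x832 video / 8.
latents = torch.randn(1, 16, 13, 60, 104, device="cuda", dtype=torch.float16)

with torch.no_grad():
    vae.decode(latents)  # warmup
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    video = vae.decode(latents).sample
    torch.cuda.synchronize()
print(f"decode: {(time.perf_counter() - t0) * 1000:.1f} ms")
```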
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.