Skip to content

Conversation

@c8ef
Copy link
Contributor

@c8ef c8ef commented Dec 6, 2025

What does this PR do?

By optimizing CausalConv3d, this patch improves the overall performance of wan autoencoders by 5-10%.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@c8ef
Copy link
Contributor Author

c8ef commented Dec 6, 2025

testing script:

import torch
import torch.nn as nn
import torch.nn.functional as F

import triton.testing


class CausalConv3d_A(nn.Conv3d):
    """
    Implementation A: Fully explicit padding using F.pad
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._padding = (
            self.padding[2],
            self.padding[2],
            self.padding[1],
            self.padding[1],
            2 * self.padding[0],
            0,
        )
        self.padding = (0, 0, 0)  # Reset internal padding to 0

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)
        return super().forward(x)


class CausalConv3d_B(nn.Conv3d):
    """
    Implementation B: Explicit Temporal padding, Implicit Spatial padding
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.temporal_padding = 2 * self.padding[0]
        # Keep spatial padding, remove temporal padding from conv layer
        self.padding = (0, self.padding[1], self.padding[2])

    def forward(self, x, cache_x=None):
        b, c, t, h, w = x.size()
        padding = self.temporal_padding
        if cache_x is not None and self.temporal_padding > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding -= cache_x.shape[2]

        # Manually pad time dimension
        if padding > 0:
            x = torch.cat([x.new_zeros(b, c, padding, h, w), x], dim=2)
        return super().forward(x)


def setup_models(in_channels, out_channels, kernel_size, padding):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_a = CausalConv3d_A(in_channels, out_channels, kernel_size, padding=padding).to(device)
    model_b = CausalConv3d_B(in_channels, out_channels, kernel_size, padding=padding).to(device)

    model_b.load_state_dict(model_a.state_dict())

    return model_a, model_b


def test_correctness():
    print("\n=== Running Correctness Test ===")
    B, C, T, H, W = 2, 32, 16, 64, 64
    out_C = 64
    kernel = 3
    pad_val = 1  # resulting in causal pad of 2*1=2

    model_a, model_b = setup_models(C, out_C, kernel, padding=(pad_val, pad_val, pad_val))
    model_a.eval()
    model_b.eval()

    x = torch.randn(B, C, T, H, W, device="cuda")

    # 1. Test without cache
    with torch.no_grad():
        out_a = model_a(x)
        out_b = model_b(x)

    try:
        torch.testing.assert_close(out_a, out_b, rtol=1e-5, atol=1e-5)
        print("[Pass] Outputs are numerically identical (No Cache).")
    except AssertionError as e:
        print("[Fail] Outputs differ!")
        print(e)
        return

    cache = torch.randn(B, C, 2, H, W, device="cuda")
    with torch.no_grad():
        out_a_cache = model_a(x, cache_x=cache)
        out_b_cache = model_b(x, cache_x=cache)

    try:
        torch.testing.assert_close(out_a_cache, out_b_cache, rtol=1e-5, atol=1e-5)
        print("[Pass] Outputs are numerically identical (With Cache).")
    except AssertionError as e:
        print("[Fail] Outputs differ with cache!")


def benchmark_performance():
    print("\n=== Running Performance Benchmark ===")
    if not torch.cuda.is_available():
        print("Skipping benchmark (CUDA not available)")
        return

    B, C, T, H, W = 4, 64, 32, 128, 128
    out_C = 64
    kernel = 3
    # Padding set to (1,1,1), so T gets padded by 2, H/W by 1
    model_a, model_b = setup_models(C, out_C, kernel, padding=(1, 1, 1))

    x = torch.randn(B, C, T, H, W, device="cuda")

    def run_a():
        return model_a(x)

    def run_b():
        return model_b(x)

    ms_a = triton.testing.do_bench(run_a, rep=100)
    ms_b = triton.testing.do_bench(run_b, rep=100)

    print(f"Implementation A (F.pad):   {ms_a:.3f} ms")
    print(f"Implementation B (Impl.H/W): {ms_b:.3f} ms")

    diff = (ms_a - ms_b) / ms_a * 100
    print(f"Implementation B is {diff:.2f}% faster")


if __name__ == "__main__":
    test_correctness()
    benchmark_performance()

result:

=== Running Correctness Test ===
[Pass] Outputs are numerically identical (No Cache).
[Pass] Outputs are numerically identical (With Cache).

=== Running Performance Benchmark ===
Implementation A (F.pad):   44.787 ms
Implementation B (Impl.H/W): 42.507 ms
Implementation B is 5.09% faster

@c8ef
Copy link
Contributor Author

c8ef commented Dec 6, 2025

@sayakpaul @yiyixuxu @DN6
Please take a look, thanks!

@c8ef c8ef changed the title perf: optimize CasualConv3d for wan autoencoders perf: optimize CausalConv3d for wan autoencoders Dec 6, 2025
Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quite a brain-teaser this 🧠

So, we're mainly just materializing the temporal padding i.e., only padding along time and NOT all 3 dimensions. This is what mainly leads to the improvements, right?

Do you see any positive effects of this when increasing the input resolutions?

Comment on lines +165 to +167
self.temporal_padding = 2 * self.padding[0]
# Keep spatial padding, remove temporal padding from conv layer
self.padding = (0, self.padding[1], self.padding[2])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This bit feels a little confusing to me, TBH.

We're assiging a scalar to temporal_padding and then a 3-member tuple to padding. I would have expected temporal_padding to be a 3-member tuple, unless I am missing something obvious.

Perhaps, you could help me understand this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although the standard Conv module includes a padding attribute, it does not support causal padding in the temporal dimension. Previously, we manually removed all internal padding and relied explicitly on F.pad. In this implementation, we apply manual padding only to the temporal dimension when necessary, while retaining the module's native padding for spatial (H/W) dimensions. This could be faster because the internal padding may use better optimizations.

@sayakpaul sayakpaul requested review from DN6 and yiyixuxu December 7, 2025 12:54
@c8ef c8ef requested a review from sayakpaul December 8, 2025 02:04
@sayakpaul
Copy link
Member

@c8ef could we also check if the performance further improves with torch.compile?

@c8ef
Copy link
Contributor Author

c8ef commented Dec 8, 2025

@c8ef could we also check if the performance further improves with torch.compile?

After compilation, the two CausalConv modules have similar performance.

However, I believe this optimization is still useful in scenarios where we cannot compile - for example, when compiling the entire VAE takes too much startup time, or when compiling certain modules may negatively impact video quality.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants