Skip to content

Commit 9bef9f4

Browse files
KimbingNgkevinkhwusayakpaulDN6
authored
Fix SVD bug (shape of time_context) (#7268)
* Fix SVD bug (shape of `time_context`) * Formatting code * Formatting src/diffusers/models/transformers/transformer_temporal.py by `make style && make quality` --------- Co-authored-by: kevinkhwu <kevinkhwu@tencent.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent 7aa4514 commit 9bef9f4

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/diffusers/models/transformers/transformer_temporal.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -311,10 +311,10 @@ def forward(
311311
time_context_first_timestep = time_context[None, :].reshape(
312312
batch_size, num_frames, -1, time_context.shape[-1]
313313
)[:, 0]
314-
time_context = time_context_first_timestep[None, :].broadcast_to(
315-
height * width, batch_size, 1, time_context.shape[-1]
314+
time_context = time_context_first_timestep[:, None].broadcast_to(
315+
batch_size, height * width, time_context.shape[-2], time_context.shape[-1]
316316
)
317-
time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
317+
time_context = time_context.reshape(batch_size * height * width, -1, time_context.shape[-1])
318318

319319
residual = hidden_states
320320

0 commit comments

Comments
 (0)