
Conversation

@NuojCheng (Collaborator) commented Dec 3, 2025

Description

This PR introduces explicit sharding support for the DeepSeek model in MaxText. This allows for more granular control over how tensors are distributed across devices, which may lead to better performance and scalability.

This PR is a simplified version of #2579.

Key Changes:

  • Enabled DeepSeek for Explicit Sharding: The deepseek decoder is now supported in ShardMode.EXPLICIT.
  • Introduced shard_mode: A shard_mode parameter has been added to various layers, including attention, embeddings, and MoE, to control sharding behavior. This allows for more flexible and explicit sharding configurations (a minimal sketch follows this list).
  • Refactored Sharding Logic: The existing sharding logic has been lightly refactored to use the new sharding utilities and NamedSharding objects, making the sharding implementation more explicit and maintainable.
  • Updated Tests: The test suite has been updated to include tests for explicit sharding, ensuring the correctness and robustness of the implementation.
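For illustration only, here is a minimal, self-contained sketch of the kind of shard_mode switch described above. It is not the MaxText implementation: ShardMode, maybe_shard, and the dense function below are hypothetical stand-ins (the real helper works on logical axis names via _maybe_shard_with_logical), but the sketch shows the core idea, namely that explicit mode pins an activation layout with a NamedSharding while auto mode leaves placement to the partitioner.

```python
import enum

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


class ShardMode(enum.Enum):
  AUTO = "auto"          # let the compiler propagate shardings
  EXPLICIT = "explicit"  # pin layouts with NamedSharding objects


def maybe_shard(x, mesh, spec, shard_mode):
  """Hypothetical helper: constrain `x` only when explicit sharding is requested."""
  if shard_mode == ShardMode.EXPLICIT:
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))
  return x  # AUTO: leave placement to the partitioner


mesh = Mesh(np.array(jax.devices()), axis_names=("data",))


@jax.jit
def dense(x, w):
  y = x @ w
  # Explicit mode: batch dimension sharded over the "data" mesh axis, features replicated.
  return maybe_shard(y, mesh, P("data", None), ShardMode.EXPLICIT)


out = dense(jnp.ones((8, 16)), jnp.ones((16, 32)))
print(out.sharding)  # NamedSharding over the "data" axis
```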

Detailed Changes by File:

  • src/MaxText/common_types.py: Added Q_LORA_UP_PROJ and KV_LORA_UP_PROJ to support LoRA projections with explicit sharding.
  • src/MaxText/configs/types.py: Enabled the deepseek decoder for explicit sharding.
  • src/MaxText/layers/attention_mla.py: Updated to use out_sharding and _maybe_shard_with_logical for more explicit control over sharding in MLA.
  • src/MaxText/layers/attentions.py: Added the shard_mode parameter to RotaryEmbedding and its variants. Refactored input and projection sharding to be more explicit.
  • src/MaxText/layers/deepseek.py: Integrated explicit sharding into the DeepSeekDecoderLayer by using _maybe_shard_with_logical and passing out_sharding and intermediate_sharding to sub-layers.
  • src/MaxText/layers/embeddings.py: Added the shard_mode parameter to all embedding classes to allow for explicit sharding configuration.
  • src/MaxText/layers/moe.py: Added shard_mode to GateLogit and updated the MoeBlock to handle explicit sharding for weights and activations.
  • src/MaxText/sharding.py: Introduced new utility functions for explicit sharding (see the sketch after this list).
  • tests/*: Updated various tests to include test cases for explicit sharding and to pass the new shard_mode and mesh parameters where necessary.
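As a rough illustration of what a sharding utility module can provide, the sketch below translates logical axis names into NamedSharding objects through a logical-to-mesh-axis rules table. The rule names, the LOGICAL_RULES table, and logical_to_named_sharding are hypothetical placeholders; MaxText's actual utilities and rule tables live in src/MaxText/sharding.py and the model configs.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical logical -> mesh-axis rules; a value of None means "replicated".
LOGICAL_RULES = {
    "activation_batch": "data",
    "activation_embed": None,
    "embed": None,
    "mlp": "tensor",
}


def logical_to_named_sharding(mesh, logical_axes):
  """Map a tuple of logical axis names to a NamedSharding on `mesh`."""
  spec = P(*(LOGICAL_RULES.get(name) for name in logical_axes))
  return NamedSharding(mesh, spec)


# A 2-D mesh with a trivial "tensor" axis so the example runs on any device count.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "tensor"))
print(logical_to_named_sharding(mesh, ("activation_batch", "activation_embed")))
# -> NamedSharding(mesh, PartitionSpec('data', None))
```

A NamedSharding built this way can then be handed to sub-layers (for example as an out_sharding argument) or applied directly with jax.lax.with_sharding_constraint.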

Tests

Model_name: deepseek3-test
Topology: v5p-8
JAX==0.8.1
TL;DR:

  • Auto Sharding: Maintains performance parity with the main branch (deviations ≤ 0.3%).
  • Explicit Sharding: While this introduces a small, expected overhead in the TP case, it provides significant improvements for TP_transpose sharding.

Performance Impact Table

| Configuration | Main Peak Memory (MB) | Main Step Time (ms) | Auto Peak Memory (MB) | Auto Peak Mem Change (%) | Auto Step Time (ms) | Auto Step Time Change (%) | Explicit Peak Memory (MB) | Explicit Peak Mem Change (%) | Explicit Step Time (ms) | Explicit Step Time Change (%) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| FSDP4 | 68748 | 2238 | 68748 | 0% | 2238 | 0% | 68961 | 0.31% | 2284 | 2.01% |
| FSDP2+TP2 | 69447 | 2009 | 69448 | 0% | 2014 | 0.25% | 69617 | 0.24% | 2378 | 18.37% |
| FSDP2+TPT2 | 93835 | 10742 | 93836 | 0% | 10743 | 0.01% | 69373 | -26.07% | 2714 | -74.68% |

Full FSDP

smoke_train model_name=deepseek3-test per_device_batch_size=1

FSDP + TP

smoke_train model_name=deepseek3-test per_device_batch_size=1 ici_tensor_parallelism=2

FSDP + TP_transpose

smoke_train model_name=deepseek3-test per_device_batch_size=1 ici_tensor_transpose_parallelism=2

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have added the necessary comments to my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@RissyRan (Collaborator) left a comment


Thanks for simplifying it! Are you planning to merge this one instead?

If you could directly apply a few minor comments from the old PR (just published), we should be good to go.

@NuojCheng (Collaborator, Author) commented Dec 4, 2025

Reply to @RissyRan:

Are you planning to merge this one instead?

Yes, let's merge this one instead. The sharding refactoring parts have been removed here.

If you could directly apply a few minor comments from the old PR (just published), we should be good to go.

Thank you for the comments! The concerns from the previous PR are all related to the sharding refactoring, which is not included in this PR.

@github-actions bot commented Dec 5, 2025

🤖 Hi @NuojCheng, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions bot left a comment


📋 Review Summary

This PR introduces explicit sharding for the DeepSeek model, which is a significant and valuable feature for performance and scalability. The changes are extensive and well-structured. I've identified a couple of minor logical issues and potential improvements in the implementation.

🔍 General Feedback

  • The introduction of shard_mode and the refactoring to use _maybe_shard_with_logical is a clean way to handle explicit sharding.
  • The tests have been updated comprehensively, which is great to see.
  • Pay close attention to copy-paste errors, especially when dealing with similar variables like inputs_q and inputs_kv.

Overall, this is a solid contribution. Once the minor issues are addressed, this PR will be in great shape.

@NuojCheng force-pushed the chengnuojin-explicit-deepseek-split branch from 9d4d6b4 to 8d2efc0 on December 5, 2025 at 23:50
@NuojCheng force-pushed the chengnuojin-explicit-deepseek-split branch from 8d2efc0 to 3dc9988 on December 6, 2025 at 01:05
@RissyRan (Collaborator) left a comment


Thanks! cc @suexu1025 FYI

@copybara-service bot merged commit 7ebcc9a into main on Dec 8, 2025 (65 of 71 checks passed)
@copybara-service bot deleted the chengnuojin-explicit-deepseek-split branch on December 8, 2025 at 20:17
