
Conversation

@Flink-ddd
Contributor

Description

This PR addresses Issue #7672.

When sequence_parallel_size is smaller than world_size (e.g., sp_size=2 on 4 GPUs), using torch.distributed.nn.functional.all_gather for loss aggregation triggers an IndexError: tuple index out of range during the backward pass. This occurs because the implementation attempts to access gradient outputs using the global rank index, which exceeds the bounds of the local sequence parallel group (which only contains sp_size elements).
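For reference, a minimal sketch of the failing user-side pattern (the function and variable names are illustrative, not DeepSpeed code):

```python
import torch
import torch.distributed.nn.functional as dist_nn_f


def aggregate_loss_all_gather(local_loss: torch.Tensor, sp_group) -> torch.Tensor:
    # Differentiable all_gather over the sequence-parallel sub-group.
    # On PyTorch <= 2.2 the backward indexes the gathered gradients with the
    # *global* rank; when sp_size < world_size that index exceeds the
    # sub-group size and raises "IndexError: tuple index out of range".
    gathered = dist_nn_f.all_gather(local_loss, group=sp_group)
    return torch.stack(gathered).sum() / len(gathered)
```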

Solution

I have replaced the problematic all_gather aggregation with a mathematically equivalent and robust all_reduce operation:

  • Before: all_gather -> manual sum -> divide (vulnerable to rank-indexing mismatch on sub-groups).
  • After: all_reduce(op=SUM) -> divide (safe for any sp_size / world_size configuration; see the sketch below).
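
A minimal sketch of the replacement pattern, assuming the differentiable all_reduce from the same torch.distributed.nn.functional module (names are again illustrative):

```python
import torch
import torch.distributed as dist
import torch.distributed.nn.functional as dist_nn_f


def aggregate_loss_all_reduce(local_loss: torch.Tensor, sp_group, sp_size: int) -> torch.Tensor:
    # Sum the per-rank losses across the sub-group, then average.
    # The backward of a differentiable all_reduce is itself an all_reduce,
    # so no rank indexing is involved and any sp_size / world_size works.
    total = dist_nn_f.all_reduce(local_loss, op=dist.ReduceOp.SUM, group=sp_group)
    return total / sp_size
```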

Verification

I added a new regression test TestUlyssesLossBackward in tests/unit/sequence_parallelism/test_ulysses.py.

1. Reproduction (Before Fix)
Confirmed IndexError crash on Rank 2/3 with sp_size=2 on a 4-GPU setup.
[Screenshot: 2026-01-23 23:53:42]

2. Verification (After Fix)
Verified the fix using the regression test logic on 4x RTX A6000. The backward pass now completes successfully on all ranks without error.
[Screenshot: 2026-01-23 23:52:54]

Signed-off-by: vensen <vensenmu@gmail.com>
@Flink-ddd force-pushed the fix/issue-7672-ulysses-sp-backward-stability branch from 2b386ab to 4dc7846 on January 23, 2026 at 15:58
@tohtana
Collaborator

tohtana commented Jan 23, 2026

@Flink-ddd Thank you for opening this PR! I only see changes in tests. Did you miss committing some changes?

@Flink-ddd
Contributor Author

Hi @tohtana, thanks for the review. There are no missing commits: the issue reported in #7672 stems from using all_gather for loss aggregation in the user's training loop, rather than from a bug in DeepSpeed's internal runtime.

Since we cannot patch user scripts directly, I submitted this regression test to:

  1. Verify that the correct approach (using all_reduce) works stably when sp_size < world_size.
  2. Prevent future regressions or confusion regarding this usage pattern.

That said, please let me know if you would prefer a different option.

@tohtana
Collaborator

tohtana commented Jan 24, 2026

Thank you for your clarification, @Flink-ddd! It looks like a bug in PyTorch. In AllGather’s backward pass, we should use the local rank within the given process group. It appears this was fixed in v2.3.

  • v2.3: rank = dist.get_rank(group=ctx.group)
  • v2.2: rank = dist.get_rank()

As you said, we can’t force client code to implement loss calculation in a particular way. So I’m wondering whether we should simply add an assertion that checks the PyTorch version when SP is enabled. We could also note in the documentation that SP requires v2.3 or later, even though the DeepSpeed code itself doesn’t have an issue with older versions.
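
For example, the guard could look roughly like this (a sketch only; the helper name, placement, and message are illustrative):

```python
from packaging import version

import torch


def assert_torch_supports_sp(sequence_parallel_size: int) -> None:
    # The backward of the differentiable all_gather only uses the group-local
    # rank from PyTorch 2.3 onward, so guard SP > 1 on older versions.
    if sequence_parallel_size > 1:
        current = version.parse(torch.__version__.split("+")[0])
        assert current >= version.parse("2.3"), (
            "Sequence parallelism with differentiable all_gather-based loss "
            f"aggregation requires PyTorch >= 2.3, found {torch.__version__}")
```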

It would still be good to add a regression test. One concern is that the all-reduce approach can’t implement weighted loss averaging, which is used in the original example.

What are your thoughts?

@Flink-ddd
Contributor Author

Flink-ddd commented Jan 25, 2026

Hi @tohtana, thanks for the suggestion. I agree that simulating the weighted averaging pattern is better for real-world scenarios. I will update the test case to implement the weighted all-reduce pattern (reducing the weighted loss and the total weights separately) to address this.
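
For concreteness, roughly along these lines (a sketch only, not the final test code; the names and the use of the differentiable all_reduce are illustrative):

```python
import torch
import torch.distributed as dist
import torch.distributed.nn.functional as dist_nn_f


def weighted_loss_all_reduce(local_loss_sum: torch.Tensor,
                             local_weight: torch.Tensor,
                             sp_group) -> torch.Tensor:
    # Reduce numerator and denominator separately so the result equals
    # sum(per-token losses) / sum(weights) over the whole sub-group,
    # i.e. a true weighted average rather than a mean of per-rank means.
    loss_sum = dist_nn_f.all_reduce(local_loss_sum, op=dist.ReduceOp.SUM, group=sp_group)
    # The weight carries no gradient, so the plain dist.all_reduce would also work here.
    weight_sum = dist_nn_f.all_reduce(local_weight, op=dist.ReduceOp.SUM, group=sp_group)
    return loss_sum / weight_sum
```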

@tohtana
Collaborator

tohtana commented Jan 25, 2026

Hi @Flink-ddd
Do you think we should support SP with v2.2 or older?

@Flink-ddd
Contributor Author

Hi @tohtana, yes, I believe we should. Many production environments and clusters are still pinned to PyTorch v2.1 or v2.2 due to CUDA driver constraints or stability requirements, so maintaining SP support on these older versions adds significant value to DeepSpeed's compatibility. This regression test ensures we keep supporting those users reliably. That said, it depends on your perspective: if you think it's unnecessary, we can instead assert that the PyTorch version is >= 2.3 and close this PR.

